def __init__(self, hidden_size=768, num_experts=8, top_k=2):
    """Mixture-of-Experts layer: a linear router plus `num_experts` FFN experts.

    Args:
        hidden_size: model width fed to the gate and each expert.
        num_experts: number of parallel Expert modules.
        top_k: experts selected per token at routing time.
    """
    super().__init__()  # bug fix: required before registering submodules
    self.gate = Linear(hidden_size, num_experts)
    # bug fix: original comprehension was truncated ("for _ in" with no iterable).
    self.experts = ModuleList([Expert(hidden_size) for _ in range(num_experts)])
    # bug fix: top_k was accepted but never stored, yet forward reads self.top_k.
    self.top_k = top_k
def forward(self, x):
    """Route each token through its top-k experts and mix their outputs.

    Args:
        x: input tensor; assumed shape (batch, hidden_size) — the per-row
            loop below indexes the batch dimension. TODO confirm callers
            never pass a (batch, seq, hidden) tensor.

    Returns:
        Tensor of the same shape as x: probability-weighted sum of the
        selected experts' outputs.
    """
    gates = softmax(self.gate(x), dim=-1)
    probs, indices = topk(gates, self.top_k, dim=-1)
    # Renormalize so the kept top-k probabilities sum to 1 per token.
    probs = probs / probs.sum(dim=-1, keepdim=True)
    out = zeros_like(x)  # bug fix: `out` was read below but never initialized
    for i in range(x.shape[0]):
        for j in range(self.top_k):
            # bug fix: ModuleList/list indexing needs a Python int, not a 0-d tensor.
            expert = self.experts[int(indices[i, j])]
            # bug fix: [0] drops the singleton batch dim; in-place adding a
            # (1, d) tensor into the (d,) slot out[i] raises a shape error.
            out[i] += probs[i, j] * expert(x[i:i + 1])[0]
    return out  # bug fix: original never returned the accumulated result
def __init__(self, hidden_size):
    """Two-layer position-wise FFN expert with a 4x hidden expansion."""
    super().__init__()  # bug fix: required before registering submodules
    self.w1 = Linear(hidden_size, hidden_size * 4)
    self.w2 = Linear(hidden_size * 4, hidden_size)
    # bug fix: forward calls self.act but it was never defined here.
    # GELU is the conventional transformer-FFN activation — NOTE(review):
    # confirm the originally intended activation.
    from torch.nn import GELU
    self.act = GELU()
def forward(self, x):
    """Apply the expert FFN: w2(act(w1(x))).

    Args:
        x: input tensor of width hidden_size on its last dimension.

    Returns:
        Tensor of the same shape as x.
    """
    # bug fix: original was a bare `return` with no enclosing def.
    return self.w2(self.act(self.w1(x)))
def __init__(self, hidden_size, num_heads, num_experts=8, top_k=2):
    """Transformer block: self-attention followed by an MoE feed-forward.

    Args:
        hidden_size: model width.
        num_heads: attention heads for MultiHeadAttention.
        num_experts: number of MoE experts (default 8).
        top_k: experts routed per token (default 2). New backward-compatible
            parameter forwarded to MoELayer, which previously always fell
            back to its own default.
    """
    super().__init__()  # bug fix: required before registering submodules
    self.attn = MultiHeadAttention(hidden_size, num_heads)
    self.moe = MoELayer(hidden_size, num_experts, top_k)
    self.norm1 = LayerNorm(hidden_size)
    self.norm2 = LayerNorm(hidden_size)
def forward(self, x):
    """Run the block: post-norm residual attention, then post-norm residual MoE.

    Args:
        x: input tensor (width hidden_size on the last dimension).

    Returns:
        Tensor of the same shape as x.
    """
    x = self.norm1(x + self.attn(x))
    x = self.norm2(x + self.moe(x))
    return x  # bug fix: original never returned the result