kimi:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
        super().__init__()
        self.d_inner = expand * d_model
        # one projection produces both the SSM path and the gate path
        self.x_proj = nn.Linear(d_model, self.d_inner * 2)
        # depthwise causal conv: left-pad by d_conv - 1, trim the right later
        self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, d_conv,
                                groups=self.d_inner, padding=d_conv - 1)
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner)
        # (d_inner, d_state) state-transition parameter (float, so it can train)
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float32)
                              .repeat(self.d_inner, 1))
        self.D = nn.Parameter(torch.ones(self.d_inner))  # skip connection
        self.out_proj = nn.Linear(self.d_inner, d_model)

    def forward(self, x):                        # x: (batch, seq, d_model)
        x, z = self.x_proj(x).chunk(2, dim=-1)   # SSM path / gate path
        seq_len = x.shape[1]
        x = self.conv1d(x.transpose(1, 2))[..., :seq_len].transpose(1, 2)
        x = F.silu(x)
        dt = F.softplus(self.dt_proj(x))         # input-dependent step size Δ(t)
        y = selective_scan(x, dt, self.A, self.D)
        return self.out_proj(F.silu(y) * F.silu(z))


def selective_scan(x, dt, A, D):
    # simplified stand-in: the B/C projections are omitted, and cumsum only
    # realizes the A_t = 1 special case of h_t = A_t h_{t-1} + B_t x_t
    return torch.cumsum((dt * x).unsqueeze(-1) * A, dim=1).sum(-1) + D * x
```
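A quick shape check, runnable right after the block above (the batch, length, and width values are arbitrary):

```python
block = MambaBlock(d_model=32)
x = torch.randn(2, 64, 32)   # (batch, seq, d_model)
print(block(x).shape)        # torch.Size([2, 64, 32]): shape is preserved
```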
florian✨: explain

Mamba = SSM + gating + hardware-friendly scan
1. x_proj
   Splits the input in two: one half feeds the SSM, the other becomes the gate (like a Transformer GLU).
2. conv1d
   Causal convolution that supplies local context; depthwise weights (groups=d_inner) keep the parameter count down.
3. selective_scan
   The core: Δ(t) is generated dynamically from the input, turning the fixed A into an input-dependent "selective" state matrix. The cumsum stands in for a hardware-friendly parallel prefix scan; the real linear recurrence h_t = A_t h_{t-1} + B_t x_t also finishes in one scan pass because its steps compose associatively, whereas a plain cumsum only covers the A_t = 1 case (see the sketch after this list).
4. Gated multiply
   silu(y) * silu(z) controls information flow, similar to a GRU update gate.
5. No attention
   O(n) memory throughout, linear scaling in sequence length, less VRAM for both training and inference.
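To make point 3 concrete: the recurrence h_t = A_t h_{t-1} + B_t x_t parallelizes because per-step affine maps compose associatively, (a1, b1) ∘ (a2, b2) = (a1·a2, a2·b1 + b2). A minimal sketch under that assumption; the function names are mine, not from any Mamba codebase:

```python
import torch

def sequential_scan(a, b):
    # reference recurrence h_t = a_t * h_{t-1} + b_t, one step at a time
    h, out = torch.zeros_like(b[:, 0]), []
    for t in range(b.shape[1]):
        h = a[:, t] * h + b[:, t]
        out.append(h)
    return torch.stack(out, dim=1)

def parallel_scan(a, b):
    # Hillis-Steele prefix scan over the associative operator
    # (a1, b1) ∘ (a2, b2) = (a1*a2, a2*b1 + b2): O(log L) rounds of O(L) work
    a, b = a.clone(), b.clone()
    step = 1
    while step < a.shape[1]:
        # fold in the partial result sitting `step` positions to the left
        b[:, step:] = a[:, step:] * b[:, :-step] + b[:, step:]
        a[:, step:] = a[:, step:] * a[:, :-step]
        step *= 2
    return b  # applied to h_0 = 0, the composed map leaves only the b term

a = torch.rand(2, 16, 4)    # per-step decay factors A_t in (0, 1)
b = torch.randn(2, 16, 4)   # per-step inputs B_t * x_t
assert torch.allclose(sequential_scan(a, b), parallel_scan(a, b), atol=1e-5)
```

The real Mamba kernel fuses this scan with the discretization of A and B in CUDA, but associativity is the same reason a single pass suffices.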
In one sentence: "dynamic SSM + convolution + gating" replaces attention, giving linear complexity, so long contexts don't eat VRAM.