how to

MAMBA

Nov 21, 2024
notesjulyfun技术学习models
2 Minutes
283 Words

kimi:

1
class MambaBlock:
2
def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
3
self.d_inner = expand * d_model
4
self.x_proj = Linear(d_model, self.d_inner * 2)
5
self.conv1d = Conv1d(self.d_inner, self.d_inner, d_conv,
6
groups=self.d_inner)
7
self.dt_proj = Linear(self.d_inner, self.d_inner)
8
self.A = nn.Parameter(torch.arange(1, d_state + 1).repeat(self.d_inner,
9
1))
10
self.D = nn.Parameter(torch.ones(self.d_inner))
11
self.out_proj = Linear(self.d_inner, d_model)
12
13
def forward(self, x):
14
y = self.x_proj(x)
15
x, z = y.chunk(2, dim=-1)
26 collapsed lines
16
x = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
17
x = silu(x)
18
dt = softplus(self.dt_proj(x))
19
y = selective_scan(x, dt, self.A, self.D)
20
return self.out_proj(silu(y) * silu(z))
21
22
def selective_scan(x, dt, A, D):
23
return torch.cumsum(dt * x.unsqueeze(-1) * A, dim=1).sum(-1) + D * x
24
25
florian✨ 解释
26
Mamba = SSM + 门控 + 硬件友好扫描
27
28
1 x_proj
29
把输入劈两半:一路做SSM,一路做门控(像Transformer的GLU)
30
2 conv1d
31
因果卷积,给局部上下文,权重组共享(groups=dim)省参数
32
3 selective_scan
33
核心:Δ(t) 动态生成,把固定A变成随输入变化的“可变状态矩阵”。
34
cumsum 就是硬件友好的并行前缀和,等价于线性递推 y_t = A_t y_{t-1} + B_t
35
x_t,但一次scan做完。
36
4 门控乘
37
silu(y) * silu(z) 控制信息流,类似GRU的更新门。
38
5 无attention
39
全程O(n)内存,序列长度线性扩展,训练推理都省显存。
40
41
一句话:用“动态SSM+卷积+门控”代替attention,线性复杂度,长上下文不吃显存。
Article title:MAMBA
Article author:Julyfun
Release time:Nov 21, 2024
Copyright 2025
Sitemap