“使用 Image Encoder 以及 Text Encoder 并使用 FiLM 进行 Fusion 后用 Transformer 处理的模型” 用高度凝练的几行 pytorch 伪代码告诉我这里的写法。尤其是 film 和如何后续采用 transformer 处理
FiLM
来自 RT1
使用 txt token => MLP 的输出缩放 img token
1# img: [B, C, H, W], txt_ids: [B, L]2v = ImageEncoder(img) # [B, Nv, D] (视觉 token)3t = TextEncoder(txt_ids).mean(1) # [B, Dt] (文本全局语义)4
5gamma, beta = Linear(t).chunk(2, dim=-1) # 各 [B, D]6v_film = v * (1 + gamma[:, None, :]) + beta[:, None, :] # FiLM: feature-wise affine7
8x = torch.cat([CLS.expand(B,1,D), v_film], dim=1) # [B, 1+Nv, D]9h = TransformerEncoder(x) # [B, 1+Nv, D]10y = Head(h[:, 0]) # 用 CLS 做下游预测更现代的(把现成 LLM Transformer 当作 fusion 主干)
1v = VisionEncoder(img) # [B, Nv, Dv]2v_tok = VisionProjector(v) # [B, Nv, D] 对齐到 LLM 维度3
4t_tok = LLM.embed_tokens(text_ids) # [B, Nt, D]5x = torch.cat([BOV, v_tok, EOV, t_tok], dim=1) # 视觉+语言统一序列(Fusion in LLM)6
7h = LLM.transformer(x, causal_mask=True) # 统一 Transformer 融合8a_logits = ActionHead(h[:, -Na:, :]) # 取动作位置 hidden state9a = sample_or_argmax(a_logits) # 自回归/并行输出离散动作 token10
11如果动作是连续量,也常见:12a = MLP(h[:, -1, :]) # 直接回归 7DoF/控制量预测 attn map
来自 RT1
就是你能想到的最简单的自主加权. (每个 token 自己算自己对 m 个结果的权重,权重和为 1)
1# x: [B, N, D] (N个输入token)2attn_logits = MLP(x) # [B, N, M] -> 每个token对M个map的打3分4attn = torch.softmax(attn_logits, dim=1) # 在token维归一化: 每张map覆盖N个token5
6# 用每张attention map对token做加权汇聚7# x_out[b,m,d] = sum_n attn[b,n,m] * x[b,n,d]8x_out = torch.einsum('bnm,bnd->bmd', attn, x) # [B, M, D]RT-2 离散动作 token
[todo]
1# 连续动作 -> 离散token(训练前/数据管线)2# a_cont: [B, T, A] (A维连续控制量,如x,y,z,roll,pitch,yaw,gripper)3bins = torch.linspace(low, high, K) # 每个维度K个bin4a_idx = bucketize_per_dim(a_cont, bins) # [B, T, A] in [0..K-1]5act_tok = a_idx + dim_offset[None, None, :] # 映射到统一动作词表ID6
7# 把动作token当“文本token”一样做自回归监督8inp = torch.cat([vision_tok, text_tok, act_tok[:, :-1, :].reshape(B, -1)], dim=1)9target = act_tok.reshape(B, -1) # 预测下一动作token10h = VLM_Transformer(inp)11logits = ActionLMHead(h[:, -target.size(1):, :]) # [B, T*A, V_act]12loss = F.cross_entropy(logits.reshape(-1, V_act), target.reshape(-1))13
14# 推理时 token -> 连续动作(解码)15pred_idx = logits.argmax(-1).reshape(B, T, A) - dim_offset5 collapsed lines
16a_pred = dequantize(pred_idx, bins) # bin中心/边界还原为连17续值18
19核心就是:把每个动作维度量化成离散 bin,转成 token 序列,用语言建模同款 CE loss20训练,再反量化回控制量。