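In one sentence: the state and images pass through a transformer encoder whose output serves as the K/V memory for the decoder hidden states (initialized to 0, with learnable positional embeddings); the decoder hidden states cross-attend to this memory to produce the final actions. ACT also adds a VAE that reconstructs the GT actions in a self-supervised way; the paper claims this models the multi-modality of human demonstrations and keeps the policy from outputting an average of the modes.
The VAE encoder here is a BERT-style transformer encoder. Its input is [[cls], [joint state], [action seq]]; after four self-attn layers only the [cls] output is kept and passed through a Linear layer to obtain the parameters (mu, logvar) from which z is sampled. The policy transformer encoder, by contrast, has no [cls] token.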
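The last step (turning the [cls] projection into z) can be sketched in plain Python. This is a minimal illustration, not the ACT implementation: `sample_latent` is a hypothetical helper, the 4-dimensional latent is a toy size, and the reparameterization shown is the standard VAE one, which these notes assume ACT follows.

```python
import math
import random

def sample_latent(mu, logvar, training, rng=random):
    """Return the latent z for one sample.

    training=True:  z = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, 1)
    training=False: z = 0, because the VAE encoder is skipped at inference
    """
    if not training:
        return [0.0] * len(mu)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]

mu = [0.5, -1.0, 0.0, 2.0]
logvar = [-40.0, -40.0, -40.0, -40.0]   # tiny variance, so z stays close to mu
z_train = sample_latent(mu, logvar, training=True)
z_infer = sample_latent(mu, logvar, training=False)
```

Sampling through `mu` and `logvar` rather than directly keeps the draw differentiable with respect to the encoder outputs, which is why the KL term can be trained jointly with the reconstruction loss.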
General
                              Transformer
                              VAE is not used at inference.
                              (acts as VAE decoder
                               during training)
                              ┌───────────────────────┐
                              │             Outputs   │
                              │                ▲      │
                              │     ┌─K───►┌───────┐  │
             ┌──────┐         │     │      │Transf.│  │
             │      │         │     ├─V───►│decoder│  │
        ┌────┴────┐ │         │     │      │       │  │
        │         │ │         │ ┌───┴───┬─►│       │  │
        │ VAE     │ │         │ │       │  └───────┘  │
        │ encoder │ │         │ │Transf.│             │
        │         │ │         │ │encoder│             │
        └───▲─────┘ latent    │ │       │             │
            │       │         │ └▲──▲─▲─┘             │
            │       │         │  │  │ │               │
          inputs    └─────────┼──┘  │ image emb.(use resnet backbone)
                              │     │ qpos emb.(No action emb.)
        action&qpos emb.      └───────────────────────┘
where latent: (b, 512)

Other
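- What is is_pad: some sampled action sequences are shorter than chunk_size. The missing positions have is_pad set to true, their inputs are filled by copying, and they are never attended to (verified).
- What is decoder_pos_embed: a learnable parameter of fixed length (num_queries), with shape [50, 8, 512] once expanded over the batch (B=8). It represents num_queries "query slots", each of which asks for one action.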
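The is_pad bookkeeping can be sketched as follows. `make_chunk` is a hypothetical helper, not a function from the ACT code base, and it assumes (as the note above states) that missing positions are filled by copying; here the copy is of the last available action, which is one plausible reading.

```python
def make_chunk(actions, start, chunk_size):
    """Slice chunk_size actions from `start`, padding past the episode end.

    Returns (chunk, is_pad); is_pad[i] is True where chunk[i] is a copy
    rather than a real recorded action.
    """
    real = actions[start:start + chunk_size]
    if not real:
        raise ValueError("start is beyond the end of the episode")
    n_pad = chunk_size - len(real)
    chunk = real + [real[-1]] * n_pad          # assumption: repeat the last action
    is_pad = [False] * len(real) + [True] * n_pad
    return chunk, is_pad

episode = [[0.1 * t, 0.2 * t] for t in range(6)]   # 6 steps, action_dim=2
chunk, is_pad = make_chunk(episode, start=4, chunk_size=5)
```

With only two real steps left in the episode, the last three slots are copies and are flagged in `is_pad`, so a loss or attention mask can skip them.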
Train
flowchart TD
qpos(["qpos
(B=8, state_dim=14)"]) --> qemb[["qpos Embedding / proj
14 -> D=512"]]
action(["GT Actions
(B=8, chunk_size=50, action_dim=14)"]) --> aemb[["action Embedding
14 -> D=512"]]
cls[["cls_embed
(B=8, 1, D=512)"]] --> encin(["VAE tokens
(seq_len=52, B=8, D=512)"])
qemb --> encin
aemb --> encin
encin --> vae[["VAE Transformer Encoder
ACTEncoder"]]
vae --> dist[["latent proj MLP
mu, logvar"]]
dist --> z(["sample latent z
(B=8, latent_dim=512)"])
dist --> kl["KL Loss"]
img(["Images
(B=8, n_cam=3, C=3, H=480, W=640)"]) --> infer[["See Infer flowchart
policy network as black box"]]
qpos --> infer
z --> infer
infer --> ahat(["Pred Actions
(B=8, 50, 14)"])
infer --> phat(["Pred is_pad
(B=8, 50, 1)"])
action --> recon["Masked L1 Loss"]
ispad(["is_pad mask
(B=8, 50)"]) --> recon
ahat --> recon
phat --> padloss["Pad / auxiliary loss"]
ispad --> padloss
recon --> total(["Total Loss
L1 + kl_weight * KL"])
kl --> total
padloss --> total
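The loss nodes at the end of the Train flowchart can be written out as a small pure-Python sketch. It covers only the `L1 + kl_weight * KL` part of the total; the function names are hypothetical, `kl_weight=10.0` is just an example value, and averaging the L1 over all elements (with padded ones zeroed) is an assumption, since normalizing by the number of valid elements is an equally plausible choice.

```python
import math

def masked_l1(pred, target, is_pad):
    """Mean absolute error over a (T, action_dim) chunk, zeroing padded steps."""
    total, count = 0.0, 0
    for p_t, y_t, pad in zip(pred, target, is_pad):
        for p, y in zip(p_t, y_t):
            total += 0.0 if pad else abs(p - y)
            count += 1
    return total / count

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def total_loss(pred, target, is_pad, mu, logvar, kl_weight):
    """Reconstruction term plus weighted KL term."""
    return masked_l1(pred, target, is_pad) + kl_weight * kl_to_standard_normal(mu, logvar)

target = [[1.0, 1.0], [2.0, 2.0], [2.0, 2.0]]
pred = [[1.5, 1.0], [2.0, 1.0], [9.0, 9.0]]    # last step is padding, so it is ignored
is_pad = [False, False, True]
loss = total_loss(pred, target, is_pad, mu=[0.0, 0.0], logvar=[0.0, 0.0], kl_weight=10.0)
```

In this toy case the padded step contributes nothing even though its prediction is far off, and the KL term vanishes because the posterior already equals the standard normal prior.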
Infer
flowchart TD
qpos(["qpos
(B=8, state_dim=14)"]) --> stateproj[["input_proj_robot_state
14 -> 512"]]
img(["Images
(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image Normalize
ImageNet mean/std"]
norm --> backbone[["ResNet Backbone
feature D=512"]]
backbone --> feat(["Image Features
(B=8, D=512, H=15, W=60)
900 tokens"])
zero(["zero latent
(B=8, latent_dim=512)"]) --> zproj[["latent_input_proj
512 -> 512"]]
zproj --> addtok(["Extra Tokens
(2, B=8, D=512)
latent + qpos"])
stateproj --> addtok
feat --> src(["Encoder src
(seq_len=902, B=8, D=512)"])
addtok --> src
src --> enc[["Transformer Encoder
multi-layer self-attn + FFN"]]
enc --> mem(["memory
(902, B=8, D=512)"])
query[["decoder_pos_embed
num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros
(50, B=8, D=512)"])
tgt --> dec[["Transformer Decoder
cross-attn to memory"]]
query --> dec
mem --> dec
dec --> hs(["Action Tokens
(B=8, 50, D=512)"])
hs --> ahead[["action_head / proj
512 -> 14"]]
hs --> phead[["is_pad_head / proj
512 -> 1"]]
ahead --> ahat(["Pred Action Chunk
(B=8, horizon=50, action_dim=14)"])
phead --> phat(["Pred is_pad
(B=8, 50, 1)"])
ahat --> exec["Execute / temporal aggregation
next action(s)"]
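The "temporal aggregation" step at the end of the Infer flowchart can be sketched as a weighted average over overlapping chunks. The ACT paper describes exponential weights w_i = exp(-m * i), with w_0 assigned to the oldest prediction; the function name `temporal_ensemble` and the value `m=0.1` below are illustrative assumptions.

```python
import math

def temporal_ensemble(predictions, m=0.1):
    """Blend several predictions of the same timestep's action.

    predictions: list of action vectors, ordered oldest first.
    Weights follow w_i = exp(-m * i), normalized to sum to 1, so the
    oldest prediction (i = 0) gets the largest weight.
    """
    weights = [math.exp(-m * i) for i in range(len(predictions))]
    norm = sum(weights)
    dim = len(predictions[0])
    return [sum(w * p[d] for w, p in zip(weights, predictions)) / norm
            for d in range(dim)]

# Chunks predicted at t-2, t-1 and t each contain a guess for the action at t.
preds_for_t = [[1.0, 0.0], [1.2, 0.1], [0.8, -0.1]]
action_t = temporal_ensemble(preds_for_t, m=0.1)
```

A smaller `m` moves the result toward a plain average of all overlapping predictions, while a larger `m` makes the oldest prediction dominate.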
Decoder Q/K/V in detail
flowchart TD
tgt(["Decoder tgt zeros
(all-zero tensor)
(T=50, B=8, D=512)"])
query[["decoder_pos_embed
marks the query position.
learnable embedding.
(T=50, 1, D=512)"]] --> selfqk["+"]
tgt --> selfqk
selfqk --> selfq(["Self-attn Q/K
(50, B=8, 512)"])
tgt --> selfv(["Self-attn V
(50, B=8, 512)"])
selfq --> selfattn[["self_attn
WTF? Isn't V all zeros here?
(only the first layer sees an all-zero tgt)"]]
selfv --> selfattn
selfattn --> x1(["Decoder hidden
(50, B=8, 512)"])
x1 --> crossqadd["+"]
query --> crossqadd
crossqadd --> crossq(["Cross-attn Q
(50, B=8, 512)"])
mem(["memory
(encoder_out)
(S=902, B=8, D=512)"]) --> crosskadd["+"]
pos(["encoder_pos_embed
2D Sine
(S=902, B=8, D=512)"]) --> crosskadd
crosskadd --> crossk(["Cross-attn K
(902, B=8, 512)"])
mem --> crossv(["Cross-attn V
(902, B=8, 512)"])
crossq --> crossattn[["cross-attn
memory stays the same across layers
each layer: self-attn + cross-attn to memory + FFN"]]
crossk --> crossattn
crossv --> crossattn
crossattn --> ff[["MLP
512 -> feedforward -> 512"]]
ff --> out(["decoder_out
(50, B=8, D=512)"])
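On the question raised at the self_attn node: in the first decoder layer the value input is the all-zero tgt, so every value row is identical (zero, or just a projection bias if one is used). A softmax-weighted average of identical rows returns that same row, so the self-attention output is the same for every query, whatever Q and K are. The queries only become distinct once decoder_pos_embed is added again for cross-attention. The sketch below shows this with a toy single-head attention in plain Python; it has no learned Q/K projections, and `bias_v` is a hypothetical stand-in for a value-projection bias.

```python
import math

def attention(q, k, v):
    """Single-head scaled dot-product attention on lists of vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        mx = max(scores)                       # subtract the max for stability
        exps = [math.exp(s - mx) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out

pos = [[0.3, -1.0], [2.0, 0.5], [-0.7, 1.1]]   # stand-in for decoder_pos_embed
tgt = [[0.0, 0.0] for _ in pos]                # all-zero decoder input
bias_v = [0.25, -0.5]                          # hypothetical value-projection bias

qk = [[t + p for t, p in zip(ti, pi)] for ti, pi in zip(tgt, pos)]   # Q = K = tgt + pos
v = [[t + b for t, b in zip(ti, bias_v)] for ti in tgt]              # V = 0 + bias
out = attention(qk, qk, v)
```

Every row of `out` equals `bias_v` up to floating-point error, even though the three queries differ, which is why the first self-attention block carries no position-specific information yet.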