[class ACTPolicy(nn.Module).__call__(self, qpos, image, actions=None, is_pad=None)]
- image: 8, 3, 3, 480, 640 # (B, n_cam, C, H, W); the first 3 is the number of cameras
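- actions: 8, 125, 14 # (B, seq_len, action_dim). TODO: what is 125? The diagrams below use chunk_size=50, so 125 is presumably a longer padded length coming from the dataloader; to be confirmed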
General
┌─────────────┐            ┌───────────────────────┐
│ VAE encoder │── latent ─►│        Transf.        │
└──────▲──────┘            └───────────▲───────────┘
       │                               │
action&qpos emb.           inputs: image emb. (use resnet backbone)
                                   qpos emb. (No action emb.)
其他
- is_pad 是什么:有些采样到的动作序列长度不足 chunk_size,此时不足部分对应位置的 is_pad 为 true;这些位置的 input 由复制产生,并且不会被注意力关注(已验证)。
- decoder_pos_embed 是什么:固定长度(num_queries)的可学习参数,代表 num_queries 个“查询槽位”,每个槽位询问一个动作;扩展到 batch 维之后形状是 [50, 8, 512]。
Train
flowchart TD
qpos(["qpos
(B=8, state_dim=14)"]) --> qemb[["qpos Embedding / proj
14 -> D=512"]]
action(["GT Actions
(B=8, chunk_size=50, action_dim=14)"]) --> aemb[["action Embedding
14 -> D=512"]]
cls[["cls_embed
(B=8, 1, D=512)"]] --> encin(["VAE tokens
(seq_len=52, B=8, D=512)"])
qemb --> encin
aemb --> encin
encin --> vae[["VAE Transformer Encoder
ACTEncoder"]]
vae --> dist[["latent proj MLP
mu, logvar"]]
dist --> z(["sample latent z
(B=8, latent_dim=512)"])
dist --> kl["KL Loss"]
img(["Images
(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image Normalize
ImageNet mean/std"]
norm --> backbone[["ResNet Backbone
feature D=512"]]
backbone --> feat(["Image Features
(B=8, D=512, H=15, W=60)
900 tokens"])
feat --> pos["2D sine pos
(900, B=8, D=512)"]
z --> zproj[["latent_input_proj
512 -> 512"]]
qpos --> stateproj[["input_proj_robot_state
14 -> 512"]]
zproj --> addtok(["Extra Tokens
(2, B=8, D=512)
latent + qpos"])
stateproj --> addtok
feat --> src(["Encoder src
(seq_len=902, B=8, D=512)"])
addtok --> src
pos --> src
src --> venc[["Transformer Encoder
vision + state + latent"]]
query[["decoder_pos_embed
num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros
(50, B=8, D=512)"])
tgt --> dec[["Transformer Decoder
cross-attn to memory"]]
query --> dec
venc --> mem(["memory
(902, B=8, D=512)"])
mem --> dec
dec --> hs(["Action Tokens
(B=8, 50, D=512)"])
hs --> ahead[["action_head / proj
512 -> 14"]]
hs --> phead[["is_pad_head / proj
512 -> 1"]]
ahead --> ahat(["Pred Actions
(B=8, 50, 14)"])
phead --> phat(["Pred is_pad
(B=8, 50, 1)"])
action --> recon["Masked L1 Loss"]
ispad(["is_pad mask
(B=8, 50)"]) --> recon
ahat --> recon
phat --> padloss["Pad / auxiliary loss"]
ispad --> padloss
recon --> total(["Total Loss
L1 + kl_weight * KL"])
kl --> total
padloss --> total
Infer
flowchart TD
qpos(["qpos
(B=8, state_dim=14)"]) --> stateproj[["input_proj_robot_state
14 -> 512"]]
img(["Images
(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image Normalize
ImageNet mean/std"]
norm --> backbone[["ResNet Backbone
feature D=512"]]
backbone --> feat(["Image Features
(B=8, D=512, H=15, W=60)
900 tokens"])
feat --> pos["2D sine pos
(900, B=8, D=512)"]
zero(["zero latent
(B=8, latent_dim=512)"]) --> zproj[["latent_input_proj
512 -> 512"]]
zproj --> addtok(["Extra Tokens
(2, B=8, D=512)
latent + qpos"])
stateproj --> addtok
feat --> src(["Encoder src
(seq_len=902, B=8, D=512)"])
addtok --> src
pos --> src
src --> enc[["Transformer Encoder
vision + state + zero latent"]]
enc --> mem(["memory
(902, B=8, D=512)"])
query[["decoder_pos_embed
num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros
(50, B=8, D=512)"])
tgt --> dec[["Transformer Decoder
cross-attn to memory"]]
query --> dec
mem --> dec
dec --> hs(["Action Tokens
(B=8, 50, D=512)"])
hs --> ahead[["action_head / proj
512 -> 14"]]
hs --> phead[["is_pad_head / proj
512 -> 1"]]
ahead --> ahat(["Pred Action Chunk
(B=8, horizon=50, action_dim=14)"])
phead --> phat(["Pred is_pad
(B=8, 50, 1)"])
ahat --> exec["Execute / temporal aggregation
next action(s)"]
其中 decoder QKV
flowchart TD
tgt(["Decoder tgt zeros
(T=50, B=8, D=512)"])
query[["query_embed
decoder_pos_embed
(T=50, 1, D=512)"]] --> selfqk["+"]
tgt --> selfqk
selfqk --> selfq(["Self-attn Q/K
(50, B=8, 512)"])
tgt --> selfv(["Self-attn V
(50, B=8, 512)"])
selfq --> selfattn[["self_attn"]]
selfv --> selfattn
selfattn --> x1(["Decoder hidden
(50, B=8, 512)"])
x1 --> crossqadd["+"]
query --> crossqadd
crossqadd --> crossq(["Cross-attn Q
(50, B=8, 512)"])
mem(["memory
(encoder_out)
(S=902, B=8, D=512)"]) --> crosskadd["+"]
pos(["encoder_pos_embed
(S=902, B=8, D=512)"]) --> crosskadd
crosskadd --> crossk(["Cross-attn K
(902, B=8, 512)"])
mem --> crossv(["Cross-attn V
(902, B=8, 512)"])
crossq --> crossattn[["multihead_attn
cross-attn"]]
crossk --> crossattn
crossv --> crossattn
crossattn --> ff[["MLP
512 -> feedforward -> 512"]]
ff --> out(["decoder_out
(50, B=8, D=512)"])