
act

Jan 1, 2025
[class ACTPolicy(nn.Module).__call__(self, qpos, image, actions=None, is_pad=None)]
- qpos: 8, 14
- image: 8, 3, 3, 480, 640 # the first 3 is the number of cameras
- actions: 8, 125, 14 # what is 125?
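One plausible reading of the 125, offered as an assumption rather than something these notes confirm: the data loader pads every sample's action sequence to a fixed length (125 here), and `__call__` keeps only the first `num_queries` steps, which would match the chunk_size=50 shown for GT Actions in the Train diagram. The sketch below uses plain nested lists so it runs without PyTorch; `truncate_to_chunk` is a hypothetical helper name, and the tensor equivalent would be `actions[:, :num_queries]`.

```python
def truncate_to_chunk(actions, is_pad, num_queries):
    """Keep only the first num_queries steps of every sample (axis 1)."""
    actions = [sample[:num_queries] for sample in actions]
    is_pad = [flags[:num_queries] for flags in is_pad]
    return actions, is_pad


# Assumed sizes: batch 8, 125 padded steps, 14 action dims, chunk of 50.
B, T, A, NUM_QUERIES = 8, 125, 14, 50
actions = [[[0.0] * A for _ in range(T)] for _ in range(B)]
is_pad = [[False] * T for _ in range(B)]

actions, is_pad = truncate_to_chunk(actions, is_pad, NUM_QUERIES)
shape = (len(actions), len(actions[0]), len(actions[0][0]))
print(shape)  # (8, 50, 14)
```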

General

                       Transformer
                       The VAE is not used at inference.
                       (acts as VAE decoder
                        during training)
                      ┌───────────────────────┐
                      │             Outputs   │
                      │                ▲      │
                      │     ┌─K───►┌───────┐  │
     ┌──────┐         │     │      │Transf.│  │
     │      │         │     ├─V───►│decoder│  │
┌────┴────┐ │         │     │      │       │  │
│         │ │         │ ┌───┴───┬─►│       │  │
│ VAE     │ │         │ │       │  └───────┘  │
│ encoder │ │         │ │Transf.│             │
│         │ │         │ │encoder│             │
└───▲─────┘ latent    │ │       │             │
    │       │         │ └▲──▲─▲─┘             │
    │       │         │  │  │ │               │
  inputs    └─────────┼──┘  │ image emb. (uses ResNet backbone)
                      │     │ qpos emb. (no action emb.)
action & qpos emb.    └───────────────────────┘

where latent: (b, 512)
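The sequence lengths that show up in the Train and Infer diagrams further down (900 image tokens, 902 encoder tokens, 52 VAE tokens) can be reproduced with a little arithmetic. This is a sketch under two assumptions that the notes do not state explicitly: the ResNet backbone downsamples by a total stride of 32, and the per-camera feature maps are concatenated along the width axis. The function name `encoder_sequence_length` is made up for illustration.

```python
def encoder_sequence_length(height, width, n_cam, stride=32, extra_tokens=2):
    """Return (feat_h, feat_w, image_tokens, src_len) for the main encoder.

    stride=32 and width-wise camera concatenation are assumptions;
    extra_tokens=2 stands for the latent token and the qpos token.
    """
    feat_h = height // stride
    feat_w = (width // stride) * n_cam
    image_tokens = feat_h * feat_w
    return feat_h, feat_w, image_tokens, image_tokens + extra_tokens


feat_h, feat_w, image_tokens, src_len = encoder_sequence_length(480, 640, n_cam=3)
print(feat_h, feat_w, image_tokens, src_len)  # 15 60 900 902

# VAE encoder input during training: cls token + qpos token + 50 action tokens.
vae_len = 1 + 1 + 50
print(vae_len)  # 52
```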

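Other

  • What is is_pad: some sampled action sequences are shorter than chunk_size. At the missing positions is_pad is true, the input is filled in by copying, and those positions are not attended to (verified).
  • What is decoder_pos_embed: its shape is [50, 8, 512]. It is a learned parameter of fixed length (num_queries) that represents num_queries "query slots", each of which asks for one action.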
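To make the is_pad behaviour concrete, here is a minimal sketch of padding a short action sequence and then ignoring the padded steps in an L1 loss. It assumes that "copying" means repeating the last real action, which the notes do not spell out, and it covers only the loss side (the attention mask is not shown). `pad_chunk` and `masked_l1` are hypothetical helper names, written with plain lists so the example runs without PyTorch.

```python
def pad_chunk(actions, chunk_size):
    """Pad a non-empty action sequence to chunk_size.

    Assumption: padding repeats the last real action.
    Returns the padded sequence and the matching is_pad flags.
    """
    n_real = min(len(actions), chunk_size)
    padded = [list(a) for a in actions[:n_real]]
    padded += [list(padded[-1]) for _ in range(chunk_size - n_real)]
    is_pad = [False] * n_real + [True] * (chunk_size - n_real)
    return padded, is_pad


def masked_l1(pred, target, is_pad):
    """Mean absolute error over the non-padded steps only."""
    total, count = 0.0, 0
    for p_step, t_step, pad in zip(pred, target, is_pad):
        if pad:
            continue  # padded steps contribute nothing to the loss
        total += sum(abs(p - t) for p, t in zip(p_step, t_step))
        count += len(t_step)
    return total / count if count else 0.0


target, is_pad = pad_chunk([[1.0, 2.0], [3.0, 4.0]], chunk_size=4)
pred = [[1.5, 2.0], [3.0, 3.0], [9.0, 9.0], [9.0, 9.0]]
loss = masked_l1(pred, target, is_pad)
print(is_pad)  # [False, False, True, True]
print(loss)    # 0.375
```

The large errors at the two padded steps do not affect the result, which is the point of the mask.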

Train

flowchart TD
    qpos(["qpos<br/>(B=8, state_dim=14)"]) --> qemb[["qpos Embedding / proj<br/>14 -> D=512"]]
    action(["GT Actions<br/>(B=8, chunk_size=50, action_dim=14)"]) --> aemb[["action Embedding<br/>14 -> D=512"]]
    cls[["cls_embed<br/>(B=8, 1, D=512)"]] --> encin(["VAE tokens<br/>(seq_len=52, B=8, D=512)"])
    qemb --> encin
    aemb --> encin
    encin --> vae[["VAE Transformer Encoder<br/>ACTEncoder"]]
    vae --> dist[["latent proj MLP<br/>mu, logvar"]]
    dist --> z(["sample latent z<br/>(B=8, latent_dim=512)"])
    dist --> kl["KL Loss"]
    img(["Images<br/>(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image Normalize<br/>ImageNet mean/std"]
    norm --> backbone[["ResNet Backbone<br/>feature D=512"]]
    backbone --> feat(["Image Features<br/>(B=8, D=512, H=15, W=60)<br/>900 tokens"])
    feat --> pos["2D sine pos<br/>(900, B=8, D=512)"]
    z --> zproj[["latent_input_proj<br/>512 -> 512"]]
    qpos --> stateproj[["input_proj_robot_state<br/>14 -> 512"]]
    zproj --> addtok(["Extra Tokens<br/>(2, B=8, D=512)<br/>latent + qpos"])
    stateproj --> addtok
    feat --> src(["Encoder src<br/>(seq_len=902, B=8, D=512)"])
    addtok --> src
    pos --> src
    src --> venc[["Transformer Encoder<br/>vision + state + latent"]]
    query[["decoder_pos_embed<br/>num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros<br/>(50, B=8, D=512)"])
    tgt --> dec[["Transformer Decoder<br/>cross-attn to memory"]]
    query --> dec
    venc --> mem(["memory<br/>(902, B=8, D=512)"])
    mem --> dec
    dec --> hs(["Action Tokens<br/>(B=8, 50, D=512)"])
    hs --> ahead[["action_head / proj<br/>512 -> 14"]]
    hs --> phead[["is_pad_head / proj<br/>512 -> 1"]]
    ahead --> ahat(["Pred Actions<br/>(B=8, 50, 14)"])
    phead --> phat(["Pred is_pad<br/>(B=8, 50, 1)"])
    action --> recon["Masked L1 Loss"]
    ispad(["is_pad mask<br/>(B=8, 50)"]) --> recon
    ahat --> recon
    phat --> padloss["Pad / auxiliary loss"]
    ispad --> padloss
    recon --> total(["Total Loss<br/>L1 + kl_weight * KL"])
    kl --> total
    padloss --> total
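The "mu, logvar", "sample latent z" and "KL Loss" nodes in the diagram correspond to the usual VAE recipe. The sketch below is a plain-Python illustration, not the author's code: it assumes a diagonal Gaussian posterior compared against a standard normal prior, and the `kl_weight` value in the usage lines is a made-up number. All function names are hypothetical.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), per latent dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def total_loss(l1, kl, kl_weight):
    """Combine reconstruction and KL terms as in the diagram: L1 + kl_weight * KL."""
    return l1 + kl_weight * kl


mu, logvar = [1.0, 0.0], [0.0, 0.0]
z = reparameterize(mu, logvar, random.Random(0))
kl = kl_to_standard_normal(mu, logvar)
loss = total_loss(l1=0.2, kl=kl, kl_weight=10.0)  # kl_weight is an assumed value
print(kl)  # 0.5
```

When mu is all zeros and logvar is all zeros the KL term vanishes, which is why feeding a zero latent at inference corresponds to using the prior mean.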

Infer

flowchart TD
    qpos(["qpos<br/>(B=8, state_dim=14)"]) --> stateproj[["input_proj_robot_state<br/>14 -> 512"]]
    img(["Images<br/>(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image Normalize<br/>ImageNet mean/std"]
    norm --> backbone[["ResNet Backbone<br/>feature D=512"]]
    backbone --> feat(["Image Features<br/>(B=8, D=512, H=15, W=60)<br/>900 tokens"])
    feat --> pos["2D sine pos<br/>(900, B=8, D=512)"]
    zero(["zero latent<br/>(B=8, latent_dim=512)"]) --> zproj[["latent_input_proj<br/>512 -> 512"]]
    zproj --> addtok(["Extra Tokens<br/>(2, B=8, D=512)<br/>latent + qpos"])
    stateproj --> addtok
    feat --> src(["Encoder src<br/>(seq_len=902, B=8, D=512)"])
    addtok --> src
    pos --> src
    src --> enc[["Transformer Encoder<br/>vision + state"]]
    enc --> mem(["memory<br/>(902, B=8, D=512)"])
    query[["decoder_pos_embed<br/>num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros<br/>(50, B=8, D=512)"])
    tgt --> dec[["Transformer Decoder<br/>cross-attn to memory"]]
    query --> dec
    mem --> dec
    dec --> hs(["Action Tokens<br/>(B=8, 50, D=512)"])
    hs --> ahead[["action_head / proj<br/>512 -> 14"]]
    hs --> phead[["is_pad_head / proj<br/>512 -> 1"]]
    ahead --> ahat(["Pred Action Chunk<br/>(B=8, horizon=50, action_dim=14)"])
    phead --> phat(["Pred is_pad<br/>(B=8, 50, 1)"])
    ahat --> exec["Execute / temporal aggregation<br/>next action(s)"]
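The final "Execute / temporal aggregation" node is not detailed in these notes. One common way to do it, sketched here as an assumption, is the exponentially weighted temporal ensembling described for ACT: every chunk predicted so far that covers the current timestep contributes one candidate action, and candidates are averaged with weights exp(-k * i), where i = 0 is the oldest prediction. The helper names and the value k = 0.01 are illustrative choices.

```python
import math


def actions_for_step(chunks, t):
    """Collect every prediction that targets timestep t, oldest first.

    chunks[s] is the action chunk predicted at timestep s,
    so its entry for timestep t sits at index t - s.
    """
    horizon = len(chunks[0])
    return [chunks[s][t - s] for s in range(len(chunks)) if 0 <= t - s < horizon]


def temporal_ensemble(predictions, k=0.01):
    """Weighted average of the candidates; weight exp(-k * i), i = 0 is the oldest."""
    weights = [math.exp(-k * i) for i in range(len(predictions))]
    norm = sum(weights)
    dim = len(predictions[0])
    return [sum(w * p[d] for w, p in zip(weights, predictions)) / norm for d in range(dim)]


# Two chunks of horizon 3 with a 1-dimensional action, predicted at t=0 and t=1.
chunks = [
    [[0.0], [1.0], [2.0]],
    [[10.0], [11.0], [12.0]],
]
candidates = actions_for_step(chunks, t=1)
print(candidates)  # [[1.0], [10.0]]
action = temporal_ensemble(candidates, k=0.01)
```

With k = 0 this reduces to a plain average; a larger k favours older predictions more strongly.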

Decoder QKV in detail

flowchart TD
    tgt(["Decoder tgt zeros<br/>(T=50, B=8, D=512)"])
    query[["query_embed<br/>decoder_pos_embed<br/>(T=50, 1, D=512)"]] --> selfqk["+"]
    tgt --> selfqk
    selfqk --> selfq(["Self-attn Q/K<br/>(50, B=8, 512)"])
    tgt --> selfv(["Self-attn V<br/>(50, B=8, 512)"])
    selfq --> selfattn[["self_attn"]]
    selfv --> selfattn
    selfattn --> x1(["Decoder hidden<br/>(50, B=8, 512)"])
    x1 --> crossqadd["+"]
    query --> crossqadd
    crossqadd --> crossq(["Cross-attn Q<br/>(50, B=8, 512)"])
    mem(["memory<br/>(encoder_out)<br/>(S=902, B=8, D=512)"]) --> crosskadd["+"]
    pos(["encoder_pos_embed<br/>(S=902, B=8, D=512)"]) --> crosskadd
    crosskadd --> crossk(["Cross-attn K<br/>(902, B=8, 512)"])
    mem --> crossv(["Cross-attn V<br/>(902, B=8, 512)"])
    crossq --> crossattn[["multihead_attn<br/>cross-attn"]]
    crossk --> crossattn
    crossv --> crossattn
    crossattn --> ff[["MLP<br/>512 -> feedforward -> 512"]]
    ff --> out(["decoder_out<br/>(50, B=8, D=512)"])
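The key pattern in this diagram is that positional terms are added to queries and keys but never to values. The toy sketch below shows that pattern for a single sample with single-head attention in plain Python. It is a deliberate simplification: it leaves out the learned Q/K/V projections, multiple heads, LayerNorm, dropout and the MLP, and it assumes residual connections around both attention blocks, which the diagram does not draw. All function names are hypothetical.

```python
import math


def add(a, b):
    """Element-wise sum of two equally shaped lists of row vectors."""
    return [[x + y for x, y in zip(r1, r2)] for r1, r2 in zip(a, b)]


def attention(q, k, v):
    """Single-head scaled dot-product attention on lists of row vectors."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
        z = sum(exps)
        out.append([sum(e / z * vj[c] for e, vj in zip(exps, v)) for c in range(len(v[0]))])
    return out


def decoder_attention(tgt, query_pos, memory, memory_pos):
    # Self-attention: Q = K = tgt + query_pos, V = tgt (no positional term).
    qk = add(tgt, query_pos)
    x = add(tgt, attention(qk, qk, tgt))
    # Cross-attention: Q = x + query_pos, K = memory + memory_pos, V = memory.
    return add(x, attention(add(x, query_pos), add(memory, memory_pos), memory))


tgt = [[0.0, 0.0], [0.0, 0.0]]        # decoder input starts as zeros
query_pos = [[1.0, 0.0], [0.0, 1.0]]  # two "query slots"
memory = [[2.0, 3.0]]                 # one encoder token
memory_pos = [[0.5, 0.5]]
out = decoder_attention(tgt, query_pos, memory, memory_pos)
print(out)  # [[2.0, 3.0], [2.0, 3.0]]
```

Because tgt is all zeros, the self-attention output is zero in this simplified setting, so the cross-attention queries are driven entirely by the query slots. The output equals the raw memory row rather than memory plus memory_pos, which shows that the positional term never enters V.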
Article author:Julyfun