
act

Jan 1, 2025
2501

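In one sentence: the state and images go through a transformer encoder whose output serves as the K/V memory for the decoder hidden states (initialized to zero, with learnable positional embeddings); the decoder hidden states cross-attend to this memory to produce the final actions. ACT also adds a VAE that reconstructs the ground-truth actions in a self-supervised way; the paper claims this models the multi-modality of human demonstration data and keeps the output from averaging across modes.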
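A toy NumPy sketch of that data flow at inference, with batch size 1 and random matrices standing in for trained weights. It assumes a single attention head and skips the encoder layers, the positional embedding added to the keys, FFN and LayerNorm. The 15×20 feature map per camera is an assumption (480×640 images through a stride-32 backbone), which together with the latent and qpos tokens gives the 902-token memory.

```python
import numpy as np

rng = np.random.default_rng(0)
D, NUM_QUERIES, STATE_DIM, ACTION_DIM = 512, 50, 14, 14
N_CAM, FEAT_H, FEAT_W = 3, 15, 20  # assumed: 480x640 image, stride-32 backbone -> 15x20 per camera


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(query, key, value):
    """Single-head scaled dot-product attention: (T, D) queries over (S, D) keys/values."""
    scores = query @ key.T / np.sqrt(query.shape[-1])  # (T, S)
    return softmax(scores) @ value                      # (T, D)


# Random stand-ins for learned weights (hypothetical, untrained).
W_state = rng.normal(size=(STATE_DIM, D)) / np.sqrt(STATE_DIM)  # qpos -> token
W_head = rng.normal(size=(D, ACTION_DIM)) / np.sqrt(D)          # action head
decoder_pos_embed = rng.normal(size=(NUM_QUERIES, D))            # one "query slot" per action

qpos = rng.normal(size=STATE_DIM)
image_tokens = rng.normal(size=(N_CAM * FEAT_H * FEAT_W, D))  # flattened backbone features
latent = np.zeros(D)                                          # z = 0 at inference

# Memory: [latent, qpos, image tokens] -> 1 + 1 + 900 = 902 tokens.
# The encoder layers are skipped, so memory is simply the stacked tokens.
memory = np.vstack([latent, qpos @ W_state, image_tokens])

# Decoder hidden starts at zero; adding decoder_pos_embed to Q makes the slots differ.
tgt = np.zeros((NUM_QUERIES, D))
hidden = tgt + attention(tgt + decoder_pos_embed, memory, memory)
actions = hidden @ W_head
print(memory.shape, actions.shape)  # (902, 512) (50, 14)
```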

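The VAE encoder here is a BERT-style transformer encoder. Its input is [[cls], [joint state], [action seq]]; after four self-attention layers, only the [cls] output is kept and passed through a Linear layer to obtain z (more precisely, the mean and log-variance from which z is sampled). The policy's transformer encoder, by contrast, has no [cls] token.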
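A minimal sketch of the VAE encoder path, assuming batch size 1, untrained random weights, and a parameter-free single-head self-attention standing in for the four trained transformer layers (residuals, FFN and LayerNorm omitted). `W_state`, `W_action_in` and `W_latent` are hypothetical names. It shows the 52-token layout ([cls] + qpos + 50 actions), the Linear head producing mu and logvar, the reparameterized sample used in training, and the zero latent used at inference.

```python
import numpy as np

rng = np.random.default_rng(0)
D, LATENT_DIM, CHUNK, STATE_DIM, ACTION_DIM = 512, 512, 50, 14, 14

# Hypothetical stand-ins for learned parameters.
cls_embed = rng.normal(size=(1, D))
W_state = rng.normal(size=(STATE_DIM, D)) / np.sqrt(STATE_DIM)
W_action_in = rng.normal(size=(ACTION_DIM, D)) / np.sqrt(ACTION_DIM)
W_latent = rng.normal(size=(D, 2 * LATENT_DIM)) / np.sqrt(D)  # Linear -> [mu, logvar]


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def self_attention(x, key_padding_mask):
    """Parameter-free single-head self-attention; True in the mask = key is ignored."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    scores = np.where(key_padding_mask[None, :], -np.inf, scores)
    return softmax(scores) @ x


def vae_encode(qpos, actions, is_pad, n_layers=4):
    tokens = np.vstack([cls_embed, qpos @ W_state, actions @ W_action_in])  # (52, D)
    mask = np.concatenate([[False, False], is_pad])  # [cls] and qpos are never padding
    for _ in range(n_layers):
        tokens = self_attention(tokens, mask)
    mu, logvar = np.split(tokens[0] @ W_latent, 2)  # only the [cls] output is used
    return mu, logvar


def sample_latent(mu, logvar, training):
    if not training:
        return np.zeros_like(mu)  # inference: VAE encoder unused, z = prior mean
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * logvar) * eps  # reparameterization trick


def kl_divergence(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dims."""
    return -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))


qpos = rng.normal(size=STATE_DIM)
actions = rng.normal(size=(CHUNK, ACTION_DIM))
is_pad = np.arange(CHUNK) >= 40  # pretend the last 10 steps are padding
mu, logvar = vae_encode(qpos, actions, is_pad)
z_train = sample_latent(mu, logvar, training=True)
z_infer = sample_latent(mu, logvar, training=False)
print(mu.shape, kl_divergence(mu, logvar) >= 0)  # (512,) True

# Padded steps are masked as keys, so changing them leaves [cls] (and mu) untouched.
actions2 = actions.copy()
actions2[is_pad] = 1e3
print(np.allclose(mu, vae_encode(qpos, actions2, is_pad)[0]))  # True
```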

General

                                 Transformer
                                 VAE is not used at inference.
                                 (acts as VAE decoder
                                  during training)
                                ┌───────────────────────┐
                                │             Outputs   │
                                │                ▲      │
                                │     ┌─K───►┌───────┐  │
                   ┌──────┐     │     │      │Transf.│  │
                   │      │     │     ├─V───►│decoder│  │
              ┌────┴────┐ │     │     │      │       │  │
              │         │ │     │ ┌───┴───┬─►│       │  │
              │ VAE     │ │     │ │       │  └───────┘  │
              │ encoder │ │     │ │Transf.│             │
              │         │ │     │ │encoder│             │
              └───▲─────┘ latent│ │       │             │
                  │       │     │ └▲──▲─▲─┘             │
                  │       │     │  │  │ │               │
                inputs    └─────┼──┘  │ image emb. (ResNet backbone)
          (action & qpos emb.)  │    qpos emb. (no action emb.)
                                └───────────────────────┘

where latent: (B, 512)

Other notes

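  • What is is_pad: some sampled action sequences are shorter than chunk_size. At those positions is_pad is true, the inputs are filled by copying, and they are not attended to (verified).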
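  • What is decoder_pos_embed: it has shape [50, 8, 512] and is a learnable parameter of fixed length (num_queries). It represents num_queries "query slots", each of which asks for one action. The learnable part is [50, 512]; the batch dimension of 8 comes from broadcasting.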
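The copy-padding behaviour from the is_pad note can be sketched as follows. `get_action_chunk` is a hypothetical helper, not a library function; it assumes padded steps are filled by repeating the episode's last action (clamping indices to the last frame), which matches "filled by copying" above but may differ from a particular data loader.

```python
import numpy as np


def get_action_chunk(episode_actions, t, chunk_size=50):
    """Hypothetical helper: slice actions[t : t + chunk_size] plus an is_pad mask.

    Indices past the episode end are clamped to the last frame, so padded
    entries are copies of the final action.
    """
    T = len(episode_actions)
    idx = np.arange(t, t + chunk_size)
    is_pad = idx >= T
    chunk = episode_actions[np.minimum(idx, T - 1)]
    return chunk, is_pad


episode = np.arange(60 * 14, dtype=float).reshape(60, 14)  # toy 60-step episode, action_dim=14
chunk, is_pad = get_action_chunk(episode, t=20)
print(chunk.shape, is_pad.sum())               # (50, 14) 10
print(np.array_equal(chunk[-1], episode[-1]))  # True
```

One way to keep these copies from being attended to is to prepend two False entries (for [cls] and qpos) and pass the result to the VAE encoder as its key padding mask.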

Train

flowchart TD
    qpos(["qpos<br/>(B=8, state_dim=14)"]) --> qemb[["qpos embedding / proj<br/>14 -> D=512"]]
    action(["GT actions<br/>(B=8, chunk_size=50, action_dim=14)"]) --> aemb[["action embedding<br/>14 -> D=512"]]
    cls[["cls_embed<br/>(B=8, 1, D=512)"]] --> encin(["VAE tokens<br/>(seq_len=52, B=8, D=512)"])
    qemb --> encin
    aemb --> encin
    encin --> vae[["VAE Transformer Encoder<br/>ACTEncoder"]]
    vae --> dist[["latent proj MLP<br/>mu, logvar"]]
    dist --> z(["sample latent z<br/>(B=8, latent_dim=512)"])
    dist --> kl["KL loss"]
    img(["Images<br/>(B=8, n_cam=3, C=3, H=480, W=640)"]) --> infer[["See Infer flowchart<br/>policy network as black box"]]
    qpos --> infer
    z --> infer
    infer --> ahat(["Pred actions<br/>(B=8, 50, 14)"])
    infer --> phat(["Pred is_pad<br/>(B=8, 50, 1)"])
    action --> recon["Masked L1 loss"]
    ispad(["is_pad mask<br/>(B=8, 50)"]) --> recon
    ahat --> recon
    phat --> padloss["Pad / auxiliary loss"]
    ispad --> padloss
    recon --> total(["Total loss<br/>L1 + kl_weight * KL"])
    kl --> total
    padloss --> total
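A NumPy sketch of the `L1 + kl_weight * KL` combination in the chart above (the pad/auxiliary term is left out). The default `kl_weight=10.0` is an assumed value for illustration. The masked L1 averages over all entries, with padded steps contributing zero; dividing by the number of valid entries instead is an equally reasonable choice.

```python
import numpy as np


def masked_l1(pred, target, is_pad):
    """pred/target: (B, T, A); is_pad: (B, T) bool. Padded steps contribute 0."""
    valid = ~is_pad[..., None]  # (B, T, 1), broadcasts over action_dim
    return (np.abs(pred - target) * valid).mean()


def kl_loss(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, I)): summed over latent dims, averaged over the batch."""
    return (-0.5 * (1 + logvar - mu**2 - np.exp(logvar))).sum(axis=-1).mean()


def total_loss(pred, target, is_pad, mu, logvar, kl_weight=10.0):
    return masked_l1(pred, target, is_pad) + kl_weight * kl_loss(mu, logvar)


B, T, A, L = 8, 50, 14, 512
rng = np.random.default_rng(0)
target = rng.normal(size=(B, T, A))
is_pad = np.zeros((B, T), dtype=bool)
is_pad[:, 45:] = True                # last 5 steps of every chunk are padding
pred = target.copy()
pred[is_pad] += 100.0                # large errors, but only on padded steps
mu, logvar = np.zeros((B, L)), np.zeros((B, L))  # posterior equals the prior
print(total_loss(pred, target, is_pad, mu, logvar))  # 0.0
```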

Infer

flowchart TD
    qpos(["qpos<br/>(B=8, state_dim=14)"]) --> stateproj[["input_proj_robot_state<br/>14 -> 512"]]
    img(["Images<br/>(B=8, n_cam=3, C=3, H=480, W=640)"]) --> norm["Image normalize<br/>ImageNet mean/std"]
    norm --> backbone[["ResNet backbone<br/>feature D=512"]]
    backbone --> feat(["Image features<br/>(B=8, D=512, H=15, W=60)<br/>900 tokens"])
    zero(["zero latent<br/>(B=8, latent_dim=512)"]) --> zproj[["latent_input_proj<br/>512 -> 512"]]
    zproj --> addtok(["Extra tokens<br/>(2, B=8, D=512)<br/>latent + qpos"])
    stateproj --> addtok
    feat --> src(["Encoder src<br/>(seq_len=902, B=8, D=512)"])
    addtok --> src
    src --> enc[["Transformer Encoder<br/>multi-layer self-attn + FFN"]]
    enc --> mem(["memory<br/>(902, B=8, D=512)"])
    query[["decoder_pos_embed<br/>num_queries=50, D=512"]] --> tgt(["Decoder tgt zeros<br/>(50, B=8, D=512)"])
    tgt --> dec[["Transformer Decoder<br/>cross-attn to memory"]]
    query --> dec
    mem --> dec
    dec --> hs(["Action tokens<br/>(B=8, 50, D=512)"])
    hs --> ahead[["action_head / proj<br/>512 -> 14"]]
    hs --> phead[["is_pad_head / proj<br/>512 -> 1"]]
    ahead --> ahat(["Pred action chunk<br/>(B=8, horizon=50, action_dim=14)"])
    phead --> phat(["Pred is_pad<br/>(B=8, 50, 1)"])
    ahat --> exec["Execute / temporal aggregation<br/>next action(s)"]
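The "temporal aggregation" step can be sketched as an exponentially weighted average of every chunk prediction that covers the current step. `TemporalEnsembler` is a hypothetical class; the weights w_i = exp(-m·i), with i = 0 for the oldest prediction, and the default m = 0.01 are assumptions of this sketch rather than values taken from the text.

```python
import numpy as np


class TemporalEnsembler:
    """Hypothetical helper: query the policy every step and average overlapping chunks."""

    def __init__(self, chunk_size=50, m=0.01):
        self.chunk_size = chunk_size
        self.m = m
        self.history = []  # (start_step, chunk) pairs, oldest first

    def step(self, t, chunk):
        """chunk: (chunk_size, action_dim) predicted at step t. Returns the action to run at t."""
        self.history.append((t, chunk))
        # Drop chunks that no longer cover step t.
        self.history = [(s, c) for s, c in self.history if t - s < self.chunk_size]
        preds = np.stack([c[t - s] for s, c in self.history])  # oldest prediction first
        w = np.exp(-self.m * np.arange(len(preds)))            # w_0 belongs to the oldest
        w /= w.sum()
        return (w[:, None] * preds).sum(axis=0)


ens = TemporalEnsembler(chunk_size=3, m=0.0)  # m = 0 -> plain average
a0 = ens.step(0, np.array([[0.0], [1.0], [2.0]]))
a1 = ens.step(1, np.array([[3.0], [4.0], [5.0]]))
print(a0, a1)  # [0.] [2.]
```

With m > 0 the older predictions get larger weights, which smooths the executed trajectory; m = 0 reduces to a plain average.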

Decoder Q/K/V in detail:

flowchart TD
    tgt(["Decoder tgt zeros<br/>(all-zero tensor)<br/>(T=50, B=8, D=512)"])
    query[["decoder_pos_embed<br/>marks the query position<br/>learnable embedding<br/>(T=50, 1, D=512)"]] --> selfqk["+"]
    tgt --> selfqk
    selfqk --> selfq(["Self-attn Q/K<br/>(50, B=8, 512)"])
    tgt --> selfv(["Self-attn V<br/>(50, B=8, 512)"])
    selfq --> selfattn[["self_attn<br/>WTF?? isn't this all zeros?"]]
    selfv --> selfattn
    selfattn --> x1(["Decoder hidden<br/>(50, B=8, 512)"])
    x1 --> crossqadd["+"]
    query --> crossqadd
    crossqadd --> crossq(["Cross-attn Q<br/>(50, B=8, 512)"])
    mem(["memory<br/>(encoder_out)<br/>(S=902, B=8, D=512)"]) --> crosskadd["+"]
    pos(["encoder_pos_embed<br/>2D sine<br/>(S=902, B=8, D=512)"]) --> crosskadd
    crosskadd --> crossk(["Cross-attn K<br/>(902, B=8, 512)"])
    mem --> crossv(["Cross-attn V<br/>(902, B=8, 512)"])
    crossq --> crossattn[["cross-attn<br/>memory unchanged across layers<br/>each layer: self-attn + cross-attn to memory + FFN"]]
    crossk --> crossattn
    crossv --> crossattn
    crossattn --> ff[["MLP<br/>512 -> feedforward -> 512"]]
    ff --> out(["decoder_out<br/>(50, B=8, D=512)"])
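A partial answer to the "WTF??" above, as a NumPy sketch with random stand-in weights (single head, LayerNorm and FFN omitted). Assuming the V projection has a bias, as `torch.nn.MultiheadAttention` does by default, V in the first layer is just that bias broadcast over all 50 slots; since each attention row sums to 1, the first self-attention outputs the same constant vector for every slot regardless of Q and K. The slots only become distinguishable in cross-attention, where decoder_pos_embed is added to Q.

```python
import numpy as np

rng = np.random.default_rng(0)
T, S, D = 50, 902, 512


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def attention(q_in, k_in, v_in, W, b):
    """Single-head attention with biased Q/K/V/out projections."""
    q = q_in @ W["q"] + b["q"]
    k = k_in @ W["k"] + b["k"]
    v = v_in @ W["v"] + b["v"]
    attn = softmax(q @ k.T / np.sqrt(D))
    return (attn @ v) @ W["o"] + b["o"]


def rand_params():
    """Random stand-ins for one layer's learned projections (hypothetical)."""
    W = {n: rng.normal(size=(D, D)) / np.sqrt(D) for n in "qkvo"}
    b = {n: rng.normal(size=D) for n in "qkvo"}
    return W, b


query_pos = rng.normal(size=(T, D))   # decoder_pos_embed
memory = rng.normal(size=(S, D))      # encoder output
memory_pos = rng.normal(size=(S, D))  # stand-in for the 2D sine encoder_pos_embed
tgt = np.zeros((T, D))                # decoder hidden starts at zero

# Self-attn: Q = K = tgt + query_pos, V = tgt = 0, so projected V is only the bias b_v.
self_out = attention(tgt + query_pos, tgt + query_pos, tgt, *rand_params())
print(np.allclose(self_out, self_out[0]))  # True: every slot gets the same vector

x = tgt + self_out  # residual connection (LayerNorm omitted)
# Cross-attn: Q = x + query_pos, K = memory + memory_pos, V = memory.
cross_out = attention(x + query_pos, memory + memory_pos, memory, *rand_params())
print(np.allclose(cross_out, cross_out[0]))  # False: query_pos makes each slot differ
```

This only holds for the first decoder layer; in later layers tgt is no longer zero, so self-attention does mix information across slots.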
Article author: Julyfun
Copyright 2026