Class-conditioned
This means conditioning on the class, i.e. class-label-conditioned.
How does the network input change? It is really just a concat.
- The UNet's input channel count is simply changed to in_channels=1 + class_emb_size:

      UNet2DModel(
          in_channels=1 + class_emb_size,
          # (remaining arguments omitted)
      )
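- In forward, broadcast the class embedding and torch.cat it onto the input.

      def forward(self, x, t, class_labels):
          bs, ch, w, h = x.shape

          # Defined in __init__: self.class_emb = nn.Embedding(num_classes, class_emb_size)
          class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
          # Broadcast to the spatial size of x: (bs, class_emb_size, w, h)
          class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)

          # Concatenate along the channel dimension: (bs, 1 + class_emb_size, w, h)
          net_input = torch.cat((x, class_cond), dim=1)

          # The model returns a ModelOutput.
          # sample: the predicted noise tensor.
          # additional_residuals: stores extra residual information; usually unused.
          return self.model(net_input, t).sample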
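
To check the shapes in the two steps above, here is a minimal sketch that runs without diffusers. The class name ClassConditionedNet, the sizes (10 classes, class_emb_size=4, 28x28 single-channel images) and the nn.Conv2d stand-in for the UNet are all assumptions made for illustration; they are not from the original code. The stand-in ignores the timestep t, which a real UNet2DModel would use.

```python
import torch
from torch import nn


class ClassConditionedNet(nn.Module):
    """Toy wrapper (hypothetical): concat a broadcast class embedding onto the input."""

    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # Stand-in for the UNet (assumption): any module whose input channel
        # count is 1 + class_emb_size is enough to check the shapes.
        self.model = nn.Conv2d(1 + class_emb_size, 1, kernel_size=3, padding=1)

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        class_cond = self.class_emb(class_labels)        # (bs, class_emb_size)
        # Add two singleton spatial dims, then expand them to (w, h).
        class_cond = class_cond.view(bs, -1, 1, 1).expand(bs, -1, w, h)
        net_input = torch.cat((x, class_cond), dim=1)    # (bs, 1 + class_emb_size, w, h)
        # The stand-in does not use t; a real UNet would receive it as well.
        return self.model(net_input)


net = ClassConditionedNet()
x = torch.randn(8, 1, 28, 28)            # batch of noisy single-channel images
t = torch.randint(0, 1000, (8,))         # one timestep per sample
labels = torch.randint(0, 10, (8,))      # one class label per sample
out = net(x, t, labels)
print(out.shape)  # torch.Size([8, 1, 28, 28])
```

Note that expand only creates a view: every spatial position of a sample sees the same class_emb_size values, and torch.cat is what materialises them as extra input channels.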