unit2-02_class_conditioned_diffusion_model_example

Class-conditioned

指的是类别-Conditioned. 或者说 class-label-conditioned.

网络输入改成啥样了? 其实就是 concat.

  • Unet 输入通道直接改成了 in_channels=1 + class_emb_size

  • forward 其实就是广播 + torch.cat 一下.

def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape

    # [pre-defined] self.class_emb = nn.Embedding(num_classes, class_emb_size)
    class_cond = self.class_emb(class_labels)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)

    net_input = torch.cat((x, class_cond), dim=1)

    # model 返回 ModelOutput.
    # sample: 就是预测的噪声张量.
    # additional_residuals: 存储额外残差信息. 一般没用.
    return self.model(net_input, t).sample