
unit2-02_class_conditioned_diffusion_model_example

Jun 10, 2025

Class-conditioned

"Class-conditioned" means the model is conditioned on the class, i.e. class-label-conditioned.

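How does the network input change? It is simply a concatenation: the class embedding is appended to the noisy image as extra input channels.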
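A minimal shape-level sketch of that concatenation, assuming MNIST-like 28x28 single-channel images and class_emb_size = 4 (both sizes are illustrative, not from the original). The class information is represented here by a random tensor that is already spread over every pixel; concatenating along dim=1 yields 1 + class_emb_size channels, which is exactly why the UNet's in_channels has to change.

```python
import torch

# Illustrative sizes (assumptions): batch of 8, 4 conditioning channels, 28x28 images
bs, class_emb_size, w, h = 8, 4, 28, 28

x = torch.randn(bs, 1, w, h)                        # noisy images, 1 channel
class_cond = torch.randn(bs, class_emb_size, w, h)  # class info already broadcast to every pixel

# Concatenate along the channel dimension (dim=1)
net_input = torch.cat((x, class_cond), dim=1)
print(net_input.shape)  # torch.Size([8, 5, 28, 28])
```

The first channel of net_input is the original image and the remaining class_emb_size channels carry the class information.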

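  • The UNet's input channel count is changed directly to in_channels=1 + class_emb_size (1 image channel plus class_emb_size class-embedding channels):
```python
UNet2DModel(
    in_channels=1 + class_emb_size,  # 1 image channel + class_emb_size conditioning channels
    # ... other arguments unchanged
)
```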
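  • In forward, broadcast the class embedding over the spatial dimensions and torch.cat it onto the input.
```python
import torch

def forward(self, x, t, class_labels):
    # x: (bs, ch, w, h) noisy images; t: timesteps; class_labels: (bs,) integer labels
    bs, ch, w, h = x.shape

    # self.class_emb = nn.Embedding(num_classes, class_emb_size), created in __init__
    class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
    # Broadcast: (bs, class_emb_size) -> (bs, class_emb_size, 1, 1) -> (bs, class_emb_size, w, h)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)

    # Concatenate along the channel dim: (bs, ch + class_emb_size, w, h)
    net_input = torch.cat((x, class_cond), dim=1)

    # self.model returns a ModelOutput (UNet2DOutput).
    # .sample is the predicted noise tensor; other output fields are generally not needed here.
    return self.model(net_input, t).sample
```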
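To check that the pieces fit together without installing diffusers, here is a minimal runnable sketch. StandInNet is a hypothetical placeholder for UNet2DModel (a single 3x3 conv that ignores t and wraps its result in an object with a .sample attribute); the class name ClassConditionedNet and the defaults num_classes=10 and class_emb_size=4 are likewise illustrative assumptions. The forward method mirrors the broadcast + torch.cat steps described above.

```python
from types import SimpleNamespace

import torch
from torch import nn


class StandInNet(nn.Module):
    """Hypothetical stand-in for UNet2DModel: one conv layer that ignores t.
    Returns an object with a .sample attribute, like the diffusers output."""

    def __init__(self, in_channels, out_channels=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        return SimpleNamespace(sample=self.conv(x))


class ClassConditionedNet(nn.Module):
    def __init__(self, num_classes=10, class_emb_size=4):
        super().__init__()
        # One learnable vector per class label
        self.class_emb = nn.Embedding(num_classes, class_emb_size)
        # Input channels: 1 image channel + class_emb_size conditioning channels
        self.model = StandInNet(in_channels=1 + class_emb_size)

    def forward(self, x, t, class_labels):
        bs, ch, w, h = x.shape
        class_cond = self.class_emb(class_labels)  # (bs, class_emb_size)
        # Broadcast the embedding to every pixel: (bs, class_emb_size, w, h)
        class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        net_input = torch.cat((x, class_cond), dim=1)  # (bs, 1 + class_emb_size, w, h)
        return self.model(net_input, t).sample


net = ClassConditionedNet()
x = torch.randn(8, 1, 28, 28)
t = torch.randint(0, 1000, (8,))
labels = torch.randint(0, 10, (8,))
pred = net(x, t, labels)
print(pred.shape)  # torch.Size([8, 1, 28, 28])
```

Because expand only creates a view, the per-pixel class map costs no extra memory until torch.cat builds the combined input tensor. Swapping StandInNet for UNet2DModel(in_channels=1 + class_emb_size, ...) should leave forward unchanged, since UNet2DModel is also called as model(sample, timestep) and its output also exposes .sample.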
Article author: Julyfun
Copyright 2025