- Tutorial on understanding VAEs: https://spaces.ac.cn/archives/5253
- Assume $p(Z|X)$ is a normal distribution ($Z$ is the latent variable, $X$ is the target distribution).
    - Note that this is not assuming $p(Z)$ is normal: different $X$ clearly must have different latent distributions, otherwise the decoder cannot tell them apart, and at training time $X_i$ and $\hat{X}_i$ could not be matched up.
- Train the encoder so that the $\mu$ and $\log\sigma^2$ it outputs for each sample are close to the standard normal distribution, yet still differ across samples enough for the decoder to reconstruct the corresponding image.
    - Staying close to the normal while keeping some distinction between samples is a trade-off (the closed-form KL term after this list makes it explicit).
    - To prevent different samples from blending together once the normal distribution is sampled, different classes of images in practice still occupy their own regions of the latent space (I have plotted this with PCA).
- Hint:
    - The parameters of the normal distribution are runtime values (outputs of an intermediate layer), not train-time weights.
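For a single latent dimension with posterior $\mathcal{N}(\mu, \sigma^2)$, the KL divergence to the standard normal $\mathcal{N}(0, 1)$ has the closed form

$$
\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big) = -\frac{1}{2}\left(1 + \log\sigma^2 - \mu^2 - \sigma^2\right),
$$

which is zero only when $\mu = 0$ and $\sigma = 1$. Summed over latent dimensions, this is exactly the `KLD` term in the loss below; the reconstruction term is what pushes different samples to keep distinguishable $\mu$, hence the trade-off.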
Rough code (see the speit-ml-tp repository):
```python
import torch
import torch.nn as nn
from torch.nn.functional import binary_cross_entropy
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


class Sampling(nn.Module):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    def forward(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std


class Encoder(nn.Module):
    """Maps a 1x28x28 image to the mean and log-variance of q(z|x)."""
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)   # 28x28 -> 14x14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # 14x14 -> 7x7
        self.fc = nn.Linear(64 * 7 * 7, 16)
        self.fc_mu = nn.Linear(16, latent_dim)
        self.fc_logvar = nn.Linear(16, latent_dim)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = x.view(x.size(0), -1)
        x = torch.relu(self.fc(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nn.Module):
    """Maps a latent vector back to a 1x28x28 image with pixel values in (0, 1)."""
    def __init__(self, latent_dim):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # 7x7 -> 14x14
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # 14x14 -> 28x28
        self.deconv3 = nn.ConvTranspose2d(32, 1, kernel_size=3, padding=1)

    def forward(self, z):
        x = torch.relu(self.fc(z))
        x = x.view(-1, 64, 7, 7)
        x = torch.relu(self.deconv1(x))
        x = torch.relu(self.deconv2(x))
        x = torch.sigmoid(self.deconv3(x))
        return x


class VAE(nn.Module):
    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.sampling = Sampling()

    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.sampling(mu, logvar)
        return self.decoder(z), mu, logvar


def loss_function(recon_x, x, mu, logvar):
    BCE = binary_cross_entropy(recon_x, x, reduction='sum')        # reconstruction term
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())  # KL(q(z|x) || N(0, I))
    return BCE + KLD
```
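A minimal usage sketch to go with the model above. Assumptions on my part (the actual speit-ml-tp code may differ): MNIST, Adam, `latent_dim=2`, 10 epochs. After training, it scatters the encoder means, which is how the per-class regions of the latent space mentioned above show up (with a 2-D latent no PCA projection is needed).

```python
# Training sketch: MNIST pixels from ToTensor() are already in [0, 1],
# so the BCE reconstruction term is well-defined.
transform = transforms.ToTensor()
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)

model = VAE(latent_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    total = 0.0
    for x, _ in train_loader:
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        loss = loss_function(recon, x, mu, logvar)
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"epoch {epoch}: avg loss {total / len(train_set):.4f}")

# Scatter the per-sample means mu on the test set, colored by digit class,
# to see how classes occupy different regions of the latent space.
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_set, batch_size=256)
model.eval()
mus, labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        mu, _ = model.encoder(x)
        mus.append(mu)
        labels.append(y)
mus, labels = torch.cat(mus), torch.cat(labels)
plt.scatter(mus[:, 0], mus[:, 1], c=labels, cmap='tab10', s=2)
plt.colorbar()
plt.show()
```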