how to

VAE

Dec 12, 2024
notesjulyfun大四上机器学习
3 Minutes
461 Words
  • 关于 VAE 理解的教程: https://spaces.ac.cn/archives/5253

  • 假定 $p(Z|X)$ 为一正态分布.(Z 为隐变量,X 为目标分布)

    • 注意不是假设 $p(Z)$为正态分布,不同 $X$ 显然必须有不同的隐藏分布,否则解码器无法区分它们,训练时 $X_i$ 和 $X^“hat”_i$ 就无法对应上.
  • 训练编码器使得样本对应的(编码器输出的)$mu$ 和 $log sigma^2$ 既要接近正态分布,又要对不同样本产生一些区别使得解码器能够将其还原到对应图像.

    • 接近正态又要有些微区分,这是一个权衡问题.
    • 为了防止正态分布采样以后,不同样本直接混在一起,其实不同类图像还是独占某一隐变量空间的区域的(这个我主成分分析绘制过).
  • Hint:

    • 正态分布的参数为 runtime 参数(中间层输出结果),不是 traintime 权重.

大致代码(见 speit-ml-tp 仓库) :

1
from torch.nn.functional import binary_cross_entropy
2
from torchvision import datasets, transforms
3
from torch.utils.data import DataLoader
4
import matplotlib.pyplot as plt
5
6
class Sampling(nn.Module):
7
def forward(self, mu, logvar):
8
std = torch.exp(0.5 * logvar)
9
eps = torch.randn_like(std)
10
return mu + eps * std
11
12
class Encoder(nn.Module):
13
def __init__(self, latent_dim):
14
super(Encoder, self).__init__()
15
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)
46 collapsed lines
16
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
17
self.fc = nn.Linear(64 * 7 * 7, 16)
18
self.fc_mu = nn.Linear(16, latent_dim)
19
self.fc_logvar = nn.Linear(16, latent_dim)
20
21
def forward(self, x):
22
x = torch.relu(self.conv1(x))
23
x = torch.relu(self.conv2(x))
24
x = x.view(x.size(0), -1)
25
x = torch.relu(self.fc(x))
26
mu = self.fc_mu(x)
27
logvar = self.fc_logvar(x)
28
return mu, logvar
29
30
class Decoder(nn.Module):
31
def __init__(self, latent_dim):
32
super(Decoder, self).__init__()
33
self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
34
self.deconv1 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
35
self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
36
self.deconv3 = nn.ConvTranspose2d(32, 1, kernel_size=3, padding=1)
37
38
def forward(self, z):
39
x = torch.relu(self.fc(z))
40
x = x.view(-1, 64, 7, 7)
41
x = torch.relu(self.deconv1(x))
42
x = torch.relu(self.deconv2(x))
43
x = torch.sigmoid(self.deconv3(x))
44
return x
45
46
class VAE(nn.Module):
47
def __init__(self, latent_dim):
48
super(VAE, self).__init__()
49
self.encoder = Encoder(latent_dim)
50
self.decoder = Decoder(latent_dim)
51
self.sampling = Sampling()
52
53
def forward(self, x):
54
mu, logvar = self.encoder(x)
55
z = self.sampling(mu, logvar)
56
return self.decoder(z), mu, logvar
57
58
def loss_function(recon_x, x, mu, logvar): # reconstruction
59
BCE = binary_cross_entropy(recon_x, x, reduction='sum')
60
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
61
return BCE + KLD
Article title:VAE
Article author:Julyfun
Release time:Dec 12, 2024
Copyright 2025
Sitemap