条件变分自编码器 CVAE

本文最后更新于:2023年5月24日 晚上

之前学习了 VAE 相关内容,本文记录 VAE 的条件版 CVAE(Conditional VAE)。

CVAE原理

  • 在条件变分自编码器(CVAE)中,模型的输出就不是 $\mathbf{x}_j$了,而是对应于输入$\mathbf{x}_i$的任务相关数据$\mathbf{y}_i$

    也就是说输入的是条件,输出是在条件约束下的数据样本;

    比如手写数字生成任务中:输入 $x$ 可以是想输出的数字,比如 $6$,输出 $y$ 则是数字 6 的手写图片。

  • 因此,我们采样的时候,不再是从 $P(Z)$ 中直接采样,而是从 $P(Z|X)$ 中进行采样,因此假设变成了, $ P(Z \mid X)=N(Z \mid \mu(X), I) $

  • 套路和VAE是一样的,这次的最大似然估计变成了 $\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})$ ,即:

$$ \begin{aligned} \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) &=1 \cdot \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \\ &=\left(\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \mathrm{d} \mathbf{z}\right) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \mathrm{d} \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=D_{K L}\left(q_{\phi}, p_{\theta}\right)+\ell\left(p_{\theta}, q_{\phi}\right) \end{aligned} \tag{16} $$
  • 则ELBO(Empirical Lower Bound)为 $\ell(p_{\theta}, q_{\phi})$,进一步:
$$ \begin{aligned} \ell\left(p_{\theta}, q_{\phi}\right) &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z}) p_{\theta}(\mathbf{Z} \mid \mathbf{X}) p_{\theta}(\mathbf{X})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{Z} \mid \mathbf{X})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \mathrm{d} \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z}) \mathrm{d} \mathbf{z} \\ &=-D_{K L}\left(q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \mid p_{\theta}(\mathbf{Z} \mid \mathbf{X})\right)+\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z})\right] \end{aligned} \tag{17} $$

其中 $(x,y)$ 为一般有监督学习中的数据对。可以看出CVAE相当于一个有监督版本的VAE,它重构/生成的是 $ y \mid x $ (VAE重构/生成的是 $ x $ )。举个例子,若令 $ x $ 表示手写数字的类别标签, $ y $ 表示手写数字图像,就可以通过采样 $ z $ 生成指定的数字 $ x $ 对应的图像 $ y $ 。值得一提的是,VAE 中的关于 $ z $ 的先验项是 $ p_{\theta}(z) $ ,而 CVAE 中的先验项 $ p_{\theta}(z \mid x) $ 与 $ x $ 有关,在网络实现上就会有一个从 $ x $ 到 $ z $ 的 “先验网络”。

网络结构

网络结构包含三个部分:

  1. 先验网络 $p_{\theta}(\mathbf{z}\mid\mathbf{X})$,如下图(b)所示
  2. Recognition 网络 $q_{\phi}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})$, 如下图©所示
  3. Decoder网络 $p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})$,如下图(b)所示

 概率图模型

通过条件改变隐变量的均值,从而控制了隐变量采样的位置,控制最后的输出结果。

先看图 (b),代表了整个从 $ x $ 推断到 $ y $ 的过程,如果理解的话其实这是一个生成的过程 (生成 $y \mid x$) :先从输入 $x$ 经过一个先验网络到 $z$ (重参数采样),再由 $x$ 和 $z$ 生成 $y$ 。然而,这篇文章后面的实验都采用了图 (d) 的架构。也就是 $ x $ 先通过一个 baseline CNN得到一个 $ \hat{y} $ ,再 由 $ x $ 和 $ \hat{y} $ 共同得到 $ z $ 的先验。个人认为这个操作就是为了得到效果保证而启发式地设计的,理论上不太漂亮。

对比 VAE

  • VAE的变分下界为:
$$ \mathcal{L}(\phi, \theta ; x)=-K L\left(q_{\phi}(z \mid x) \| p_{\theta}(z)\right)+\mathbb{E}_{q_{\phi}(z \mid x)}\left[\log p_{\theta}(z \mid x)\right] \leq \log p_{\theta}(x) $$
  • CVAE的变分下界为:

    $$ \mathcal{L}(\phi, \theta ; x, y)=-K L\left(q_{\phi}(z \mid x, y) \| p_{\theta}(z \mid x)\right)+\mathbb{E}_{q_{\phi}(z \mid x, y)}\left[\log p_{\theta}(y \mid x, z)\right] \leq \log p_{\theta}(y \mid x) $$

示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class CVAE(nn.Module):
"""Implementation of CVAE(Conditional Variational Auto-Encoder)"""
def __init__(self, feature_size, class_size, latent_size):
super(CVAE, self).__init__()
self.fc1 = nn.Linear(feature_size + class_size, 200)
self.fc2_mu = nn.Linear(200, latent_size)
self.fc2_log_std = nn.Linear(200, latent_size)
self.fc3 = nn.Linear(latent_size + class_size, 200)
self.fc4 = nn.Linear(200, feature_size)
def encode(self, x, y):
h1 = F.relu(self.fc1(torch.cat([x, y], dim=1))) # concat features and labels
mu = self.fc2_mu(h1)
log_std = self.fc2_log_std(h1)
return mu, log_std
def decode(self, z, y):
h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) # concat latents and labels
recon = torch.sigmoid(self.fc4(h3)) # use sigmoid because the input image's pixel is between 0-1
return recon
def reparametrize(self, mu, log_std):
std = torch.exp(log_std)
eps = torch.randn_like(std) # simple from standard normal distribution
z = mu + eps * std
return z
def forward(self, x, y):
mu, log_std = self.encode(x, y)
z = self.reparametrize(mu, log_std)
recon = self.decode(z, y)
return recon, mu, log_std
def loss_function(self, recon, x, mu, log_std) -> torch.Tensor:
recon_loss = F.mse_loss(recon, x, reduction="sum") # use "mean" may have a bad effect on gradients
kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std))
kl_loss = torch.sum(kl_loss)
loss = recon_loss + kl_loss
return loss

原始论文

参考资料



文章链接:
https://www.zywvvd.com/notes/study/deep-learning/generation/vae/cvae/cvae/


条件变分自编码器 CVAE
https://www.zywvvd.com/notes/study/deep-learning/generation/vae/cvae/cvae/
作者
Yiwei Zhang
发布于
2023年5月18日
许可协议