本文最后更新于:2024年5月7日 下午
之前学习了 VAE 相关内容,本文记录 VAE 的条件版 CVAE(Conditional VAE)。
CVAE原理
-
在条件变分自编码器(CVAE)中,模型的输出就不是 $\mathbf{x}_j$了,而是对应于输入$\mathbf{x}_i$的任务相关数据$\mathbf{y}_i$
也就是说输入的是条件,输出是在条件约束下的数据样本;
比如手写数字生成任务中:输入 $x$ 可以是想输出的数字,比如 $6$,输出 $y$ 则是数字 6 的手写图片。
-
因此,我们采样的时候,不再是从 $P(Z)$ 中直接采样,而是从 $P(Z|X)$ 中进行采样,因此假设变成了, $ P(Z \mid X)=N(Z \mid \mu(X), I) $
-
套路和VAE是一样的,这次的最大似然估计变成了 $\log p_{\theta}(\mathbf{Y}\mid\mathbf{X})$ ,即:
$$
\begin{aligned} \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) &=1 \cdot \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \\ &=\left(\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \mathrm{d} \mathbf{z}\right) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}) \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})}{p_{\theta}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \mathrm{d} \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=D_{K L}\left(q_{\phi}, p_{\theta}\right)+\ell\left(p_{\theta}, q_{\phi}\right) \end{aligned} \tag{16}
$$
- 则ELBO(Empirical Lower Bound)为 $\ell(p_{\theta}, q_{\phi})$,进一步:
$$
\begin{aligned} \ell\left(p_{\theta}, q_{\phi}\right) &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{z}, \mathbf{X}, \mathbf{Y})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z}) p_{\theta}(\mathbf{Z} \mid \mathbf{X}) p_{\theta}(\mathbf{X})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) p_{\theta}(\mathbf{X})} \mathrm{d} \mathbf{z} \\ &=\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log \frac{p_{\theta}(\mathbf{Z} \mid \mathbf{X})}{q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y})} \mathrm{d} \mathbf{z}+\int_{\mathbf{z}} q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \log p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z}) \mathrm{d} \mathbf{z} \\ &=-D_{K L}\left(q_{\phi}(\mathbf{z} \mid \mathbf{X}, \mathbf{Y}) \mid p_{\theta}(\mathbf{Z} \mid \mathbf{X})\right)+\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(\mathbf{Y} \mid \mathbf{X}, \mathbf{Z})\right] \end{aligned} \tag{17}
$$
其中 $(x,y)$ 为一般有监督学习中的数据对。可以看出CVAE相当于一个有监督版本的VAE,它重构/生成的是 $ y \mid x $ (VAE重构/生成的是 $ x $ )。举个例子,若令 $ x $ 表示手写数字的类别标签, $ y $ 表示手写数字图像,就可以通过采样 $ z $ 生成指定的数字 $ x $ 对应的图像 $ y $ 。值得一提的是,VAE 中的关于 $ z $ 的先验项是 $ p_{\theta}(z) $ ,而 CVAE 中的先验项 $ p_{\theta}(z \mid x) $ 与 $ x $ 有关,在网络实现上就会有一个从 $ x $ 到 $ z $ 的 “先验网络”。
网络结构
网络结构包含三个部分:
- 先验网络 $p_{\theta}(\mathbf{z}\mid\mathbf{X})$,如下图(b)所示
- Recognition 网络 $q_{\phi}(\mathbf{z}\mid\mathbf{X},\mathbf{Y})$, 如下图©所示
- Decoder网络 $p_{\theta}(\mathbf{Y}\mid\mathbf{X},\mathbf{Z})$,如下图(b)所示
通过条件改变隐变量的均值,从而控制了隐变量采样的位置,控制最后的输出结果。
先看图 (b),代表了整个从 $ x $ 推断到 $ y $ 的过程,如果理解的话其实这是一个生成的过程 (生成 $y \mid x$) :先从输入 $x$ 经过一个先验网络到 $z$ (重参数采样),再由 $x$ 和 $z$ 生成 $y$ 。然而,这篇文章后面的实验都采用了图 (d) 的架构。也就是 $ x $ 先通过一个 baseline CNN得到一个 $ \hat{y} $ ,再 由 $ x $ 和 $ \hat{y} $ 共同得到 $ z $ 的先验。个人认为这个操作就是为了得到效果保证而启发式地设计的,理论上不太漂亮。
对比 VAE
$$
\mathcal{L}(\phi, \theta ; x)=-K L\left(q_{\phi}(z \mid x) \| p_{\theta}(z)\right)+\mathbb{E}_{q_{\phi}(z \mid x)}\left[\log p_{\theta}(z \mid x)\right] \leq \log p_{\theta}(x)
$$
-
CVAE的变分下界为:
$$
\mathcal{L}(\phi, \theta ; x, y)=-K L\left(q_{\phi}(z \mid x, y) \| p_{\theta}(z \mid x)\right)+\mathbb{E}_{q_{\phi}(z \mid x, y)}\left[\log p_{\theta}(y \mid x, z)\right] \leq \log p_{\theta}(y \mid x)
$$
示例代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| class CVAE(nn.Module): """Implementation of CVAE(Conditional Variational Auto-Encoder)""" def __init__(self, feature_size, class_size, latent_size): super(CVAE, self).__init__() self.fc1 = nn.Linear(feature_size + class_size, 200) self.fc2_mu = nn.Linear(200, latent_size) self.fc2_log_std = nn.Linear(200, latent_size) self.fc3 = nn.Linear(latent_size + class_size, 200) self.fc4 = nn.Linear(200, feature_size) def encode(self, x, y): h1 = F.relu(self.fc1(torch.cat([x, y], dim=1))) mu = self.fc2_mu(h1) log_std = self.fc2_log_std(h1) return mu, log_std def decode(self, z, y): h3 = F.relu(self.fc3(torch.cat([z, y], dim=1))) recon = torch.sigmoid(self.fc4(h3)) return recon def reparametrize(self, mu, log_std): std = torch.exp(log_std) eps = torch.randn_like(std) z = mu + eps * std return z def forward(self, x, y): mu, log_std = self.encode(x, y) z = self.reparametrize(mu, log_std) recon = self.decode(z, y) return recon, mu, log_std def loss_function(self, recon, x, mu, log_std) -> torch.Tensor: recon_loss = F.mse_loss(recon, x, reduction="sum") kl_loss = -0.5 * (1 + 2*log_std - mu.pow(2) - torch.exp(2*log_std)) kl_loss = torch.sum(kl_loss) loss = recon_loss + kl_loss return loss
|
原始论文
参考资料
文章链接:
https://www.zywvvd.com/notes/study/deep-learning/generation/vae/cvae/cvae/