
# Diffusion Models Explained

## 1. Background: Generative Models

## 2. Diffusion Models

#### 2.1.3 The Objective Function

$$\begin{aligned} \mathcal L & = - \log p(\boldsymbol x_0) \\ & = - \log \int \frac{p_{\boldsymbol\theta}(\boldsymbol x_{0:T})\, q(\boldsymbol x_{1:T} \mid \boldsymbol x_0)}{q(\boldsymbol x_{1:T} \mid \boldsymbol x_0)}\, d\boldsymbol x_{1:T} \\ & \leq - \mathbb E_{q(\boldsymbol x_{1:T} \mid \boldsymbol x_0)} \left[ \log \frac{p_{\boldsymbol\theta}(\boldsymbol x_{0:T})}{q(\boldsymbol x_{1:T} \mid \boldsymbol x_0)} \right] \\ & = - \underbrace{\mathbb E_{q(\boldsymbol x_1 \mid \boldsymbol x_0)}\left[\log p_{\boldsymbol\theta}(\boldsymbol x_0 \mid \boldsymbol x_1)\right]}_{\text{reconstruction term}} + \underbrace{D_{\mathrm{KL}}\left(q(\boldsymbol x_T \mid \boldsymbol x_0)\ \|\ p(\boldsymbol x_T)\right)}_{\text{prior matching term}} + \sum_{t=2}^T \underbrace{\mathbb E_{q(\boldsymbol x_t \mid \boldsymbol x_0)}\left[D_{\mathrm{KL}}\left(q(\boldsymbol x_{t-1} \mid \boldsymbol x_t, \boldsymbol x_0)\ \|\ p_{\boldsymbol\theta}(\boldsymbol x_{t-1} \mid \boldsymbol x_t)\right)\right]}_{\text{denoising matching term}} \end{aligned} \tag{6}$$
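The inequality in the third line is Jensen's inequality applied to the logarithm. The final decomposition then follows from the Markov factorizations of the two processes (see [11] for the step-by-step derivation):

$$p_{\boldsymbol\theta}(\boldsymbol x_{0:T}) = p(\boldsymbol x_T) \prod_{t=1}^T p_{\boldsymbol\theta}(\boldsymbol x_{t-1} \mid \boldsymbol x_t), \qquad q(\boldsymbol x_{1:T} \mid \boldsymbol x_0) = \prod_{t=1}^T q(\boldsymbol x_t \mid \boldsymbol x_{t-1})$$

Writing each forward factor as $q(\boldsymbol x_t \mid \boldsymbol x_{t-1}, \boldsymbol x_0)$ (valid by the Markov property) and applying Bayes' rule regroups the ratio into the three terms above.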

$$\begin{aligned} q(\boldsymbol x_{t-1} \mid \boldsymbol x_t, \boldsymbol x_0) & = \frac{q(\boldsymbol x_t \mid \boldsymbol x_{t-1}, \boldsymbol x_0)\, q(\boldsymbol x_{t-1} \mid \boldsymbol x_0)}{q(\boldsymbol x_t \mid \boldsymbol x_0)} \\ & \propto \mathcal N\left( \boldsymbol x_{t-1};\ \frac{\sqrt{\alpha_t}\,(1 - \bar\alpha_{t-1})\, \boldsymbol x_t + \sqrt{\bar\alpha_{t-1}}\,(1 - \alpha_t)\, \boldsymbol x_0}{1 - \bar\alpha_t},\ \frac{(1 - \alpha_t)(1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t} \mathbf I \right) \\ & = \mathcal N\left(\boldsymbol x_{t-1};\ \boldsymbol\mu_q(\boldsymbol x_t, \boldsymbol x_0),\ \boldsymbol\Sigma_q(t)\right) \end{aligned} \tag{8}$$

$$p_{\boldsymbol\theta}(\boldsymbol x_{t-1} \mid \boldsymbol x_t) = \mathcal N\left(\boldsymbol x_{t-1};\ \boldsymbol\mu_{\boldsymbol\theta}(\boldsymbol x_t, t),\ \boldsymbol\Sigma_q(t)\right) \tag{9}$$

$$\begin{aligned} & D_{\mathrm{KL}}\left(\mathcal N(\boldsymbol x;\ \boldsymbol\mu_x, \boldsymbol\Sigma_x)\ \|\ \mathcal N(\boldsymbol y;\ \boldsymbol\mu_y, \boldsymbol\Sigma_y)\right) \\ = {} & \frac{1}{2}\left[ \log\frac{|\boldsymbol\Sigma_y|}{|\boldsymbol\Sigma_x|} - d + \operatorname{tr}(\boldsymbol\Sigma_y^{-1}\boldsymbol\Sigma_x) + (\boldsymbol\mu_y - \boldsymbol\mu_x)^\intercal \boldsymbol\Sigma_y^{-1}(\boldsymbol\mu_y - \boldsymbol\mu_x)\right] \end{aligned} \tag{10}$$

$$\mathop{\arg\min}_{\boldsymbol\theta}\, D_{\mathrm{KL}}\left(q(\boldsymbol x_{t-1} \mid \boldsymbol x_t, \boldsymbol x_0)\ \|\ p_{\boldsymbol\theta}(\boldsymbol x_{t-1} \mid \boldsymbol x_t)\right) = \mathop{\arg\min}_{\boldsymbol\theta}\, \frac{1}{2\sigma_q^2(t)} \left\| \boldsymbol\mu_{\boldsymbol\theta}(\boldsymbol x_t, t) - \boldsymbol\mu_q(\boldsymbol x_t, \boldsymbol x_0) \right\|_2^2 \tag{11}$$
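Substituting the reparameterization $\boldsymbol x_t = \sqrt{\bar\alpha_t}\, \boldsymbol x_0 + \sqrt{1 - \bar\alpha_t}\, \boldsymbol\epsilon$ into $\boldsymbol\mu_q$ from Eq. (8) gives the noise-prediction form used by DDPM [2]:

$$\boldsymbol\mu_q(\boldsymbol x_t, \boldsymbol x_0) = \frac{1}{\sqrt{\alpha_t}} \left( \boldsymbol x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\, \boldsymbol\epsilon \right)$$

Parameterizing $\boldsymbol\mu_{\boldsymbol\theta}$ in the same form with a noise predictor $\boldsymbol\epsilon_{\boldsymbol\theta}(\boldsymbol x_t, t)$ therefore reduces Eq. (11) to a mean-squared error between $\boldsymbol\epsilon_{\boldsymbol\theta}$ and the true noise $\boldsymbol\epsilon$, which is exactly the loss computed in the training code later in this section.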

### 2.2 Implementation

#### 2.2.1 Model Architecture

When DDPM predicts the added noise, its input is the image after noise injection, and its output is a noise map with the same size as the input image, so the task can be viewed as image-to-image (Img2Img). DDPM adopts U-Net [9] as the noise-prediction architecture. U-Net is a U-shaped network composed of an encoder, a decoder, and skip connections between them. The encoder downsamples the image into a feature map, the decoder upsamples that feature map back into the target noise, and the skip connections concatenate encoder features with the corresponding decoder features.

1. First, a convolutional layer is applied to the noisy image $\boldsymbol x_t$, and a time embedding is computed for the noise level $t$;
2. Next comes the downsampling stage. Each stage applies, in order, two convolutional blocks (WRN [12] or ConvNeXT [14]) + GN (group normalization) [16] + Attention [15] + a downsampling layer;
3. At the very middle of the network there are, in order, a convolutional block + Attention + a convolutional block;
4. Next comes the upsampling stage. It first concatenates the short-cut features of the same size from the downsampling stage, then applies two convolutional blocks + GN + Attention + an upsampling layer;
5. Finally, a WRN or ConvNeXT convolutional block serves as the output layer.
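The time embedding in step 1 is typically the sinusoidal position encoding of [15]. Below is a minimal sketch; the class name is hypothetical, and the actual `time_mlp` in [3] additionally passes this embedding through a small MLP:

```python
import math
import torch

class SinusoidalTimeEmbedding(torch.nn.Module):
    """Embed an integer timestep t with sinusoids, as in Transformer position encodings."""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        # Geometric progression of frequencies from 1 down to 1/10000
        freqs = torch.exp(-math.log(10000) * torch.arange(half) / (half - 1))
        args = time[:, None].float() * freqs[None, :]
        # First half sines, second half cosines -> shape (batch, dim)
        return torch.cat([args.sin(), args.cos()], dim=-1)
```

Because every U-Net block receives this embedding, a single network can denoise at all noise levels $t$.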

The `forward` function of the U-Net class is shown in the snippet below; see [3] for the complete implementation.

```python
def forward(self, x, time):
    x = self.init_conv(x)
    t = self.time_mlp(time) if exists(self.time_mlp) else None
    h = []
    # downsample
    for block1, block2, attn, downsample in self.downs:
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        h.append(x)
        x = downsample(x)
    # bottleneck
    x = self.mid_block1(x, t)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)
    # upsample
    for block1, block2, attn, upsample in self.ups:
        x = torch.cat((x, h.pop()), dim=1)
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        x = upsample(x)
    return self.final_conv(x)
```



#### 2.2.2 Forward Noising

```python
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    # Sample the noise to add if none was supplied
    if noise is None:
        noise = torch.randn_like(x_start)
    # 1. Add noise to the image x_start according to timestep t
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    # 2. Predict the added noise from the noisy image and timestep t
    predicted_noise = denoise_model(x_noisy, t)
    # 3. Compare the added noise with the predicted noise
    if loss_type == "l1":
        return F.l1_loss(noise, predicted_noise)
    return F.mse_loss(noise, predicted_noise)
```
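`q_sample` is not shown above. Here is a minimal sketch based on the closed-form forward equation $\boldsymbol x_t = \sqrt{\bar\alpha_t}\,\boldsymbol x_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol\epsilon$; the linear $\beta$ schedule and the `extract` helper mirror the names used in the snippets, but the concrete values here are illustrative:

```python
import torch

timesteps = 200
betas = torch.linspace(1e-4, 0.02, timesteps)  # illustrative linear schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

def extract(a, t, x_shape):
    """Gather a[t] for each sample and reshape to broadcast over an image batch."""
    out = a.gather(-1, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

def q_sample(x_start, t, noise=None):
    """Jump directly to x_t: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps."""
    if noise is None:
        noise = torch.randn_like(x_start)
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
```

The closed form means training never has to simulate the chain step by step: a timestep $t$ can be sampled uniformly and $\boldsymbol x_t$ produced in one shot.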



#### 2.2.3 Sample Generation

```python
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # Compute the model mean using Eq. (13)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
        # Look up the stored posterior variance
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Line 4 of Algorithm 2
        return model_mean + torch.sqrt(posterior_variance_t) * noise

# The procedure of Algorithm 2, but keeping every intermediate sample
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs
```
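The sampling code above relies on several precomputed lookup tables (`betas`, `sqrt_recip_alphas`, `sqrt_one_minus_alphas_cumprod`, `posterior_variance`) indexed per timestep via `extract`. One plausible way to derive them from a $\beta$ schedule, with the posterior variance $\sigma_q^2(t)$ taken from Eq. (8) (the schedule values are illustrative, not the author's exact choice):

```python
import torch

timesteps = 200
betas = torch.linspace(1e-4, 0.02, timesteps)  # illustrative linear schedule
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# abar_{t-1}, with abar_0 defined as 1
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

sqrt_recip_alphas = torch.rsqrt(alphas)                        # 1 / sqrt(alpha_t)
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()  # sqrt(1 - abar_t)
# sigma_q^2(t) = (1 - alpha_t)(1 - abar_{t-1}) / (1 - abar_t), cf. Eq. (8)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
```

Note that $\sigma_q^2(1) = 0$, which is why `p_sample` returns the bare mean at `t_index == 0` instead of adding noise.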



## 3. Summary

DDPM has several clear shortcomings:

1. Slow sampling: denoising in DDPM runs a full Markov chain from step $T$ down to step $1$, and DDPM needs a fairly large $T$ to achieve good results, so its sampling process is inherently very slow.
2. Limited generation quality: DDPM's results are not outstanding, especially for high-resolution image generation. This is partly because its computational cost prevents scaling to larger models, and partly because of design issues; for example, a per-pixel loss with uniform weights that ignores the subject of the image is not an ideal strategy.
3. Uncontrollable content: what DDPM generates is determined entirely by its training set. It introduces no prior conditions, so we cannot specify the generated content by controlling details of the image.

## References

[1] Sohl-Dickstein, Jascha, et al. "Deep unsupervised learning using nonequilibrium thermodynamics." International Conference on Machine Learning. PMLR, 2015.

[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. "Improved denoising diffusion probabilistic models." International Conference on Machine Learning. PMLR, 2021.

[7] Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013).

[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. "Reducing the dimensionality of data with neural networks." Science 313.5786 (2006): 504-507.

[9] Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2015: 234-241.

[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for semantic segmentation." Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

[11] Luo, Calvin. "Understanding diffusion models: A unified perspective." arXiv preprint arXiv:2208.11970 (2022).

[12] Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv preprint arXiv:1605.07146 (2016).

[14] Liu, Zhuang, et al. "A convnet for the 2020s." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.

[15] Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).

[16] Wu, Yuxin, and Kaiming He. "Group normalization." Proceedings of the European conference on computer vision (ECCV). 2018.
