DDIM, a Sampling Method for Generative Diffusion Models, and Its Derivation

Notes from working through the DDIM sampling method.

Why DDIM?

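The previous post gave an overview of the DDPM model and its derivation. DDPM is built on a Markov chain, so sampling has to iterate through every timestep, which makes inference slow. DDIM addresses this problem: through a mathematical reformulation it drops the Markov assumption and, most elegantly, it does not require retraining the DDPM (the forward noising marginals are unchanged). Only the sampler is modified, and the modified sampler is substantially faster.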

From the DDPM derivation, training ultimately fits a noise predictor, with the loss:

\[\begin{aligned} \left\|\boldsymbol{\epsilon}_t-\boldsymbol{\epsilon}_\theta\left(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_t, t\right)\right\|^2 \end{aligned}\]

This loss depends only on the forward marginal

\begin{equation} q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0, (1-\bar{\alpha}_t) \boldsymbol{I}) \end{equation}

and not on the individual forward transitions.
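To make the dependence on the marginal concrete, here is a minimal NumPy sketch of the training objective. The linear beta schedule, the 0-based timestep index, and the names `make_alpha_bar`, `q_sample`, `noise_prediction_loss`, and `eps_model` are illustrative assumptions, not taken from the original post.

```python
import numpy as np

def make_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t for an assumed linear beta schedule."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bar, eps):
    """Draw x_t from q(x_t | x_0) using the closed-form marginal (t is 0-based)."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

def noise_prediction_loss(eps_model, x0, t, alpha_bar, rng):
    """Squared L2 distance between the true noise and the predicted noise."""
    eps = rng.standard_normal(x0.shape)
    x_t = q_sample(x0, t, alpha_bar, eps)
    return np.sum((eps - eps_model(x_t, t)) ** 2)
```

Only `alpha_bar` enters the loss; the one-step transition \(q(\mathbf{x}_t \mid \mathbf{x}_{t-1})\) never appears.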

The sampling step is:

\[\begin{aligned} \boldsymbol{\mu}_\theta\left(\mathbf{x}_t, t\right) & =\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right) \\ \text { Thus } \mathbf{x}_{t-1} & \sim \mathcal{N}\left(\frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \boldsymbol{\epsilon}_\theta\left(\mathbf{x}_t, t\right)\right), \boldsymbol{\Sigma}_\theta\left(\mathbf{x}_t, t\right)\right)\end{aligned}\]

This step depends only on \(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t)\).
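As a baseline for comparison, the DDPM step above can be sketched as follows. This assumes a linear beta schedule, a 0-based timestep index, and the posterior variance \(\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\beta_t\) as the choice of \(\boldsymbol{\Sigma}_\theta\). The name `ddpm_step` and the argument `eps_pred` (the output of the noise predictor) are hypothetical.

```python
import numpy as np

def ddpm_step(x_t, t, eps_pred, alphas, alpha_bar, rng):
    """One DDPM ancestral step x_t -> x_{t-1}; t is a 0-based index."""
    coef = (1.0 - alphas[t]) / np.sqrt(1.0 - alpha_bar[t])
    mean = (x_t - coef * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added at the final step
    # Assumed variance: posterior variance (1 - ab_{t-1}) / (1 - ab_t) * beta_t.
    var = (1.0 - alpha_bar[t - 1]) / (1.0 - alpha_bar[t]) * (1.0 - alphas[t])
    return mean + np.sqrt(var) * rng.standard_normal(x_t.shape)

betas = np.linspace(1e-4, 0.02, 1000)  # assumed linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)
rng = np.random.default_rng(0)
x_t = rng.standard_normal((2, 3))
x_prev = ddpm_step(x_t, 999, np.zeros_like(x_t), alphas, alpha_bar, rng)
```

Each call moves exactly one timestep back, which is why a full DDPM sample needs as many network evaluations as there are training timesteps.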

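Since neither training nor sampling uses \(q(\mathbf{x}_{t} \mid \mathbf{x}_{t-1})\) directly, we can boldly drop \(q(\mathbf{x}_{t} \mid \mathbf{x}_{t-1})\) from the derivation. The Markov chain then disappears, and this is exactly the idea DDIM starts from.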

What is DDIM?

From the derivation in the previous post, \(q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0)\) takes the form

\[q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; \tilde{\boldsymbol{\mu}}(\mathbf{x}_t, \mathbf{x}_0), \tilde{\beta_t} \mathbf{I})\]

We may therefore assume a Gaussian whose mean is linear in \(\mathbf{x}_0\) and \(\mathbf{x}_t\):

\[q(\mathbf{x}_{t-1} \mid \mathbf{x}_t, \mathbf{x}_0) = \mathcal{N}(\mathbf{x}_{t-1}; k\mathbf{x}_0+m\mathbf{x}_t, \sigma^2 \mathbf{I})\]

which gives:

\[\begin{align} \mathbf{x}_{t-1} = k\mathbf{x}_0+m\mathbf{x}_t + \sigma \varepsilon, \quad \varepsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \end{align}\]

From equation (1): \(\mathbf{x}_{t} = \sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_t\)

Substituting into equation (2):

\[\begin{aligned} \mathbf{x}_{t-1} & = k\mathbf{x}_0+m(\sqrt{\bar{\alpha}_t} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_t} \boldsymbol{\epsilon}_t) + \sigma \varepsilon, \quad \varepsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}) \\ & = (k+m\sqrt{\bar{\alpha}_t} )\mathbf{x}_0 + m\sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}_t + \sigma \varepsilon \\ & = (k+m\sqrt{\bar{\alpha}_t} )\mathbf{x}_0 + \boldsymbol{\epsilon}^{\prime}, \quad \boldsymbol{\epsilon}^{\prime} \sim \mathcal{N}\left(\mathbf{0}, \left(m^2 (1 - \bar{\alpha}_t) + \sigma^2\right) \mathbf{I}\right) \end{aligned}\]

On the other hand, the marginal at step \(t-1\) must still satisfy

\[\mathbf{x}_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_{t-1}} \boldsymbol{\epsilon}_{t-1}, \quad \boldsymbol{\epsilon}_{t-1} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\]

Matching the coefficient of \(\mathbf{x}_0\) and the noise variance gives:

\[\begin{aligned} k+m\sqrt{\bar{\alpha}_t} = \sqrt{\bar{\alpha}_{t-1}} \\ m^2 (1 - \bar{\alpha}_t) + \sigma^2 = 1-\bar{\alpha}_{t-1} \end{aligned}\]

Solving these equations:

\[\begin{aligned} m &= \sqrt{\frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} \\ k &= \sqrt{\bar{\alpha}_{t-1}}- \sqrt{\bar{\alpha}_t \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} \\ \mathbf{x}_{t-1} &= \mathbf{x}_{0}(\sqrt{\bar{\alpha}_{t-1}}- \sqrt{\bar{\alpha}_t \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}}) + \mathbf{x}_{t}(\sqrt{\frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}}) + \sigma \varepsilon \end{aligned}\]
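As a quick sanity check, the solved coefficients can be plugged back into the two matching conditions numerically. The values of \(\bar{\alpha}_t\), \(\bar{\alpha}_{t-1}\), and \(\sigma\) below are arbitrary illustrative numbers, and `ddim_coefficients` is a hypothetical helper name.

```python
import numpy as np

def ddim_coefficients(ab_t, ab_prev, sigma):
    """Return (k, m) from the solved system; ab_* are cumulative alpha products."""
    m = np.sqrt((1.0 - ab_prev - sigma**2) / (1.0 - ab_t))
    k = np.sqrt(ab_prev) - np.sqrt(ab_t) * m
    return k, m

# Arbitrary example values with ab_t < ab_prev and sigma^2 <= 1 - ab_prev.
ab_t, ab_prev, sigma = 0.5, 0.7, 0.2
k, m = ddim_coefficients(ab_t, ab_prev, sigma)

mean_coeff = k + m * np.sqrt(ab_t)           # should equal sqrt(ab_prev)
variance = m**2 * (1.0 - ab_t) + sigma**2    # should equal 1 - ab_prev
```

Both conditions hold for any \(\sigma\) with \(\sigma^2 \le 1-\bar{\alpha}_{t-1}\), which is the free parameter DDIM exploits.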

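Solving the forward marginal

\[\mathbf{x}_{t} = \sqrt{\bar{\alpha}_{t}} \mathbf{x}_0+\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t\]

for \(\mathbf{x}_0\) gives \(\mathbf{x}_0 = \left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t\right)/\sqrt{\bar{\alpha}_{t}}\). Substituting this into the expression above, and using \(\bar{\alpha}_t = \alpha_t \bar{\alpha}_{t-1}\), yields: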

\[\begin{aligned} \mathbf{x}_{t-1} &= \mathbf{x}_{0}\left(\sqrt{\bar{\alpha}_{t-1}}- \sqrt{\bar{\alpha}_t \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}}\right) + \mathbf{x}_{t}\sqrt{\frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} + \sigma \varepsilon \\ &= \frac{\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t}{\sqrt{\bar{\alpha}_{t}}}\left(\sqrt{\bar{\alpha}_{t-1}}- \sqrt{\bar{\alpha}_t \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}}\right) + \mathbf{x}_{t}\sqrt{\frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} + \sigma \varepsilon \\ &= \frac{\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t}{\sqrt{\alpha_{t}}} - \left(\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t\right)\sqrt{ \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}}+ \mathbf{x}_{t}\sqrt{\frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} + \sigma \varepsilon \\ &= \frac{\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t}{\sqrt{\alpha_{t}}} + \sqrt{1-\bar{\alpha}_{t}}\, \boldsymbol{\epsilon}_t \sqrt{ \frac{1-\bar{\alpha}_{t-1} - \sigma^2}{1 - \bar{\alpha}_t}} + \sigma \varepsilon \\ &= \frac{\mathbf{x}_{t}-\sqrt{1-\bar{\alpha}_{t}} \boldsymbol{\epsilon}_t}{\sqrt{\alpha_{t}}} + \boldsymbol{\epsilon}_t \sqrt{ 1-\bar{\alpha}_{t-1} - \sigma^2} + \sigma \varepsilon \\ &= \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t} - \sqrt{1-\bar{\alpha}_{t}}\boldsymbol{\epsilon}_t + \sqrt{\alpha_{t}}\boldsymbol{\epsilon}_t \sqrt{ 1-\bar{\alpha}_{t-1} - \sigma^2}\right) + \sigma \varepsilon \\ &= \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t} - \left(\sqrt{1-\bar{\alpha}_{t}} - \sqrt{\alpha_{t}} \sqrt{ 1-\bar{\alpha}_{t-1} - \sigma^2}\right)\boldsymbol{\epsilon}_t\right) + \sigma \varepsilon \end{aligned}\]

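Note that this derivation never relies on \(q(\mathbf{x}_{t} \mid \mathbf{x}_{t-1})\), so \(t-1\) can be replaced by any earlier timestep \(prev\), separated from \(t\) by an arbitrary number of steps. In that case the single-step factor \(\alpha_t = \bar{\alpha}_t / \bar{\alpha}_{t-1}\) becomes \(\bar{\alpha}_t / \bar{\alpha}_{prev}\), and more generally:

\[\begin{aligned} \mathbf{x}_{prev}&= \sqrt{\frac{\bar{\alpha}_{prev}}{\bar{\alpha}_{t}}}\left(\mathbf{x}_{t} - \left(\sqrt{1-\bar{\alpha}_{t}} - \sqrt{\frac{\bar{\alpha}_{t}}{\bar{\alpha}_{prev}}} \sqrt{ 1-\bar{\alpha}_{prev} - \sigma^2}\right)\boldsymbol{\epsilon}_t\right) + \sigma \varepsilon \end{aligned}\]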
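A minimal sketch of one such jump is shown below. It uses the equivalent "predict \(\mathbf{x}_0\), then re-noise" form \(\mathbf{x}_{prev} = \sqrt{\bar{\alpha}_{prev}}\,\hat{\mathbf{x}}_0 + \sqrt{1-\bar{\alpha}_{prev}-\sigma^2}\,\boldsymbol{\epsilon}_t + \sigma\varepsilon\), which is the solved \(k\), \(m\) expression before \(\mathbf{x}_0\) is eliminated. The function name `ddim_step` and its arguments are hypothetical, and `eps_pred` stands in for the noise predictor's output.

```python
import numpy as np

def ddim_step(x_t, eps_pred, ab_t, ab_prev, sigma=0.0, noise=None):
    """Jump from a timestep with alpha_bar = ab_t to an earlier one with ab_prev."""
    # Estimate x_0 by inverting x_t = sqrt(ab_t) x_0 + sqrt(1 - ab_t) eps.
    x0_pred = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    # Deterministic part: re-noise the estimate to the target noise level.
    x_prev = np.sqrt(ab_prev) * x0_pred + np.sqrt(1.0 - ab_prev - sigma**2) * eps_pred
    if sigma > 0.0:
        if noise is None:
            noise = np.random.standard_normal(x_t.shape)
        x_prev = x_prev + sigma * noise  # fresh Gaussian noise
    return x_prev
```

Because only `ab_t` and `ab_prev` appear, the two timesteps do not have to be adjacent.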

That is, we obtain

\[\begin{align} q(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t} - \left(\sqrt{1-\bar{\alpha}_{t}} - \sqrt{\alpha_{t}} \sqrt{ 1-\bar{\alpha}_{t-1} - \sigma^2}\right)\boldsymbol{\epsilon}_t\right) ,\sigma^2 \mathbf{I}\right) \end{align}\]

The same result can also be obtained with the method of undetermined coefficients as presented by 苏神 (Su Jianlin).

From the previous post, the DDPM Markov-chain derivation gives:

\[\begin{align} q(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_t\right) ,\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t) \end{align}\]

Equations (3) and (4) look similar. If we set

\[\sigma^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t\]

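and substitute into (3), we recover (4). DDIM is therefore a generalization of DDPM. If we instead set \(\sigma = 0\), the reverse process of the diffusion model becomes a deterministic process with no noise: given the initial random noise \(\mathbf{x}_T\), there is exactly one resulting sample \(\mathbf{x}_0\). A probabilistic model whose output is determined in this way is called an implicit probabilistic model, which is why the paper's authors call the zero-variance version of the diffusion model DDIM (Denoising Diffusion Implicit Model).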
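The substitution can be checked directly. Using \(\beta_t = 1-\alpha_t\) and \(\bar{\alpha}_t = \alpha_t\bar{\alpha}_{t-1}\), the DDPM choice of \(\sigma^2\) gives:

\[\begin{aligned} 1-\bar{\alpha}_{t-1}-\sigma^2 &= (1-\bar{\alpha}_{t-1})\,\frac{1-\bar{\alpha}_t-\beta_t}{1-\bar{\alpha}_t} = (1-\bar{\alpha}_{t-1})\,\frac{\alpha_t-\bar{\alpha}_t}{1-\bar{\alpha}_t} = \frac{\alpha_t\,(1-\bar{\alpha}_{t-1})^2}{1-\bar{\alpha}_t} \\ \sqrt{1-\bar{\alpha}_t}-\sqrt{\alpha_t}\sqrt{1-\bar{\alpha}_{t-1}-\sigma^2} &= \sqrt{1-\bar{\alpha}_t}-\frac{\alpha_t\,(1-\bar{\alpha}_{t-1})}{\sqrt{1-\bar{\alpha}_t}} = \frac{1-\bar{\alpha}_t-\alpha_t+\bar{\alpha}_t}{\sqrt{1-\bar{\alpha}_t}} = \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \end{aligned}\]

The last expression is exactly the coefficient of the noise term in the DDPM mean.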

To make the variance easy to choose, the authors rewrite it as:

\[\begin{align} \tilde{\beta}_t(\eta) = \eta^2 \cdot \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t} \cdot \beta_t \end{align}\]

where \(\eta \in [0,1]\). Choosing different values of \(\eta\) interpolates between DDPM and DDIM, with \(\eta\) controlling the ratio: \(\eta=0\) gives DDIM and \(\eta=1\) gives DDPM.
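Putting the pieces together, here is a minimal sketch of an \(\eta\)-controlled sampler that visits only a sub-sequence of the training timesteps. It assumes the DDIM paper's convention that the standard deviation is \(\eta\) times the square root of the DDPM posterior variance, with \(\alpha_t\) replaced by \(\bar{\alpha}_t/\bar{\alpha}_{prev}\) for skipped steps. It also assumes evenly spaced timesteps and a callable `eps_model(x, t)`. The names `ddim_sigma` and `ddim_sample` are hypothetical.

```python
import numpy as np

def ddim_sigma(ab_t, ab_prev, eta):
    """Assumed convention: sigma = eta * sqrt(posterior variance) for the jump t -> prev."""
    return eta * np.sqrt((1.0 - ab_prev) / (1.0 - ab_t) * (1.0 - ab_t / ab_prev))

def ddim_sample(eps_model, shape, alpha_bar, num_steps=50, eta=0.0, seed=0):
    """Sample with num_steps network calls instead of len(alpha_bar)."""
    rng = np.random.default_rng(seed)
    # Evenly spaced sub-sequence of 0-based timesteps, from most to least noisy.
    timesteps = np.linspace(len(alpha_bar) - 1, 0, num_steps).round().astype(int)
    x = rng.standard_normal(shape)  # x_T: pure Gaussian noise
    for i, t in enumerate(timesteps):
        ab_t = alpha_bar[t]
        # After the last visited timestep, jump to the clean image (alpha_bar = 1).
        ab_prev = alpha_bar[timesteps[i + 1]] if i + 1 < num_steps else 1.0
        eps = eps_model(x, t)
        sigma = ddim_sigma(ab_t, ab_prev, eta)
        x0_pred = (x - np.sqrt(1.0 - ab_t) * eps) / np.sqrt(ab_t)
        x = np.sqrt(ab_prev) * x0_pred + np.sqrt(1.0 - ab_prev - sigma**2) * eps
        x = x + sigma * rng.standard_normal(shape)  # vanishes when eta = 0
    return x
```

With `eta=0.0` the output is a deterministic function of the initial noise, so the same `seed` always yields the same sample. With `eta=1.0` and `num_steps=len(alpha_bar)` the update has the same mean and variance as the DDPM step.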

How to sample with DDIM?

For an implementation of DDIM and its sampling procedure, one can refer directly to