Skip to main content

ddpm

Ref: https://zhuanlan.zhihu.com/p/12591930520

扩散过程

正向过程施加小范围的噪声

xt+1:=xt+nt,nN(0,σ2)x_{t+1}:=x_t+n_t, n\sim N(0, \sigma^2)

DDMP 采样器的目标

DDMP(去噪扩散模型)的核心思想是逆向操作

  • 正向过程:从真实数据 x0x_0​ 出发,一步步加高斯噪声,得到 x1,x2,,xtx_1​, x_2​,…,x_t​。
  • 逆向过程:从加了很多噪声的 xtx_t​ 出发,一步步 “去噪”,最终还原出真实数据 x0x_0​。

理想的 DDMP 采样器,就是在每一步 tt,根据当前的含噪数据 xt=zx_t=z​,输出一个能生成前一步数据 xt1x_{t−1}​ 的条件分布采样器:

p(xt1xt=z)p(x_{t−1}​∣x_t​=z)

info

这个条件分布是逆向过程的核心。

如果直接去学习这个条件分布 p(xt1xt)p(x_{t−1}​∣x_t​),我们需要为每一个 xtx_t​ 都训练一个生成模型,这在计算上是巨大的负担。

当噪声足够小时,逆向条件分布近似高斯 p(xt1xt)N(xt1;μ,σ2)p(x_{t−1​}∣x_t​)\approx N(x_{t-1};\mu,\sigma^2),这是整个DDPM的核心。

这意味着,我们不需要学习整个复杂的分布,只需要学习这个高斯分布的均值 μ 就够了,因为高斯分布的形状完全由均值和方差决定,而这里的方差 σ2σ^2 是我们在正向过程中已知的超参数。

因此问题变成了:给在时间 tt 和条件 xtx_t,学习pxt1xtp_{x_{t-1}|x_t} 的均值即可。学习均值比较简单,可以用回归方法。

info

对任意分布 p(xt1)p(x_{t−1}​),如果我们通过加极小方差高斯噪声得到 xtx_t​,那么逆向的条件分布 p(xt1xt)p(x_{t−1}​∣x_t​)近似服从高斯分布,并且其方差与正向过程的噪声方差几乎相同。

p(xt1)p(x_{t-1})服从高斯分布时,逆向条件分布才会严格服从高斯分布。

训练过程

一步一步正向加噪,公式为

xt=αˉtx0+1αˉtϵ(1)x_t = \sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t} \epsilon \tag1

其中

αˉt=i=1tαi\bar{\alpha}_t = \prod_{i=1}^t \alpha_i

噪声方差的确定

考虑一个加噪序列,时间均匀间隔为 Δt=1T\Delta t=\frac{1}{T},每一步都加上一个均值为0,方差为σ2\sigma^2的噪声,那个任意一步的总噪声和(仍是高斯分布)的方差就是各分布方差和Nσ2N \sigma^2

info

但我们的目标是噪声方差与时间步数NN无关,因此将噪声添加一个归一化项σΔt\sigma\sqrt \Delta t,这样任意时间步总噪声就变为tΔtσ2Δt=σ2t\frac{t}{\Delta t}\sigma^2\Delta t=\sigma^2 t,只与扩散时间(而不是扩散步数)有关。

这就是连续时间的随机微分方程(SDE),表面扩散速度和时间的平方根成正比,这也是布朗运动的典型特征。

DDPM 的前向加噪过程,本质上就是在模拟一个受限的布朗运动。

DDPM的成立

前文说到只有当噪声足够小时,DDPM的核心依据才成立,根据上文噪声方差的确定,即Δt\Delta t足够小,因为DDPM也是使用了1000个时间步。

证明过程看论文。

info

直接把时间步 Δt\Delta t 缩小,虽然每一步引入的噪声会变小,但需要的总步数 T=1ΔtT=\frac{1}{\Delta t}​ 会随之增加。如果每一步的误差没有以足够快的速度衰减,最终的累计误差可能会变得很大,因此原证明的严谨性需要补充。

可以使用用 KL 散度来量化 “单步噪声添加” 带来的分布误差,近似高斯分布与真实后验分布之间的 KL 散度误差是 O(σ4)O(\sigma^4) 每步噪声的方差 σ2=σq2Δtσ^2=σ_q^2​\Delta t,因此 σ4=σq4(Δt)2σ^4=σ_q^4​(\Delta t)^2

为什么会存在这个 O(Δt2)\mathcal{O}(\Delta t^2) 的误差?

当我们用高斯分布近似 q(xt1xt)q(x_{t-1}|x_t) 时,我们实际上只匹配了分布的前二阶矩(均值和方差)。

  • 真实后验: 包含更高阶的累积量(Cumulants),如偏度(Skewness)和峰度(Kurtosis)。
  • 量级评估: 这些高阶项的大小与得分函数(Score Function)的导数(即 Hessian 矩阵)有关。如果数据的分布非常复杂(曲率很大),那么高斯近似的误差就会显著增加。 因此有其他工作如improved ddpm,同时学习方差预测。

总误差为 TO(σ4)=1ΔtO((Δt)2)=O(Δt)T⋅O(\sigma^4)=\frac 1 {\Delta t}​⋅O((\Delta t)^2)=O(\Delta t)Δt0\Delta t→0 时,总误差也趋于 0,这就保证了整个离散化过程的收敛性,也证明了 DDPM 的正确性。

直接预测x0

由式 (1)(1) 可知,x0x_0和噪声ϵ\epsilon是线性关系,因此可以直接预测。

并且由于每一步的噪声是同分布的(不独立, 正向过程是i.i.di.i.d),与其估计每一步的噪声,不如直接等效地估计所有先验噪声的平均,并且方差更小。

E[xtΔtxtxt]=ΔttE[(x0xt)xt]E[x_{t-\Delta t}-x_t|x_t]=\frac{\Delta t}{t}E[(x_0-x_t)|x_t]

采样过程

均值计算:利用预测噪声ϵθ(xt,t)\epsilon_\theta(x_t,t)推导反向分布均值,公式如下:

μθ(xt,t)=1αt(xt1αt1αˉtϵθ(xt,t))\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\left(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t,t)\right)

该式核心是从当前带噪样本xtx_t中减去网络预测的对应噪声分量,再通过αt\alpha_t相关系数校正,得到去噪后的均值估计。

单步采样更新:在均值基础上加入高斯噪声实现随机采样,公式为:

xt1=μθ(xt,t)+σtz,zN(0,I)x_{t-1}=\mu_\theta(x_t,t)+\sigma_t\cdot z,z\sim\mathcal{N}(0,\mathbf{I})

其中σt=1αˉt11αˉtβt\sigma_t=\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\cdot\beta_t},且当t=1t=1时,z=0z=0(避免最终样本引入额外噪声)。