← Blog

拡散モデルの再構成誤差とスコア関数の一致について

Tweedie’s formulaを導出する。

最近はDiffusion Models(以下、「拡散モデル」)に興味があります。 EncodingとDecodingを経て得られるデータと元のデータの誤差である再構成誤差を最小化するのが自己符号化器のモチベーションであり、拡散モデルもこの一種と言えます。 実際、EncodingをNoising、DecodingをDenoisingと見なせば拡散モデルが自己符号化器の一種であることがわかります。

さて、今回の目標は次の式を示すことです。

r(x~)x=σ2x~logqσ(x~)r^*(\tilde{x}) - x = \sigma^2 \nabla_{\tilde{x}} \log q_\sigma(\tilde{x})

ただし、rr^* は最適なデノイザ(復号器)、qσ(x~)q_\sigma(\tilde{x}) はノイズを加えたデータの周辺分布、σ2\sigma^2 はノイズの分散です。

拡散モデルの再構成誤差とスコア関数の差がなぜ o(σ2)o(\sigma^{2}) なのか

仮定

ノイズ付加モデル qσ(x~x)q_\sigma(\tilde{x} \mid x) は次の正規分布に従うよう仮定します。

qσ(x~x)=N(x~;x,σ2I)=1(2πσ2)d/2exp ⁣(x~x22σ2)q_\sigma(\tilde{x} \mid x) = \mathcal{N}(\tilde{x};\, x,\, \sigma^2 I) = \frac{1}{(2\pi\sigma^2)^{d/2}} \exp\!\left( -\frac{\|\tilde{x} - x\|^2}{2\sigma^2} \right)

周辺分布による損失関数の表現

データ分布を p(x)p(x) とすると、同時分布は qσ(x~,x)=qσ(x~x)p(x)q_\sigma(\tilde{x}, x) = q_\sigma(\tilde{x} \mid x)\, p(x) であり、 ノイズ付きデータの周辺分布は

qσ(x~)=qσ(x~x)p(x)dxq_\sigma(\tilde{x}) = \int q_\sigma(\tilde{x} \mid x)\, p(x)\, dx

と書けます。復号器 rr の再構成誤差(損失関数)は、この同時分布に関する期待値として

L(r)=Eqσ(x~,x) ⁣[r(x~)x2]= ⁣ ⁣r(x~)x2qσ(x~x)p(x)dxdx~L(r) = \mathbb{E}_{q_\sigma(\tilde{x},\, x)}\!\Big[\|r(\tilde{x}) - x\|^2\Big] = \int\!\!\int \|r(\tilde{x}) - x\|^2\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx\, d\tilde{x}

と表せます。

最適な復号器を変分法で求める

L(r)L(r) を最小化する rr^* を求めます。各 x~\tilde{x} に対して被積分関数の汎関数微分をとり、00 とおくと

δLδr(x~)=2(r(x~)x)qσ(x~x)p(x)dx=0\frac{\delta L}{\delta r(\tilde{x})} = 2\int \big(r(\tilde{x}) - x\big)\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx = 0

これを r(x~)r^*(\tilde{x}) について解けば

r(x~)=xqσ(x~x)p(x)dxqσ(x~x)p(x)dx=xqσ(x~x)p(x)dxqσ(x~)r^*(\tilde{x}) = \frac{\int x\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx}{\int q_\sigma(\tilde{x} \mid x)\, p(x)\, dx} = \frac{\int x\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx}{q_\sigma(\tilde{x})}

ベイズの定理より事後分布は

qσ(xx~)=qσ(x~x)p(x)qσ(x~)q_\sigma(x \mid \tilde{x}) = \frac{q_\sigma(\tilde{x} \mid x)\, p(x)}{q_\sigma(\tilde{x})}

と分解できるので、最適な復号器は事後期待値になります。

r(x~)=Eqσ(xx~)[x]r^*(\tilde{x}) = \mathbb{E}_{q_\sigma(x \mid \tilde{x})}[x]

ガウス分布の微分による導出

私はここが一番天下りだと感じました。ガウス分布 qσ(x~x)q_\sigma(\tilde{x} \mid x)x~\tilde{x} で微分すると

x~qσ(x~x)=x~xσ2qσ(x~x)\nabla_{\tilde{x}}\, q_\sigma(\tilde{x} \mid x) = -\frac{\tilde{x} - x}{\sigma^2}\, q_\sigma(\tilde{x} \mid x)

この関係を使って r(x~)x~r^*(\tilde{x}) - \tilde{x} を計算します。先ほどのベイズの定理による分解の式を代入すると

r(x~)x~=(xx~)qσ(x~x)p(x)dxqσ(x~)r^*(\tilde{x}) - \tilde{x} = \frac{\int (x - \tilde{x})\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx}{q_\sigma(\tilde{x})}

この分子に上の微分の関係式を代入します。

=σ2x~qσ(x~x)p(x)dxqσ(x~)= \frac{\sigma^2 \int \nabla_{\tilde{x}}\, q_\sigma(\tilde{x} \mid x)\, p(x)\, dx}{q_\sigma(\tilde{x})}

p(x)p(x)x~\tilde{x} に依存しないので、微分と積分の順序を交換できます。

=σ2x~qσ(x~x)p(x)dxqσ(x~)=σ2x~qσ(x~)qσ(x~)= \frac{\sigma^2\, \nabla_{\tilde{x}} \int q_\sigma(\tilde{x} \mid x)\, p(x)\, dx}{q_\sigma(\tilde{x})} = \frac{\sigma^2\, \nabla_{\tilde{x}}\, q_\sigma(\tilde{x})}{q_\sigma(\tilde{x})}

logf=ff\nabla \log f = \frac{\nabla f}{f} であることを用いれば、最終的に

r(x~)x~=σ2x~logqσ(x~)\boxed{\, r^*(\tilde{x}) - \tilde{x} = \sigma^2\, \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}) \,}

が得られます。これは Tweedie’s formula として知られる等式です。

σ20\sigma^2 \to 0 の極限と logp(x)\log p(x) への帰着

σ20\sigma^2 \to 0 のとき、ノイズ付加モデルはデルタ分布に収束します。

qσ(x~x)δ(x~x)q_\sigma(\tilde{x} \mid x) \to \delta(\tilde{x} - x)

したがって周辺分布は元のデータ分布に一致し

qσ(x~)p(x~)q_\sigma(\tilde{x}) \to p(\tilde{x})

よってスコア関数も

x~logqσ(x~)    xlogp(x)(σ20)\nabla_{\tilde{x}} \log q_\sigma(\tilde{x}) \;\to\; \nabla_x \log p(x) \quad (\sigma^2 \to 0)

と収束します。すなわち、ノイズを十分小さくとれば、最適デノイザから推定されるスコア関数は真のデータ分布のスコア関数 xlogp(x)\nabla_x \log p(x) に一致するということです。

結論

拡散モデルをノイズ付き自己符号化器として定式化し、最適なデノイザの復元誤差がスコア関数と直結することを示しました。

r(x~)x~σ2=x~logqσ(x~)  σ20  xlogp(x)\frac{r^*(\tilde{x}) - \tilde{x}}{\sigma^2} = \nabla_{\tilde{x}} \log q_\sigma(\tilde{x}) \;\xrightarrow{\sigma^2 \to 0}\; \nabla_x \log p(x)

このようにデノイジングを学習することがスコア関数を学習することと等価であるということから、再構成誤差を直接最小化する代わりにスコアを推定するアプローチが取られるようになったんですね。