最近はDiffusion Models(以下、「拡散モデル」)に興味があります。
EncodingとDecodingを経て得られるデータと元のデータの誤差である再構成誤差を最小化するのが自己符号化器のモチベーションであり、拡散モデルもこの一種と言えます。
実際、EncodingをNoising、DecodingをDenoisingと見なせば拡散モデルが自己符号化器の一種であることがわかります。
さて、今回の目標は次の式を示すことです。
r∗(x~)−x=σ2∇x~logqσ(x~)
ただし、r∗ は最適なデノイザ(復号器)、qσ(x~) はノイズを加えたデータの周辺分布、σ2 はノイズの分散です。
拡散モデルの再構成誤差とスコア関数の差がなぜ o(σ2) なのか
仮定
ノイズ付加モデル qσ(x~∣x) は次の正規分布に従うよう仮定します。
qσ(x~∣x)=N(x~;x,σ2I)=(2πσ2)d/21exp(−2σ2∥x~−x∥2)
周辺分布による損失関数の表現
データ分布を p(x) とすると、同時分布は qσ(x~,x)=qσ(x~∣x)p(x) であり、
ノイズ付きデータの周辺分布は
qσ(x~)=∫qσ(x~∣x)p(x)dx
と書けます。復号器 r の再構成誤差(損失関数)は、この同時分布に関する期待値として
L(r)=Eqσ(x~,x)[∥r(x~)−x∥2]=∫∫∥r(x~)−x∥2qσ(x~∣x)p(x)dxdx~
と表せます。
最適な復号器を変分法で求める
L(r) を最小化する r∗ を求めます。各 x~ に対して被積分関数の汎関数微分をとり、0 とおくと
δr(x~)δL=2∫(r(x~)−x)qσ(x~∣x)p(x)dx=0
これを r∗(x~) について解けば
r∗(x~)=∫qσ(x~∣x)p(x)dx∫xqσ(x~∣x)p(x)dx=qσ(x~)∫xqσ(x~∣x)p(x)dx
ベイズの定理より事後分布は
qσ(x∣x~)=qσ(x~)qσ(x~∣x)p(x)
と分解できるので、最適な復号器は事後期待値になります。
r∗(x~)=Eqσ(x∣x~)[x]
ガウス分布の微分による導出
私はここが一番天下りだと感じました。ガウス分布 qσ(x~∣x) を x~ で微分すると
∇x~qσ(x~∣x)=−σ2x~−xqσ(x~∣x)
この関係を使って r∗(x~)−x~ を計算します。先ほどのベイズの定理による分解の式を代入すると
r∗(x~)−x~=qσ(x~)∫(x−x~)qσ(x~∣x)p(x)dx
この分子に上の微分の関係式を代入します。
=qσ(x~)σ2∫∇x~qσ(x~∣x)p(x)dx
p(x) は x~ に依存しないので、微分と積分の順序を交換できます。
=qσ(x~)σ2∇x~∫qσ(x~∣x)p(x)dx=qσ(x~)σ2∇x~qσ(x~)
∇logf=f∇f であることを用いれば、最終的に
r∗(x~)−x~=σ2∇x~logqσ(x~)
が得られます。これは Tweedie’s formula として知られる等式です。
σ2→0 の極限と logp(x) への帰着
σ2→0 のとき、ノイズ付加モデルはデルタ分布に収束します。
qσ(x~∣x)→δ(x~−x)
したがって周辺分布は元のデータ分布に一致し
qσ(x~)→p(x~)
よってスコア関数も
∇x~logqσ(x~)→∇xlogp(x)(σ2→0)
と収束します。すなわち、ノイズを十分小さくとれば、最適デノイザから推定されるスコア関数は真のデータ分布のスコア関数 ∇xlogp(x) に一致するということです。
結論
拡散モデルをノイズ付き自己符号化器として定式化し、最適なデノイザの復元誤差がスコア関数と直結することを示しました。
σ2r∗(x~)−x~=∇x~logqσ(x~)σ2→0∇xlogp(x)
このようにデノイジングを学習することがスコア関数を学習することと等価であるということから、再構成誤差を直接最小化する代わりにスコアを推定するアプローチが取られるようになったんですね。