KL ダイバージェンス#
モチベーション#
2 つの確率分布 \(p(x),q(x)\) の間に,距離のようなものを定義したい
以下の条件を満たしてほしい
条件
\(p(x), q(x)\) が異なるほど大きい値をとる
\(p(x) = q(x)\) のときのみ \(0\) ,それ以外は正の値をとる
これを満たす関数が KL ダイバージェンスである.
KL ダイバージェンスとは#
KL ダイバージェンスとは,2 つの確率密度を測る尺度である.
2 つの確率分布 \(p(x)\) と \(q(x)\) の間の KL ダイバージェンスは以下のように定義される.
\[
D_\mathrm{KL}(p || q) = \int p(x) \log \frac{p(x)}{q(x)} dx
\]
KL ダイバージェンスの特徴#
KL ダイバージェンスは上の条件 1,2 を満たす
\(D_\mathrm{KL}(p||q)\) と \(D_\mathrm{KL}(q||p)\) は同じ値をとるとは限らない
→ 距離の公理を満たさないので,厳密には距離ではない
実際に検証してみる#
// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"
// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }
// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"
// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};
use plotters::prelude::*;
use rand::prelude::*;
use ndarray::{Array, ArrayView, Axis, array, s};
:dep myml = { path = "../myml/" }
use myml::utility::linspace;
比較する確率分布#
\[\begin{split}
\begin{align*}
f(x) &= \frac{1}{\sqrt{2\pi}} \exp \left( -\frac{x^2}{2} \right) \quad (= \mathcal{N}(0, 1))\\[10pt]
g(x) &=
\left\{
\begin{array}{ll}
1/4 x + 1/2 & \mathrm{if} ~ (-2 \le x \lt 0)\\
-1/4 x + 1/2 & \mathrm{if} ~ (0 \le x \lt 2)\\
0 & \mathrm{otherwise}
\end{array}
\right.\\[10pt]
h(x) &=
\left\{
\begin{array}{ll}
1 / 6 & \mathrm{if} ~ (-3 \le x \lt 3)\\
0 & \mathrm{otherwise}
\end{array}
\right.
\end{align*}
\end{split}\]
関数の定義
use std::f64::consts::PI;
fn f(x: f64) -> f64 {
1.0 / (2.0 * PI).sqrt() * f64::exp(-x.powi(2) / 2.0)
}
fn g(x: f64) -> f64 {
(0.5 - 0.25 * x.abs()).max(0.0)
}
fn h(x: f64) -> f64 {
if -3.0 <= x && x < 3.0 {
1.0 / 6.0
} else {
0.0
}
}
プロット
const NUM: usize = 200;
// x
let x = linspace(-5.0, 5.0, NUM);
// f(x)
let fx = x.iter().copied().map(f).collect::<Vec<_>>();
// g(x)
let gx = x.iter().copied().map(g).collect::<Vec<_>>();
// h(x)
let hx = x.iter().copied().map(h).collect::<Vec<_>>();
evcxr_figure((600, 480), |root| {
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("functions", ("Sans", 20).into_font())
.x_label_area_size(40)
.y_label_area_size(40)
.build_cartesian_2d(-4.0..4.0, 0.0..0.6)?;
chart.configure_mesh()
.draw()?;
chart.draw_series(
LineSeries::new(
x.iter().copied().zip(fx.iter().copied()),
&BLUE
)
)?
.label("y = f(x)")
.legend(|(x, y)| Circle::new((x, y), 3, BLUE.filled()));
chart.draw_series(
LineSeries::new(
x.iter().copied().zip(gx.iter().copied()),
&RED
)
)?
.label("y = g(x)")
.legend(|(x, y)| Circle::new((x, y), 3, RED.filled()));
chart.draw_series(
LineSeries::new(
x.iter().copied().zip(hx.iter().copied()),
&GREEN
)
)?
.label("y = h(x)")
.legend(|(x, y)| Circle::new((x, y), 3, GREEN.filled()));
// 凡例の設定
chart.configure_series_labels()
.position(SeriesLabelPosition::UpperRight)
.margin(10)
.legend_area_size(10)
.border_style(BLACK)
.background_style(BLACK.mix(0.1))
.label_font(("roman", 20))
.draw()?;
Ok(())
})
KLダイバージェンスの計算#
離散値に直して積分する
const EPS: f64 = 1e-10;
fn kl_divergence(p: &Vec<f64>, q: &Vec<f64>) -> f64 {
p
.iter()
.zip(q.iter())
.map(|(&pi, &qi)| (
pi / NUM as f64 + EPS,
qi / NUM as f64 + EPS
))
.map(|(pi, qi)| pi * (pi / qi).ln())
.sum()
}
let funcs = [fx.clone(), gx.clone(), hx.clone()];
let labels = ["f", "g", "h"];
for i in 0..3 {
for j in 0..3 {
let kld = kl_divergence(&funcs[i], &funcs[j]);
println!("D_KL( {} || {} ) = {}", labels[i], labels[j], kld);
}
};
D_KL( f || f ) = 0
D_KL( f || g ) = 0.0592454658316637
D_KL( f || h ) = 0.041165007754176974
D_KL( g || f ) = 0.00586053773246323
D_KL( g || g ) = 0
D_KL( g || h ) = 0.05953179630295036
D_KL( h || f ) = 0.06421860035958628
D_KL( h || g ) = 0.5258095061012997
D_KL( h || h ) = 0
KLダイバージェンスの性質が確認できる.
自分自身に対してゼロを返す
\(D_\mathrm{KL}(f || f) = 0\)
\(D_\mathrm{KL}(g || g) = 0\)
\(D_\mathrm{KL}(h || h) = 0\)
近い分布ほど小さい値をとる
\(D_\mathrm{KL}(g || f) = 0.00586053773246323\)
遠い分布ほど大きい値をとる
\(D_\mathrm{KL}(h || g) = 0.5258095061012997\)