KL ダイバージェンス#

モチベーション#

  • 2 つの確率分布 \(p(x),q(x)\) の間に,距離のようなものを定義したい

  • 以下の条件を満たしてほしい

条件

  1. \(p(x), q(x)\) が異なるほど大きい値をとる

  2. \(p(x) = q(x)\) のときのみ \(0\) ,それ以外は正の値をとる

これを満たす関数が KL ダイバージェンスである.

KL ダイバージェンスとは#

KL ダイバージェンスとは,2 つの確率密度を測る尺度である.

2 つの確率分布 \(p(x)\)\(q(x)\) の間の KL ダイバージェンスは以下のように定義される.

\[ D_\mathrm{KL}(p || q) = \int p(x) \log \frac{p(x)}{q(x)} dx \]

KL ダイバージェンスの特徴#

  • KL ダイバージェンスは上の条件 1,2 を満たす

  • \(D_\mathrm{KL}(p||q)\)\(D_\mathrm{KL}(q||p)\) は同じ値をとるとは限らない
    距離の公理を満たさないので,厳密には距離ではない

実際に検証してみる#

// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"

// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }

// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"

// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};

use plotters::prelude::*;
use rand::prelude::*;

use ndarray::{Array, ArrayView, Axis, array, s};
:dep myml = { path = "../myml/" }
use myml::utility::linspace;

比較する確率分布#

\[\begin{split} \begin{align*} f(x) &= \frac{1}{\sqrt{2\pi}} \exp \left( -\frac{x^2}{2} \right) \quad (= \mathcal{N}(0, 1))\\[10pt] g(x) &= \left\{ \begin{array}{ll} 1/4 x + 1/2 & \mathrm{if} ~ (-2 \le x \lt 0)\\ -1/4 x + 1/2 & \mathrm{if} ~ (0 \le x \lt 2)\\ 0 & \mathrm{otherwise} \end{array} \right.\\[10pt] h(x) &= \left\{ \begin{array}{ll} 1 / 6 & \mathrm{if} ~ (-3 \le x \lt 3)\\ 0 & \mathrm{otherwise} \end{array} \right. \end{align*} \end{split}\]

関数の定義

use std::f64::consts::PI;

fn f(x: f64) -> f64 {
    1.0 / (2.0 * PI).sqrt() * f64::exp(-x.powi(2) / 2.0)
}

fn g(x: f64) -> f64 {
    (0.5 - 0.25 * x.abs()).max(0.0)
}

fn h(x: f64) -> f64 {
    if -3.0 <= x && x < 3.0 {
        1.0 / 6.0
    } else {
        0.0
    }
}

プロット

const NUM: usize = 200;

// x
let x = linspace(-5.0, 5.0, NUM);

// f(x)
let fx = x.iter().copied().map(f).collect::<Vec<_>>();

// g(x)
let gx = x.iter().copied().map(g).collect::<Vec<_>>();

// h(x)
let hx = x.iter().copied().map(h).collect::<Vec<_>>();
evcxr_figure((600, 480), |root| {
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("functions", ("Sans", 20).into_font())
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d(-4.0..4.0, 0.0..0.6)?;

    chart.configure_mesh()
        .draw()?;

    chart.draw_series(
        LineSeries::new(
            x.iter().copied().zip(fx.iter().copied()),
            &BLUE
        )
    )?
    .label("y = f(x)")
    .legend(|(x, y)| Circle::new((x, y), 3, BLUE.filled()));

    chart.draw_series(
        LineSeries::new(
            x.iter().copied().zip(gx.iter().copied()),
            &RED
        )
    )?
    .label("y = g(x)")
    .legend(|(x, y)| Circle::new((x, y), 3, RED.filled()));

    chart.draw_series(
        LineSeries::new(
            x.iter().copied().zip(hx.iter().copied()),
            &GREEN
        )
    )?
    .label("y = h(x)")
    .legend(|(x, y)| Circle::new((x, y), 3, GREEN.filled()));

    // 凡例の設定
    chart.configure_series_labels()
        .position(SeriesLabelPosition::UpperRight)
        .margin(10)
        .legend_area_size(10)
        .border_style(BLACK)
        .background_style(BLACK.mix(0.1))
        .label_font(("roman", 20))
        .draw()?;

    Ok(())
})
functions 0.0 0.1 0.2 0.3 0.4 0.5 0.6 -4.0 -3.0 -2.0 -1.0 0.0 1.0 2.0 3.0 4.0 y = f(x) y = g(x) y = h(x)

KLダイバージェンスの計算#

離散値に直して積分する

const EPS: f64 = 1e-10;

fn kl_divergence(p: &Vec<f64>, q: &Vec<f64>) -> f64 {
    p
        .iter()
        .zip(q.iter())
        .map(|(&pi, &qi)| (
            pi / NUM as f64 + EPS,
            qi / NUM as f64 + EPS
        ))
        .map(|(pi, qi)| pi * (pi / qi).ln())
        .sum()
}
let funcs = [fx.clone(), gx.clone(), hx.clone()];
let labels = ["f", "g", "h"];
for i in 0..3 {
    for j in 0..3 {
        let kld = kl_divergence(&funcs[i], &funcs[j]);

        println!("D_KL( {} || {} ) = {}", labels[i], labels[j], kld);
    }
};
D_KL( f || f ) = 0
D_KL( f || g ) = 0.0592454658316637
D_KL( f || h ) = 0.041165007754176974
D_KL( g || f ) = 0.00586053773246323
D_KL( g || g ) = 0
D_KL( g || h ) = 0.05953179630295036
D_KL( h || f ) = 0.06421860035958628
D_KL( h || g ) = 0.5258095061012997
D_KL( h || h ) = 0

KLダイバージェンスの性質が確認できる.

  • 自分自身に対してゼロを返す

    • \(D_\mathrm{KL}(f || f) = 0\)

    • \(D_\mathrm{KL}(g || g) = 0\)

    • \(D_\mathrm{KL}(h || h) = 0\)

  • 近い分布ほど小さい値をとる

    • \(D_\mathrm{KL}(g || f) = 0.00586053773246323\)

  • 遠い分布ほど大きい値をとる

    • \(D_\mathrm{KL}(h || g) = 0.5258095061012997\)