混合ガウスモデルの実装

混合ガウスモデルの実装#

GMMのアルゴリズム#

GMMのアルゴリズムは以下の通り.

K 個の正規分布を用意する.

以下を繰り返す

  1. K個の正規分布のなかから,確率分布 \(\phi\) にしたがって1つを選ぶ

  2. 選んだ正規分布からサンプリングする

GMMの式#

観測データ \(\boldsymbol{x}\) の確率分布 \(p(\boldsymbol{x})\) は,複数の正規分布の重み付き和として表される.

\[ p(\boldsymbol{x}) = \sum_{k = 1}^K \phi_k \mathcal{N}(\boldsymbol{x}; \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k) \]

ここで \(\phi_k\)\(k\) 番目の正規分布が選ばれる確率である.

実装#

// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"

// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }

// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"

// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};

use plotters::prelude::*;
use rand::prelude::*;

use ndarray::{Array, ArrayView, Axis, array, s};
:dep myml = { path = "../myml" }
use myml::{gmm::gmm, utility::linspace};

可視化

let mus = [
    array![2.0, 54.50],
    array![4.3, 80.0]
];

let covs = [
    array![
        [0.07, 0.44],
        [0.44, 33.7],
    ],
    array![
        [0.17, 0.94],
        [0.94, 36.0],
    ]
];

let phis = [0.35, 0.65];
evcxr_figure((600, 600), |root| {
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("GMM", ("sans-serif", 20).into_font())
        // .x_label_area_size(50)
        // .y_label_area_size(50)
        .build_cartesian_3d(1.0..6.0, 0.0..0.08, 40.0..100.0)?;

    chart.with_projection(|mut p| {
        p.pitch = f64::to_radians(20.0);
        p.yaw = f64::to_radians(20.0);
        p.scale = 0.8;
        p.into_matrix()
    });

    chart.configure_axes()
        .draw()?;

    chart.draw_series(
        SurfaceSeries::xoz(
            linspace(1.0, 6.0, 100).into_iter(),
            linspace(40.0, 100.0, 100).into_iter(),
            |x, z| {
                let x = array![[x, z]];
                gmm(x.view(), &mus, &covs, &phis).unwrap()[0]
            }
        )
        .style_func(&|&v| (VulcanoHSL::get_color(v * 30.0)).into())
    )?;

    Ok(())
})
GMM 1.0 2.0 3.0 4.0 5.0 6.0 0.0 0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08 40.0 50.0 60.0 70.0 80.0 90.0 100.0