混合ガウスモデルの実装#
GMMのアルゴリズム#
GMMのアルゴリズムは以下の通り.
K 個の正規分布を用意する.
以下を繰り返す
K個の正規分布のなかから,確率分布 \(\phi\) にしたがって1つを選ぶ
選んだ正規分布からサンプリングする
GMMの式#
観測データ \(\boldsymbol{x}\) の確率分布 \(p(\boldsymbol{x})\) は,複数の正規分布の重み付き和として表される.
\[
p(\boldsymbol{x}) = \sum_{k = 1}^K \phi_k \mathcal{N}(\boldsymbol{x}; \boldsymbol{\mu}_k, \boldsymbol{\Sigma}_k)
\]
ここで \(\phi_k\) は \(k\) 番目の正規分布が選ばれる確率である.
実装#
// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"
// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }
// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"
// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};
use plotters::prelude::*;
use rand::prelude::*;
use ndarray::{Array, ArrayView, Axis, array, s};
:dep myml = { path = "../myml" }
use myml::{gmm::gmm, utility::linspace};
可視化
let mus = [
array![2.0, 54.50],
array![4.3, 80.0]
];
let covs = [
array![
[0.07, 0.44],
[0.44, 33.7],
],
array![
[0.17, 0.94],
[0.94, 36.0],
]
];
let phis = [0.35, 0.65];
evcxr_figure((600, 600), |root| {
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("GMM", ("sans-serif", 20).into_font())
// .x_label_area_size(50)
// .y_label_area_size(50)
.build_cartesian_3d(1.0..6.0, 0.0..0.08, 40.0..100.0)?;
chart.with_projection(|mut p| {
p.pitch = f64::to_radians(20.0);
p.yaw = f64::to_radians(20.0);
p.scale = 0.8;
p.into_matrix()
});
chart.configure_axes()
.draw()?;
chart.draw_series(
SurfaceSeries::xoz(
linspace(1.0, 6.0, 100).into_iter(),
linspace(40.0, 100.0, 100).into_iter(),
|x, z| {
let x = array![[x, z]];
gmm(x.view(), &mus, &covs, &phis).unwrap()[0]
}
)
.style_func(&|&v| (VulcanoHSL::get_color(v * 30.0)).into())
)?;
Ok(())
})