混合ガウスモデル#
多峰性関数のモデル化方法を学ぶ.
// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"
// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }
// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"
// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};
use plotters::prelude::*;
use rand::prelude::*;
use ndarray::{Array, ArrayView, Axis, array, s};
use std::{fs::File, io::{self, prelude::*}, path::Path};
// ファイル
let old_faithful_path = Path::new("./data/old_faithful.txt");
// 読み込み
let file = File::open(&old_faithful_path)?;
let reader = io::BufReader::new(file);
let old_faithful_data = reader
.lines()
.filter_map(|line| line.ok())
.filter_map(|val| {
let mut itr = val.split_whitespace();
let eruptions = itr.next()?.parse::<f64>().ok()?;
let waiting = itr.next()?.parse::<f64>().ok()?;
Some((eruptions, waiting))
})
.collect::<Vec<_>>();
// 読み取れていることを確認
&old_faithful_data[..5]
[(3.6, 79.0), (1.8, 54.0), (3.333, 74.0), (2.283, 62.0), (4.533, 85.0)]
// データの可視化
evcxr_figure((600, 450), |root| {
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("Eruptions and Waiting", ("sans", 20).into_font())
.x_label_area_size(40)
.y_label_area_size(40)
.build_cartesian_2d(1.0..6.0, 30.0..110.0)?;
chart.configure_mesh()
.draw()?;
chart.draw_series(
old_faithful_data.iter().map(|&(x, y)| Circle::new((x, y), 3, BLUE.filled()))
)?;
Ok(())
})
混合ガウスモデルのデータ生成#
混合ガウスモデル (GMM) とは?#
複数の正規分布を組合せたモデル
GMMからのサンプリング方法#
正規分布2個の場合
以下を繰り返す
2つの正規分布から,1つを選ぶ
選んだ正規分布からデータを生成する
:dep myml = { path = "../myml/" }
use rand::distributions::WeightedIndex;
use myml::normal::multivariate_normal_sample;
/// 混合ガウスモデルからサンプリングを行う関数
fn sample_n(n: usize) -> Vec<(f64, f64)> {
// ===== 学習済みパラメータ =====
let mus = [
array![2.0, 54.50],
array![4.3, 80.0]
];
let covs = [
array![
[0.07, 0.44],
[0.44, 33.7],
],
array![
[0.17, 0.94],
[0.94, 36.0],
]
];
let phis = [0.35, 0.65];
// ==============================
// 乱数生成器の初期化
let mut rng = thread_rng();
// 重み付きインデックス
let dist = WeightedIndex::new(&phis).unwrap();
dist.sample_iter(&mut rng)
.take(n)
.map(|i| {
let res = multivariate_normal_sample(
1,
mus[i].view().clone(),
covs[i].view().clone()
).unwrap()[0].clone();
(res[0], res[1])
})
.collect()
}
let sample_500 = sample_n(500);
&sample_500[..5]
[(3.881909744757437, 77.94282717452054), (3.9074054929941586, 78.21915888118534), (5.155770743536556, 78.7678070149949), (4.419981411180869, 77.25807163642133), (4.604574522923448, 73.92620420502918)]
// データの可視化
evcxr_figure((600, 450), |root| {
root.fill(&WHITE)?;
let mut chart = ChartBuilder::on(&root)
.caption("Eruptions and Waiting", ("sans", 20).into_font())
.x_label_area_size(40)
.y_label_area_size(40)
.build_cartesian_2d(1.0..6.0, 30.0..110.0)?;
chart.configure_mesh()
.draw()?;
// 元データ
chart.draw_series(
old_faithful_data.iter().map(|&(x, y)| Circle::new((x, y), 3, BLUE.filled()))
)?
.label("raw data")
.legend(|(x, y)| Circle::new((x, y), 3, BLUE.filled()));
// 生成したデータ
chart.draw_series(
sample_500.iter().map(|&(x, y)| Circle::new((x, y), 3, RED.filled()))
)?
.label("generated data")
.legend(|(x, y)| Circle::new((x, y), 3, RED.filled()));
chart.configure_series_labels()
.position(SeriesLabelPosition::LowerRight)
.margin(10)
.legend_area_size(10)
.border_style(BLACK)
.background_style(BLACK.mix(0.1))
.label_font(("Calibri", 20))
.draw()?;
Ok(())
})