混合ガウスモデル#

多峰性関数のモデル化方法を学ぶ.

// 依存関係のインストール
:dep image = "0.23"
:dep evcxr_image = "1.1"

// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "point_series", "colormaps"] }

// 乱数
:dep rand = "0.8.5"
:dep rand_distr = "0.4.3"

// 配列
:dep ndarray = { version = "0.15.6" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};

use plotters::prelude::*;
use rand::prelude::*;

use ndarray::{Array, ArrayView, Axis, array, s};
use std::{fs::File, io::{self, prelude::*}, path::Path};
// ファイル
let old_faithful_path = Path::new("./data/old_faithful.txt");

// 読み込み
let file = File::open(&old_faithful_path)?;
let reader = io::BufReader::new(file);
let old_faithful_data = reader
    .lines()
    .filter_map(|line| line.ok())
    .filter_map(|val| {
        let mut itr = val.split_whitespace();
        let eruptions = itr.next()?.parse::<f64>().ok()?;
        let waiting = itr.next()?.parse::<f64>().ok()?;
        Some((eruptions, waiting))
    })
    .collect::<Vec<_>>();

// 読み取れていることを確認
&old_faithful_data[..5]
[(3.6, 79.0), (1.8, 54.0), (3.333, 74.0), (2.283, 62.0), (4.533, 85.0)]
// データの可視化
evcxr_figure((600, 450), |root| {
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("Eruptions and Waiting", ("sans", 20).into_font())
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d(1.0..6.0, 30.0..110.0)?;

    chart.configure_mesh()
        .draw()?;
        
    chart.draw_series(
        old_faithful_data.iter().map(|&(x, y)| Circle::new((x, y), 3, BLUE.filled()))
    )?;

    Ok(())
})
Eruptions and Waiting 30.0 40.0 50.0 60.0 70.0 80.0 90.0 100.0 110.0 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0

混合ガウスモデルのデータ生成#

混合ガウスモデル (GMM) とは?#

  • 複数の正規分布を組合せたモデル

GMMからのサンプリング方法#

正規分布2個の場合

  1. 以下を繰り返す

    1. 2つの正規分布から,1つを選ぶ

    2. 選んだ正規分布からデータを生成する

:dep myml = { path = "../myml/" }
use rand::distributions::WeightedIndex;

use myml::normal::multivariate_normal_sample;
/// 混合ガウスモデルからサンプリングを行う関数
fn sample_n(n: usize) -> Vec<(f64, f64)> {
    // ===== 学習済みパラメータ =====
    let mus = [
        array![2.0, 54.50],
        array![4.3, 80.0]
    ];

    let covs = [
        array![
            [0.07, 0.44],
            [0.44, 33.7],
        ],
        array![
            [0.17, 0.94],
            [0.94, 36.0],
        ]
    ];

    let phis = [0.35, 0.65];
    // ==============================

    // 乱数生成器の初期化
    let mut rng = thread_rng();
    // 重み付きインデックス
    let dist = WeightedIndex::new(&phis).unwrap();

    dist.sample_iter(&mut rng)
        .take(n)
        .map(|i| {
            let res = multivariate_normal_sample(
                1,
                mus[i].view().clone(),
                covs[i].view().clone()
            ).unwrap()[0].clone();
            (res[0], res[1])
        })
        .collect()
}
let sample_500 = sample_n(500);

&sample_500[..5]
[(3.881909744757437, 77.94282717452054), (3.9074054929941586, 78.21915888118534), (5.155770743536556, 78.7678070149949), (4.419981411180869, 77.25807163642133), (4.604574522923448, 73.92620420502918)]
// データの可視化
evcxr_figure((600, 450), |root| {
    root.fill(&WHITE)?;

    let mut chart = ChartBuilder::on(&root)
        .caption("Eruptions and Waiting", ("sans", 20).into_font())
        .x_label_area_size(40)
        .y_label_area_size(40)
        .build_cartesian_2d(1.0..6.0, 30.0..110.0)?;

    chart.configure_mesh()
        .draw()?;
        
    // 元データ
    chart.draw_series(
        old_faithful_data.iter().map(|&(x, y)| Circle::new((x, y), 3, BLUE.filled()))
    )?
    .label("raw data")
    .legend(|(x, y)| Circle::new((x, y), 3, BLUE.filled()));

    // 生成したデータ
    chart.draw_series(
        sample_500.iter().map(|&(x, y)| Circle::new((x, y), 3, RED.filled()))
    )?
    .label("generated data")
    .legend(|(x, y)| Circle::new((x, y), 3, RED.filled()));

    chart.configure_series_labels()
        .position(SeriesLabelPosition::LowerRight)
        .margin(10)
        .legend_area_size(10)
        .border_style(BLACK)
        .background_style(BLACK.mix(0.1))
        .label_font(("Calibri", 20))
        .draw()?;

    Ok(())
})
Eruptions and Waiting 30.0 40.0 50.0 60.0 70.0 80.0 90.0 100.0 110.0 1.0 1.5 2.0 2.5 3.0 3.5 4.0 4.5 5.0 5.5 6.0 raw data generated data