多次元正規分布

多次元正規分布

:dep ndarray = { version = "0.15.6" }
:dep ndarray-linalg = { version = "0.16.0", features = ["openblas"] }
use ndarray::{Array, ArrayView, array, s};
use ndarray_linalg::{Determinant, Inverse};
:dep myml = { path = "../myml" }
use myml::normal::multivariate_normal;
// Evaluate the multivariate normal density at a single point.
// Distribution parameters: mean vector and covariance matrix.
let mu = array![0.0, 0.0];
let cov = array![[1.0, 0.0], [0.0, 4.0]];
// One query point, stored as a single row of a 2-D array.
let x = array![[0.0, 0.0]];

multivariate_normal(x.view(), mu.view(), cov.view())
Some([0.07957747154594767], shape=[1], strides=[1], layout=CFcf (0xf), const ndim=1)

2次元正規分布の可視化

// 依存関係のインストール

:dep image = "0.23"
:dep evcxr_image = "1.1"

// プロット用ライブラリ
:dep plotters = { version = "^0.3.5", default_features = false, features = ["evcxr", "all_series", "all_elements", "bitmap_backend", "full_palette", "colormaps"] }

// 乱数
:dep rand = { version = "0.8.5" }
// インポート
use evcxr_image::ImageDisplay;
use image::{GenericImage, imageops::FilterType};

use plotters::prelude::*;
use rand::prelude::*;
use myml::utility::linspace;
// Draw the density of the normal distribution defined by `mu` and `cov`
// as a 3D surface, seen from a slightly tilted and rotated viewpoint.
evcxr_figure((600, 600), |canvas| {
    canvas.fill(&WHITE)?;

    // The same interval bounds both horizontal axes and the sampling grid.
    let (lo, hi) = (-8.0, 8.0);

    // The middle range (0.0..0.1) is the vertical axis carrying the density.
    let mut chart = ChartBuilder::on(&canvas)
        .caption("multivariate normal", ("sans-serif", 20).into_font())
        .build_cartesian_3d(lo..hi, 0.0..0.1, lo..hi)?;

    // Camera: 20 degrees of pitch and yaw, drawn at 80% scale.
    chart.with_projection(|mut view| {
        view.pitch = 20.0_f64.to_radians();
        view.yaw = 20.0_f64.to_radians();
        view.scale = 0.8;
        view.into_matrix()
    });

    chart.configure_axes().draw()?;

    // 100 sample positions along each horizontal axis.
    let grid_x = linspace(lo, hi, 100).into_iter();
    let grid_z = linspace(lo, hi, 100).into_iter();

    chart.draw_series(
        SurfaceSeries::xoz(grid_x, grid_z, |px, pz| {
            // Wrap the grid coordinates as a one-row array and take the
            // single density value that comes back.
            let point = array![[px, pz]];
            multivariate_normal(point.view(), mu.view(), cov.view()).unwrap()[0]
        })
        // The density is multiplied by 20 before the colormap lookup.
        // NOTE(review): the density at the mean is about 0.08, so the scaled
        // value can exceed 1.0 — confirm how the colormap treats that.
        .style_func(&|&height| (VulcanoHSL::get_color(height * 20.0)).into()),
    )?;

    Ok(())
})
multivariate normal -8.0 -6.0 -4.0 -2.0 0.0 2.0 4.0 6.0 8.0 0.0 0.02 0.04 0.06 0.08 0.1 -8.0 -6.0 -4.0 -2.0 0.0 2.0 4.0 6.0 8.0
// Draw the same density surface over a narrower range with a 90-degree pitch
// and no yaw (presumably a view from directly above — confirm against the
// rendered figure).
evcxr_figure((600, 600), |canvas| {
    canvas.fill(&WHITE)?;

    // The same interval bounds both horizontal axes and the sampling grid.
    let (lo, hi) = (-4.0, 4.0);

    // The middle range (0.0..0.1) is the vertical axis carrying the density.
    let mut chart = ChartBuilder::on(&canvas)
        .caption("multivariate normal", ("sans-serif", 20).into_font())
        .build_cartesian_3d(lo..hi, 0.0..0.1, lo..hi)?;

    // Camera: pitch of 90 degrees, yaw of 0; the scale is left untouched.
    chart.with_projection(|mut view| {
        view.pitch = 90.0_f64.to_radians();
        view.yaw = 0.0_f64.to_radians();
        view.into_matrix()
    });

    chart.configure_axes().draw()?;

    // 100 sample positions along each horizontal axis.
    let grid_x = linspace(lo, hi, 100).into_iter();
    let grid_z = linspace(lo, hi, 100).into_iter();

    chart.draw_series(
        SurfaceSeries::xoz(grid_x, grid_z, |px, pz| {
            // Wrap the grid coordinates as a one-row array and take the
            // single density value that comes back.
            let point = array![[px, pz]];
            multivariate_normal(point.view(), mu.view(), cov.view()).unwrap()[0]
        })
        // The density is multiplied by 20 before the colormap lookup.
        // NOTE(review): the density at the mean is about 0.08, so the scaled
        // value can exceed 1.0 — confirm how the colormap treats that.
        .style_func(&|&height| (VulcanoHSL::get_color(height * 20.0)).into()),
    )?;

    Ok(())
})
multivariate normal -4.0 -3.0 -2.0 -1.0 0.0 1.0 2.0 3.0 4.0 0.0 0.02 0.04 0.06 0.08 0.1 -4.0 -3.0 -2.0 -1.0 0.0 1.0 2.0 3.0 4.0