// cp_library_rs/data_structure/acc2d_cyclic.rs

//! トーラス上での区間和取得ができる2次元累積和 (2D prefix sums supporting range-sum queries on a torus)
2
3use std::{
4    convert::{TryFrom, TryInto},
5    fmt::Debug,
6    ops::{
7        Bound::{Excluded, Included, Unbounded},
8        Mul, RangeBounds,
9    },
10};
11
12use num_traits::Num;
13
14/// 2次元累積和
15pub struct Acc2D<T: Num + Copy> {
16    pub H: usize,
17    pub W: usize,
18    pub S: Vec<Vec<T>>,
19}
20
21impl<T> Acc2D<T>
22where
23    T: Num + Copy + TryFrom<usize> + Mul,
24    <T as TryFrom<usize>>::Error: Debug,
25{
26    #[inline]
27    fn parse_range<R: RangeBounds<usize>>(&self, range: &R, max: usize) -> Option<(usize, usize)> {
28        let start = match range.start_bound() {
29            Unbounded => 0,
30            Excluded(&v) => v + 1,
31            Included(&v) => v,
32        };
33        let end = match range.end_bound() {
34            Unbounded => max,
35            Excluded(&v) => v,
36            Included(&v) => v + 1,
37        };
38        if start <= end && end <= max {
39            Some((start, end))
40        } else {
41            None
42        }
43    }
44
45    /// 2次元配列から累積和を初期化する
46    #[allow(clippy::ptr_arg)]
47    pub fn new(array: &Vec<Vec<T>>) -> Self {
48        let (H, W) = (array.len(), array[0].len());
49        let mut S = vec![vec![T::zero(); W + 1]; H + 1];
50        for i in 0..H {
51            for j in 0..W {
52                S[i + 1][j + 1] = array[i][j] + S[i][j + 1] + S[i + 1][j] - S[i][j];
53            }
54        }
55        Self { H, W, S }
56    }
57
58    /// 累積和の値を求める
59    pub fn sum<R, C>(&self, row: R, col: C) -> T
60    where
61        R: RangeBounds<usize> + Debug,
62        C: RangeBounds<usize> + Debug,
63    {
64        let Some((rs, re)) = self.parse_range(&row, self.H) else {
65            panic!("The given range is wrong (row): {:?}", row);
66        };
67        let Some((cs, ce)) = self.parse_range(&col, self.W) else {
68            panic!("The given range is wrong (col): {:?}", col);
69        };
70        self.S[re][ce] + self.S[rs][cs] - self.S[re][cs] - self.S[rs][ce]
71    }
72
73    /// トーラスとみなしたときの和を求める
74    /// ## Args
75    /// - `(top,left)`:左上の座標
76    /// - `(height,width)`:取得する区間
77    pub fn sum_cyclic(&self, mut top: usize, mut left: usize, height: usize, width: usize) -> T {
78        top %= self.H;
79        left %= self.W;
80
81        // 繰り返し回数
82        let hrep: T = (height / self.H).try_into().unwrap();
83        let wrep: T = (width / self.W).try_into().unwrap();
84
85        // 右下の座標
86        let bottom = (top + height) % self.H;
87        let right = (left + width) % self.W;
88
89        // 内部領域
90        let S_inner = self.sum(.., ..) * hrep * wrep;
91
92        // 左右の領域
93        let S_lr = if left <= right {
94            self.sum(.., left..right) * hrep
95        } else {
96            (self.sum(.., left..) + self.sum(.., ..right)) * hrep
97        };
98
99        // 上下の領域
100        let S_tb = if top <= bottom {
101            self.sum(top..bottom, ..) * wrep
102        } else {
103            (self.sum(top.., ..) + self.sum(..bottom, ..)) * wrep
104        };
105
106        // 端の領域
107        let S_edge = match (top <= bottom, left <= right) {
108            (true, true) => self.sum(top..bottom, left..right),
109            (true, false) => self.sum(top..bottom, left..) + self.sum(top..bottom, ..right),
110            (false, true) => self.sum(top.., left..right) + self.sum(..bottom, left..right),
111            (false, false) => {
112                self.sum(top.., left..)
113                    + self.sum(top.., ..right)
114                    + self.sum(..bottom, left..)
115                    + self.sum(..bottom, ..right)
116            }
117        };
118
119        S_inner + S_lr + S_tb + S_edge
120    }
121}