cp_library_rs/data_structure/
bitset.rs

1//! ビット列を管理する
2
3use std::{
4    fmt::Debug,
5    ops::{Deref, DerefMut, Index},
6};
7
8/// ビット列を高速に処理する
9#[derive(Clone)]
10pub struct BitSet<const SIZE: usize> {
11    bits: Vec<u64>,
12}
13
14impl<const SIZE: usize> Default for BitSet<SIZE> {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl<const SIZE: usize> BitSet<SIZE> {
21    /// ⌈size / 64⌉個のu64
22    const ARRAY_SIZE: usize = SIZE.div_ceil(64);
23
24    /// あまりのビット
25    const REM_BIT: usize = SIZE % 64;
26
27    /// 一時的な値
28    const TMP_BOOL: [bool; 2] = [false, true];
29
30    /// Bitsetを初期化する
31    /// - `size`: ビットの数
32    pub fn new() -> Self {
33        Self {
34            bits: vec![0; Self::ARRAY_SIZE],
35        }
36    }
37
38    /// `idx`bit目を1に設定
39    pub fn set(&mut self, index: usize) {
40        let arr_idx = index / 64;
41        let bit_idx = index % 64;
42        self.bits[arr_idx] |= 1 << bit_idx;
43    }
44
45    /// `idx`bit目を0に設定
46    pub fn unset(&mut self, index: usize) {
47        let arr_idx = index / 64;
48        let bit_idx = index % 64;
49        self.bits[arr_idx] &= !(1 << bit_idx);
50    }
51
52    /// `idx`bit目を反転
53    pub fn flip(&mut self, index: usize) {
54        if self[index] {
55            self.unset(index);
56        } else {
57            self.set(index);
58        }
59    }
60
61    /// すべてのbitが0になっているかを判定する
62    pub fn any(&self) -> bool {
63        self.bits.iter().all(|&b64| b64 == 0)
64    }
65
66    /// すべてのbitが1になっているかを判定する
67    pub fn all(&self) -> bool {
68        // あまりだけ個別に判定
69        let filter = !0_u64 >> (64 - Self::REM_BIT);
70        self.bits[Self::ARRAY_SIZE - 1] ^ filter == 0
71            && self
72                .bits
73                .iter()
74                .take(Self::ARRAY_SIZE - 1)
75                .all(|&b64| b64 == !0)
76    }
77
78    /// あるbitを更新する
79    fn update(&mut self, index: usize, new_val: bool) {
80        if new_val {
81            self.set(index);
82        } else {
83            self.unset(index);
84        }
85    }
86
87    /// 1であるビットの数を求める
88    pub fn count_ones(&self) -> usize {
89        self.bits
90            .iter()
91            .map(|b64| b64.count_ones() as usize)
92            .sum::<usize>()
93    }
94
95    /// あるbitの可変参照を取得する
96    /// - `index`: 取得するbitのインデックス
97    pub fn get_mut(&mut self, index: usize) -> Option<BitMut<'_, SIZE>> {
98        if index < SIZE {
99            let default = self[index];
100            Some(BitMut {
101                bitset: self,
102                idx: index,
103                new_val: default,
104            })
105        } else {
106            None
107        }
108    }
109}
110
111impl<const SIZE: usize> Index<usize> for BitSet<SIZE> {
112    type Output = bool;
113    fn index(&self, index: usize) -> &Self::Output {
114        let arr_idx = index / 64;
115        let bit_idx = index % 64;
116        if (self.bits[arr_idx] >> bit_idx) & 1 == 0 {
117            &Self::TMP_BOOL[0]
118        } else {
119            &Self::TMP_BOOL[1]
120        }
121    }
122}
123
124/// bitsetの更新を行う
125pub struct BitMut<'a, const SIZE: usize> {
126    bitset: &'a mut BitSet<SIZE>,
127    idx: usize,
128    new_val: bool,
129}
130
131impl<const SIZE: usize> Deref for BitMut<'_, SIZE> {
132    type Target = bool;
133    fn deref(&self) -> &Self::Target {
134        &self.new_val
135    }
136}
137
138impl<const SIZE: usize> DerefMut for BitMut<'_, SIZE> {
139    fn deref_mut(&mut self) -> &mut Self::Target {
140        &mut self.new_val
141    }
142}
143
144impl<const SIZE: usize> Drop for BitMut<'_, SIZE> {
145    fn drop(&mut self) {
146        self.bitset.update(self.idx, self.new_val);
147    }
148}
149
150impl<const SIZE: usize> Debug for BitSet<SIZE> {
151    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152        let mut bit_str = format!("{:b}", self.bits[Self::ARRAY_SIZE - 1]);
153        // ゼロ埋め
154        bit_str = "0".repeat(Self::REM_BIT - bit_str.len()) + &bit_str;
155        bit_str = self.bits[..Self::ARRAY_SIZE - 1]
156            .iter()
157            .rev()
158            .map(|b64| format!(",{:0>64b}", b64))
159            .fold(bit_str, |acc, b64| acc + &b64);
160        write!(f, "BitSet {{ {:?} }}", bit_str)
161    }
162}