cp_library_rs/data_structure/
trie.rs

//! Trie (prefix tree) keyed by strings over a fixed, contiguous character alphabet.
2
3use std::fmt::Debug;
4
5// 定数
6const ORIGIN: char = 'a'; // 基準となる文字
7const ORIGIN_ID: usize = ORIGIN as usize; // 基準となる文字のID
8const KINDS: usize = 26; // 文字の種類数
9type NodePointer<T> = Option<Box<TrieNode<T>>>;
10
11/// 何番目の文字かを判定する
12fn ord(c: char) -> usize {
13    let num = c as usize;
14    num - ORIGIN_ID
15}
16
17/// i番目の文字を返す
18fn chr(i: usize) -> char {
19    (ORIGIN_ID + i) as u8 as char
20}
21
22/// # TrieNode
23/// - トライ木のノード
24#[derive(Debug, Clone)]
25struct TrieNode<T> {
26    data: Option<T>,
27    children: Vec<NodePointer<T>>,
28}
29
30impl<T> TrieNode<T>
31where
32    T: Clone,
33{
34    pub fn new(data: Option<T>) -> Self {
35        Self {
36            data,
37            children: vec![NodePointer::None; KINDS],
38        }
39    }
40}
41
42/// # Trie
43/// - トライ木の実装
44#[derive(Debug)]
45pub struct Trie<T> {
46    size: usize,
47    root: NodePointer<T>,
48}
49
50impl<T> Trie<T>
51where
52    T: Clone + Debug,
53{
54    pub fn len(&self) -> usize {
55        self.size
56    }
57
58    pub fn is_empty(&self) -> bool {
59        self.size == 0
60    }
61
62    pub fn insert(&mut self, key: &str, data: T) -> Option<T> {
63        let res = self.get_or_insert_mut(key).replace(data);
64        if res.is_none() {
65            self.size += 1;
66        }
67        res
68    }
69
70    pub fn get(&self, key: &str) -> Option<&T> {
71        let mut node = &self.root;
72        for c in key.chars().map(ord) {
73            node = &node.as_ref()?.children[c];
74        }
75        node.as_deref()?.data.as_ref()
76    }
77
78    pub fn get_mut(&mut self, key: &str) -> Option<&mut T> {
79        let mut node = &mut self.root;
80        for c in key.chars().map(ord) {
81            node = node.as_mut()?.children.get_mut(c).unwrap();
82        }
83        node.as_deref_mut()?.data.as_mut()
84    }
85
86    pub fn get_or_insert_mut(&mut self, key: &str) -> &mut Option<T> {
87        let mut node = &mut self.root;
88        for c in key.chars().map(ord).chain(KINDS..=KINDS) {
89            // データの挿入
90            if c == KINDS {
91                if node.as_ref().is_none() {
92                    *node = Some(Box::new(TrieNode::new(None)));
93                }
94                break;
95            }
96            if node.as_ref().is_none() {
97                *node = Some(Box::new(TrieNode::new(None)));
98            }
99            node = node.as_mut().unwrap().children.get_mut(c).unwrap();
100        }
101        &mut node.as_deref_mut().unwrap().data
102    }
103
104    pub fn traverse(&self) -> Vec<(String, &T)> {
105        let mut res = vec![];
106        let mut cur = String::new();
107        traverse_inner(&self.root, &mut cur, &mut res);
108        res
109    }
110}
111
112impl<T: Clone> Default for Trie<T> {
113    fn default() -> Self {
114        Trie {
115            size: 0,
116            root: Some(Box::new(TrieNode {
117                data: None,
118                children: vec![NodePointer::None; KINDS],
119            })),
120        }
121    }
122}
123
124/// trieを順に探索する
125fn traverse_inner<'a, T>(
126    node: &'a NodePointer<T>,
127    cur: &mut String,
128    list: &mut Vec<(String, &'a T)>,
129) {
130    if let Some(value) = node.as_ref().unwrap().data.as_ref() {
131        let key = cur.clone();
132        list.push((key, value));
133    }
134    if let Some(node) = node.as_deref() {
135        for (i, child) in node.children.iter().enumerate() {
136            if child.as_ref().is_some() {
137                cur.push(chr(i));
138                traverse_inner(child, cur, list);
139                cur.pop();
140            }
141        }
142    }
143}