sudachi/dic/lexicon/
trie.rs

1/*
2 * Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::util::cow_array::CowArray;
18use std::iter::FusedIterator;
19
/// A single match produced by walking the trie over some input bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TrieEntry {
    /// Value stored in the trie leaf. This is not a pointer to a WordId,
    /// but an offset into the WordId table.
    pub value: u32,
    /// Position in the input just past the last byte of the matched key
    /// (exclusive end of the match).
    pub end: usize,
}

impl TrieEntry {
    /// Builds an entry from a leaf `value` and the exclusive end `offset`
    /// of the match.
    #[inline]
    pub fn new(value: u32, offset: usize) -> Self {
        Self { end: offset, value }
    }
}
34
35pub struct Trie<'a> {
36    array: CowArray<'a, u32>,
37}
38
39pub struct TrieEntryIter<'a> {
40    trie: &'a [u32],
41    node_pos: usize,
42    data: &'a [u8],
43    offset: usize,
44}
45
46impl<'a> TrieEntryIter<'a> {
47    #[inline(always)]
48    fn get(&self, index: usize) -> u32 {
49        debug_assert!(index < self.trie.len());
50        // UB if out of bounds
51        // Should we panic in release builds here instead?
52        // Safe version is not optimized away
53        *unsafe { self.trie.get_unchecked(index) }
54    }
55}
56
57impl<'a> Iterator for TrieEntryIter<'a> {
58    type Item = TrieEntry;
59
60    #[inline]
61    fn next(&mut self) -> Option<Self::Item> {
62        let mut node_pos = self.node_pos;
63        let mut unit;
64
65        for i in self.offset..self.data.len() {
66            // Unwrap is safe: access is always in bounds
67            // It is optimized away: https://rust.godbolt.org/z/va9K3az4n
68            let k = self.data.get(i).unwrap();
69            node_pos ^= *k as usize;
70            unit = self.get(node_pos) as usize;
71            if Trie::label(unit) != *k as usize {
72                return None;
73            }
74
75            node_pos ^= Trie::offset(unit);
76            if Trie::has_leaf(unit) {
77                let r = TrieEntry::new(Trie::value(self.get(node_pos)), i + 1);
78                self.offset = r.end;
79                self.node_pos = node_pos;
80                return Some(r);
81            }
82        }
83        None
84    }
85}
86
87impl FusedIterator for TrieEntryIter<'_> {}
88
89impl<'a> Trie<'a> {
90    pub fn new(data: &'a [u8], size: usize) -> Trie<'a> {
91        Trie {
92            array: CowArray::from_bytes(data, 0, size),
93        }
94    }
95
96    pub fn new_owned(data: Vec<u32>) -> Trie<'a> {
97        Trie {
98            array: CowArray::from_owned(data),
99        }
100    }
101
102    pub fn total_size(&self) -> usize {
103        4 * self.array.len()
104    }
105
106    #[inline]
107    pub fn common_prefix_iterator<'b>(&'a self, input: &'b [u8], offset: usize) -> TrieEntryIter<'b>
108    where
109        'a: 'b,
110    {
111        let unit: usize = self.get(0) as usize;
112
113        TrieEntryIter {
114            node_pos: Trie::offset(unit),
115            data: input,
116            trie: &self.array,
117            offset,
118        }
119    }
120
121    #[inline(always)]
122    fn get(&self, index: usize) -> u32 {
123        debug_assert!(index < self.array.len());
124        // UB if out of bounds
125        // Should we panic in release builds here instead?
126        // Safe version is not optimized away
127        *unsafe { self.array.get_unchecked(index) }
128    }
129
130    #[inline(always)]
131    fn has_leaf(unit: usize) -> bool {
132        ((unit >> 8) & 1) == 1
133    }
134
135    #[inline(always)]
136    fn value(unit: u32) -> u32 {
137        unit & ((1 << 31) - 1)
138    }
139
140    #[inline(always)]
141    fn label(unit: usize) -> usize {
142        unit & ((1 << 31) | 0xFF)
143    }
144
145    #[inline(always)]
146    fn offset(unit: usize) -> usize {
147        (unit >> 10) << ((unit & (1 << 9)) >> 6)
148    }
149}