sudachi/analysis/
stateful_tokenizer.rs

1/*
2 *  Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *   Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17use crate::analysis::created::CreatedWords;
18use crate::analysis::inner::{Node, NodeIdx};
19use crate::analysis::lattice::Lattice;
20use crate::analysis::node::{LatticeNode, ResultNode};
21use crate::analysis::stateless_tokenizer::{dump_path, split_path, DictionaryAccess};
22use crate::analysis::Mode;
23use crate::dic::category_type::CategoryType;
24use crate::dic::connect::ConnectionMatrix;
25use crate::dic::lexicon::word_infos::WordInfoData;
26use crate::dic::lexicon_set::LexiconSet;
27use crate::dic::subset::InfoSubset;
28use crate::error::{SudachiError, SudachiResult};
29use crate::input_text::InputBuffer;
30use crate::input_text::InputTextIndex;
31use crate::plugin::oov::OovProviderPlugin;
32use crate::prelude::MorphemeList;
33
34pub struct StatefulTokenizer<D> {
35    dictionary: D,
36    input: InputBuffer,
37    debug: bool,
38    mode: Mode,
39    oov: Vec<Node>,
40    lattice: Lattice,
41    top_path_ids: Vec<NodeIdx>,
42    top_path: Option<Vec<ResultNode>>,
43    subset: InfoSubset,
44}
45
46impl<D: DictionaryAccess + Clone> StatefulTokenizer<D> {
47    /// Get a clone of current dictionary
48    pub fn dict_clone(&self) -> D {
49        self.dictionary.clone()
50    }
51}
52
53impl<D: DictionaryAccess> StatefulTokenizer<D> {
54    /// Create a new non-debug stateful tokenizer
55    pub fn new(dic: D, mode: Mode) -> Self {
56        Self::create(dic, false, mode)
57    }
58
59    /// Create a new debug stateful tokenizer with the following options
60    pub fn create(dic: D, debug: bool, mode: Mode) -> Self {
61        Self {
62            dictionary: dic,
63            input: InputBuffer::default(),
64            debug,
65            mode,
66            oov: Vec::with_capacity(10),
67            lattice: Lattice::default(),
68            top_path_ids: Vec::new(),
69            top_path: Some(Vec::new()),
70            subset: InfoSubset::all(),
71        }
72    }
73
74    /// Set debug flag and returns the current one
75    pub fn set_debug(&mut self, debug: bool) -> bool {
76        std::mem::replace(&mut self.debug, debug)
77    }
78
79    /// Set the analysis mode and returns the current one
80    pub fn set_mode(&mut self, mode: Mode) -> Mode {
81        self.subset |= match mode {
82            Mode::A => InfoSubset::SPLIT_A,
83            Mode::B => InfoSubset::SPLIT_B,
84            _ => InfoSubset::empty(),
85        };
86        std::mem::replace(&mut self.mode, mode)
87    }
88
89    /// Return current analysis mode
90    pub fn mode(&self) -> Mode {
91        self.mode
92    }
93
94    /// Analyzer will read only following [`WordInfo`] field subset
95    pub fn set_subset(&mut self, subset: InfoSubset) -> InfoSubset {
96        let mode_subset = match self.mode {
97            Mode::A => InfoSubset::SPLIT_A,
98            Mode::B => InfoSubset::SPLIT_B,
99            _ => InfoSubset::empty(),
100        };
101        let new_subset = (subset | mode_subset).normalize();
102        std::mem::replace(&mut self.subset, new_subset | mode_subset)
103    }
104
105    /// Prepare StatefulTokenizer for the next data.
106    /// Data must be written in the returned reference.
107    pub fn reset(&mut self) -> &mut String {
108        if let Some(p) = self.top_path.as_mut() {
109            p.clear()
110        }
111        self.oov.clear();
112        self.input.reset()
113    }
114
115    /// Borrow current dictionary
116    pub fn dict(&self) -> &D {
117        &self.dictionary
118    }
119
120    /// Perform the actual tokenization so the analysis result will be available
121    /// for consumption
122    pub fn do_tokenize(&mut self) -> SudachiResult<()> {
123        self.input.start_build()?;
124        self.rewrite_input()?;
125        self.input.build(self.dictionary.grammar())?;
126
127        if self.input.current().is_empty() {
128            return Ok(());
129        }
130
131        let debug = self.debug;
132
133        if debug {
134            println!("=== Input dump:\n{}", self.input.current());
135        }
136
137        self.build_lattice()?;
138
139        if debug {
140            println!("=== Lattice dump:");
141            let dict = &self.dictionary;
142            let mut writer = std::io::stdout();
143            self.lattice
144                .dump(&self.input, dict.grammar(), dict.lexicon(), &mut writer)?;
145        };
146
147        let mut path = self.resolve_best_path()?;
148
149        if debug {
150            println!("=== Before Rewriting:");
151            dump_path(&path);
152        };
153
154        for plugin in self.dictionary.path_rewrite_plugins() {
155            path = plugin.rewrite(&self.input, path, &self.lattice)?;
156        }
157
158        path = split_path(&self.dictionary, path, self.mode, self.subset, &self.input)?;
159
160        if debug {
161            println!("=== After Rewriting:");
162            dump_path(&path);
163            println!("===");
164        };
165
166        self.top_path = Some(path);
167
168        Ok(())
169    }
170
171    /// Resolve the path (as ResultNodes) with the smallest cost
172    fn resolve_best_path(&mut self) -> SudachiResult<Vec<ResultNode>> {
173        let lex = self.dictionary.lexicon();
174        let mut path = self.top_path.take().unwrap_or_default();
175        self.lattice.fill_top_path(&mut self.top_path_ids);
176        self.top_path_ids.reverse();
177        for pid in self.top_path_ids.drain(..) {
178            let (inner, cost) = self.lattice.node(pid);
179            let wi = if inner.word_id().is_oov() {
180                let curr_slice = self.input.curr_slice_c(inner.char_range()).to_owned();
181                WordInfoData {
182                    pos_id: inner.word_id().word() as u16,
183                    surface: curr_slice,
184                    ..Default::default()
185                }
186                .into()
187            } else {
188                lex.get_word_info_subset(inner.word_id(), self.subset)?
189            };
190
191            let byte_begin = self.input.to_curr_byte_idx(inner.begin());
192            let byte_end = self.input.to_curr_byte_idx(inner.end());
193
194            path.push(ResultNode::new(
195                inner.clone(),
196                cost,
197                byte_begin as u16,
198                byte_end as u16,
199                wi,
200            ));
201        }
202        Ok(path)
203    }
204
205    /// Swap result data with the current analyzer
206    pub fn swap_result(
207        &mut self,
208        input: &mut InputBuffer,
209        result: &mut Vec<ResultNode>,
210        subset: &mut InfoSubset,
211    ) {
212        std::mem::swap(&mut self.input, input);
213        std::mem::swap(self.top_path.as_mut().unwrap(), result);
214        *subset = self.subset;
215    }
216
217    fn rewrite_input(&mut self) -> SudachiResult<()> {
218        for p in self.dictionary.input_text_plugins() {
219            p.rewrite(&mut self.input)?;
220        }
221        Ok(())
222    }
223
224    fn build_lattice(&mut self) -> SudachiResult<()> {
225        let mut builder = LatticeBuilder {
226            node_buffer: &mut self.oov,
227            lattice: &mut self.lattice,
228            matrix: self.dictionary.grammar().conn_matrix(),
229            oov_providers: self.dictionary.oov_provider_plugins(),
230            lexicon: self.dictionary.lexicon(),
231            input: &self.input,
232        };
233        builder.build_lattice()
234    }
235
236    /// Consume the Tokenizer and produce MorphemeList
237    pub fn into_morpheme_list(self) -> SudachiResult<MorphemeList<D>> {
238        match self.top_path {
239            None => Err(SudachiError::EosBosDisconnect),
240            Some(path) => Ok(MorphemeList::from_components(
241                self.dictionary,
242                self.input,
243                path,
244                self.subset,
245            )),
246        }
247    }
248}
249
250// This structure is purely for Rust.
251// Otherwise splitting code into functions fails to compile with double borrow errors
252struct LatticeBuilder<'a> {
253    node_buffer: &'a mut Vec<Node>,
254    lattice: &'a mut Lattice,
255    matrix: &'a ConnectionMatrix<'a>,
256    input: &'a InputBuffer,
257    lexicon: &'a LexiconSet<'a>,
258    oov_providers: &'a [Box<dyn OovProviderPlugin + Sync + Send>],
259}
260
261impl<'a> LatticeBuilder<'a> {
262    #[inline]
263    fn build_lattice(&mut self) -> SudachiResult<()> {
264        self.lattice.reset(self.input.current_chars().len());
265        let input_bytes = self.input.current().as_bytes();
266
267        for (ch_off, &byte_off) in self.input.curr_byte_offsets().iter().enumerate() {
268            if !self.lattice.has_previous_node(ch_off) {
269                continue;
270            }
271
272            self.node_buffer.clear();
273            let mut created = CreatedWords::default();
274            for e in self.lexicon.lookup(input_bytes, byte_off) {
275                // do we really need input.can_bow condition?
276                if (e.end < input_bytes.len()) && !self.input.can_bow(e.end) {
277                    continue;
278                }
279                let (left_id, right_id, cost) = self.lexicon.get_word_param(e.word_id);
280                let end_c = self.input.ch_idx(e.end);
281                let node = Node::new(
282                    ch_off as u16,
283                    end_c as u16,
284                    left_id as u16,
285                    right_id as u16,
286                    cost,
287                    e.word_id,
288                );
289                created = created.add_word((end_c - ch_off) as i64);
290                self.node_buffer.push(node.clone());
291                self.lattice.insert(node, self.matrix);
292            }
293
294            // OOV
295            if !self
296                .input
297                .cat_at_char(ch_off)
298                .intersects(CategoryType::NOOOVBOW | CategoryType::NOOOVBOW2)
299            {
300                for provider in self.oov_providers {
301                    created = self.provide_oovs(ch_off, created, provider.as_ref())?;
302                }
303            }
304
305            if created.is_empty() {
306                let provider = self.oov_providers.last().unwrap();
307                created = self.provide_oovs(ch_off, created, provider.as_ref())?;
308            }
309
310            if created.is_empty() {
311                return Err(SudachiError::EosBosDisconnect);
312            }
313        }
314        self.lattice.connect_eos(self.matrix)?;
315
316        Ok(())
317    }
318
319    #[inline]
320    fn provide_oovs<P>(
321        &mut self,
322        char_offset: usize,
323        mut other: CreatedWords,
324        plugin: &P,
325    ) -> SudachiResult<CreatedWords>
326    where
327        P: OovProviderPlugin + 'a + ?Sized,
328    {
329        let start_size = self.node_buffer.len();
330        let num_provided = plugin.provide_oov(self.input, char_offset, other, self.node_buffer)?;
331        for idx in start_size..(start_size + num_provided) {
332            let node = self.node_buffer[idx].clone();
333            other = other.add_word(node.char_range().len() as i64);
334            self.lattice.insert(node, self.matrix);
335        }
336        Ok(other)
337    }
338}