1use crate::analysis::created::CreatedWords;
18use crate::analysis::inner::{Node, NodeIdx};
19use crate::analysis::lattice::Lattice;
20use crate::analysis::node::{LatticeNode, ResultNode};
21use crate::analysis::stateless_tokenizer::{dump_path, split_path, DictionaryAccess};
22use crate::analysis::Mode;
23use crate::dic::category_type::CategoryType;
24use crate::dic::connect::ConnectionMatrix;
25use crate::dic::lexicon::word_infos::WordInfoData;
26use crate::dic::lexicon_set::LexiconSet;
27use crate::dic::subset::InfoSubset;
28use crate::error::{SudachiError, SudachiResult};
29use crate::input_text::InputBuffer;
30use crate::input_text::InputTextIndex;
31use crate::plugin::oov::OovProviderPlugin;
32use crate::prelude::MorphemeList;
33
34pub struct StatefulTokenizer<D> {
35 dictionary: D,
36 input: InputBuffer,
37 debug: bool,
38 mode: Mode,
39 oov: Vec<Node>,
40 lattice: Lattice,
41 top_path_ids: Vec<NodeIdx>,
42 top_path: Option<Vec<ResultNode>>,
43 subset: InfoSubset,
44}
45
46impl<D: DictionaryAccess + Clone> StatefulTokenizer<D> {
47 pub fn dict_clone(&self) -> D {
49 self.dictionary.clone()
50 }
51}
52
53impl<D: DictionaryAccess> StatefulTokenizer<D> {
54 pub fn new(dic: D, mode: Mode) -> Self {
56 Self::create(dic, false, mode)
57 }
58
59 pub fn create(dic: D, debug: bool, mode: Mode) -> Self {
61 Self {
62 dictionary: dic,
63 input: InputBuffer::default(),
64 debug,
65 mode,
66 oov: Vec::with_capacity(10),
67 lattice: Lattice::default(),
68 top_path_ids: Vec::new(),
69 top_path: Some(Vec::new()),
70 subset: InfoSubset::all(),
71 }
72 }
73
74 pub fn set_debug(&mut self, debug: bool) -> bool {
76 std::mem::replace(&mut self.debug, debug)
77 }
78
79 pub fn set_mode(&mut self, mode: Mode) -> Mode {
81 self.subset |= match mode {
82 Mode::A => InfoSubset::SPLIT_A,
83 Mode::B => InfoSubset::SPLIT_B,
84 _ => InfoSubset::empty(),
85 };
86 std::mem::replace(&mut self.mode, mode)
87 }
88
89 pub fn mode(&self) -> Mode {
91 self.mode
92 }
93
94 pub fn set_subset(&mut self, subset: InfoSubset) -> InfoSubset {
96 let mode_subset = match self.mode {
97 Mode::A => InfoSubset::SPLIT_A,
98 Mode::B => InfoSubset::SPLIT_B,
99 _ => InfoSubset::empty(),
100 };
101 let new_subset = (subset | mode_subset).normalize();
102 std::mem::replace(&mut self.subset, new_subset | mode_subset)
103 }
104
105 pub fn reset(&mut self) -> &mut String {
108 if let Some(p) = self.top_path.as_mut() {
109 p.clear()
110 }
111 self.oov.clear();
112 self.input.reset()
113 }
114
115 pub fn dict(&self) -> &D {
117 &self.dictionary
118 }
119
120 pub fn do_tokenize(&mut self) -> SudachiResult<()> {
123 self.input.start_build()?;
124 self.rewrite_input()?;
125 self.input.build(self.dictionary.grammar())?;
126
127 if self.input.current().is_empty() {
128 return Ok(());
129 }
130
131 let debug = self.debug;
132
133 if debug {
134 println!("=== Input dump:\n{}", self.input.current());
135 }
136
137 self.build_lattice()?;
138
139 if debug {
140 println!("=== Lattice dump:");
141 let dict = &self.dictionary;
142 let mut writer = std::io::stdout();
143 self.lattice
144 .dump(&self.input, dict.grammar(), dict.lexicon(), &mut writer)?;
145 };
146
147 let mut path = self.resolve_best_path()?;
148
149 if debug {
150 println!("=== Before Rewriting:");
151 dump_path(&path);
152 };
153
154 for plugin in self.dictionary.path_rewrite_plugins() {
155 path = plugin.rewrite(&self.input, path, &self.lattice)?;
156 }
157
158 path = split_path(&self.dictionary, path, self.mode, self.subset, &self.input)?;
159
160 if debug {
161 println!("=== After Rewriting:");
162 dump_path(&path);
163 println!("===");
164 };
165
166 self.top_path = Some(path);
167
168 Ok(())
169 }
170
171 fn resolve_best_path(&mut self) -> SudachiResult<Vec<ResultNode>> {
173 let lex = self.dictionary.lexicon();
174 let mut path = self.top_path.take().unwrap_or_default();
175 self.lattice.fill_top_path(&mut self.top_path_ids);
176 self.top_path_ids.reverse();
177 for pid in self.top_path_ids.drain(..) {
178 let (inner, cost) = self.lattice.node(pid);
179 let wi = if inner.word_id().is_oov() {
180 let curr_slice = self.input.curr_slice_c(inner.char_range()).to_owned();
181 WordInfoData {
182 pos_id: inner.word_id().word() as u16,
183 surface: curr_slice,
184 ..Default::default()
185 }
186 .into()
187 } else {
188 lex.get_word_info_subset(inner.word_id(), self.subset)?
189 };
190
191 let byte_begin = self.input.to_curr_byte_idx(inner.begin());
192 let byte_end = self.input.to_curr_byte_idx(inner.end());
193
194 path.push(ResultNode::new(
195 inner.clone(),
196 cost,
197 byte_begin as u16,
198 byte_end as u16,
199 wi,
200 ));
201 }
202 Ok(path)
203 }
204
205 pub fn swap_result(
207 &mut self,
208 input: &mut InputBuffer,
209 result: &mut Vec<ResultNode>,
210 subset: &mut InfoSubset,
211 ) {
212 std::mem::swap(&mut self.input, input);
213 std::mem::swap(self.top_path.as_mut().unwrap(), result);
214 *subset = self.subset;
215 }
216
217 fn rewrite_input(&mut self) -> SudachiResult<()> {
218 for p in self.dictionary.input_text_plugins() {
219 p.rewrite(&mut self.input)?;
220 }
221 Ok(())
222 }
223
224 fn build_lattice(&mut self) -> SudachiResult<()> {
225 let mut builder = LatticeBuilder {
226 node_buffer: &mut self.oov,
227 lattice: &mut self.lattice,
228 matrix: self.dictionary.grammar().conn_matrix(),
229 oov_providers: self.dictionary.oov_provider_plugins(),
230 lexicon: self.dictionary.lexicon(),
231 input: &self.input,
232 };
233 builder.build_lattice()
234 }
235
236 pub fn into_morpheme_list(self) -> SudachiResult<MorphemeList<D>> {
238 match self.top_path {
239 None => Err(SudachiError::EosBosDisconnect),
240 Some(path) => Ok(MorphemeList::from_components(
241 self.dictionary,
242 self.input,
243 path,
244 self.subset,
245 )),
246 }
247 }
248}
249
250struct LatticeBuilder<'a> {
253 node_buffer: &'a mut Vec<Node>,
254 lattice: &'a mut Lattice,
255 matrix: &'a ConnectionMatrix<'a>,
256 input: &'a InputBuffer,
257 lexicon: &'a LexiconSet<'a>,
258 oov_providers: &'a [Box<dyn OovProviderPlugin + Sync + Send>],
259}
260
261impl<'a> LatticeBuilder<'a> {
262 #[inline]
263 fn build_lattice(&mut self) -> SudachiResult<()> {
264 self.lattice.reset(self.input.current_chars().len());
265 let input_bytes = self.input.current().as_bytes();
266
267 for (ch_off, &byte_off) in self.input.curr_byte_offsets().iter().enumerate() {
268 if !self.lattice.has_previous_node(ch_off) {
269 continue;
270 }
271
272 self.node_buffer.clear();
273 let mut created = CreatedWords::default();
274 for e in self.lexicon.lookup(input_bytes, byte_off) {
275 if (e.end < input_bytes.len()) && !self.input.can_bow(e.end) {
277 continue;
278 }
279 let (left_id, right_id, cost) = self.lexicon.get_word_param(e.word_id);
280 let end_c = self.input.ch_idx(e.end);
281 let node = Node::new(
282 ch_off as u16,
283 end_c as u16,
284 left_id as u16,
285 right_id as u16,
286 cost,
287 e.word_id,
288 );
289 created = created.add_word((end_c - ch_off) as i64);
290 self.node_buffer.push(node.clone());
291 self.lattice.insert(node, self.matrix);
292 }
293
294 if !self
296 .input
297 .cat_at_char(ch_off)
298 .intersects(CategoryType::NOOOVBOW | CategoryType::NOOOVBOW2)
299 {
300 for provider in self.oov_providers {
301 created = self.provide_oovs(ch_off, created, provider.as_ref())?;
302 }
303 }
304
305 if created.is_empty() {
306 let provider = self.oov_providers.last().unwrap();
307 created = self.provide_oovs(ch_off, created, provider.as_ref())?;
308 }
309
310 if created.is_empty() {
311 return Err(SudachiError::EosBosDisconnect);
312 }
313 }
314 self.lattice.connect_eos(self.matrix)?;
315
316 Ok(())
317 }
318
319 #[inline]
320 fn provide_oovs<P>(
321 &mut self,
322 char_offset: usize,
323 mut other: CreatedWords,
324 plugin: &P,
325 ) -> SudachiResult<CreatedWords>
326 where
327 P: OovProviderPlugin + 'a + ?Sized,
328 {
329 let start_size = self.node_buffer.len();
330 let num_provided = plugin.provide_oov(self.input, char_offset, other, self.node_buffer)?;
331 for idx in start_size..(start_size + num_provided) {
332 let node = self.node_buffer[idx].clone();
333 other = other.add_word(node.char_range().len() as i64);
334 self.lattice.insert(node, self.matrix);
335 }
336 Ok(other)
337 }
338}