sudachi/analysis/
node.rs

/*
 *  Copyright (c) 2021-2024 Works Applications Co., Ltd.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
16
17use std::fmt;
18use std::iter::FusedIterator;
19use std::ops::Range;
20
21use crate::analysis::inner::Node;
22use crate::dic::lexicon::word_infos::{WordInfo, WordInfoData};
23use crate::dic::lexicon_set::LexiconSet;
24use crate::dic::subset::InfoSubset;
25use crate::dic::word_id::WordId;
26use crate::input_text::InputBuffer;
27use crate::prelude::*;
28
/// Accessor trait for right connection id
pub trait RightId {
    /// Right connection id of the node (used when looking up the
    /// connection cost against the following node's left id)
    fn right_id(&self) -> u16;
}
33
/// Accessor trait for the full path cost
pub trait PathCost {
    /// Full path cost accumulated up to this node
    fn total_cost(&self) -> i32;

    /// Whether the node is reachable from BOS.
    /// `i32::MAX` serves as the "not connected" sentinel value.
    #[inline]
    fn is_connected_to_bos(&self) -> bool {
        self.total_cost() != i32::MAX
    }
}
43
/// Common accessor interface for nodes in the analysis lattice.
pub trait LatticeNode: RightId {
    /// Begin offset of the node in codepoints
    fn begin(&self) -> usize;
    /// End offset of the node in codepoints
    fn end(&self) -> usize;
    /// Word cost of the node itself
    fn cost(&self) -> i16;
    /// Dictionary word id of the node
    fn word_id(&self) -> WordId;
    /// Left connection id of the node
    fn left_id(&self) -> u16;

    /// Is true when the word does not come from the dictionary.
    /// BOS and EOS are also treated as OOV.
    #[inline]
    fn is_oov(&self) -> bool {
        self.word_id().is_oov()
    }

    /// If a node is a special system node like BOS or EOS.
    /// Java name isSystem (which is similar to a regular node coming from the system dictionary)
    #[inline]
    fn is_special_node(&self) -> bool {
        self.word_id().is_special()
    }

    /// Returns number of codepoints in the current node
    #[inline]
    fn num_codepts(&self) -> usize {
        self.end() - self.begin()
    }

    /// Utility method for extracting [begin, end) codepoint range.
    #[inline]
    fn char_range(&self) -> Range<usize> {
        self.begin()..self.end()
    }
}
77
#[derive(Clone)]
/// Full lattice node, as the result of analysis.
/// All indices (including inner) are in the modified sentence space
/// Indices are converted to original sentence space when user request them.
pub struct ResultNode {
    // Lattice node data: char-based range, connection ids, word id, word cost
    inner: Node,
    // Full path cost up to this node (i32::MAX when not connected to BOS)
    total_cost: i32,
    // Begin offset of the node surface in bytes
    begin_bytes: u16,
    // End offset of the node surface in bytes
    end_bytes: u16,
    // Dictionary information for the word
    word_info: WordInfo,
}
89
90impl ResultNode {
91    pub fn new(
92        inner: Node,
93        total_cost: i32,
94        begin_bytes: u16,
95        end_bytes: u16,
96        word_info: WordInfo,
97    ) -> ResultNode {
98        ResultNode {
99            inner,
100            total_cost,
101            begin_bytes,
102            end_bytes,
103            word_info,
104        }
105    }
106}
107
impl RightId for ResultNode {
    // Delegates to the inner lattice node
    fn right_id(&self) -> u16 {
        self.inner.right_id()
    }
}
113
impl PathCost for ResultNode {
    // Path cost is resolved during analysis and stored on the node
    fn total_cost(&self) -> i32 {
        self.total_cost
    }
}
119
// All accessors delegate to the inner lattice node.
impl LatticeNode for ResultNode {
    fn begin(&self) -> usize {
        self.inner.begin()
    }

    fn end(&self) -> usize {
        self.inner.end()
    }

    fn cost(&self) -> i16 {
        self.inner.cost()
    }

    fn word_id(&self) -> WordId {
        self.inner.word_id()
    }

    fn left_id(&self) -> u16 {
        self.inner.left_id()
    }
}
141
142impl ResultNode {
143    pub fn word_info(&self) -> &WordInfo {
144        &self.word_info
145    }
146
147    /// Returns begin offset in bytes of node surface in a sentence
148    pub fn begin_bytes(&self) -> usize {
149        self.begin_bytes as usize
150    }
151
152    /// Returns end offset in bytes of node surface in a sentence
153    pub fn end_bytes(&self) -> usize {
154        self.end_bytes as usize
155    }
156
157    /// Returns range in bytes (for easy string slicing)
158    pub fn bytes_range(&self) -> Range<usize> {
159        self.begin_bytes()..self.end_bytes()
160    }
161
162    pub fn set_bytes_range(&mut self, begin: u16, end: u16) {
163        self.begin_bytes = begin;
164        self.end_bytes = end;
165    }
166
167    pub fn set_char_range(&mut self, begin: u16, end: u16) {
168        self.inner.set_range(begin, end)
169    }
170
171    /// Returns number of splits in a specified mode
172    pub fn num_splits(&self, mode: Mode) -> usize {
173        match mode {
174            Mode::A => self.word_info.a_unit_split().len(),
175            Mode::B => self.word_info.b_unit_split().len(),
176            Mode::C => 0,
177        }
178    }
179
180    /// Split the node with a specified mode using the dictionary data
181    pub fn split<'a>(
182        &'a self,
183        mode: Mode,
184        lexicon: &'a LexiconSet<'a>,
185        subset: InfoSubset,
186        text: &'a InputBuffer,
187    ) -> NodeSplitIterator<'a> {
188        let splits: &[WordId] = match mode {
189            Mode::A => self.word_info.a_unit_split(),
190            Mode::B => self.word_info.b_unit_split(),
191            Mode::C => panic!("splitting Node with Mode::C is not supported"),
192        };
193
194        NodeSplitIterator {
195            splits,
196            index: 0,
197            lexicon,
198            subset,
199            text,
200            byte_offset: self.begin_bytes,
201            byte_end: self.end_bytes,
202            char_offset: self.begin() as u16,
203            char_end: self.end() as u16,
204        }
205    }
206}
207
208impl fmt::Display for ResultNode {
209    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
210        write!(
211            f,
212            "{} {} {}{} {} {} {} {}",
213            self.begin(),
214            self.end(),
215            self.word_info.surface(),
216            self.word_id(),
217            self.word_info().pos_id(),
218            self.left_id(),
219            self.right_id(),
220            self.cost()
221        )
222    }
223}
224
/// Iterator over the sub-unit nodes produced by splitting a single node.
pub struct NodeSplitIterator<'a> {
    // Word ids of the split units, in surface order
    splits: &'a [WordId],
    // Dictionary used to resolve word info for each split unit
    lexicon: &'a LexiconSet<'a>,
    // Position of the next split to yield
    index: usize,
    // Which parts of the word info to fetch from the dictionary
    subset: InfoSubset,
    // Analyzed input, used for byte -> char offset conversion
    text: &'a InputBuffer,
    // Char offset where the next yielded node begins
    char_offset: u16,
    // Byte offset where the next yielded node begins
    byte_offset: u16,
    // Char offset where the parent node ends
    char_end: u16,
    // Byte offset where the parent node ends
    byte_end: u16,
}
236
237impl Iterator for NodeSplitIterator<'_> {
238    type Item = ResultNode;
239
240    #[inline]
241    fn next(&mut self) -> Option<Self::Item> {
242        let idx = self.index;
243        if idx >= self.splits.len() {
244            return None;
245        }
246
247        let char_start = self.char_offset;
248        let byte_start = self.byte_offset;
249
250        let word_id = self.splits[idx];
251        // data comes from dictionary, panicking here is OK
252        let word_info = self
253            .lexicon
254            .get_word_info_subset(word_id, self.subset)
255            .unwrap();
256
257        let (char_end, byte_end) = if idx + 1 == self.splits.len() {
258            (self.char_end, self.byte_end)
259        } else {
260            let byte_end = byte_start as usize + word_info.head_word_length();
261            let char_end = self.text.ch_idx(byte_end);
262            (char_end as u16, byte_end as u16)
263        };
264
265        self.char_offset = char_end;
266        self.byte_offset = byte_end;
267
268        let inner = Node::new(char_start, char_end, u16::MAX, u16::MAX, i16::MAX, word_id);
269
270        let node = ResultNode::new(inner, i32::MAX, byte_start, byte_end, word_info);
271
272        self.index += 1;
273        Some(node)
274    }
275
276    #[inline]
277    fn size_hint(&self) -> (usize, Option<usize>) {
278        (self.splits.len(), Some(self.splits.len()))
279    }
280}
281
// Once `next` returns None (index reached the end of `splits`) it keeps
// returning None, so the iterator satisfies the FusedIterator contract.
impl FusedIterator for NodeSplitIterator<'_> {}
283
284/// Concatenate the nodes in the range and replace normalized_form if given.
285pub fn concat_nodes(
286    mut path: Vec<ResultNode>,
287    begin: usize,
288    end: usize,
289    normalized_form: Option<String>,
290) -> SudachiResult<Vec<ResultNode>> {
291    if begin >= end {
292        return Err(SudachiError::InvalidRange(begin, end));
293    }
294
295    let end_bytes = path[end - 1].end_bytes();
296    let beg_bytes = path[begin].begin_bytes();
297
298    let mut surface = String::with_capacity(end_bytes - beg_bytes);
299    let mut reading_form = String::with_capacity(end_bytes - beg_bytes);
300    let mut dictionary_form = String::with_capacity(end_bytes - beg_bytes);
301    let mut head_word_length: u16 = 0;
302
303    for node in path[begin..end].iter() {
304        let data = node.word_info().borrow_data();
305        surface.push_str(&data.surface);
306        reading_form.push_str(&data.reading_form);
307        dictionary_form.push_str(&data.dictionary_form);
308        head_word_length += data.head_word_length;
309    }
310
311    let normalized_form = normalized_form.unwrap_or_else(|| {
312        let mut norm = String::with_capacity(end_bytes - beg_bytes);
313        for node in path[begin..end].iter() {
314            norm.push_str(&node.word_info().borrow_data().normalized_form);
315        }
316        norm
317    });
318
319    let pos_id = path[begin].word_info().pos_id();
320
321    let new_wi = WordInfoData {
322        surface,
323        head_word_length,
324        pos_id,
325        normalized_form,
326        reading_form,
327        dictionary_form,
328        dictionary_form_word_id: -1,
329        ..Default::default()
330    };
331
332    let inner = Node::new(
333        path[begin].begin() as u16,
334        path[end - 1].end() as u16,
335        u16::MAX,
336        u16::MAX,
337        i16::MAX,
338        WordId::INVALID,
339    );
340
341    let node = ResultNode::new(
342        inner,
343        path[end - 1].total_cost,
344        path[begin].begin_bytes,
345        path[end - 1].end_bytes,
346        new_wi.into(),
347    );
348
349    path[begin] = node;
350    path.drain(begin + 1..end);
351    Ok(path)
352}
353
354/// Concatenate the nodes in the range and set pos_id.
355pub fn concat_oov_nodes(
356    mut path: Vec<ResultNode>,
357    begin: usize,
358    end: usize,
359    pos_id: u16,
360) -> SudachiResult<Vec<ResultNode>> {
361    if begin >= end {
362        return Err(SudachiError::InvalidRange(begin, end));
363    }
364
365    let capa = path[end - 1].end_bytes() - path[begin].begin_bytes();
366
367    let mut surface = String::with_capacity(capa);
368    let mut head_word_length: u16 = 0;
369    let mut wid = WordId::from_raw(0);
370
371    for node in path[begin..end].iter() {
372        let data = node.word_info().borrow_data();
373        surface.push_str(&data.surface);
374        head_word_length += data.head_word_length;
375        wid = wid.max(node.word_id());
376    }
377
378    if !wid.is_oov() {
379        wid = WordId::new(wid.dic(), WordId::MAX_WORD);
380    }
381
382    let new_wi = WordInfoData {
383        normalized_form: surface.clone(),
384        dictionary_form: surface.clone(),
385        surface,
386        head_word_length,
387        pos_id,
388        dictionary_form_word_id: -1,
389        ..Default::default()
390    };
391
392    let inner = Node::new(
393        path[begin].begin() as u16,
394        path[end - 1].end() as u16,
395        u16::MAX,
396        u16::MAX,
397        i16::MAX,
398        wid,
399    );
400
401    let node = ResultNode::new(
402        inner,
403        path[end - 1].total_cost,
404        path[begin].begin_bytes,
405        path[end - 1].end_bytes,
406        new_wi.into(),
407    );
408
409    path[begin] = node;
410    path.drain(begin + 1..end);
411    Ok(path)
412}