sudachi/input_text/buffer/
mod.rs

1/*
2 *  Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *   Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17mod edit;
18#[cfg(test)]
19mod test_basic;
20#[cfg(test)]
21mod test_ported;
22
23pub use self::edit::InputEditor;
24use crate::dic::category_type::CategoryType;
25use crate::dic::grammar::Grammar;
26use std::ops::Range;
27
28use crate::error::{SudachiError, SudachiResult};
29use crate::input_text::InputTextIndex;
30
/// Upper bound on the byte length of the original input text, enforced by
/// `start_build`. Equals three quarters of `u16::MAX`, i.e. 49149 bytes.
const MAX_LENGTH: usize = (u16::MAX as usize / 4) * 3;

/// Hard upper bound checked by `commit` against the size reported by edit resolution.
/// Once the rewritten sentence grows past this number, processing is aborted with an error.
const REALLY_MAX_LENGTH: usize = u16::MAX as usize;
36
/// Lifecycle stage of an `InputBuffer`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
enum BufferState {
    /// Initial stage; also set by `reset` and `swap_original`.
    #[default]
    Clean,
    /// Editing stage, entered by `start_build`.
    RW,
    /// Finalized stage, entered by `build`.
    RO,
}
44
/// InputBuffer - prepares the input data for the analysis
///
/// By saying char we actually mean Unicode codepoint here.
/// In the context of this struct these terms are synonyms.
///
/// Lifecycle: `reset` (state `Clean`) -> `start_build` (state `RW`, edits are
/// applied through `with_editor`) -> `build` (state `RO`).
#[derive(Default, Clone)]
pub struct InputBuffer {
    /// Original input data, output is done on this
    original: String,
    /// Normalized input data, analysis is done on this. Byte-based indexing.
    modified: String,
    /// Buffer for normalization, reusing allocations
    modified_2: String,
    /// Byte mapping from normalized data to originals.
    /// Only values lying on codepoint boundaries are correct. Byte-based indexing.
    /// `start_build` initializes it as the identity mapping with one extra trailing entry.
    m2o: Vec<usize>,
    /// Buffer for normalization.
    /// After building it is used as byte-to-char mapping for original data:
    /// bytes which do not start a codepoint hold `usize::MAX`.
    m2o_2: Vec<usize>,
    /// Characters of the modified string. Char-based indexing.
    mod_chars: Vec<char>,
    /// Char-to-byte mapping for the modified string. Char-based indexing.
    /// `build` appends one trailing sentinel entry.
    mod_c2b: Vec<usize>,
    /// Byte-to-char mapping for the modified string. Byte-based indexing.
    /// `build` appends one trailing sentinel entry.
    mod_b2c: Vec<usize>,
    /// Markers whether the byte can start new word or not. Byte-based indexing.
    /// Only bytes which start a codepoint are ever set to `true`.
    mod_bow: Vec<bool>,
    /// Character categories. Char-based indexing.
    mod_cat: Vec<CategoryType>,
    /// Number of consecutive codepoints, starting from the current one,
    /// which share a common category. Char-based indexing.
    mod_cat_continuity: Vec<usize>,
    /// This very temporarily keeps the replacement data.
    /// 'static lifetime is a lie and it is **incorrect** to use
    /// it outside `with_editor` function or its callees.
    replaces: Vec<edit::ReplaceOp<'static>>,
    /// Current state of the buffer
    state: BufferState,
}
82
83impl InputBuffer {
84    /// Creates new InputBuffer
85    pub fn new() -> InputBuffer {
86        InputBuffer::default()
87    }
88
89    /// Resets the input buffer, so it could be used to process new input.
90    /// New input should be written to the returned mutable reference.
91    pub fn reset(&mut self) -> &mut String {
92        // extended buffers can be ignored during cleaning,
93        // they will be cleaned before usage automatically
94        self.original.clear();
95        self.modified.clear();
96        self.m2o.clear();
97        self.mod_chars.clear();
98        self.mod_c2b.clear();
99        self.mod_b2c.clear();
100        self.mod_bow.clear();
101        self.mod_cat.clear();
102        self.mod_cat_continuity.clear();
103        self.state = BufferState::Clean;
104        &mut self.original
105    }
106
107    /// Creates input from the passed string. Should be used mostly for tests.
108    ///
109    /// Panics if the input string is too long.
110    pub fn from<T: AsRef<str>>(data: T) -> InputBuffer {
111        let mut buf = Self::new();
112        buf.reset().push_str(data.as_ref());
113        buf.start_build().expect("");
114        buf
115    }
116
117    /// Moves InputBuffer into RW state, making it possible to perform edits on it
118    pub fn start_build(&mut self) -> SudachiResult<()> {
119        if self.original.len() > MAX_LENGTH {
120            return Err(SudachiError::InputTooLong(self.original.len(), MAX_LENGTH));
121        }
122        debug_assert_eq!(self.state, BufferState::Clean);
123        self.state = BufferState::RW;
124        self.modified.push_str(&self.original);
125        self.m2o.extend(0..self.modified.len() + 1);
126        Ok(())
127    }
128
129    /// Finalizes InputBuffer state, making it RO
130    pub fn build(&mut self, grammar: &Grammar) -> SudachiResult<()> {
131        debug_assert_eq!(self.state, BufferState::RW);
132        self.state = BufferState::RO;
133        self.mod_chars.clear();
134        let cats = &grammar.character_category;
135        let mut last_offset = 0;
136        let mut last_chidx = 0;
137
138        // Special cases for BOW logic
139        let non_starting = CategoryType::ALPHA | CategoryType::GREEK | CategoryType::CYRILLIC;
140        let mut prev_cat = CategoryType::empty();
141        self.mod_bow.resize(self.modified.len(), false);
142        let mut next_bow = true;
143
144        for (chidx, (bidx, ch)) in self.modified.char_indices().enumerate() {
145            self.mod_chars.push(ch);
146            let cat = cats.get_category_types(ch);
147            self.mod_cat.push(cat);
148            self.mod_c2b.push(bidx);
149            self.mod_b2c
150                .extend(std::iter::repeat(last_chidx).take(bidx - last_offset));
151            last_offset = bidx;
152            last_chidx = chidx;
153
154            let can_bow = if !next_bow {
155                // this char was forbidden by the previous one
156                next_bow = true;
157                false
158            } else if cat.intersects(CategoryType::NOOOVBOW2) {
159                // this rule is stronger than the next one and must come before
160                // this and next are forbidden
161                next_bow = false;
162                false
163            } else if cat.intersects(CategoryType::NOOOVBOW) {
164                // this char is forbidden
165                false
166            } else if cat.intersects(non_starting) {
167                // the previous char is compatible
168                !cat.intersects(prev_cat)
169            } else {
170                true
171            };
172
173            self.mod_bow[bidx] = can_bow;
174            prev_cat = cat;
175        }
176        // trailing indices for the last codepoint
177        self.mod_b2c
178            .extend(std::iter::repeat(last_chidx).take(self.modified.len() - last_offset));
179        // sentinel values for range translations
180        self.mod_c2b.push(self.mod_b2c.len());
181        self.mod_b2c.push(last_chidx + 1);
182
183        self.fill_cat_continuity();
184        self.fill_orig_b2c();
185
186        Ok(())
187    }
188
189    fn fill_cat_continuity(&mut self) {
190        if self.mod_chars.is_empty() {
191            return;
192        }
193        // single pass algorithm
194        // by default continuity is 1 codepoint
195        // go from the back and set it prev + 1 when chars are compatible
196        self.mod_cat_continuity.resize(self.mod_chars.len(), 1);
197        let mut cat = *self.mod_cat.last().unwrap_or(&CategoryType::all());
198        for i in (0..self.mod_cat.len() - 1).rev() {
199            let cur = self.mod_cat[i];
200            let common = cur & cat;
201            if !common.is_empty() {
202                self.mod_cat_continuity[i] = self.mod_cat_continuity[i + 1] + 1;
203                cat = common;
204            } else {
205                cat = cur;
206            }
207        }
208    }
209
210    fn fill_orig_b2c(&mut self) {
211        self.m2o_2.clear();
212        self.m2o_2.resize(self.original.len() + 1, usize::MAX);
213        let mut max = 0;
214        for (ch_idx, (b_idx, _)) in self.original.char_indices().enumerate() {
215            self.m2o_2[b_idx] = ch_idx;
216            max = ch_idx
217        }
218        self.m2o_2[self.original.len()] = max + 1;
219    }
220
221    fn commit(&mut self) -> SudachiResult<()> {
222        if self.replaces.is_empty() {
223            return Ok(());
224        }
225
226        self.mod_chars.clear();
227        self.modified_2.clear();
228        self.m2o_2.clear();
229
230        let sz = edit::resolve_edits(
231            &self.modified,
232            &self.m2o,
233            &mut self.modified_2,
234            &mut self.m2o_2,
235            &mut self.replaces,
236        );
237        if sz > REALLY_MAX_LENGTH {
238            // super improbable, but still
239            return Err(SudachiError::InputTooLong(sz, REALLY_MAX_LENGTH));
240        }
241        std::mem::swap(&mut self.modified, &mut self.modified_2);
242        std::mem::swap(&mut self.m2o, &mut self.m2o_2);
243        Ok(())
244    }
245
246    fn rollback(&mut self) {
247        self.replaces.clear()
248    }
249
250    fn make_editor<'a>(&mut self) -> InputEditor<'a> {
251        // SAFETY: while it is possible to write into borrowed replaces
252        // the buffer object itself will be accessible as RO
253        let replaces: &'a mut Vec<edit::ReplaceOp<'a>> =
254            unsafe { std::mem::transmute(&mut self.replaces) };
255        return InputEditor::new(replaces);
256    }
257
258    /// Execute a function which can modify the contents of the current buffer
259    ///
260    /// Edit can borrow &str from the context with the borrow checker working correctly     
261    pub fn with_editor<'a, F>(&mut self, func: F) -> SudachiResult<()>
262    where
263        F: FnOnce(&InputBuffer, InputEditor<'a>) -> SudachiResult<InputEditor<'a>>,
264        F: 'a,
265    {
266        debug_assert_eq!(self.state, BufferState::RW);
267        // InputBufferReplacer should have 'a lifetime parameter for API safety
268        // It is impossible to create it outside of this function
269        // And the API forces user to return it by value
270        let editor: InputEditor<'a> = self.make_editor();
271        match func(self, editor) {
272            Ok(_) => self.commit(),
273            Err(e) => {
274                self.rollback();
275                Err(e)
276            }
277        }
278    }
279
280    /// Recompute chars from modified string (useful if the processing will use chars)
281    pub fn refresh_chars(&mut self) {
282        debug_assert_eq!(self.state, BufferState::RW);
283        if self.mod_chars.is_empty() {
284            self.mod_chars.extend(self.modified.chars());
285        }
286    }
287}
288
289// RO Accessors
290impl InputBuffer {
291    /// Borrow original data
292    pub fn original(&self) -> &str {
293        debug_assert_ne!(self.state, BufferState::Clean);
294        &self.original
295    }
296
297    /// Borrow modified data
298    pub fn current(&self) -> &str {
299        debug_assert_ne!(self.state, BufferState::Clean);
300        &self.modified
301    }
302
303    /// Borrow array of current characters
304    pub fn current_chars(&self) -> &[char] {
305        debug_assert_ne!(self.state, BufferState::Clean);
306        debug_assert_eq!(self.modified.is_empty(), self.mod_chars.is_empty());
307        &self.mod_chars
308    }
309
310    /// Returns byte offsets of current chars
311    pub fn curr_byte_offsets(&self) -> &[usize] {
312        debug_assert_eq!(self.state, BufferState::RO);
313        let len = self.mod_c2b.len();
314        &self.mod_c2b[0..len - 1]
315    }
316
317    /// Get index of the current byte in original sentence
318    /// Bytes not on character boundaries are not supported
319    pub fn get_original_index(&self, index: usize) -> usize {
320        debug_assert!(self.modified.is_char_boundary(index));
321        self.m2o[index]
322    }
323
324    /// Mod Char Idx -> Orig Byte Idx
325    pub fn to_orig_byte_idx(&self, index: usize) -> usize {
326        debug_assert_ne!(self.state, BufferState::Clean);
327        let byte_idx = self.mod_c2b[index];
328        self.m2o[byte_idx]
329    }
330
331    /// Mod Char Idx -> Orig Char Idx
332    pub fn to_orig_char_idx(&self, index: usize) -> usize {
333        let b_idx = self.to_orig_byte_idx(index);
334        let res = self.m2o_2[b_idx];
335        debug_assert_ne!(res, usize::MAX);
336        res
337    }
338
339    /// Mod Char Idx -> Mod Byte Idx
340    pub fn to_curr_byte_idx(&self, index: usize) -> usize {
341        debug_assert_eq!(self.state, BufferState::RO);
342        self.mod_c2b[index]
343    }
344
345    /// Input: Mod Char Idx
346    pub fn curr_slice_c(&self, data: Range<usize>) -> &str {
347        debug_assert_eq!(self.state, BufferState::RO);
348        let start = self.mod_c2b[data.start];
349        let end = self.mod_c2b[data.end];
350        &self.modified[start..end]
351    }
352
353    /// Input: Mod Char Idx
354    pub fn orig_slice_c(&self, data: Range<usize>) -> &str {
355        debug_assert_eq!(self.state, BufferState::RO);
356        let start = self.to_orig_byte_idx(data.start);
357        let end = self.to_orig_byte_idx(data.end);
358        &self.original[start..end]
359    }
360
361    pub fn ch_idx(&self, idx: usize) -> usize {
362        debug_assert_eq!(self.state, BufferState::RO);
363        self.mod_b2c[idx]
364    }
365
366    /// Swaps original data with the passed location
367    pub fn swap_original(&mut self, target: &mut String) {
368        std::mem::swap(&mut self.original, target);
369        self.state = BufferState::Clean;
370    }
371
372    /// Return original data as owned, consuming itself    
373    pub fn into_original(self) -> String {
374        self.original
375    }
376
377    /// Whether the byte can start a new word.
378    /// Supports bytes not on character boundaries.
379    #[inline]
380    pub fn can_bow(&self, offset: usize) -> bool {
381        debug_assert_eq!(self.state, BufferState::RO);
382        self.mod_bow[offset]
383    }
384
385    /// Returns char length to the next can_bow point
386    ///
387    /// Used by SimpleOOV plugin
388    pub fn get_word_candidate_length(&self, char_idx: usize) -> usize {
389        debug_assert_eq!(self.state, BufferState::RO);
390        let char_len = self.mod_chars.len();
391
392        for i in (char_idx + 1)..char_len {
393            let byte_idx = self.mod_c2b[i];
394            if self.can_bow(byte_idx) {
395                return i - char_idx;
396            }
397        }
398        char_len - char_idx
399    }
400}
401
402impl InputTextIndex for InputBuffer {
403    #[inline]
404    fn cat_of_range(&self, range: Range<usize>) -> CategoryType {
405        debug_assert_eq!(self.state, BufferState::RO);
406        if range.is_empty() {
407            return CategoryType::empty();
408        }
409
410        self.mod_cat[range]
411            .iter()
412            .fold(CategoryType::all(), |a, b| a & *b)
413    }
414
415    #[inline]
416    fn cat_at_char(&self, offset: usize) -> CategoryType {
417        debug_assert_eq!(self.state, BufferState::RO);
418        self.mod_cat[offset]
419    }
420
421    #[inline]
422    fn cat_continuous_len(&self, offset: usize) -> usize {
423        debug_assert_eq!(self.state, BufferState::RO);
424        self.mod_cat_continuity[offset]
425    }
426
427    fn char_distance(&self, cpt: usize, offset: usize) -> usize {
428        debug_assert_eq!(self.state, BufferState::RO);
429        let end = (cpt + offset).min(self.mod_chars.len());
430        end - cpt
431    }
432
433    #[inline]
434    fn orig_slice(&self, range: Range<usize>) -> &str {
435        debug_assert_ne!(self.state, BufferState::Clean);
436        debug_assert!(
437            self.modified.is_char_boundary(range.start),
438            "start is off char boundary"
439        );
440        debug_assert!(
441            self.modified.is_char_boundary(range.end),
442            "end is off char boundary"
443        );
444        &self.original[self.to_orig(range)]
445    }
446
447    #[inline]
448    fn curr_slice(&self, range: Range<usize>) -> &str {
449        debug_assert_ne!(self.state, BufferState::Clean);
450        &self.modified[range]
451    }
452
453    #[inline]
454    fn to_orig(&self, range: Range<usize>) -> Range<usize> {
455        debug_assert_ne!(self.state, BufferState::Clean);
456        self.m2o[range.start]..self.m2o[range.end]
457    }
458}