sudachi/plugin/oov/regex_oov/
mod.rs

1/*
2 *  Copyright (c) 2022-2024 Works Applications Co., Ltd.
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *   Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17use crate::analysis::created::{CreatedWords, HasWord};
18use crate::analysis::node::LatticeNode;
19use crate::analysis::Node;
20use crate::config::{Config, ConfigError};
21use crate::dic::grammar::Grammar;
22use crate::dic::word_id::WordId;
23use crate::error::{SudachiError, SudachiResult};
24use crate::input_text::{InputBuffer, InputTextIndex};
25use crate::plugin::oov::OovProviderPlugin;
26use crate::util::check_params::CheckParams;
27use crate::util::user_pos::{UserPosMode, UserPosSupport};
28use regex::{Regex, RegexBuilder};
29use serde::Deserialize;
30use serde_json::Value;
31
32#[cfg(test)]
33mod test;
34
35#[derive(Default)]
36pub(crate) struct RegexOovProvider {
37    regex: Option<Regex>,
38    left_id: u16,
39    right_id: u16,
40    cost: i16,
41    pos: u16,
42    max_length: usize,
43    debug: bool,
44    boundaries: BoundaryMode,
45}
46
47#[derive(Deserialize, Eq, PartialEq, Debug, Copy, Clone, Default)]
48#[serde(rename_all = "lowercase")]
49pub enum BoundaryMode {
50    #[default]
51    Strict,
52    Relaxed,
53}
54
/// Default for the `maxLength` setting: the regex inspects at most this many
/// characters starting from the current offset.
fn default_max_length() -> usize {
    const DEFAULT_MAX_LENGTH: usize = 32;
    DEFAULT_MAX_LENGTH
}
58
59#[derive(Deserialize)]
60#[allow(non_snake_case)]
61struct RegexProviderConfig {
62    #[serde(alias = "oovPOS")]
63    pos: Vec<String>,
64    leftId: i64,
65    rightId: i64,
66    cost: i64,
67    regex: String,
68    #[serde(default = "default_max_length")]
69    maxLength: usize,
70    #[serde(default)]
71    debug: bool,
72    #[serde(default)]
73    userPOS: UserPosMode,
74    #[serde(default)]
75    boundaries: BoundaryMode,
76}
77
78impl OovProviderPlugin for RegexOovProvider {
79    fn set_up(
80        &mut self,
81        settings: &Value,
82        _config: &Config,
83        mut grammar: &mut Grammar,
84    ) -> SudachiResult<()> {
85        let mut parsed: RegexProviderConfig = serde_json::from_value(settings.clone())?;
86
87        if !parsed.regex.starts_with('^') {
88            parsed.regex.insert(0, '^');
89        }
90
91        self.left_id = grammar.check_left_id(parsed.leftId)?;
92        self.right_id = grammar.check_right_id(parsed.rightId)?;
93        self.cost = grammar.check_cost(parsed.cost)?;
94        self.max_length = parsed.maxLength;
95        self.debug = parsed.debug;
96        self.pos = grammar.handle_user_pos(&parsed.pos, parsed.userPOS)?;
97        self.boundaries = parsed.boundaries;
98
99        match RegexBuilder::new(&parsed.regex).build() {
100            Ok(re) => self.regex = Some(re),
101            Err(e) => {
102                return Err(SudachiError::ConfigError(ConfigError::InvalidFormat(
103                    format!("regex {:?} is invalid: {:?}", &parsed.regex, e),
104                )))
105            }
106        };
107
108        Ok(())
109    }
110
111    fn provide_oov(
112        &self,
113        input_text: &InputBuffer,
114        offset: usize,
115        other_words: CreatedWords,
116        result: &mut Vec<Node>,
117    ) -> SudachiResult<usize> {
118        if self.boundaries == BoundaryMode::Strict && offset > 0 {
119            // check that we have discontinuity in character categories
120            let this_cat = input_text.cat_continuous_len(offset);
121            let prev_cat = input_text.cat_continuous_len(offset - 1);
122            if this_cat + 1 == prev_cat {
123                // no discontinuity
124                return Ok(0);
125            }
126        }
127
128        let regex = self
129            .regex
130            .as_ref()
131            .ok_or_else(|| SudachiError::InvalidDictionaryGrammar)?;
132
133        let end = input_text
134            .current_chars()
135            .len()
136            .min(offset + self.max_length);
137        let text_data = input_text.curr_slice_c(offset..end);
138        match regex.find(text_data) {
139            None => Ok(0),
140            Some(m) => {
141                if m.start() != 0 {
142                    return if self.debug {
143                        Err(SudachiError::InvalidDataFormat(m.start(), format!("in input {:?} regex {:?} matched non-starting text in non-starting position: {}", text_data, regex, m.as_str())))
144                    } else {
145                        Ok(0)
146                    };
147                }
148
149                let byte_offset = input_text.to_curr_byte_idx(offset);
150                let match_start = offset;
151                let match_end = input_text.ch_idx(byte_offset + m.end());
152
153                let match_length = match_end - match_start;
154
155                match other_words.has_word(match_length as i64) {
156                    HasWord::Yes => return Ok(0),
157                    HasWord::No => {} // do nothing
158                    HasWord::Maybe => {
159                        // need to check actual lengths for long words
160                        for node in result.iter() {
161                            if node.end() == match_end {
162                                return Ok(0);
163                            }
164                        }
165                    }
166                }
167
168                let node = Node::new(
169                    match_start as _,
170                    match_end as _,
171                    self.left_id,
172                    self.right_id,
173                    self.cost,
174                    WordId::oov(self.pos as _),
175                );
176                result.push(node);
177                Ok(1)
178            }
179        }
180    }
181}