sudachi/plugin/path_rewrite/join_numeric/
mod.rs

1/*
2 * Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::Deserialize;
18use serde_json::Value;
19
20use self::numeric_parser::NumericParser;
21use crate::analysis::lattice::Lattice;
22use crate::analysis::node::{concat_nodes, LatticeNode, ResultNode};
23use crate::config::Config;
24use crate::dic::category_type::CategoryType;
25use crate::dic::grammar::Grammar;
26use crate::input_text::InputBuffer;
27use crate::input_text::InputTextIndex;
28use crate::plugin::path_rewrite::PathRewritePlugin;
29use crate::prelude::*;
30
31mod numeric_parser;
32#[cfg(test)]
33mod test;
34
/// Path-rewrite plugin that merges consecutive numeric nodes into a single node.
#[derive(Default)]
pub struct JoinNumericPlugin {
    /// Part-of-speech id that the first node of a numeric run must carry for
    /// the run to be joined; `set_up` resolves it from `名詞,数詞,*,*,*,*`.
    numeric_pos_id: u16,
    /// When set, a joined node takes the numeric parser's normalized form
    /// instead of its own; `set_up` turns this on unless the settings disable it.
    enable_normalize: bool,
}
43
44/// Struct corresponds with raw config json file.
45#[allow(non_snake_case)]
46#[derive(Deserialize)]
47struct PluginSettings {
48    enableNormalize: Option<bool>,
49}
50
51impl JoinNumericPlugin {
52    fn concat(
53        &self,
54        mut path: Vec<ResultNode>,
55        begin: usize,
56        end: usize,
57        parser: &mut NumericParser,
58    ) -> SudachiResult<Vec<ResultNode>> {
59        let word_info = path[begin].word_info();
60
61        if word_info.pos_id() != self.numeric_pos_id {
62            return Ok(path);
63        }
64
65        if self.enable_normalize {
66            let normalized_form = parser.get_normalized();
67            if end - begin > 1 || normalized_form != word_info.normalized_form() {
68                path = concat_nodes(path, begin, end, Some(normalized_form))?;
69            }
70            return Ok(path);
71        }
72
73        if end - begin > 1 {
74            path = concat_nodes(path, begin, end, None)?;
75        }
76        Ok(path)
77    }
78
79    fn rewrite_gen<T: InputTextIndex>(
80        &self,
81        text: &T,
82        mut path: Vec<ResultNode>,
83    ) -> SudachiResult<Vec<ResultNode>> {
84        let mut begin_idx = -1;
85        let mut comma_as_digit = true;
86        let mut period_as_digit = true;
87        let mut parser = NumericParser::new();
88        let mut i = -1;
89        while i < path.len() as i32 - 1 {
90            i += 1;
91            let node = &path[i as usize];
92            let ctypes = text.cat_of_range(node.char_range());
93            let s = node.word_info().normalized_form();
94            if ctypes.intersects(CategoryType::NUMERIC | CategoryType::KANJINUMERIC)
95                || (comma_as_digit && s == ",")
96                || (period_as_digit && s == ".")
97            {
98                if begin_idx < 0 {
99                    parser.clear();
100                    begin_idx = i;
101                }
102                for c in s.chars() {
103                    if !parser.append(&c) {
104                        if begin_idx >= 0 {
105                            if parser.error_state == numeric_parser::Error::Comma {
106                                comma_as_digit = false;
107                                i = begin_idx - 1;
108                            } else if parser.error_state == numeric_parser::Error::Point {
109                                period_as_digit = false;
110                                i = begin_idx - 1;
111                            }
112                            begin_idx = -1;
113                        }
114                        break;
115                    }
116                }
117                continue;
118            }
119
120            let c = if s.len() == 1 {
121                // must be 1 byte utf-8: ASCII
122                s.as_bytes()[0] as char
123            } else {
124                char::MAX
125            };
126
127            // can't use s below this line
128
129            if begin_idx >= 0 {
130                if parser.done() {
131                    path = self.concat(path, begin_idx as usize, i as usize, &mut parser)?;
132                    i = begin_idx + 1;
133                } else {
134                    let ss = path[i as usize - 1].word_info().normalized_form();
135                    if (parser.error_state == numeric_parser::Error::Comma && ss == ",")
136                        || (parser.error_state == numeric_parser::Error::Point && ss == ".")
137                    {
138                        path =
139                            self.concat(path, begin_idx as usize, i as usize - 1, &mut parser)?;
140                        i = begin_idx + 2;
141                    }
142                }
143            }
144            begin_idx = -1;
145            if !comma_as_digit && c != ',' {
146                comma_as_digit = true;
147            }
148            if !period_as_digit && c != '.' {
149                period_as_digit = true;
150            }
151        }
152
153        // process last part
154        if begin_idx >= 0 {
155            let len = path.len();
156            if parser.done() {
157                path = self.concat(path, begin_idx as usize, len, &mut parser)?;
158            } else {
159                let ss = path[len - 1].word_info().normalized_form();
160                if (parser.error_state == numeric_parser::Error::Comma && ss == ",")
161                    || (parser.error_state == numeric_parser::Error::Point && ss == ".")
162                {
163                    path = self.concat(path, begin_idx as usize, len - 1, &mut parser)?;
164                }
165            }
166        }
167
168        Ok(path)
169    }
170}
171
172impl PathRewritePlugin for JoinNumericPlugin {
173    fn set_up(
174        &mut self,
175        settings: &Value,
176        _config: &Config,
177        grammar: &Grammar,
178    ) -> SudachiResult<()> {
179        let settings: PluginSettings = serde_json::from_value(settings.clone())?;
180
181        // this pos is fixed
182        let numeric_pos_string = vec!["名詞", "数詞", "*", "*", "*", "*"];
183        let numeric_pos_id = grammar.get_part_of_speech_id(&numeric_pos_string).ok_or(
184            SudachiError::InvalidPartOfSpeech(format!("{:?}", numeric_pos_string)),
185        )?;
186        let enable_normalize = settings.enableNormalize;
187
188        self.numeric_pos_id = numeric_pos_id;
189        self.enable_normalize = enable_normalize.unwrap_or(true);
190
191        Ok(())
192    }
193
194    fn rewrite(
195        &self,
196        text: &InputBuffer,
197        path: Vec<ResultNode>,
198        _lattice: &Lattice,
199    ) -> SudachiResult<Vec<ResultNode>> {
200        self.rewrite_gen(text, path)
201    }
202}