sudachi/plugin/path_rewrite/join_katakana_oov/
mod.rs

1/*
2 * Copyright (c) 2021 Works Applications Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use serde::Deserialize;
18use serde_json::Value;
19
20use crate::analysis::lattice::Lattice;
21use crate::analysis::node::{concat_oov_nodes, LatticeNode, ResultNode};
22use crate::config::Config;
23use crate::dic::category_type::CategoryType;
24use crate::dic::grammar::Grammar;
25use crate::input_text::InputBuffer;
26use crate::input_text::InputTextIndex;
27use crate::plugin::path_rewrite::PathRewritePlugin;
28use crate::prelude::*;
29
30#[cfg(test)]
31mod tests;
32
/// Path rewrite plugin that concatenates katakana oov nodes into one.
///
/// A default-constructed plugin has both fields set to `0`; the real values
/// are assigned in `set_up` from the plugin settings.
#[derive(Debug, Default)]
pub struct JoinKatakanaOovPlugin {
    /// The pos_id assigned to the concatenated node
    oov_pos_id: u16,
    /// Katakana nodes with fewer code points than this value trigger
    /// concatenation even if they are not oov
    min_length: usize,
}
41
/// Mirrors this plugin's settings as they appear in the raw JSON config file.
// The field names double as the JSON keys used by the derived `Deserialize`
// impl, so they keep the config's camelCase spelling; hence the lint allowance.
#[allow(non_snake_case)]
#[derive(Deserialize)]
struct PluginSettings {
    /// Part-of-speech components; resolved to a pos_id through the grammar in `set_up`
    oovPOS: Vec<String>,
    /// Copied into `JoinKatakanaOovPlugin::min_length` in `set_up`
    minLength: usize,
}
49
50impl JoinKatakanaOovPlugin {
51    fn is_katakana_node<T: InputTextIndex>(&self, text: &T, node: &ResultNode) -> bool {
52        text.cat_of_range(node.begin()..node.end())
53            .contains(CategoryType::KATAKANA)
54    }
55
56    // fn is_one_char(&self, text: &Utf8InputText, node: &Node) -> bool {
57    //     let b = node.begin;
58    //     b + text.get_code_points_offset_length(b, 1) == node.end
59    // }
60
61    fn can_oov_bow_node<T: InputTextIndex>(&self, text: &T, node: &ResultNode) -> bool {
62        !text
63            .cat_at_char(node.begin())
64            .contains(CategoryType::NOOOVBOW)
65    }
66
67    fn is_shorter(&self, node: &ResultNode) -> bool {
68        node.num_codepts() < self.min_length
69    }
70
71    fn rewrite_gen<T: InputTextIndex>(
72        &self,
73        text: &T,
74        mut path: Vec<ResultNode>,
75        _lattice: &Lattice,
76    ) -> SudachiResult<Vec<ResultNode>> {
77        let mut i = 0;
78        loop {
79            if i >= path.len() {
80                break;
81            }
82
83            let node = &path[i];
84            if !(node.is_oov() || self.is_shorter(node)) || !self.is_katakana_node(text, node) {
85                i += 1;
86                continue;
87            }
88            let mut begin = i as i32 - 1;
89            loop {
90                if begin < 0 {
91                    break;
92                }
93                if !self.is_katakana_node(text, &path[begin as usize]) {
94                    begin += 1;
95                    break;
96                }
97                begin -= 1;
98            }
99            let mut begin = if begin < 0 { 0 } else { begin as usize };
100            let mut end = i + 1;
101            loop {
102                if end >= path.len() {
103                    break;
104                }
105                if !self.is_katakana_node(text, &path[end]) {
106                    break;
107                }
108                end += 1;
109            }
110            while begin != end && !self.can_oov_bow_node(text, &path[begin]) {
111                begin += 1;
112            }
113
114            if (end - begin) > 1 {
115                path = concat_oov_nodes(path, begin, end, self.oov_pos_id)?;
116                // skip next node, as we already know it is not a joinable katakana
117                i = begin + 1;
118            }
119            i += 1;
120        }
121
122        Ok(path)
123    }
124}
125
126impl PathRewritePlugin for JoinKatakanaOovPlugin {
127    fn set_up(
128        &mut self,
129        settings: &Value,
130        _config: &Config,
131        grammar: &Grammar,
132    ) -> SudachiResult<()> {
133        let settings: PluginSettings = serde_json::from_value(settings.clone())?;
134
135        let oov_pos_string: Vec<&str> = settings.oovPOS.iter().map(|s| s.as_str()).collect();
136        let oov_pos_id = grammar.get_part_of_speech_id(&oov_pos_string).ok_or(
137            SudachiError::InvalidPartOfSpeech(format!("{:?}", oov_pos_string)),
138        )?;
139        let min_length = settings.minLength;
140
141        self.oov_pos_id = oov_pos_id;
142        self.min_length = min_length;
143
144        Ok(())
145    }
146
147    fn rewrite(
148        &self,
149        text: &InputBuffer,
150        path: Vec<ResultNode>,
151        lattice: &Lattice,
152    ) -> SudachiResult<Vec<ResultNode>> {
153        self.rewrite_gen(text, path, lattice)
154    }
155}