sudachi/plugin/oov/mecab_oov/
mod.rs

1/*
2 * Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::analysis::created::CreatedWords;
18use crate::util::user_pos::{UserPosMode, UserPosSupport};
19use serde::Deserialize;
20use serde_json::Value;
21use std::collections::HashMap;
22use std::fs;
23use std::io::{BufRead, BufReader};
24use std::path::PathBuf;
25
26use crate::analysis::Node;
27use crate::config::Config;
28use crate::dic::category_type::CategoryType;
29use crate::dic::character_category::Error as CharacterCategoryError;
30use crate::dic::grammar::Grammar;
31use crate::dic::word_id::WordId;
32use crate::hash::RoMu;
33use crate::input_text::InputBuffer;
34use crate::input_text::InputTextIndex;
35use crate::plugin::oov::OovProviderPlugin;
36use crate::prelude::*;
37
38#[cfg(test)]
39mod test;
40
41const DEFAULT_CHAR_DEF_FILE: &str = "char.def";
42const DEFAULT_CHAR_DEF_BYTES: &[u8] = include_bytes!("../../../../../resources/char.def");
43const DEFAULT_UNK_DEF_FILE: &str = "unk.def";
44const DEFAULT_UNK_DEF_BYTES: &[u8] = include_bytes!("../../../../../resources/unk.def");
45
46/// provides MeCab oov nodes
47#[derive(Default)]
48pub struct MeCabOovPlugin {
49    categories: HashMap<CategoryType, CategoryInfo, RoMu>,
50    oov_list: HashMap<CategoryType, Vec<Oov>, RoMu>,
51}
52
53/// Struct corresponds with raw config json file.
54#[allow(non_snake_case)]
55#[derive(Deserialize)]
56struct PluginSettings {
57    charDef: Option<PathBuf>,
58    unkDef: Option<PathBuf>,
59    #[serde(default)]
60    userPOS: UserPosMode,
61}
62
63impl MeCabOovPlugin {
64    /// Loads character category definition
65    ///
66    /// See resources/char.def for the syntax
67    fn read_character_property<T: BufRead>(
68        reader: T,
69    ) -> SudachiResult<HashMap<CategoryType, CategoryInfo, RoMu>> {
70        let mut categories = HashMap::with_hasher(RoMu::new());
71        for (i, line) in reader.lines().enumerate() {
72            let line = line?;
73            let line = line.trim();
74            if line.is_empty()
75                || line.starts_with('#')
76                || line.chars().take(2).collect::<Vec<_>>() == vec!['0', 'x']
77            {
78                continue;
79            }
80
81            let cols: Vec<_> = line.split_whitespace().collect();
82            if cols.len() < 4 {
83                return Err(SudachiError::InvalidCharacterCategory(
84                    CharacterCategoryError::InvalidFormat(i),
85                ));
86            }
87            let category_type: CategoryType = match cols[0].parse() {
88                Ok(t) => t,
89                Err(_) => {
90                    return Err(SudachiError::InvalidCharacterCategory(
91                        CharacterCategoryError::InvalidCategoryType(i, cols[0].to_string()),
92                    ))
93                }
94            };
95            if categories.contains_key(&category_type) {
96                return Err(SudachiError::InvalidCharacterCategory(
97                    CharacterCategoryError::MultipleTypeDefinition(i, cols[0].to_string()),
98                ));
99            }
100
101            categories.insert(
102                category_type,
103                CategoryInfo {
104                    category_type,
105                    is_invoke: cols[1] == "1",
106                    is_group: cols[2] == "1",
107                    length: cols[3].parse()?,
108                },
109            );
110        }
111
112        Ok(categories)
113    }
114
115    /// Load OOV definition
116    ///
117    /// Each line contains: CategoryType, left_id, right_id, cost, and pos
118    fn read_oov<T: BufRead>(
119        reader: T,
120        categories: &HashMap<CategoryType, CategoryInfo, RoMu>,
121        mut grammar: &mut Grammar,
122        user_pos: UserPosMode,
123    ) -> SudachiResult<HashMap<CategoryType, Vec<Oov>, RoMu>> {
124        let mut oov_list: HashMap<CategoryType, Vec<Oov>, RoMu> = HashMap::with_hasher(RoMu::new());
125        for (i, line) in reader.lines().enumerate() {
126            let line = line?;
127            let line = line.trim();
128            if line.is_empty() || line.starts_with('#') {
129                continue;
130            }
131
132            let cols: Vec<_> = line.split(',').collect();
133            if cols.len() < 10 {
134                return Err(SudachiError::InvalidDataFormat(
135                    i,
136                    format!("Invalid number of columns ({})", line),
137                ));
138            }
139            let category_type: CategoryType = cols[0].parse()?;
140            if !categories.contains_key(&category_type) {
141                return Err(SudachiError::InvalidDataFormat(
142                    i,
143                    format!("{} is undefined in char definition", cols[0]),
144                ));
145            }
146
147            let oov = Oov {
148                left_id: cols[1].parse()?,
149                right_id: cols[2].parse()?,
150                cost: cols[3].parse()?,
151                pos_id: grammar.handle_user_pos(&cols[4..10], user_pos)?,
152            };
153
154            if oov.left_id as usize > grammar.conn_matrix().num_left() {
155                return Err(SudachiError::InvalidDataFormat(
156                    0,
157                    format!(
158                        "max grammar left_id is {}, was {}",
159                        grammar.conn_matrix().num_left(),
160                        oov.left_id
161                    ),
162                ));
163            }
164
165            if oov.right_id as usize > grammar.conn_matrix().num_right() {
166                return Err(SudachiError::InvalidDataFormat(
167                    0,
168                    format!(
169                        "max grammar right_id is {}, was {}",
170                        grammar.conn_matrix().num_right(),
171                        oov.right_id
172                    ),
173                ));
174            }
175
176            match oov_list.get_mut(&category_type) {
177                None => {
178                    oov_list.insert(category_type, vec![oov]);
179                }
180                Some(l) => {
181                    l.push(oov);
182                }
183            };
184        }
185
186        Ok(oov_list)
187    }
188
189    /// Creates a new oov node
190    fn get_oov_node(&self, oov: &Oov, start: usize, end: usize) -> Node {
191        Node::new(
192            start as u16,
193            end as u16,
194            oov.left_id as u16,
195            oov.right_id as u16,
196            oov.cost,
197            WordId::oov(oov.pos_id as u32),
198        )
199    }
200
201    fn provide_oov_gen<T: InputTextIndex>(
202        &self,
203        input: &T,
204        offset: usize,
205        other_words: CreatedWords,
206        nodes: &mut Vec<Node>,
207    ) -> SudachiResult<usize> {
208        let char_len = input.cat_continuous_len(offset);
209        if char_len == 0 {
210            return Ok(0);
211        }
212        let mut num_created = 0;
213
214        for ctype in input.cat_at_char(offset).iter() {
215            let cinfo = match self.categories.get(&ctype) {
216                Some(ci) => ci,
217                None => continue,
218            };
219
220            if !cinfo.is_invoke && other_words.not_empty() {
221                continue;
222            }
223
224            let mut llength = char_len;
225            let oovs = match self.oov_list.get(&cinfo.category_type) {
226                Some(v) => v,
227                None => continue,
228            };
229
230            if cinfo.is_group {
231                for oov in oovs {
232                    nodes.push(self.get_oov_node(oov, offset, offset + char_len));
233                    num_created += 1;
234                }
235                llength -= 1;
236            }
237            for i in 1..=cinfo.length {
238                let sublength = input.char_distance(offset, i as usize);
239                if sublength > llength {
240                    break;
241                }
242                for oov in oovs {
243                    nodes.push(self.get_oov_node(oov, offset, offset + sublength));
244                    num_created += 1;
245                }
246            }
247        }
248        Ok(num_created)
249    }
250}
251
252impl OovProviderPlugin for MeCabOovPlugin {
253    fn set_up(
254        &mut self,
255        settings: &Value,
256        config: &Config,
257        grammar: &mut Grammar,
258    ) -> SudachiResult<()> {
259        let settings: PluginSettings = serde_json::from_value(settings.clone())?;
260
261        let char_def_path = config.complete_path(
262            settings
263                .charDef
264                .unwrap_or_else(|| PathBuf::from(DEFAULT_CHAR_DEF_FILE)),
265        );
266
267        let categories = if char_def_path.is_ok() {
268            let reader = BufReader::new(fs::File::open(char_def_path?)?);
269            MeCabOovPlugin::read_character_property(reader)?
270        } else {
271            let reader = BufReader::new(DEFAULT_CHAR_DEF_BYTES);
272            MeCabOovPlugin::read_character_property(reader)?
273        };
274
275        let unk_def_path = config.complete_path(
276            settings
277                .unkDef
278                .unwrap_or_else(|| PathBuf::from(DEFAULT_UNK_DEF_FILE)),
279        );
280
281        let oov_list = if unk_def_path.is_ok() {
282            let reader = BufReader::new(fs::File::open(unk_def_path?)?);
283            MeCabOovPlugin::read_oov(reader, &categories, grammar, settings.userPOS)?
284        } else {
285            let reader = BufReader::new(DEFAULT_UNK_DEF_BYTES);
286            MeCabOovPlugin::read_oov(reader, &categories, grammar, settings.userPOS)?
287        };
288
289        self.categories = categories;
290        self.oov_list = oov_list;
291
292        Ok(())
293    }
294
295    fn provide_oov(
296        &self,
297        input_text: &InputBuffer,
298        offset: usize,
299        other_words: CreatedWords,
300        result: &mut Vec<Node>,
301    ) -> SudachiResult<usize> {
302        self.provide_oov_gen(input_text, offset, other_words, result)
303    }
304}
305
306/// The character category definition
307#[derive(Debug)]
308struct CategoryInfo {
309    category_type: CategoryType,
310    is_invoke: bool,
311    is_group: bool,
312    length: u32,
313}
314
/// One OOV entry read from the unknown-word definition.
#[derive(Clone, Debug, Default)]
struct Oov {
    /// Left connection id.
    left_id: i16,
    /// Right connection id.
    right_id: i16,
    /// Cost given to nodes created from this entry.
    cost: i16,
    /// POS id used to build the OOV word id of created nodes.
    pos_id: u16,
}