sudachi/plugin/path_rewrite/join_katakana_oov/
mod.rs1use serde::Deserialize;
18use serde_json::Value;
19
20use crate::analysis::lattice::Lattice;
21use crate::analysis::node::{concat_oov_nodes, LatticeNode, ResultNode};
22use crate::config::Config;
23use crate::dic::category_type::CategoryType;
24use crate::dic::grammar::Grammar;
25use crate::input_text::InputBuffer;
26use crate::input_text::InputTextIndex;
27use crate::plugin::path_rewrite::PathRewritePlugin;
28use crate::prelude::*;
29
30#[cfg(test)]
31mod tests;
32
/// Path-rewrite plugin that joins runs of adjacent katakana nodes into a
/// single out-of-vocabulary (OOV) node.
///
/// A default-constructed value has both fields set to zero; the real values
/// are filled in by `set_up` from the plugin settings.
#[derive(Debug, Default)]
pub struct JoinKatakanaOovPlugin {
    /// Part-of-speech id handed to `concat_oov_nodes` for every joined node.
    oov_pos_id: u16,
    /// Nodes with fewer code points than this are treated as join triggers
    /// even when they are not OOV (see `is_shorter`).
    min_length: usize,
}
41
42#[allow(non_snake_case)]
44#[derive(Deserialize)]
45struct PluginSettings {
46 oovPOS: Vec<String>,
47 minLength: usize,
48}
49
50impl JoinKatakanaOovPlugin {
51 fn is_katakana_node<T: InputTextIndex>(&self, text: &T, node: &ResultNode) -> bool {
52 text.cat_of_range(node.begin()..node.end())
53 .contains(CategoryType::KATAKANA)
54 }
55
56 fn can_oov_bow_node<T: InputTextIndex>(&self, text: &T, node: &ResultNode) -> bool {
62 !text
63 .cat_at_char(node.begin())
64 .contains(CategoryType::NOOOVBOW)
65 }
66
67 fn is_shorter(&self, node: &ResultNode) -> bool {
68 node.num_codepts() < self.min_length
69 }
70
71 fn rewrite_gen<T: InputTextIndex>(
72 &self,
73 text: &T,
74 mut path: Vec<ResultNode>,
75 _lattice: &Lattice,
76 ) -> SudachiResult<Vec<ResultNode>> {
77 let mut i = 0;
78 loop {
79 if i >= path.len() {
80 break;
81 }
82
83 let node = &path[i];
84 if !(node.is_oov() || self.is_shorter(node)) || !self.is_katakana_node(text, node) {
85 i += 1;
86 continue;
87 }
88 let mut begin = i as i32 - 1;
89 loop {
90 if begin < 0 {
91 break;
92 }
93 if !self.is_katakana_node(text, &path[begin as usize]) {
94 begin += 1;
95 break;
96 }
97 begin -= 1;
98 }
99 let mut begin = if begin < 0 { 0 } else { begin as usize };
100 let mut end = i + 1;
101 loop {
102 if end >= path.len() {
103 break;
104 }
105 if !self.is_katakana_node(text, &path[end]) {
106 break;
107 }
108 end += 1;
109 }
110 while begin != end && !self.can_oov_bow_node(text, &path[begin]) {
111 begin += 1;
112 }
113
114 if (end - begin) > 1 {
115 path = concat_oov_nodes(path, begin, end, self.oov_pos_id)?;
116 i = begin + 1;
118 }
119 i += 1;
120 }
121
122 Ok(path)
123 }
124}
125
126impl PathRewritePlugin for JoinKatakanaOovPlugin {
127 fn set_up(
128 &mut self,
129 settings: &Value,
130 _config: &Config,
131 grammar: &Grammar,
132 ) -> SudachiResult<()> {
133 let settings: PluginSettings = serde_json::from_value(settings.clone())?;
134
135 let oov_pos_string: Vec<&str> = settings.oovPOS.iter().map(|s| s.as_str()).collect();
136 let oov_pos_id = grammar.get_part_of_speech_id(&oov_pos_string).ok_or(
137 SudachiError::InvalidPartOfSpeech(format!("{:?}", oov_pos_string)),
138 )?;
139 let min_length = settings.minLength;
140
141 self.oov_pos_id = oov_pos_id;
142 self.min_length = min_length;
143
144 Ok(())
145 }
146
147 fn rewrite(
148 &self,
149 text: &InputBuffer,
150 path: Vec<ResultNode>,
151 lattice: &Lattice,
152 ) -> SudachiResult<Vec<ResultNode>> {
153 self.rewrite_gen(text, path, lattice)
154 }
155}