1use crate::dic::character_category::CharacterCategory;
18use crate::dic::connect::ConnectionMatrix;
19use crate::dic::read::u16str::utf16_string_parser;
20use crate::dic::POS_DEPTH;
21use crate::error::SudachiNomResult;
22use crate::prelude::*;
23use itertools::Itertools;
24use nom::{
25 bytes::complete::take,
26 number::complete::{le_i16, le_u16},
27};
28use std::ops::Index;
29
/// List of part-of-speech entries; every entry is a list of string components.
type PosList = Vec<Vec<String>>;
31
32pub struct Grammar<'a> {
37 _bytes: &'a [u8],
38 pub pos_list: PosList,
39 pub storage_size: usize,
40
41 connection: ConnectionMatrix<'a>,
43
44 pub character_category: CharacterCategory,
46}
47
48impl<'a> Grammar<'a> {
49 pub const INHIBITED_CONNECTION: i16 = i16::MAX;
50
51 pub const BOS_PARAMETER: (i16, i16, i16) = (0, 0, 0); pub const EOS_PARAMETER: (i16, i16, i16) = (0, 0, 0); pub fn parse(buf: &[u8], offset: usize) -> SudachiResult<Grammar> {
59 let (rest, (pos_list, left_id_size, right_id_size)) = grammar_parser(buf, offset)
60 .map_err(|e| SudachiError::InvalidDictionaryGrammar.with_context(e.to_string()))?;
61
62 let connect_table_offset = buf.len() - rest.len();
63 let storage_size =
64 (connect_table_offset - offset) + 2 * left_id_size as usize * right_id_size as usize;
65
66 let conn = ConnectionMatrix::from_offset_size(
67 buf,
68 connect_table_offset,
69 left_id_size as usize,
70 right_id_size as usize,
71 )?;
72
73 Ok(Grammar {
74 _bytes: buf,
75 pos_list,
76 connection: conn,
77 storage_size,
78 character_category: CharacterCategory::default(),
79 })
80 }
81
82 #[inline(always)]
87 pub fn connect_cost(&self, left_id: i16, right_id: i16) -> i16 {
88 self.connection.cost(left_id as u16, right_id as u16)
89 }
90
91 #[inline]
92 pub fn conn_matrix(&self) -> &ConnectionMatrix {
93 &self.connection
94 }
95
96 pub fn set_character_category(&mut self, character_category: CharacterCategory) {
101 self.character_category = character_category;
102 }
103
104 pub fn set_connect_cost(&mut self, left_id: i16, right_id: i16, cost: i16) {
109 self.connection
111 .update(left_id as u16, right_id as u16, cost);
112 }
113
114 pub fn get_part_of_speech_id<S>(&self, pos1: &[S]) -> Option<u16>
116 where
117 S: AsRef<str>,
118 {
119 if pos1.len() != POS_DEPTH {
120 return None;
121 }
122 for (i, pos2) in self.pos_list.iter().enumerate() {
123 if pos1.iter().zip(pos2).all(|(a, b)| a.as_ref() == b) {
124 return Some(i as u16);
125 }
126 }
127 None
128 }
129
130 pub fn register_pos<S>(&mut self, pos: &[S]) -> SudachiResult<u16>
131 where
132 S: AsRef<str> + ToString,
133 {
134 if pos.len() != POS_DEPTH {
135 let pos_string = pos.iter().map(|x| x.as_ref()).join(",");
136 return Err(SudachiError::InvalidPartOfSpeech(pos_string));
137 }
138 match self.get_part_of_speech_id(pos) {
139 Some(id) => Ok(id),
140 None => {
141 let new_id = self.pos_list.len();
142 if new_id > u16::MAX as usize {
143 return Err(SudachiError::InvalidPartOfSpeech(
144 "Too much POS tags registered".to_owned(),
145 ));
146 }
147 let components = pos.iter().map(|x| x.to_string()).collect();
148 self.pos_list.push(components);
149 Ok(new_id as u16)
150 }
151 }
152 }
153
154 pub fn pos_components(&self, pos_id: u16) -> &[String] {
157 self.pos_list.index(pos_id as usize)
158 }
159
160 pub fn merge(&mut self, other: Grammar) {
164 self.pos_list.extend(other.pos_list);
165 }
166}
167
168fn pos_list_parser(input: &[u8]) -> SudachiNomResult<&[u8], PosList> {
169 let (rest, pos_size) = le_u16(input)?;
170 nom::multi::count(
171 nom::multi::count(utf16_string_parser, POS_DEPTH),
172 pos_size as usize,
173 )(rest)
174}
175
176fn grammar_parser(input: &[u8], offset: usize) -> SudachiNomResult<&[u8], (PosList, i16, i16)> {
177 nom::sequence::preceded(
178 take(offset),
179 nom::sequence::tuple((pos_list_parser, le_i16, le_i16)),
180 )(input)
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `bytes` as a grammar section starting at offset 0.
    fn parse_fixture(bytes: &[u8]) -> Grammar<'_> {
        Grammar::parse(bytes, 0).expect("failed to create grammar")
    }

    #[test]
    fn storage_size() {
        let bytes = setup_bytes();
        let grammar = parse_fixture(&bytes);
        // The fixture contains nothing but the grammar section.
        assert_eq!(bytes.len(), grammar.storage_size);
    }

    #[test]
    fn partofspeech_string() {
        let bytes = setup_bytes();
        let grammar = parse_fixture(&bytes);
        let pos = &grammar.pos_list;

        assert_eq!(6, pos[0].len());
        assert_eq!("BOS/EOS", pos[0][0]);
        assert_eq!("*", pos[0][5]);

        assert_eq!("一般", pos[1][1]);
        assert_eq!("*", pos[1][5]);

        assert_eq!("五段-サ行", pos[2][4]);
        assert_eq!("終止形-一般", pos[2][5]);
    }

    #[test]
    fn get_connect_cost() {
        let bytes = setup_bytes();
        let grammar = parse_fixture(&bytes);
        assert_eq!(0, grammar.connect_cost(0, 0));
        assert_eq!(-100, grammar.connect_cost(2, 1));
        assert_eq!(200, grammar.connect_cost(1, 2));
    }

    #[test]
    fn set_connect_cost() {
        let bytes = setup_bytes();
        let mut grammar = parse_fixture(&bytes);
        grammar.set_connect_cost(0, 0, 300);
        assert_eq!(300, grammar.connect_cost(0, 0));
    }

    #[test]
    fn register_pos() {
        let bytes = setup_bytes();
        let mut grammar = parse_fixture(&bytes);

        // Registering the same POS twice must yield the same id.
        let pos = ["a", "b", "c", "d", "e", "f"];
        let id1 = grammar.register_pos(&pos[..]).expect("failed");
        let id2 = grammar.register_pos(&pos[..]).expect("failed");
        assert_eq!(id1, id2);
    }

    #[test]
    fn bos_parameter() {
        let (first, second, third) = Grammar::BOS_PARAMETER;
        assert_eq!(0, first);
        assert_eq!(0, second);
        assert_eq!(0, third);
    }

    #[test]
    fn eos_parameter() {
        let (first, second, third) = Grammar::EOS_PARAMETER;
        assert_eq!(0, first);
        assert_eq!(0, second);
        assert_eq!(0, third);
    }

    /// Builds a grammar section: a POS list followed by a connection table.
    fn setup_bytes() -> Vec<u8> {
        let mut storage: Vec<u8> = Vec::new();
        build_partofspeech(&mut storage);
        build_connect_table(&mut storage);
        storage
    }

    /// Encodes `s` as UTF-16 code units in little-endian byte order.
    fn string_to_bytes(s: &str) -> Vec<u8> {
        let mut encoded = Vec::new();
        for unit in s.encode_utf16() {
            encoded.extend_from_slice(&unit.to_le_bytes());
        }
        encoded
    }

    /// Appends three POS entries; every string is a one-byte length followed by
    /// its UTF-16 little-endian code units.
    fn build_partofspeech(storage: &mut Vec<u8>) {
        // Number of POS entries.
        storage.extend_from_slice(&3_i16.to_le_bytes());

        // Entry 0: "BOS/EOS" followed by five "*".
        storage.extend_from_slice(
            b"\x07B\x00O\x00S\x00/\x00E\x00O\x00S\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00",
        );

        // Entry 1.
        storage.extend_from_slice(b"\x02");
        storage.extend_from_slice(&string_to_bytes("名刺"));
        storage.extend_from_slice(b"\x02");
        storage.extend_from_slice(&string_to_bytes("一般"));
        storage.extend_from_slice(b"\x01*\x00\x01*\x00\x01*\x00\x01*\x00");

        // Entry 2.
        storage.extend_from_slice(b"\x02");
        storage.extend_from_slice(&string_to_bytes("動詞"));
        storage.extend_from_slice(b"\x02");
        storage.extend_from_slice(&string_to_bytes("一般"));
        storage.extend_from_slice(b"\x01*\x00\x01*\x00\x05");
        storage.extend_from_slice(&string_to_bytes("五段-サ行"));
        storage.extend_from_slice(b"\x06");
        storage.extend_from_slice(&string_to_bytes("終止形-一般"));
    }

    /// Appends a 3x3 connection table: two dimensions, then nine cost cells.
    fn build_connect_table(storage: &mut Vec<u8>) {
        const DIMENSIONS: [i16; 2] = [3, 3];
        const COSTS: [i16; 9] = [0, -300, 300, 300, -500, -100, -3000, 200, 2000];

        for value in DIMENSIONS.iter().chain(COSTS.iter()) {
            storage.extend_from_slice(&value.to_le_bytes());
        }
    }
}