sudachi/dic/
grammar.rs

1/*
2 * Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17use crate::dic::character_category::CharacterCategory;
18use crate::dic::connect::ConnectionMatrix;
19use crate::dic::read::u16str::utf16_string_parser;
20use crate::dic::POS_DEPTH;
21use crate::error::SudachiNomResult;
22use crate::prelude::*;
23use itertools::Itertools;
24use nom::{
25    bytes::complete::take,
26    number::complete::{le_i16, le_u16},
27};
28use std::ops::Index;
29
30type PosList = Vec<Vec<String>>;
31
32/// Dictionary grammar
33///
34/// Contains part_of_speech list and connection cost map.
35/// It also holds character category.
36pub struct Grammar<'a> {
37    _bytes: &'a [u8],
38    pub pos_list: PosList,
39    pub storage_size: usize,
40
41    /// The mapping to overload cost table
42    connection: ConnectionMatrix<'a>,
43
44    /// The mapping from character to character_category_type
45    pub character_category: CharacterCategory,
46}
47
48impl<'a> Grammar<'a> {
49    pub const INHIBITED_CONNECTION: i16 = i16::MAX;
50
51    pub const BOS_PARAMETER: (i16, i16, i16) = (0, 0, 0); // left_id, right_id, cost
52    pub const EOS_PARAMETER: (i16, i16, i16) = (0, 0, 0); // left_id, right_id, cost
53
54    /// Creates a Grammar from dictionary bytes
55    ///
56    /// buf: reference to the dictionary bytes
57    /// offset: offset to the grammar section in the buf
58    pub fn parse(buf: &[u8], offset: usize) -> SudachiResult<Grammar> {
59        let (rest, (pos_list, left_id_size, right_id_size)) = grammar_parser(buf, offset)
60            .map_err(|e| SudachiError::InvalidDictionaryGrammar.with_context(e.to_string()))?;
61
62        let connect_table_offset = buf.len() - rest.len();
63        let storage_size =
64            (connect_table_offset - offset) + 2 * left_id_size as usize * right_id_size as usize;
65
66        let conn = ConnectionMatrix::from_offset_size(
67            buf,
68            connect_table_offset,
69            left_id_size as usize,
70            right_id_size as usize,
71        )?;
72
73        Ok(Grammar {
74            _bytes: buf,
75            pos_list,
76            connection: conn,
77            storage_size,
78            character_category: CharacterCategory::default(),
79        })
80    }
81
82    /// Returns connection cost of nodes
83    ///
84    /// left_id: right_id of left node
85    /// right_id: left_if of right node
86    #[inline(always)]
87    pub fn connect_cost(&self, left_id: i16, right_id: i16) -> i16 {
88        self.connection.cost(left_id as u16, right_id as u16)
89    }
90
91    #[inline]
92    pub fn conn_matrix(&self) -> &ConnectionMatrix {
93        &self.connection
94    }
95
96    /// Sets character category
97    ///
98    /// This is the only way to set character category.
99    /// Character category will be a empty map by default.
100    pub fn set_character_category(&mut self, character_category: CharacterCategory) {
101        self.character_category = character_category;
102    }
103
104    /// Sets connect cost for a specific pair of ids
105    ///
106    /// left_id: right_id of left node
107    /// right_id: left_if of right node
108    pub fn set_connect_cost(&mut self, left_id: i16, right_id: i16, cost: i16) {
109        // for edit connection cost plugin
110        self.connection
111            .update(left_id as u16, right_id as u16, cost);
112    }
113
114    /// Returns a pos_id of given pos in the grammar
115    pub fn get_part_of_speech_id<S>(&self, pos1: &[S]) -> Option<u16>
116    where
117        S: AsRef<str>,
118    {
119        if pos1.len() != POS_DEPTH {
120            return None;
121        }
122        for (i, pos2) in self.pos_list.iter().enumerate() {
123            if pos1.iter().zip(pos2).all(|(a, b)| a.as_ref() == b) {
124                return Some(i as u16);
125            }
126        }
127        None
128    }
129
130    pub fn register_pos<S>(&mut self, pos: &[S]) -> SudachiResult<u16>
131    where
132        S: AsRef<str> + ToString,
133    {
134        if pos.len() != POS_DEPTH {
135            let pos_string = pos.iter().map(|x| x.as_ref()).join(",");
136            return Err(SudachiError::InvalidPartOfSpeech(pos_string));
137        }
138        match self.get_part_of_speech_id(pos) {
139            Some(id) => Ok(id),
140            None => {
141                let new_id = self.pos_list.len();
142                if new_id > u16::MAX as usize {
143                    return Err(SudachiError::InvalidPartOfSpeech(
144                        "Too much POS tags registered".to_owned(),
145                    ));
146                }
147                let components = pos.iter().map(|x| x.to_string()).collect();
148                self.pos_list.push(components);
149                Ok(new_id as u16)
150            }
151        }
152    }
153
154    /// Gets POS components for POS ID.
155    /// Panics if out of bounds.
156    pub fn pos_components(&self, pos_id: u16) -> &[String] {
157        self.pos_list.index(pos_id as usize)
158    }
159
160    /// Merge a another grammar into this grammar
161    ///
162    /// Only pos_list is merged
163    pub fn merge(&mut self, other: Grammar) {
164        self.pos_list.extend(other.pos_list);
165    }
166}
167
168fn pos_list_parser(input: &[u8]) -> SudachiNomResult<&[u8], PosList> {
169    let (rest, pos_size) = le_u16(input)?;
170    nom::multi::count(
171        nom::multi::count(utf16_string_parser, POS_DEPTH),
172        pos_size as usize,
173    )(rest)
174}
175
176fn grammar_parser(input: &[u8], offset: usize) -> SudachiNomResult<&[u8], (PosList, i16, i16)> {
177    nom::sequence::preceded(
178        take(offset),
179        nom::sequence::tuple((pos_list_parser, le_i16, le_i16)),
180    )(input)
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn storage_size() {
        let data = setup_bytes();
        let parsed = Grammar::parse(&data, 0).expect("failed to create grammar");
        // the fixture contains nothing but the grammar section
        assert_eq!(data.len(), parsed.storage_size);
    }

    #[test]
    fn partofspeech_string() {
        let data = setup_bytes();
        let parsed = Grammar::parse(&data, 0).expect("failed to create grammar");
        let pos = &parsed.pos_list;

        assert_eq!(6, pos[0].len());
        assert_eq!("BOS/EOS", pos[0][0]);
        assert_eq!("*", pos[0][5]);

        assert_eq!("一般", pos[1][1]);
        assert_eq!("*", pos[1][5]);

        assert_eq!("五段-サ行", pos[2][4]);
        assert_eq!("終止形-一般", pos[2][5]);
    }

    #[test]
    fn get_connect_cost() {
        let data = setup_bytes();
        let parsed = Grammar::parse(&data, 0).expect("failed to create grammar");
        assert_eq!(0, parsed.connect_cost(0, 0));
        assert_eq!(-100, parsed.connect_cost(2, 1));
        assert_eq!(200, parsed.connect_cost(1, 2));
    }

    #[test]
    fn set_connect_cost() {
        let data = setup_bytes();
        let mut parsed = Grammar::parse(&data, 0).expect("failed to create grammar");
        parsed.set_connect_cost(0, 0, 300);
        assert_eq!(300, parsed.connect_cost(0, 0));
    }

    #[test]
    fn register_pos() {
        let data = setup_bytes();
        let mut parsed = Grammar::parse(&data, 0).expect("failed to create grammar");
        let tag = ["a", "b", "c", "d", "e", "f"];

        // registering the same tag twice must give the same id
        let first = parsed.register_pos(tag.as_slice()).expect("failed");
        let second = parsed.register_pos(tag.as_slice()).expect("failed");
        assert_eq!(first, second);
    }

    #[test]
    fn bos_parameter() {
        let (left_id, right_id, cost) = Grammar::BOS_PARAMETER;
        assert_eq!(0, left_id);
        assert_eq!(0, right_id);
        assert_eq!(0, cost);
    }

    #[test]
    fn eos_parameter() {
        let (left_id, right_id, cost) = Grammar::EOS_PARAMETER;
        assert_eq!(0, left_id);
        assert_eq!(0, right_id);
        assert_eq!(0, cost);
    }

    /// Builds a grammar section: POS table followed by the connection table.
    fn setup_bytes() -> Vec<u8> {
        let mut storage = Vec::new();
        build_partofspeech(&mut storage);
        build_connect_table(&mut storage);
        storage
    }

    /// Encodes a string as UTF-16 little-endian bytes.
    fn string_to_bytes(s: &str) -> Vec<u8> {
        let mut encoded = Vec::new();
        for unit in s.encode_utf16() {
            encoded.extend_from_slice(&unit.to_le_bytes());
        }
        encoded
    }

    /// Appends a one-byte length followed by the UTF-16LE encoded string.
    fn push_string(storage: &mut Vec<u8>, length: u8, s: &str) {
        storage.push(length);
        storage.extend_from_slice(&string_to_bytes(s));
    }

    /// Appends a POS table with three entries.
    fn build_partofspeech(storage: &mut Vec<u8>) {
        // number of part of speech
        storage.extend_from_slice(&3_i16.to_le_bytes());

        // entry 0
        storage.extend_from_slice(
            b"\x07B\x00O\x00S\x00/\x00E\x00O\x00S\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00",
        );

        // entry 1
        push_string(storage, 2, "名刺");
        push_string(storage, 2, "一般");
        storage.extend_from_slice(b"\x01*\x00\x01*\x00\x01*\x00\x01*\x00");

        // entry 2
        push_string(storage, 2, "動詞");
        push_string(storage, 2, "一般");
        storage.extend_from_slice(b"\x01*\x00\x01*\x00");
        push_string(storage, 5, "五段-サ行");
        push_string(storage, 6, "終止形-一般");
    }

    /// Appends the matrix size header (3 x 3) and the nine connection costs.
    fn build_connect_table(storage: &mut Vec<u8>) {
        let sizes: [i16; 2] = [3, 3];
        let costs: [i16; 9] = [0, -300, 300, 300, -500, -100, -3000, 200, 2000];
        for value in sizes.iter().chain(costs.iter()) {
            storage.extend_from_slice(&value.to_le_bytes());
        }
    }
}