use crate::dic::character_category::CharacterCategory;
use crate::dic::connect::ConnectionMatrix;
use crate::dic::read::u16str::utf16_string_parser;
use crate::dic::POS_DEPTH;
use crate::error::SudachiNomResult;
use crate::prelude::*;
use itertools::Itertools;
use nom::{
bytes::complete::take,
number::complete::{le_i16, le_u16},
};
use std::ops::Index;
/// List of part-of-speech entries; each entry is the list of its component
/// strings (the parser in this file reads `POS_DEPTH` components per entry).
type PosList = Vec<Vec<String>>;
/// Grammar section of a dictionary: the part-of-speech list, the connection
/// matrix and the character category definition.
pub struct Grammar<'a> {
// Buffer the grammar was parsed from. It is stored by `parse` but not read
// by any method of this type visible in this file.
_bytes: &'a [u8],
/// Part-of-speech entries; the index of an entry is its part-of-speech id.
pub pos_list: PosList,
/// Size in bytes of the grammar section: the parsed header plus two bytes
/// per connection matrix cell (as computed in `parse`).
pub storage_size: usize,
// Connection cost matrix built over the bytes that follow the header.
connection: ConnectionMatrix<'a>,
/// Character category definition. `parse` sets it to
/// `CharacterCategory::default()`; replace it with `set_character_category`.
pub character_category: CharacterCategory,
}
impl<'a> Grammar<'a> {
    /// Connection cost constant for inhibited connections (`i16::MAX`).
    pub const INHIBITED_CONNECTION: i16 = i16::MAX;

    /// Parameter triple for the beginning-of-sentence node.
    ///
    /// NOTE(review): presumably `(left_id, right_id, cost)` — confirm against callers.
    pub const BOS_PARAMETER: (i16, i16, i16) = (0, 0, 0);

    /// Parameter triple for the end-of-sentence node.
    ///
    /// NOTE(review): presumably `(left_id, right_id, cost)` — confirm against callers.
    pub const EOS_PARAMETER: (i16, i16, i16) = (0, 0, 0);

    /// Parses the grammar section of `buf` that starts at byte `offset`.
    ///
    /// The section consists of the part-of-speech list, the left and right id
    /// counts (two little-endian `i16` values) and the connection matrix that
    /// immediately follows them. `character_category` is initialised to its
    /// default value.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidDictionaryGrammar` error when the header cannot be
    /// parsed or when either id count is negative, and forwards any error
    /// reported by `ConnectionMatrix::from_offset_size`.
    pub fn parse(buf: &[u8], offset: usize) -> SudachiResult<Grammar> {
        let (rest, (pos_list, left_id_size, right_id_size)) = grammar_parser(buf, offset)
            .map_err(|e| SudachiError::InvalidDictionaryGrammar.with_context(e.to_string()))?;

        // A negative i16 sign-extends to an enormous usize when cast, which
        // would overflow the size computation below. Reject such (corrupt)
        // headers up front instead of panicking or wrapping around.
        if left_id_size < 0 || right_id_size < 0 {
            return Err(SudachiError::InvalidDictionaryGrammar.with_context(format!(
                "negative connection matrix size: left={}, right={}",
                left_id_size, right_id_size
            )));
        }
        let num_left = left_id_size as usize;
        let num_right = right_id_size as usize;

        // Everything consumed by the header parser lies before the matrix.
        let connect_table_offset = buf.len() - rest.len();
        // Header bytes of this section plus two bytes per matrix cell.
        let storage_size = (connect_table_offset - offset) + 2 * num_left * num_right;

        let conn =
            ConnectionMatrix::from_offset_size(buf, connect_table_offset, num_left, num_right)?;

        Ok(Grammar {
            _bytes: buf,
            pos_list,
            connection: conn,
            storage_size,
            character_category: CharacterCategory::default(),
        })
    }

    /// Returns the connection cost stored for the pair `(left_id, right_id)`.
    ///
    /// Both ids are reinterpreted as `u16` before the matrix lookup.
    #[inline(always)]
    pub fn connect_cost(&self, left_id: i16, right_id: i16) -> i16 {
        self.connection.cost(left_id as u16, right_id as u16)
    }

    /// Returns a reference to the underlying connection matrix.
    #[inline]
    pub fn conn_matrix(&self) -> &ConnectionMatrix {
        &self.connection
    }

    /// Replaces the character category definition of this grammar.
    pub fn set_character_category(&mut self, character_category: CharacterCategory) {
        self.character_category = character_category;
    }

    /// Overwrites the connection cost stored for the pair `(left_id, right_id)`.
    ///
    /// Both ids are reinterpreted as `u16` before the matrix update.
    pub fn set_connect_cost(&mut self, left_id: i16, right_id: i16, cost: i16) {
        self.connection
            .update(left_id as u16, right_id as u16, cost);
    }

    /// Looks up the id of the part of speech whose components equal `pos1`.
    ///
    /// Returns `None` when `pos1` does not have exactly `POS_DEPTH` components
    /// or when no registered entry matches. The id is the index of the first
    /// matching entry in `pos_list`.
    pub fn get_part_of_speech_id<S>(&self, pos1: &[S]) -> Option<u16>
    where
        S: AsRef<str>,
    {
        if pos1.len() != POS_DEPTH {
            return None;
        }
        for (i, pos2) in self.pos_list.iter().enumerate() {
            if pos1.iter().zip(pos2).all(|(a, b)| a.as_ref() == b) {
                return Some(i as u16);
            }
        }
        None
    }

    /// Returns the id of `pos`, appending it to `pos_list` when it is not
    /// registered yet.
    ///
    /// # Errors
    ///
    /// Returns `SudachiError::InvalidPartOfSpeech` when `pos` does not have
    /// exactly `POS_DEPTH` components, or when the new id would not fit in a
    /// `u16`.
    pub fn register_pos<S>(&mut self, pos: &[S]) -> SudachiResult<u16>
    where
        S: AsRef<str> + ToString,
    {
        if pos.len() != POS_DEPTH {
            let pos_string = pos.iter().map(|x| x.as_ref()).join(",");
            return Err(SudachiError::InvalidPartOfSpeech(pos_string));
        }
        match self.get_part_of_speech_id(pos) {
            Some(id) => Ok(id),
            None => {
                // The id of a new entry is its index in the list.
                let new_id = self.pos_list.len();
                if new_id > u16::MAX as usize {
                    return Err(SudachiError::InvalidPartOfSpeech(
                        "Too much POS tags registered".to_owned(),
                    ));
                }
                let components = pos.iter().map(|x| x.to_string()).collect();
                self.pos_list.push(components);
                Ok(new_id as u16)
            }
        }
    }

    /// Returns the component strings of the part of speech with id `pos_id`.
    ///
    /// # Panics
    ///
    /// Panics when `pos_id` is not a valid index into `pos_list`.
    pub fn pos_components(&self, pos_id: u16) -> &[String] {
        self.pos_list.index(pos_id as usize)
    }

    /// Appends every part-of-speech entry of `other` to this grammar.
    ///
    /// Only `pos_list` is merged; the other fields of `other` are dropped.
    pub fn merge(&mut self, other: Grammar) {
        self.pos_list.extend(other.pos_list);
    }
}
/// Parses the part-of-speech list: a little-endian `u16` entry count followed
/// by that many entries of `POS_DEPTH` UTF-16 strings each.
fn pos_list_parser(input: &[u8]) -> SudachiNomResult<&[u8], PosList> {
    let (after_count, entry_count) = le_u16(input)?;
    // One entry is a fixed-length run of strings.
    let one_entry = nom::multi::count(utf16_string_parser, POS_DEPTH);
    nom::multi::count(one_entry, usize::from(entry_count))(after_count)
}
/// Skips the first `offset` bytes of `input`, then parses the grammar header:
/// the part-of-speech list followed by two little-endian `i16` values (the
/// left and right id counts, in that order, as consumed by `Grammar::parse`).
fn grammar_parser(input: &[u8], offset: usize) -> SudachiNomResult<&[u8], (PosList, i16, i16)> {
    let skip_prefix = take(offset);
    let header = nom::sequence::tuple((pos_list_parser, le_i16, le_i16));
    nom::sequence::preceded(skip_prefix, header)(input)
}
// Unit tests for `Grammar`, run against a small hand-built grammar section
// (three part-of-speech entries and a 3x3 connection matrix).
#[cfg(test)]
mod tests {
use super::*;
// The reported storage size must equal the length of the fixture, which
// contains nothing but the grammar section.
#[test]
fn storage_size() {
let bytes = setup_bytes();
let grammar = Grammar::parse(&bytes, 0).expect("failed to create grammar");
assert_eq!(bytes.len(), grammar.storage_size);
}
// Spot-checks individual components of the three parsed POS entries.
#[test]
fn partofspeech_string() {
let bytes = setup_bytes();
let grammar = Grammar::parse(&bytes, 0).expect("failed to create grammar");
assert_eq!(6, grammar.pos_list[0].len());
assert_eq!("BOS/EOS", grammar.pos_list[0][0]);
assert_eq!("*", grammar.pos_list[0][5]);
assert_eq!("一般", grammar.pos_list[1][1]);
assert_eq!("*", grammar.pos_list[1][5]);
assert_eq!("五段-サ行", grammar.pos_list[2][4]);
assert_eq!("終止形-一般", grammar.pos_list[2][5]);
}
// With the values written by `build_connect_table`, cost(left, right) is
// the value at position `right * 3 + left`: (2, 1) -> 6th value, (1, 2) -> 8th.
#[test]
fn get_connect_cost() {
let bytes = setup_bytes();
let grammar = Grammar::parse(&bytes, 0).expect("failed to create grammar");
assert_eq!(0, grammar.connect_cost(0, 0));
assert_eq!(-100, grammar.connect_cost(2, 1));
assert_eq!(200, grammar.connect_cost(1, 2));
}
// A cost written with `set_connect_cost` must be visible via `connect_cost`.
#[test]
fn set_connect_cost() {
let bytes = setup_bytes();
let mut grammar = Grammar::parse(&bytes, 0).expect("failed to create grammar");
grammar.set_connect_cost(0, 0, 300);
assert_eq!(300, grammar.connect_cost(0, 0));
}
// Registering the same part of speech twice must yield the same id.
#[test]
fn register_pos() {
let bytes = setup_bytes();
let mut grammar = Grammar::parse(&bytes, 0).expect("failed to create grammar");
let id1 = grammar
.register_pos(["a", "b", "c", "d", "e", "f"].as_slice())
.expect("failed");
let id2 = grammar
.register_pos(["a", "b", "c", "d", "e", "f"].as_slice())
.expect("failed");
assert_eq!(id1, id2);
}
// All three components of the BOS parameter are zero.
#[test]
fn bos_parameter() {
assert_eq!(0, Grammar::BOS_PARAMETER.0);
assert_eq!(0, Grammar::BOS_PARAMETER.1);
assert_eq!(0, Grammar::BOS_PARAMETER.2);
}
// All three components of the EOS parameter are zero.
#[test]
fn eos_parameter() {
assert_eq!(0, Grammar::EOS_PARAMETER.0);
assert_eq!(0, Grammar::EOS_PARAMETER.1);
assert_eq!(0, Grammar::EOS_PARAMETER.2);
}
// Builds the complete fixture: the part-of-speech list followed directly by
// the connection table (sizes, then cost values).
fn setup_bytes() -> Vec<u8> {
let mut storage: Vec<u8> = Vec::new();
build_partofspeech(&mut storage);
build_connect_table(&mut storage);
storage
}
// Encodes `s` as UTF-16 little-endian bytes, without any length prefix.
fn string_to_bytes(s: &str) -> Vec<u8> {
s.encode_utf16().flat_map(|c| c.to_le_bytes()).collect()
}
// Writes three POS entries of six strings each. Every string is written as a
// one-byte length (in UTF-16 code units) followed by its UTF-16LE bytes.
fn build_partofspeech(storage: &mut Vec<u8>) {
// Entry count; `pos_list_parser` reads these two bytes as a little-endian u16.
storage.extend(&3_i16.to_le_bytes());
// Entry 0: "BOS/EOS" followed by five "*".
storage.extend(
b"\x07B\x00O\x00S\x00/\x00E\x00O\x00S\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00\x01*\x00",
);
// Entry 1: two two-character strings followed by four "*".
// NOTE(review): "名刺" looks like a typo for "名詞"; no test asserts on this
// component, so it does not affect the results — confirm before changing.
storage.extend(b"\x02");
storage.extend(string_to_bytes("名刺"));
storage.extend(b"\x02");
storage.extend(string_to_bytes("一般"));
storage.extend(b"\x01*\x00\x01*\x00\x01*\x00\x01*\x00");
// Entry 2: two two-character strings, two "*", then a five- and a
// six-character string (the trailing \x05 below is the length prefix of
// the five-character one).
storage.extend(b"\x02");
storage.extend(string_to_bytes("動詞"));
storage.extend(b"\x02");
storage.extend(string_to_bytes("一般"));
storage.extend(b"\x01*\x00\x01*\x00\x05");
storage.extend(string_to_bytes("五段-サ行"));
storage.extend(b"\x06");
storage.extend(string_to_bytes("終止形-一般"));
}
// Writes the two id counts (3 and 3) followed by the nine little-endian i16
// cost values of the 3x3 connection matrix.
fn build_connect_table(storage: &mut Vec<u8>) {
// Left and right id counts, in the order `Grammar::parse` reads them.
storage.extend(&3_i16.to_le_bytes());
storage.extend(&3_i16.to_le_bytes());
// Nine cost values of the matrix.
storage.extend(&0_i16.to_le_bytes());
storage.extend(&(-300_i16).to_le_bytes());
storage.extend(&300_i16.to_le_bytes());
storage.extend(&300_i16.to_le_bytes());
storage.extend(&(-500_i16).to_le_bytes());
storage.extend(&(-100_i16).to_le_bytes());
storage.extend(&(-3000_i16).to_le_bytes());
storage.extend(&200_i16.to_le_bytes());
storage.extend(&2000_i16.to_le_bytes());
}
}