use crate::dic::grammar::Grammar;
use crate::error::{SudachiError, SudachiResult};
use itertools::Itertools;
use serde::Deserialize;
use std::fmt::Display;
/// Policy for part-of-speech tuples that are not already present in the grammar.
///
/// Deserialized from the lowercase strings `"forbid"` and `"allow"`; any other
/// value is rejected by serde. The default is [`UserPosMode::Forbid`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum UserPosMode {
    /// Unknown POS tuples are rejected with an error (the default).
    #[default]
    Forbid,
    /// Unknown POS tuples are registered in the grammar.
    Allow,
}
/// Ability to map a part-of-speech tuple to its numeric id, with a
/// configurable policy for tuples that are not yet known.
pub trait UserPosSupport {
    /// Returns the id for the POS tuple `pos`.
    ///
    /// `mode` decides what happens when `pos` is not already known; see
    /// [`UserPosMode`]. Implementations report failures through the returned
    /// `SudachiResult`.
    fn handle_user_pos<S>(&mut self, pos: &[S], mode: UserPosMode) -> SudachiResult<u16>
    where
        S: AsRef<str> + ToString + Display;
}
impl UserPosSupport for &mut Grammar<'_> {
    /// Looks `pos` up in the grammar and returns its id if it is present.
    ///
    /// If it is absent, the outcome depends on `mode`:
    /// * [`UserPosMode::Allow`]: delegates to `register_pos` and returns its result.
    /// * [`UserPosMode::Forbid`]: returns `SudachiError::InvalidPartOfSpeech`, whose
    ///   message lists the comma-joined POS components.
    fn handle_user_pos<S: AsRef<str> + ToString + Display>(
        &mut self,
        pos: &[S],
        mode: UserPosMode,
    ) -> SudachiResult<u16> {
        // Bind the lookup result first so the shared borrow of the grammar has
        // ended before `register_pos` needs mutable access.
        let existing = self.get_part_of_speech_id(pos);
        match (existing, mode) {
            (Some(known), _) => Ok(known),
            (None, UserPosMode::Allow) => self.register_pos(pos),
            (None, UserPosMode::Forbid) => {
                let joined = pos.iter().join(",");
                Err(SudachiError::InvalidPartOfSpeech(format!(
                    "POS {} was not in the dictionary, user-defined POS are forbidden",
                    joined
                )))
            }
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Deserializes a JSON document into a [`UserPosMode`].
    fn parse(json: &str) -> Result<UserPosMode, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn allow() {
        let parsed = parse("\"allow\"").expect("fails");
        assert_eq!(UserPosMode::Allow, parsed)
    }

    #[test]
    fn forbid() {
        let parsed = parse("\"forbid\"").expect("fails");
        assert_eq!(UserPosMode::Forbid, parsed)
    }

    #[test]
    fn other_value() {
        // Only the lowercase variant names are accepted.
        assert!(parse("\"test\"").is_err())
    }
}