1use crate::dic::grammar::Grammar;
18use crate::error::{SudachiError, SudachiResult};
19use itertools::Itertools;
20use serde::Deserialize;
21use std::fmt::Display;
22
23#[derive(Eq, PartialEq, Deserialize, Clone, Copy, Debug, Default)]
24#[serde(rename_all = "lowercase")]
25pub enum UserPosMode {
26 #[default]
27 Forbid,
28 Allow,
29}
30
31pub trait UserPosSupport {
32 fn handle_user_pos<S: AsRef<str> + ToString + Display>(
33 &mut self,
34 pos: &[S],
35 mode: UserPosMode,
36 ) -> SudachiResult<u16>;
37}
38
39impl<'a> UserPosSupport for &'a mut Grammar<'_> {
40 fn handle_user_pos<S: AsRef<str> + ToString + Display>(
41 &mut self,
42 pos: &[S],
43 mode: UserPosMode,
44 ) -> SudachiResult<u16> {
45 if let Some(id) = self.get_part_of_speech_id(pos) {
46 return Ok(id);
47 }
48
49 match mode {
50 UserPosMode::Allow => self.register_pos(pos),
51 UserPosMode::Forbid => Err(SudachiError::InvalidPartOfSpeech(format!(
52 "POS {} was not in the dictionary, user-defined POS are forbidden",
53 pos.iter().join(",")
54 ))),
55 }
56 }
57}
58
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn allow() {
        let mode: UserPosMode = serde_json::from_str("\"allow\"").expect("fails");
        assert_eq!(UserPosMode::Allow, mode)
    }

    #[test]
    fn forbid() {
        let mode: UserPosMode = serde_json::from_str("\"forbid\"").expect("fails");
        assert_eq!(UserPosMode::Forbid, mode)
    }

    #[test]
    fn other_value() {
        let mode: Result<UserPosMode, _> = serde_json::from_str("\"test\"");
        assert!(mode.is_err())
    }

    /// User-defined POS must be opt-in: the default policy is `Forbid`.
    #[test]
    fn default_is_forbid() {
        assert_eq!(UserPosMode::Forbid, UserPosMode::default());
    }

    /// `rename_all = "lowercase"` accepts only the exact lowercase spelling;
    /// the Rust variant name itself is not a valid input.
    #[test]
    fn capitalized_value_rejected() {
        let mode: Result<UserPosMode, _> = serde_json::from_str("\"Allow\"");
        assert!(mode.is_err())
    }
}