sudachi/analysis/
lattice.rs1use crate::analysis::inner::{Node, NodeIdx};
18use crate::analysis::node::{LatticeNode, PathCost, RightId};
19use crate::dic::connect::ConnectionMatrix;
20use crate::dic::grammar::Grammar;
21use crate::dic::lexicon_set::LexiconSet;
22use crate::dic::subset::InfoSubset;
23use crate::dic::word_id::WordId;
24use crate::error::SudachiResult;
25use crate::input_text::InputBuffer;
26use crate::prelude::SudachiError;
27use std::fmt::{Display, Formatter};
28use std::io::Write;
29
/// Compact per-node record kept for every node that ends at a boundary.
///
/// It carries only the two values read while connecting a following node:
/// the accumulated path cost and the right connection id.
struct VNode {
    /// Accumulated cost of the path that ends with this node.
    total_cost: i32,
    /// Right connection id, used as the left-hand argument of the connection cost lookup.
    right_id: u16,
}
39
40impl RightId for VNode {
41 #[inline]
42 fn right_id(&self) -> u16 {
43 self.right_id
44 }
45}
46
47impl PathCost for VNode {
48 #[inline]
49 fn total_cost(&self) -> i32 {
50 self.total_cost
51 }
52}
53
54impl VNode {
55 #[inline]
56 fn new(right_id: u16, total_cost: i32) -> VNode {
57 VNode {
58 right_id,
59 total_cost,
60 }
61 }
62}
63
64#[derive(Default)]
74pub struct Lattice {
75 ends: Vec<Vec<VNode>>,
76 ends_full: Vec<Vec<Node>>,
77 indices: Vec<Vec<NodeIdx>>,
78 eos: Option<(NodeIdx, i32)>,
79 size: usize,
80}
81
82impl Lattice {
83 fn reset_vec<T>(data: &mut Vec<Vec<T>>, target: usize) {
84 for v in data.iter_mut() {
85 v.clear();
86 }
87 let cur_len = data.len();
88 if cur_len <= target {
89 data.reserve(target - cur_len);
90 for _ in cur_len..target {
91 data.push(Vec::with_capacity(16))
92 }
93 }
94 }
95
96 pub fn reset(&mut self, length: usize) {
99 Self::reset_vec(&mut self.ends, length + 1);
100 Self::reset_vec(&mut self.ends_full, length + 1);
101 Self::reset_vec(&mut self.indices, length + 1);
102 self.eos = None;
103 self.size = length + 1;
104 self.connect_bos();
105 }
106
107 fn connect_bos(&mut self) {
108 self.ends[0].push(VNode::new(0, 0));
109 }
110
111 pub fn connect_eos(&mut self, conn: &ConnectionMatrix) -> SudachiResult<()> {
113 let len = self.size;
114 let eos_start = (len - 1) as u16;
115 let eos_end = (len - 1) as u16;
116 let node = Node::new(eos_start, eos_end, 0, 0, 0, WordId::EOS);
117 let (idx, cost) = self.connect_node(&node, conn);
118 if cost == i32::MAX {
119 Err(SudachiError::EosBosDisconnect)
120 } else {
121 self.eos = Some((idx, cost));
122 Ok(())
123 }
124 }
125
126 pub fn insert(&mut self, node: Node, conn: &ConnectionMatrix) -> i32 {
129 let (idx, cost) = self.connect_node(&node, conn);
130 let end_idx = node.end();
131 self.ends[end_idx].push(VNode::new(node.right_id(), cost));
132 self.indices[end_idx].push(idx);
133 self.ends_full[end_idx].push(node);
134 cost
135 }
136
137 #[inline]
140 pub fn connect_node(&self, r_node: &Node, conn: &ConnectionMatrix) -> (NodeIdx, i32) {
141 let begin = r_node.begin();
142
143 let node_cost = r_node.cost() as i32;
144 let mut min_cost = i32::MAX;
145 let mut prev_idx = NodeIdx::empty();
146
147 for (i, l_node) in self.ends[begin].iter().enumerate() {
148 if !l_node.is_connected_to_bos() {
149 continue;
150 }
151
152 let connect_cost = conn.cost(l_node.right_id(), r_node.left_id()) as i32;
153 let new_cost = l_node.total_cost() + connect_cost + node_cost;
154 if new_cost < min_cost {
155 min_cost = new_cost;
156 prev_idx = NodeIdx::new(begin as u16, i as u16);
157 }
158 }
159
160 (prev_idx, min_cost)
161 }
162
163 pub fn has_previous_node(&self, i: usize) -> bool {
165 self.ends.get(i).map(|d| !d.is_empty()).unwrap_or(false)
166 }
167
168 pub fn node(&self, id: NodeIdx) -> (&Node, i32) {
170 let node = &self.ends_full[id.end() as usize][id.index() as usize];
171 let cost = self.ends[id.end() as usize][id.index() as usize].total_cost;
172 (node, cost)
173 }
174
175 pub fn fill_top_path(&self, result: &mut Vec<NodeIdx>) {
179 if self.eos.is_none() {
180 return;
181 }
182 let (mut idx, _) = self.eos.unwrap();
184 result.push(idx);
185 loop {
186 let prev_idx = self.indices[idx.end() as usize][idx.index() as usize];
187 if prev_idx.end() != 0 {
188 result.push(prev_idx);
190 idx = prev_idx;
191 } else {
192 break;
194 }
195 }
196 }
197}
198
199impl Lattice {
200 pub fn dump<W: Write>(
201 &self,
202 input: &InputBuffer,
203 grammar: &Grammar,
204 lexicon: &LexiconSet,
205 out: &mut W,
206 ) -> SudachiResult<()> {
207 enum PosData<'a> {
208 Bos,
209 Borrow(&'a [String]),
210 }
211
212 impl Display for PosData<'_> {
213 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214 match self {
215 PosData::Bos => write!(f, "BOS/EOS"),
216 PosData::Borrow(data) => {
217 for (i, s) in data.iter().enumerate() {
218 write!(f, "{}", s)?;
219 if i + 1 != data.len() {
220 write!(f, ", ")?;
221 }
222 }
223 Ok(())
224 }
225 }
226 }
227 }
228
229 let mut dump_idx = 0;
230
231 for boundary in (0..self.indices.len()).rev() {
232 for r_node in &self.ends_full[boundary] {
233 let (surface, pos) = if r_node.is_special_node() {
234 ("(null)", PosData::Bos)
235 } else if r_node.is_oov() {
236 let pos_id = r_node.word_id().word() as usize;
237 (
238 input.curr_slice_c(r_node.begin()..r_node.end()),
239 PosData::Borrow(&grammar.pos_list[pos_id]),
240 )
241 } else {
242 let winfo =
243 lexicon.get_word_info_subset(r_node.word_id(), InfoSubset::POS_ID)?;
244 (
245 input.orig_slice_c(r_node.begin()..r_node.end()),
246 PosData::Borrow(&grammar.pos_list[winfo.pos_id() as usize]),
247 )
248 };
249
250 write!(
251 out,
252 "{}: {} {} {}{} {} {} {} {}:",
253 dump_idx,
254 r_node.begin(),
255 r_node.end(),
256 surface,
257 r_node.word_id(),
258 pos,
259 r_node.left_id(),
260 r_node.right_id(),
261 r_node.cost()
262 )?;
263
264 let conn = grammar.conn_matrix();
265
266 for l_node in &self.ends[r_node.begin()] {
267 let connect_cost = conn.cost(l_node.right_id(), r_node.left_id());
268 write!(out, " {}", connect_cost)?;
269 }
270
271 writeln!(out)?;
272
273 dump_idx += 1;
274 }
275 }
276 Ok(())
277 }
278}