sudachi/analysis/
lattice.rs

1/*
2 *  Copyright (c) 2021-2024 Works Applications Co., Ltd.
3 *
4 *  Licensed under the Apache License, Version 2.0 (the "License");
5 *  you may not use this file except in compliance with the License.
6 *  You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 *   Unless required by applicable law or agreed to in writing, software
11 *  distributed under the License is distributed on an "AS IS" BASIS,
12 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *  See the License for the specific language governing permissions and
14 *  limitations under the License.
15 */
16
17use crate::analysis::inner::{Node, NodeIdx};
18use crate::analysis::node::{LatticeNode, PathCost, RightId};
19use crate::dic::connect::ConnectionMatrix;
20use crate::dic::grammar::Grammar;
21use crate::dic::lexicon_set::LexiconSet;
22use crate::dic::subset::InfoSubset;
23use crate::dic::word_id::WordId;
24use crate::error::SudachiResult;
25use crate::input_text::InputBuffer;
26use crate::prelude::SudachiError;
27use std::fmt::{Display, Formatter};
28use std::io::Write;
29
/// Compact lattice node used by the Viterbi search.
///
/// Kept extremely small for better cache locality: it stores only the accumulated
/// path cost and the right connection id.
/// The payload is 6 bytes (`i32` + `u16`) but the struct is padded to 8 bytes, so
/// 25% of the storage is lost to padding.
/// A struct-of-arrays layout would avoid that loss, but its effect should be
/// measured separately from the effects of the current rewrite.
struct VNode {
    /// Right connection id of the node.
    right_id: u16,
    /// Accumulated cost of the cheapest known path that ends with this node.
    total_cost: i32,
}
39
40impl RightId for VNode {
41    #[inline]
42    fn right_id(&self) -> u16 {
43        self.right_id
44    }
45}
46
47impl PathCost for VNode {
48    #[inline]
49    fn total_cost(&self) -> i32 {
50        self.total_cost
51    }
52}
53
54impl VNode {
55    #[inline]
56    fn new(right_id: u16, total_cost: i32) -> VNode {
57        VNode {
58            right_id,
59            total_cost,
60        }
61    }
62}
63
64/// Lattice which is constructed for performing the Viterbi search.
65/// Contain several parallel arrays.
66/// First level of parallel arrays is indexed by end word boundary.
67/// Word boundaries are always aligned to codepoint boundaries, not to byte boundaries.
68///
69/// During the successive analysis, we do not drop inner vectors, so
70/// the size of vectors never shrink.
71/// You must use the size parameter to check the current size and never
72/// access vectors after the end.
73#[derive(Default)]
74pub struct Lattice {
75    ends: Vec<Vec<VNode>>,
76    ends_full: Vec<Vec<Node>>,
77    indices: Vec<Vec<NodeIdx>>,
78    eos: Option<(NodeIdx, i32)>,
79    size: usize,
80}
81
82impl Lattice {
83    fn reset_vec<T>(data: &mut Vec<Vec<T>>, target: usize) {
84        for v in data.iter_mut() {
85            v.clear();
86        }
87        let cur_len = data.len();
88        if cur_len <= target {
89            data.reserve(target - cur_len);
90            for _ in cur_len..target {
91                data.push(Vec::with_capacity(16))
92            }
93        }
94    }
95
96    /// Prepare lattice for the next analysis of a sentence with the
97    /// specified length (in codepoints)
98    pub fn reset(&mut self, length: usize) {
99        Self::reset_vec(&mut self.ends, length + 1);
100        Self::reset_vec(&mut self.ends_full, length + 1);
101        Self::reset_vec(&mut self.indices, length + 1);
102        self.eos = None;
103        self.size = length + 1;
104        self.connect_bos();
105    }
106
107    fn connect_bos(&mut self) {
108        self.ends[0].push(VNode::new(0, 0));
109    }
110
111    /// Find EOS node -- finish the lattice construction
112    pub fn connect_eos(&mut self, conn: &ConnectionMatrix) -> SudachiResult<()> {
113        let len = self.size;
114        let eos_start = (len - 1) as u16;
115        let eos_end = (len - 1) as u16;
116        let node = Node::new(eos_start, eos_end, 0, 0, 0, WordId::EOS);
117        let (idx, cost) = self.connect_node(&node, conn);
118        if cost == i32::MAX {
119            Err(SudachiError::EosBosDisconnect)
120        } else {
121            self.eos = Some((idx, cost));
122            Ok(())
123        }
124    }
125
126    /// Insert a single node in the lattice, founding the path to the previous node
127    /// Assumption: lattice for all previous boundaries is already constructed
128    pub fn insert(&mut self, node: Node, conn: &ConnectionMatrix) -> i32 {
129        let (idx, cost) = self.connect_node(&node, conn);
130        let end_idx = node.end();
131        self.ends[end_idx].push(VNode::new(node.right_id(), cost));
132        self.indices[end_idx].push(idx);
133        self.ends_full[end_idx].push(node);
134        cost
135    }
136
137    /// Find the path with the minimal cost through the lattice to the attached node
138    /// Assumption: lattice for all previous boundaries is already constructed
139    #[inline]
140    pub fn connect_node(&self, r_node: &Node, conn: &ConnectionMatrix) -> (NodeIdx, i32) {
141        let begin = r_node.begin();
142
143        let node_cost = r_node.cost() as i32;
144        let mut min_cost = i32::MAX;
145        let mut prev_idx = NodeIdx::empty();
146
147        for (i, l_node) in self.ends[begin].iter().enumerate() {
148            if !l_node.is_connected_to_bos() {
149                continue;
150            }
151
152            let connect_cost = conn.cost(l_node.right_id(), r_node.left_id()) as i32;
153            let new_cost = l_node.total_cost() + connect_cost + node_cost;
154            if new_cost < min_cost {
155                min_cost = new_cost;
156                prev_idx = NodeIdx::new(begin as u16, i as u16);
157            }
158        }
159
160        (prev_idx, min_cost)
161    }
162
163    /// Checks if there exist at least one at the word end boundary
164    pub fn has_previous_node(&self, i: usize) -> bool {
165        self.ends.get(i).map(|d| !d.is_empty()).unwrap_or(false)
166    }
167
168    /// Lookup a node for the index
169    pub fn node(&self, id: NodeIdx) -> (&Node, i32) {
170        let node = &self.ends_full[id.end() as usize][id.index() as usize];
171        let cost = self.ends[id.end() as usize][id.index() as usize].total_cost;
172        (node, cost)
173    }
174
175    /// Fill the path with the minimum cost (indices only).
176    /// **Attention**: the path will be reversed (end to beginning) and will need to be traversed
177    /// in the reverse order.
178    pub fn fill_top_path(&self, result: &mut Vec<NodeIdx>) {
179        if self.eos.is_none() {
180            return;
181        }
182        // start with EOS
183        let (mut idx, _) = self.eos.unwrap();
184        result.push(idx);
185        loop {
186            let prev_idx = self.indices[idx.end() as usize][idx.index() as usize];
187            if prev_idx.end() != 0 {
188                // add if not BOS
189                result.push(prev_idx);
190                idx = prev_idx;
191            } else {
192                // finish if BOS
193                break;
194            }
195        }
196    }
197}
198
199impl Lattice {
200    pub fn dump<W: Write>(
201        &self,
202        input: &InputBuffer,
203        grammar: &Grammar,
204        lexicon: &LexiconSet,
205        out: &mut W,
206    ) -> SudachiResult<()> {
207        enum PosData<'a> {
208            Bos,
209            Borrow(&'a [String]),
210        }
211
212        impl Display for PosData<'_> {
213            fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214                match self {
215                    PosData::Bos => write!(f, "BOS/EOS"),
216                    PosData::Borrow(data) => {
217                        for (i, s) in data.iter().enumerate() {
218                            write!(f, "{}", s)?;
219                            if i + 1 != data.len() {
220                                write!(f, ", ")?;
221                            }
222                        }
223                        Ok(())
224                    }
225                }
226            }
227        }
228
229        let mut dump_idx = 0;
230
231        for boundary in (0..self.indices.len()).rev() {
232            for r_node in &self.ends_full[boundary] {
233                let (surface, pos) = if r_node.is_special_node() {
234                    ("(null)", PosData::Bos)
235                } else if r_node.is_oov() {
236                    let pos_id = r_node.word_id().word() as usize;
237                    (
238                        input.curr_slice_c(r_node.begin()..r_node.end()),
239                        PosData::Borrow(&grammar.pos_list[pos_id]),
240                    )
241                } else {
242                    let winfo =
243                        lexicon.get_word_info_subset(r_node.word_id(), InfoSubset::POS_ID)?;
244                    (
245                        input.orig_slice_c(r_node.begin()..r_node.end()),
246                        PosData::Borrow(&grammar.pos_list[winfo.pos_id() as usize]),
247                    )
248                };
249
250                write!(
251                    out,
252                    "{}: {} {} {}{} {} {} {} {}:",
253                    dump_idx,
254                    r_node.begin(),
255                    r_node.end(),
256                    surface,
257                    r_node.word_id(),
258                    pos,
259                    r_node.left_id(),
260                    r_node.right_id(),
261                    r_node.cost()
262                )?;
263
264                let conn = grammar.conn_matrix();
265
266                for l_node in &self.ends[r_node.begin()] {
267                    let connect_cost = conn.cost(l_node.right_id(), r_node.left_id());
268                    write!(out, " {}", connect_cost)?;
269                }
270
271                writeln!(out)?;
272
273                dump_idx += 1;
274            }
275        }
276        Ok(())
277    }
278}