1 #ifndef MARISA_ALPHA_TRIE_INLINE_H_
2 #define MARISA_ALPHA_TRIE_INLINE_H_
3
4 #include <stdexcept>
5
6 #include "cell.h"
7
8 namespace marisa_alpha {
9
10 inline std::string Trie::operator[](UInt32 key_id) const {
11 std::string key;
12 restore(key_id, &key);
13 return key;
14 }
15
16 inline UInt32 Trie::operator[](const char *str) const {
17 return lookup(str);
18 }
19
20 inline UInt32 Trie::operator[](const std::string &str) const {
21 return lookup(str);
22 }
23
lookup(const std::string & str)24 inline UInt32 Trie::lookup(const std::string &str) const {
25 return lookup(str.c_str(), str.length());
26 }
27
find(const std::string & str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results)28 inline std::size_t Trie::find(const std::string &str,
29 UInt32 *key_ids, std::size_t *key_lengths,
30 std::size_t max_num_results) const {
31 return find(str.c_str(), str.length(),
32 key_ids, key_lengths, max_num_results);
33 }
34
find(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results)35 inline std::size_t Trie::find(const std::string &str,
36 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
37 std::size_t max_num_results) const {
38 return find(str.c_str(), str.length(),
39 key_ids, key_lengths, max_num_results);
40 }
41
find_first(const std::string & str,std::size_t * key_length)42 inline UInt32 Trie::find_first(const std::string &str,
43 std::size_t *key_length) const {
44 return find_first(str.c_str(), str.length(), key_length);
45 }
46
find_last(const std::string & str,std::size_t * key_length)47 inline UInt32 Trie::find_last(const std::string &str,
48 std::size_t *key_length) const {
49 return find_last(str.c_str(), str.length(), key_length);
50 }
51
52 template <typename T>
find_callback(const char * str,T callback)53 inline std::size_t Trie::find_callback(const char *str,
54 T callback) const {
55 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
56 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
57 return find_callback_<CQuery>(CQuery(str), callback);
58 }
59
60 template <typename T>
find_callback(const char * ptr,std::size_t length,T callback)61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
62 T callback) const {
63 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
64 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
65 MARISA_ALPHA_PARAM_ERROR);
66 return find_callback_<const Query &>(Query(ptr, length), callback);
67 }
68
69 template <typename T>
find_callback(const std::string & str,T callback)70 inline std::size_t Trie::find_callback(const std::string &str,
71 T callback) const {
72 return find_callback(str.c_str(), str.length(), callback);
73 }
74
predict(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)75 inline std::size_t Trie::predict(const std::string &str,
76 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
77 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
78 }
79
predict(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)80 inline std::size_t Trie::predict(const std::string &str,
81 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
82 std::size_t max_num_results) const {
83 return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
84 }
85
predict_breadth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)86 inline std::size_t Trie::predict_breadth_first(const std::string &str,
87 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
88 return predict_breadth_first(str.c_str(), str.length(),
89 key_ids, keys, max_num_results);
90 }
91
predict_breadth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)92 inline std::size_t Trie::predict_breadth_first(const std::string &str,
93 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
94 std::size_t max_num_results) const {
95 return predict_breadth_first(str.c_str(), str.length(),
96 key_ids, keys, max_num_results);
97 }
98
predict_depth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)99 inline std::size_t Trie::predict_depth_first(const std::string &str,
100 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
101 return predict_depth_first(str.c_str(), str.length(),
102 key_ids, keys, max_num_results);
103 }
104
predict_depth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)105 inline std::size_t Trie::predict_depth_first(const std::string &str,
106 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
107 std::size_t max_num_results) const {
108 return predict_depth_first(str.c_str(), str.length(),
109 key_ids, keys, max_num_results);
110 }
111
112 template <typename T>
predict_callback(const char * str,T callback)113 inline std::size_t Trie::predict_callback(
114 const char *str, T callback) const {
115 return predict_callback_<CQuery>(CQuery(str), callback);
116 }
117
118 template <typename T>
predict_callback(const char * ptr,std::size_t length,T callback)119 inline std::size_t Trie::predict_callback(
120 const char *ptr, std::size_t length,
121 T callback) const {
122 return predict_callback_<const Query &>(Query(ptr, length), callback);
123 }
124
125 template <typename T>
predict_callback(const std::string & str,T callback)126 inline std::size_t Trie::predict_callback(
127 const std::string &str, T callback) const {
128 return predict_callback(str.c_str(), str.length(), callback);
129 }
130
empty()131 inline bool Trie::empty() const {
132 return louds_.empty();
133 }
134
num_keys()135 inline std::size_t Trie::num_keys() const {
136 return num_keys_;
137 }
138
notfound()139 inline UInt32 Trie::notfound() {
140 return MARISA_ALPHA_NOT_FOUND;
141 }
142
mismatch()143 inline std::size_t Trie::mismatch() {
144 return MARISA_ALPHA_MISMATCH;
145 }
146
147 template <typename T>
find_child(UInt32 & node,T query,std::size_t & pos)148 inline bool Trie::find_child(UInt32 &node, T query,
149 std::size_t &pos) const {
150 UInt32 louds_pos = get_child(node);
151 if (!louds_[louds_pos]) {
152 return false;
153 }
154 node = louds_pos_to_node(louds_pos, node);
155 UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
156 do {
157 if (has_link(node)) {
158 if (link_id == MARISA_ALPHA_UINT32_MAX) {
159 link_id = get_link_id(node);
160 } else {
161 ++link_id;
162 }
163 std::size_t next_pos = has_trie() ?
164 trie_->trie_match<T>(get_link(node, link_id), query, pos) :
165 tail_match<T>(node, link_id, query, pos);
166 if (next_pos == mismatch()) {
167 return false;
168 } else if (next_pos != pos) {
169 pos = next_pos;
170 return true;
171 }
172 } else if (labels_[node] == query[pos]) {
173 ++pos;
174 return true;
175 }
176 ++node;
177 ++louds_pos;
178 } while (louds_[louds_pos]);
179 return false;
180 }
181
182 template <typename T, typename U>
find_callback_(T query,U callback)183 std::size_t Trie::find_callback_(T query, U callback) const try {
184 std::size_t count = 0;
185 UInt32 node = 0;
186 std::size_t pos = 0;
187 do {
188 if (terminal_flags_[node]) {
189 ++count;
190 if (!callback(node_to_key_id(node), pos)) {
191 return count;
192 }
193 }
194 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
195 return count;
196 } catch (const std::bad_alloc &) {
197 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
catch(const std::length_error &)198 } catch (const std::length_error &) {
199 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
200 }
201
202 template <typename T>
predict_child(UInt32 & node,T query,std::size_t & pos,std::string * key)203 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
204 std::string *key) const {
205 UInt32 louds_pos = get_child(node);
206 if (!louds_[louds_pos]) {
207 return false;
208 }
209 node = louds_pos_to_node(louds_pos, node);
210 UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
211 do {
212 if (has_link(node)) {
213 if (link_id == MARISA_ALPHA_UINT32_MAX) {
214 link_id = get_link_id(node);
215 } else {
216 ++link_id;
217 }
218 std::size_t next_pos = has_trie() ?
219 trie_->trie_prefix_match<T>(
220 get_link(node, link_id), query, pos, key) :
221 tail_prefix_match<T>(node, link_id, query, pos, key);
222 if (next_pos == mismatch()) {
223 return false;
224 } else if (next_pos != pos) {
225 pos = next_pos;
226 return true;
227 }
228 } else if (labels_[node] == query[pos]) {
229 ++pos;
230 return true;
231 }
232 ++node;
233 ++louds_pos;
234 } while (louds_[louds_pos]);
235 return false;
236 }
237
238 template <typename T, typename U>
predict_callback_(T query,U callback)239 std::size_t Trie::predict_callback_(T query, U callback) const try {
240 std::string key;
241 UInt32 node = 0;
242 std::size_t pos = 0;
243 while (!query.ends_at(pos)) {
244 if (!predict_child<T>(node, query, pos, &key)) {
245 return 0;
246 }
247 }
248 query.insert(&key);
249 std::size_t count = 0;
250 if (terminal_flags_[node]) {
251 ++count;
252 if (!callback(node_to_key_id(node), key)) {
253 return count;
254 }
255 }
256 Cell cell;
257 cell.set_louds_pos(get_child(node));
258 if (!louds_[cell.louds_pos()]) {
259 return count;
260 }
261 cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
262 cell.set_key_id(node_to_key_id(cell.node()));
263 cell.set_length(key.length());
264 Vector<Cell> stack;
265 stack.push_back(cell);
266 std::size_t stack_pos = 1;
267 while (stack_pos != 0) {
268 Cell &cur = stack[stack_pos - 1];
269 if (!louds_[cur.louds_pos()]) {
270 cur.set_louds_pos(cur.louds_pos() + 1);
271 --stack_pos;
272 continue;
273 }
274 cur.set_louds_pos(cur.louds_pos() + 1);
275 key.resize(cur.length());
276 if (has_link(cur.node())) {
277 if (has_trie()) {
278 trie_->trie_restore(get_link(cur.node()), &key);
279 } else {
280 tail_restore(cur.node(), &key);
281 }
282 } else {
283 key += labels_[cur.node()];
284 }
285 if (terminal_flags_[cur.node()]) {
286 ++count;
287 if (!callback(cur.key_id(), key)) {
288 return count;
289 }
290 cur.set_key_id(cur.key_id() + 1);
291 }
292 if (stack_pos == stack.size()) {
293 cell.set_louds_pos(get_child(cur.node()));
294 cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
295 cell.set_key_id(node_to_key_id(cell.node()));
296 stack.push_back(cell);
297 }
298 stack[stack_pos].set_length(key.length());
299 stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
300 ++stack_pos;
301 }
302 return count;
303 } catch (const std::bad_alloc &) {
304 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
catch(const std::length_error &)305 } catch (const std::length_error &) {
306 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
307 }
308
key_id_to_node(UInt32 key_id)309 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
310 return terminal_flags_.select1(key_id);
311 }
312
node_to_key_id(UInt32 node)313 inline UInt32 Trie::node_to_key_id(UInt32 node) const {
314 return terminal_flags_.rank1(node);
315 }
316
louds_pos_to_node(UInt32 louds_pos,UInt32 parent_node)317 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
318 UInt32 parent_node) const {
319 return louds_pos - parent_node - 1;
320 }
321
get_child(UInt32 node)322 inline UInt32 Trie::get_child(UInt32 node) const {
323 return louds_.select0(node) + 1;
324 }
325
get_parent(UInt32 node)326 inline UInt32 Trie::get_parent(UInt32 node) const {
327 return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
328 }
329
has_link(UInt32 node)330 inline bool Trie::has_link(UInt32 node) const {
331 return (link_flags_.empty()) ? false : link_flags_[node];
332 }
333
get_link_id(UInt32 node)334 inline UInt32 Trie::get_link_id(UInt32 node) const {
335 return link_flags_.rank1(node);
336 }
337
get_link(UInt32 node)338 inline UInt32 Trie::get_link(UInt32 node) const {
339 return get_link(node, get_link_id(node));
340 }
341
get_link(UInt32 node,UInt32 link_id)342 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
343 return (links_[link_id] * 256) + labels_[node];
344 }
345
has_link()346 inline bool Trie::has_link() const {
347 return !link_flags_.empty();
348 }
349
has_trie()350 inline bool Trie::has_trie() const {
351 return trie_.get() != NULL;
352 }
353
has_tail()354 inline bool Trie::has_tail() const {
355 return !tail_.empty();
356 }
357
358 } // namespace marisa_alpha
359
360 #endif // MARISA_ALPHA_TRIE_INLINE_H_
361