1 #include <algorithm>
2 #include <stdexcept>
3
4 #include "trie.h"
5
6 namespace marisa_alpha {
7 namespace {
8
9 template <typename T, typename U>
10 class PredictCallback {
11 public:
PredictCallback(T key_ids,U keys,std::size_t max_num_results)12 PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13 : key_ids_(key_ids), keys_(keys),
14 max_num_results_(max_num_results), num_results_(0) {}
PredictCallback(const PredictCallback & callback)15 PredictCallback(const PredictCallback &callback)
16 : key_ids_(callback.key_ids_), keys_(callback.keys_),
17 max_num_results_(callback.max_num_results_),
18 num_results_(callback.num_results_) {}
19
operator ()(marisa_alpha::UInt32 key_id,const std::string & key)20 bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) {
21 if (key_ids_.is_valid()) {
22 key_ids_.insert(num_results_, key_id);
23 }
24 if (keys_.is_valid()) {
25 keys_.insert(num_results_, key);
26 }
27 return ++num_results_ < max_num_results_;
28 }
29
30 private:
31 T key_ids_;
32 U keys_;
33 const std::size_t max_num_results_;
34 std::size_t num_results_;
35
36 // Disallows assignment.
37 PredictCallback &operator=(const PredictCallback &);
38 };
39
40 } // namespace
41
restore(UInt32 key_id) const42 std::string Trie::restore(UInt32 key_id) const {
43 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
44 MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
45 std::string key;
46 restore_(key_id, &key);
47 return key;
48 }
49
// Appends the key stored under `key_id` to `*key`; characters already in
// `*key` are kept (restore_() starts writing at the current length).
// Raises MARISA_ALPHA_STATE_ERROR when empty() holds, and
// MARISA_ALPHA_PARAM_ERROR when `key_id` is not below num_keys_ or `key` is
// NULL.
void Trie::restore(UInt32 key_id, std::string *key) const {
  MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
  MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
  MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR);
  restore_(key_id, key);
}
56
restore(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58 std::size_t key_buf_size) const {
59 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
60 MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
61 MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62 MARISA_ALPHA_PARAM_ERROR);
63 return restore_(key_id, key_buf, key_buf_size);
64 }
65
lookup(const char * str) const66 UInt32 Trie::lookup(const char *str) const {
67 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
68 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
69 return lookup_<CQuery>(CQuery(str));
70 }
71
lookup(const char * ptr,std::size_t length) const72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
74 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
75 MARISA_ALPHA_PARAM_ERROR);
76 return lookup_<const Query &>(Query(ptr, length));
77 }
78
find(const char * str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const79 std::size_t Trie::find(const char *str,
80 UInt32 *key_ids, std::size_t *key_lengths,
81 std::size_t max_num_results) const {
82 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
83 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
84 return find_<CQuery>(CQuery(str),
85 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
86 }
87
find(const char * ptr,std::size_t length,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const88 std::size_t Trie::find(const char *ptr, std::size_t length,
89 UInt32 *key_ids, std::size_t *key_lengths,
90 std::size_t max_num_results) const {
91 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
92 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
93 MARISA_ALPHA_PARAM_ERROR);
94 return find_<const Query &>(Query(ptr, length),
95 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
96 }
97
find(const char * str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const98 std::size_t Trie::find(const char *str,
99 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
100 std::size_t max_num_results) const {
101 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
102 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
103 return find_<CQuery>(CQuery(str),
104 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
105 }
106
find(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const107 std::size_t Trie::find(const char *ptr, std::size_t length,
108 std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
109 std::size_t max_num_results) const {
110 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
111 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
112 MARISA_ALPHA_PARAM_ERROR);
113 return find_<const Query &>(Query(ptr, length),
114 MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
115 }
116
find_first(const char * str,std::size_t * key_length) const117 UInt32 Trie::find_first(const char *str,
118 std::size_t *key_length) const {
119 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
120 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
121 return find_first_<CQuery>(CQuery(str), key_length);
122 }
123
find_first(const char * ptr,std::size_t length,std::size_t * key_length) const124 UInt32 Trie::find_first(const char *ptr, std::size_t length,
125 std::size_t *key_length) const {
126 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
127 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
128 MARISA_ALPHA_PARAM_ERROR);
129 return find_first_<const Query &>(Query(ptr, length), key_length);
130 }
131
find_last(const char * str,std::size_t * key_length) const132 UInt32 Trie::find_last(const char *str,
133 std::size_t *key_length) const {
134 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
135 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
136 return find_last_<CQuery>(CQuery(str), key_length);
137 }
138
find_last(const char * ptr,std::size_t length,std::size_t * key_length) const139 UInt32 Trie::find_last(const char *ptr, std::size_t length,
140 std::size_t *key_length) const {
141 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
142 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
143 MARISA_ALPHA_PARAM_ERROR);
144 return find_last_<const Query &>(Query(ptr, length), key_length);
145 }
146
predict(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const147 std::size_t Trie::predict(const char *str,
148 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
149 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
150 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
151 return (keys == NULL) ?
152 predict_breadth_first(str, key_ids, keys, max_num_results) :
153 predict_depth_first(str, key_ids, keys, max_num_results);
154 }
155
predict(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const156 std::size_t Trie::predict(const char *ptr, std::size_t length,
157 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
158 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
159 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
160 MARISA_ALPHA_PARAM_ERROR);
161 return (keys == NULL) ?
162 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
163 predict_depth_first(ptr, length, key_ids, keys, max_num_results);
164 }
165
predict(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const166 std::size_t Trie::predict(const char *str,
167 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
168 std::size_t max_num_results) const {
169 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
170 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
171 return (keys == NULL) ?
172 predict_breadth_first(str, key_ids, keys, max_num_results) :
173 predict_depth_first(str, key_ids, keys, max_num_results);
174 }
175
predict(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const176 std::size_t Trie::predict(const char *ptr, std::size_t length,
177 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
178 std::size_t max_num_results) const {
179 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
180 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
181 MARISA_ALPHA_PARAM_ERROR);
182 return (keys == NULL) ?
183 predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
184 predict_depth_first(ptr, length, key_ids, keys, max_num_results);
185 }
186
predict_breadth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const187 std::size_t Trie::predict_breadth_first(const char *str,
188 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
189 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
190 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
191 return predict_breadth_first_<CQuery>(CQuery(str),
192 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
193 }
194
predict_breadth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const195 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
196 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
197 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
198 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
199 MARISA_ALPHA_PARAM_ERROR);
200 return predict_breadth_first_<const Query &>(Query(ptr, length),
201 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
202 }
203
predict_breadth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const204 std::size_t Trie::predict_breadth_first(const char *str,
205 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
206 std::size_t max_num_results) const {
207 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
208 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
209 return predict_breadth_first_<CQuery>(CQuery(str),
210 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
211 }
212
predict_breadth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const213 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
214 std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
215 std::size_t max_num_results) const {
216 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
217 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
218 MARISA_ALPHA_PARAM_ERROR);
219 return predict_breadth_first_<const Query &>(Query(ptr, length),
220 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
221 }
222
predict_depth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const223 std::size_t Trie::predict_depth_first(const char *str,
224 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
225 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
226 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
227 return predict_depth_first_<CQuery>(CQuery(str),
228 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
229 }
230
predict_depth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const231 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
232 UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
233 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
234 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
235 MARISA_ALPHA_PARAM_ERROR);
236 return predict_depth_first_<const Query &>(Query(ptr, length),
237 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
238 }
239
predict_depth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const240 std::size_t Trie::predict_depth_first(
241 const char *str, std::vector<UInt32> *key_ids,
242 std::vector<std::string> *keys, std::size_t max_num_results) const {
243 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
244 MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
245 return predict_depth_first_<CQuery>(CQuery(str),
246 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
247 }
248
predict_depth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const249 std::size_t Trie::predict_depth_first(
250 const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
251 std::vector<std::string> *keys, std::size_t max_num_results) const {
252 MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
253 MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
254 MARISA_ALPHA_PARAM_ERROR);
255 return predict_depth_first_<const Query &>(Query(ptr, length),
256 MakeContainer(key_ids), MakeContainer(keys), max_num_results);
257 }
258
restore_(UInt32 key_id,std::string * key) const259 void Trie::restore_(UInt32 key_id, std::string *key) const {
260 const std::size_t start_pos = key->length();
261 try {
262 UInt32 node = key_id_to_node(key_id);
263 while (node != 0) {
264 if (has_link(node)) {
265 const std::size_t prev_pos = key->length();
266 if (has_trie()) {
267 trie_->trie_restore(get_link(node), key);
268 } else {
269 tail_restore(node, key);
270 }
271 std::reverse(key->begin() + prev_pos, key->end());
272 } else {
273 *key += labels_[node];
274 }
275 node = get_parent(node);
276 }
277 std::reverse(key->begin() + start_pos, key->end());
278 } catch (const std::bad_alloc &) {
279 key->resize(start_pos);
280 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
281 } catch (const std::length_error &) {
282 key->resize(start_pos);
283 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
284 }
285 }
286
trie_restore(UInt32 node,std::string * key) const287 void Trie::trie_restore(UInt32 node, std::string *key) const {
288 do {
289 if (has_link(node)) {
290 if (has_trie()) {
291 trie_->trie_restore(get_link(node), key);
292 } else {
293 tail_restore(node, key);
294 }
295 } else {
296 *key += labels_[node];
297 }
298 node = get_parent(node);
299 } while (node != 0);
300 }
301
tail_restore(UInt32 node,std::string * key) const302 void Trie::tail_restore(UInt32 node, std::string *key) const {
303 const UInt32 link_id = link_flags_.rank1(node);
304 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
305 if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
306 const UInt32 length = (links_[link_id + 1] * 256)
307 + labels_[link_flags_.select1(link_id + 1)] - offset;
308 key->append(reinterpret_cast<const char *>(tail_[offset]), length);
309 } else {
310 key->append(reinterpret_cast<const char *>(tail_[offset]));
311 }
312 }
313
restore_(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const314 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
315 std::size_t key_buf_size) const {
316 std::size_t pos = 0;
317 UInt32 node = key_id_to_node(key_id);
318 while (node != 0) {
319 if (has_link(node)) {
320 const std::size_t prev_pos = pos;
321 if (has_trie()) {
322 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
323 } else {
324 tail_restore(node, key_buf, key_buf_size, pos);
325 }
326 if (pos < key_buf_size) {
327 std::reverse(key_buf + prev_pos, key_buf + pos);
328 }
329 } else {
330 if (pos < key_buf_size) {
331 key_buf[pos] = labels_[node];
332 }
333 ++pos;
334 }
335 node = get_parent(node);
336 }
337 if (pos < key_buf_size) {
338 key_buf[pos] = '\0';
339 std::reverse(key_buf, key_buf + pos);
340 }
341 return pos;
342 }
343
trie_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const344 void Trie::trie_restore(UInt32 node, char *key_buf,
345 std::size_t key_buf_size, std::size_t &pos) const {
346 do {
347 if (has_link(node)) {
348 if (has_trie()) {
349 trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
350 } else {
351 tail_restore(node, key_buf, key_buf_size, pos);
352 }
353 } else {
354 if (pos < key_buf_size) {
355 key_buf[pos] = labels_[node];
356 }
357 ++pos;
358 }
359 node = get_parent(node);
360 } while (node != 0);
361 }
362
tail_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const363 void Trie::tail_restore(UInt32 node, char *key_buf,
364 std::size_t key_buf_size, std::size_t &pos) const {
365 const UInt32 link_id = link_flags_.rank1(node);
366 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
367 if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
368 const UInt8 *ptr = tail_[offset];
369 const UInt32 length = (links_[link_id + 1] * 256)
370 + labels_[link_flags_.select1(link_id + 1)] - offset;
371 for (UInt32 i = 0; i < length; ++i) {
372 if (pos < key_buf_size) {
373 key_buf[pos] = ptr[i];
374 }
375 ++pos;
376 }
377 } else {
378 for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
379 if (pos < key_buf_size) {
380 key_buf[pos] = *str;
381 }
382 ++pos;
383 }
384 }
385 }
386
387 template <typename T>
lookup_(T query) const388 UInt32 Trie::lookup_(T query) const {
389 UInt32 node = 0;
390 std::size_t pos = 0;
391 while (!query.ends_at(pos)) {
392 if (!find_child<T>(node, query, pos)) {
393 return notfound();
394 }
395 }
396 return terminal_flags_[node] ? node_to_key_id(node) : notfound();
397 }
398
399 template <typename T>
trie_match(UInt32 node,T query,std::size_t pos) const400 std::size_t Trie::trie_match(UInt32 node, T query,
401 std::size_t pos) const {
402 if (has_link(node)) {
403 std::size_t next_pos;
404 if (has_trie()) {
405 next_pos = trie_->trie_match<T>(get_link(node), query, pos);
406 } else {
407 next_pos = tail_match<T>(node, get_link_id(node), query, pos);
408 }
409 if ((next_pos == mismatch()) || (next_pos == pos)) {
410 return next_pos;
411 }
412 pos = next_pos;
413 } else if (labels_[node] != query[pos]) {
414 return pos;
415 } else {
416 ++pos;
417 }
418 node = get_parent(node);
419 while (node != 0) {
420 if (query.ends_at(pos)) {
421 return mismatch();
422 }
423 if (has_link(node)) {
424 std::size_t next_pos;
425 if (has_trie()) {
426 next_pos = trie_->trie_match<T>(get_link(node), query, pos);
427 } else {
428 next_pos = tail_match<T>(node, get_link_id(node), query, pos);
429 }
430 if ((next_pos == mismatch()) || (next_pos == pos)) {
431 return mismatch();
432 }
433 pos = next_pos;
434 } else if (labels_[node] != query[pos]) {
435 return mismatch();
436 } else {
437 ++pos;
438 }
439 node = get_parent(node);
440 }
441 return pos;
442 }
443
444 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
445 CQuery query, std::size_t pos) const;
446 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
447 const Query &query, std::size_t pos) const;
448
449 template <typename T>
tail_match(UInt32 node,UInt32 link_id,T query,std::size_t pos) const450 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
451 T query, std::size_t pos) const {
452 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
453 const UInt8 *ptr = tail_[offset];
454 if (*ptr != query[pos]) {
455 return pos;
456 } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
457 const UInt32 length = (links_[link_id + 1] * 256)
458 + labels_[link_flags_.select1(link_id + 1)] - offset;
459 for (UInt32 i = 1; i < length; ++i) {
460 if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
461 return mismatch();
462 }
463 }
464 return pos + length;
465 } else {
466 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
467 if (query.ends_at(pos) || (*ptr != query[pos])) {
468 return mismatch();
469 }
470 }
471 return pos;
472 }
473 }
474
475 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
476 UInt32 link_id, CQuery query, std::size_t pos) const;
477 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
478 UInt32 link_id, const Query &query, std::size_t pos) const;
479
480 template <typename T, typename U, typename V>
find_(T query,U key_ids,V key_lengths,std::size_t max_num_results) const481 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
482 std::size_t max_num_results) const try {
483 if (max_num_results == 0) {
484 return 0;
485 }
486 std::size_t count = 0;
487 UInt32 node = 0;
488 std::size_t pos = 0;
489 do {
490 if (terminal_flags_[node]) {
491 if (key_ids.is_valid()) {
492 key_ids.insert(count, node_to_key_id(node));
493 }
494 if (key_lengths.is_valid()) {
495 key_lengths.insert(count, pos);
496 }
497 if (++count >= max_num_results) {
498 return count;
499 }
500 }
501 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
502 return count;
503 } catch (const std::bad_alloc &) {
504 MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
505 } catch (const std::length_error &) {
506 MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
507 }
508
509 template <typename T>
find_first_(T query,std::size_t * key_length) const510 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
511 UInt32 node = 0;
512 std::size_t pos = 0;
513 do {
514 if (terminal_flags_[node]) {
515 if (key_length != NULL) {
516 *key_length = pos;
517 }
518 return node_to_key_id(node);
519 }
520 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
521 return notfound();
522 }
523
524 template <typename T>
find_last_(T query,std::size_t * key_length) const525 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
526 UInt32 node = 0;
527 UInt32 node_found = notfound();
528 std::size_t pos = 0;
529 std::size_t pos_found = mismatch();
530 do {
531 if (terminal_flags_[node]) {
532 node_found = node;
533 pos_found = pos;
534 }
535 } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
536 if (node_found != notfound()) {
537 if (key_length != NULL) {
538 *key_length = pos_found;
539 }
540 return node_to_key_id(node_found);
541 }
542 return notfound();
543 }
544
// Predictive search that reports results level by level below the node
// reached by `query`.
//
// Key IDs are written to `key_ids` and restored key strings to `keys`, each
// only if the container is_valid(). At most `max_num_results` entries are
// written; the number of results is returned (0 if the query cannot be
// followed or `max_num_results` is 0).
//
// std::bad_alloc and std::length_error escaping the body are re-reported as
// MARISA_ALPHA_MEMORY_ERROR and MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U, typename V>
std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const try {
  if (max_num_results == 0) {
    return 0;
  }
  // Follow the query down from node 0; no key text is collected (NULL).
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  std::string key;
  std::size_t count = 0;
  // The node reached by the query is itself a result if it is terminal.
  if (terminal_flags_[node]) {
    const UInt32 key_id = node_to_key_id(node);
    if (key_ids.is_valid()) {
      key_ids.insert(count, key_id);
    }
    if (keys.is_valid()) {
      // restore() appends to `key`, which is still empty here.
      restore(key_id, &key);
      keys.insert(count, key);
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  // An unset louds_ bit at the child position means there are no children.
  const UInt32 louds_pos = get_child(node);
  if (!louds_[louds_pos]) {
    return count;
  }
  // [node_begin, node_end) is the node range of the current level below
  // `node`: it starts at the first child of `node` and ends at the first
  // child of `node + 1`.
  UInt32 node_begin = louds_pos_to_node(louds_pos, node);
  UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
  while (node_begin < node_end) {
    // The keys of this level are reported as the ID range
    // [key_id_begin, key_id_end).
    const UInt32 key_id_begin = node_to_key_id(node_begin);
    const UInt32 key_id_end = node_to_key_id(node_end);
    if (key_ids.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        key_ids.insert(temp_count, key_id);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    if (keys.is_valid()) {
      UInt32 temp_count = count;
      for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
        // restore() appends, so the scratch string is emptied first.
        key.clear();
        restore(key_id, &key);
        keys.insert(temp_count, key);
        if (++temp_count >= max_num_results) {
          break;
        }
      }
    }
    // `count` may overshoot the limit here; the return value is clamped.
    count += key_id_end - key_id_begin;
    if (count >= max_num_results) {
      return max_num_results;
    }
    // Descend one level: the next range is bounded by the first children of
    // the current range's two end points.
    node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
    node_end = louds_pos_to_node(get_child(node_end), node_end);
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
615
// Predictive search that reports results in depth-first order below the node
// reached by `query`.
//
// When key strings are requested (`keys` is_valid()), the work is delegated
// to predict_callback_() with a PredictCallback. Otherwise only key IDs are
// produced, using an explicit stack of Cell cursors. Returns the number of
// results (at most `max_num_results`; 0 if the query cannot be followed or
// `max_num_results` is 0).
//
// std::bad_alloc and std::length_error escaping the body are re-reported as
// MARISA_ALPHA_MEMORY_ERROR and MARISA_ALPHA_SIZE_ERROR.
template <typename T, typename U, typename V>
std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
    std::size_t max_num_results) const try {
  if (max_num_results == 0) {
    return 0;
  } else if (keys.is_valid()) {
    PredictCallback<U, V> callback(key_ids, keys, max_num_results);
    return predict_callback_(query, callback);
  }

  // Follow the query down from node 0; no key text is collected (NULL).
  UInt32 node = 0;
  std::size_t pos = 0;
  while (!query.ends_at(pos)) {
    if (!predict_child<T>(node, query, pos, NULL)) {
      return 0;
    }
  }
  std::size_t count = 0;
  // The node reached by the query is itself a result if it is terminal.
  if (terminal_flags_[node]) {
    if (key_ids.is_valid()) {
      key_ids.insert(count, node_to_key_id(node));
    }
    if (++count >= max_num_results) {
      return count;
    }
  }
  // Each Cell is a cursor for one depth: a louds_ position, the node at that
  // position and the key ID to hand out for the next terminal node there.
  Cell cell;
  cell.set_louds_pos(get_child(node));
  if (!louds_[cell.louds_pos()]) {
    // No children below the query node.
    return count;
  }
  cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
  cell.set_key_id(node_to_key_id(cell.node()));
  Vector<Cell> stack;
  stack.push_back(cell);
  // `stack_pos` is the current depth (1-based). Cells above it are kept in
  // `stack` and reused with their cursors intact when the depth is reached
  // again.
  std::size_t stack_pos = 1;
  while (stack_pos != 0) {
    Cell &cur = stack[stack_pos - 1];
    if (!louds_[cur.louds_pos()]) {
      // End of this sibling run: step over the unset bit and go up a level.
      cur.set_louds_pos(cur.louds_pos() + 1);
      --stack_pos;
      continue;
    }
    cur.set_louds_pos(cur.louds_pos() + 1);
    if (terminal_flags_[cur.node()]) {
      if (key_ids.is_valid()) {
        key_ids.insert(count, cur.key_id());
      }
      if (++count >= max_num_results) {
        return count;
      }
      cur.set_key_id(cur.key_id() + 1);
    }
    // First visit to the next depth: create its cursor from the first child
    // of the current node.
    if (stack_pos == stack.size()) {
      cell.set_louds_pos(get_child(cur.node()));
      cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
      cell.set_key_id(node_to_key_id(cell.node()));
      stack.push_back(cell);
    }
    // Advance this depth's node cursor. The element is re-indexed rather
    // than accessed through `cur` -- presumably because the push_back above
    // may have invalidated that reference.
    stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
    ++stack_pos;
  }
  return count;
} catch (const std::bad_alloc &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
} catch (const std::length_error &) {
  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
}
684
685 template <typename T>
trie_prefix_match(UInt32 node,T query,std::size_t pos,std::string * key) const686 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
687 std::size_t pos, std::string *key) const {
688 if (has_link(node)) {
689 std::size_t next_pos;
690 if (has_trie()) {
691 next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
692 } else {
693 next_pos = tail_prefix_match<T>(
694 node, get_link_id(node), query, pos, key);
695 }
696 if ((next_pos == mismatch()) || (next_pos == pos)) {
697 return next_pos;
698 }
699 pos = next_pos;
700 } else if (labels_[node] != query[pos]) {
701 return pos;
702 } else {
703 ++pos;
704 }
705 node = get_parent(node);
706 while (node != 0) {
707 if (query.ends_at(pos)) {
708 if (key != NULL) {
709 trie_restore(node, key);
710 }
711 return pos;
712 }
713 if (has_link(node)) {
714 std::size_t next_pos;
715 if (has_trie()) {
716 next_pos = trie_->trie_prefix_match<T>(
717 get_link(node), query, pos, key);
718 } else {
719 next_pos = tail_prefix_match<T>(
720 node, get_link_id(node), query, pos, key);
721 }
722 if ((next_pos == mismatch()) || (next_pos == pos)) {
723 return next_pos;
724 }
725 pos = next_pos;
726 } else if (labels_[node] != query[pos]) {
727 return mismatch();
728 } else {
729 ++pos;
730 }
731 node = get_parent(node);
732 }
733 return pos;
734 }
735
736 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
737 CQuery query, std::size_t pos, std::string *key) const;
738 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
739 const Query &query, std::size_t pos, std::string *key) const;
740
741 template <typename T>
tail_prefix_match(UInt32 node,UInt32 link_id,T query,std::size_t pos,std::string * key) const742 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
743 T query, std::size_t pos, std::string *key) const {
744 const UInt32 offset = (links_[link_id] * 256) + labels_[node];
745 const UInt8 *ptr = tail_[offset];
746 if (*ptr != query[pos]) {
747 return pos;
748 } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
749 const UInt32 length = (links_[link_id + 1] * 256)
750 + labels_[link_flags_.select1(link_id + 1)] - offset;
751 for (UInt32 i = 1; i < length; ++i) {
752 if (query.ends_at(pos + i)) {
753 if (key != NULL) {
754 key->append(reinterpret_cast<const char *>(ptr + i), length - i);
755 }
756 return pos + i;
757 } else if (ptr[i] != query[pos + i]) {
758 return mismatch();
759 }
760 }
761 return pos + length;
762 } else {
763 for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
764 if (query.ends_at(pos)) {
765 if (key != NULL) {
766 key->append(reinterpret_cast<const char *>(ptr));
767 }
768 return pos;
769 } else if (*ptr != query[pos]) {
770 return mismatch();
771 }
772 }
773 return pos;
774 }
775 }
776
777 template std::size_t Trie::tail_prefix_match<CQuery>(
778 UInt32 node, UInt32 link_id,
779 CQuery query, std::size_t pos, std::string *key) const;
780 template std::size_t Trie::tail_prefix_match<const Query &>(
781 UInt32 node, UInt32 link_id,
782 const Query &query, std::size_t pos, std::string *key) const;
783
784 } // namespace marisa_alpha
785