• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <algorithm>
2 #include <stdexcept>
3 
4 #include "trie.h"
5 
6 namespace marisa_alpha {
7 namespace {
8 
9 template <typename T, typename U>
10 class PredictCallback {
11  public:
PredictCallback(T key_ids,U keys,std::size_t max_num_results)12   PredictCallback(T key_ids, U keys, std::size_t max_num_results)
13       : key_ids_(key_ids), keys_(keys),
14         max_num_results_(max_num_results), num_results_(0) {}
PredictCallback(const PredictCallback & callback)15   PredictCallback(const PredictCallback &callback)
16       : key_ids_(callback.key_ids_), keys_(callback.keys_),
17         max_num_results_(callback.max_num_results_),
18         num_results_(callback.num_results_) {}
19 
operator ()(marisa_alpha::UInt32 key_id,const std::string & key)20   bool operator()(marisa_alpha::UInt32 key_id, const std::string &key) {
21     if (key_ids_.is_valid()) {
22       key_ids_.insert(num_results_, key_id);
23     }
24     if (keys_.is_valid()) {
25       keys_.insert(num_results_, key);
26     }
27     return ++num_results_ < max_num_results_;
28   }
29 
30  private:
31   T key_ids_;
32   U keys_;
33   const std::size_t max_num_results_;
34   std::size_t num_results_;
35 
36   // Disallows assignment.
37   PredictCallback &operator=(const PredictCallback &);
38 };
39 
40 }  // namespace
41 
restore(UInt32 key_id) const42 std::string Trie::restore(UInt32 key_id) const {
43   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
44   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
45   std::string key;
46   restore_(key_id, &key);
47   return key;
48 }
49 
restore(UInt32 key_id,std::string * key) const50 void Trie::restore(UInt32 key_id, std::string *key) const {
51   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
52   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
53   MARISA_ALPHA_THROW_IF(key == NULL, MARISA_ALPHA_PARAM_ERROR);
54   restore_(key_id, key);
55 }
56 
restore(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const57 std::size_t Trie::restore(UInt32 key_id, char *key_buf,
58     std::size_t key_buf_size) const {
59   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
60   MARISA_ALPHA_THROW_IF(key_id >= num_keys_, MARISA_ALPHA_PARAM_ERROR);
61   MARISA_ALPHA_THROW_IF((key_buf == NULL) && (key_buf_size != 0),
62       MARISA_ALPHA_PARAM_ERROR);
63   return restore_(key_id, key_buf, key_buf_size);
64 }
65 
lookup(const char * str) const66 UInt32 Trie::lookup(const char *str) const {
67   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
68   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
69   return lookup_<CQuery>(CQuery(str));
70 }
71 
lookup(const char * ptr,std::size_t length) const72 UInt32 Trie::lookup(const char *ptr, std::size_t length) const {
73   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
74   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
75       MARISA_ALPHA_PARAM_ERROR);
76   return lookup_<const Query &>(Query(ptr, length));
77 }
78 
find(const char * str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const79 std::size_t Trie::find(const char *str,
80     UInt32 *key_ids, std::size_t *key_lengths,
81     std::size_t max_num_results) const {
82   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
83   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
84   return find_<CQuery>(CQuery(str),
85       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
86 }
87 
find(const char * ptr,std::size_t length,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results) const88 std::size_t Trie::find(const char *ptr, std::size_t length,
89     UInt32 *key_ids, std::size_t *key_lengths,
90     std::size_t max_num_results) const {
91   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
92   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
93       MARISA_ALPHA_PARAM_ERROR);
94   return find_<const Query &>(Query(ptr, length),
95       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
96 }
97 
find(const char * str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const98 std::size_t Trie::find(const char *str,
99     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
100     std::size_t max_num_results) const {
101   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
102   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
103   return find_<CQuery>(CQuery(str),
104       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
105 }
106 
find(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results) const107 std::size_t Trie::find(const char *ptr, std::size_t length,
108     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
109     std::size_t max_num_results) const {
110   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
111   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
112       MARISA_ALPHA_PARAM_ERROR);
113   return find_<const Query &>(Query(ptr, length),
114       MakeContainer(key_ids), MakeContainer(key_lengths), max_num_results);
115 }
116 
find_first(const char * str,std::size_t * key_length) const117 UInt32 Trie::find_first(const char *str,
118     std::size_t *key_length) const {
119   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
120   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
121   return find_first_<CQuery>(CQuery(str), key_length);
122 }
123 
find_first(const char * ptr,std::size_t length,std::size_t * key_length) const124 UInt32 Trie::find_first(const char *ptr, std::size_t length,
125     std::size_t *key_length) const {
126   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
127   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
128       MARISA_ALPHA_PARAM_ERROR);
129   return find_first_<const Query &>(Query(ptr, length), key_length);
130 }
131 
find_last(const char * str,std::size_t * key_length) const132 UInt32 Trie::find_last(const char *str,
133     std::size_t *key_length) const {
134   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
135   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
136   return find_last_<CQuery>(CQuery(str), key_length);
137 }
138 
find_last(const char * ptr,std::size_t length,std::size_t * key_length) const139 UInt32 Trie::find_last(const char *ptr, std::size_t length,
140     std::size_t *key_length) const {
141   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
142   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
143       MARISA_ALPHA_PARAM_ERROR);
144   return find_last_<const Query &>(Query(ptr, length), key_length);
145 }
146 
predict(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const147 std::size_t Trie::predict(const char *str,
148     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
149   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
150   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
151   return (keys == NULL) ?
152       predict_breadth_first(str, key_ids, keys, max_num_results) :
153       predict_depth_first(str, key_ids, keys, max_num_results);
154 }
155 
predict(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const156 std::size_t Trie::predict(const char *ptr, std::size_t length,
157     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
158   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
159   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
160       MARISA_ALPHA_PARAM_ERROR);
161   return (keys == NULL) ?
162       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
163       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
164 }
165 
predict(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const166 std::size_t Trie::predict(const char *str,
167     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
168     std::size_t max_num_results) const {
169   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
170   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
171   return (keys == NULL) ?
172       predict_breadth_first(str, key_ids, keys, max_num_results) :
173       predict_depth_first(str, key_ids, keys, max_num_results);
174 }
175 
predict(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const176 std::size_t Trie::predict(const char *ptr, std::size_t length,
177     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
178     std::size_t max_num_results) const {
179   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
180   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
181       MARISA_ALPHA_PARAM_ERROR);
182   return (keys == NULL) ?
183       predict_breadth_first(ptr, length, key_ids, keys, max_num_results) :
184       predict_depth_first(ptr, length, key_ids, keys, max_num_results);
185 }
186 
predict_breadth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const187 std::size_t Trie::predict_breadth_first(const char *str,
188     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
189   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
190   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
191   return predict_breadth_first_<CQuery>(CQuery(str),
192       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
193 }
194 
predict_breadth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const195 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
196     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
197   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
198   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
199       MARISA_ALPHA_PARAM_ERROR);
200   return predict_breadth_first_<const Query &>(Query(ptr, length),
201       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
202 }
203 
predict_breadth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const204 std::size_t Trie::predict_breadth_first(const char *str,
205     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
206     std::size_t max_num_results) const {
207   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
208   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
209   return predict_breadth_first_<CQuery>(CQuery(str),
210       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
211 }
212 
predict_breadth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const213 std::size_t Trie::predict_breadth_first(const char *ptr, std::size_t length,
214     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
215     std::size_t max_num_results) const {
216   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
217   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
218       MARISA_ALPHA_PARAM_ERROR);
219   return predict_breadth_first_<const Query &>(Query(ptr, length),
220       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
221 }
222 
predict_depth_first(const char * str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const223 std::size_t Trie::predict_depth_first(const char *str,
224     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
225   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
226   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
227   return predict_depth_first_<CQuery>(CQuery(str),
228       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
229 }
230 
predict_depth_first(const char * ptr,std::size_t length,UInt32 * key_ids,std::string * keys,std::size_t max_num_results) const231 std::size_t Trie::predict_depth_first(const char *ptr, std::size_t length,
232     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
233   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
234   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
235       MARISA_ALPHA_PARAM_ERROR);
236   return predict_depth_first_<const Query &>(Query(ptr, length),
237       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
238 }
239 
predict_depth_first(const char * str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const240 std::size_t Trie::predict_depth_first(
241     const char *str, std::vector<UInt32> *key_ids,
242     std::vector<std::string> *keys, std::size_t max_num_results) const {
243   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
244   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
245   return predict_depth_first_<CQuery>(CQuery(str),
246       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
247 }
248 
predict_depth_first(const char * ptr,std::size_t length,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results) const249 std::size_t Trie::predict_depth_first(
250     const char *ptr, std::size_t length, std::vector<UInt32> *key_ids,
251     std::vector<std::string> *keys, std::size_t max_num_results) const {
252   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
253   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
254       MARISA_ALPHA_PARAM_ERROR);
255   return predict_depth_first_<const Query &>(Query(ptr, length),
256       MakeContainer(key_ids), MakeContainer(keys), max_num_results);
257 }
258 
restore_(UInt32 key_id,std::string * key) const259 void Trie::restore_(UInt32 key_id, std::string *key) const {
260   const std::size_t start_pos = key->length();
261   try {
262     UInt32 node = key_id_to_node(key_id);
263     while (node != 0) {
264       if (has_link(node)) {
265         const std::size_t prev_pos = key->length();
266         if (has_trie()) {
267           trie_->trie_restore(get_link(node), key);
268         } else {
269           tail_restore(node, key);
270         }
271         std::reverse(key->begin() + prev_pos, key->end());
272       } else {
273         *key += labels_[node];
274       }
275       node = get_parent(node);
276     }
277     std::reverse(key->begin() + start_pos, key->end());
278   } catch (const std::bad_alloc &) {
279     key->resize(start_pos);
280     MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
281   } catch (const std::length_error &) {
282     key->resize(start_pos);
283     MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
284   }
285 }
286 
trie_restore(UInt32 node,std::string * key) const287 void Trie::trie_restore(UInt32 node, std::string *key) const {
288   do {
289     if (has_link(node)) {
290       if (has_trie()) {
291         trie_->trie_restore(get_link(node), key);
292       } else {
293         tail_restore(node, key);
294       }
295     } else {
296       *key += labels_[node];
297     }
298     node = get_parent(node);
299   } while (node != 0);
300 }
301 
tail_restore(UInt32 node,std::string * key) const302 void Trie::tail_restore(UInt32 node, std::string *key) const {
303   const UInt32 link_id = link_flags_.rank1(node);
304   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
305   if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
306     const UInt32 length = (links_[link_id + 1] * 256)
307         + labels_[link_flags_.select1(link_id + 1)] - offset;
308     key->append(reinterpret_cast<const char *>(tail_[offset]), length);
309   } else {
310     key->append(reinterpret_cast<const char *>(tail_[offset]));
311   }
312 }
313 
restore_(UInt32 key_id,char * key_buf,std::size_t key_buf_size) const314 std::size_t Trie::restore_(UInt32 key_id, char *key_buf,
315     std::size_t key_buf_size) const {
316   std::size_t pos = 0;
317   UInt32 node = key_id_to_node(key_id);
318   while (node != 0) {
319     if (has_link(node)) {
320       const std::size_t prev_pos = pos;
321       if (has_trie()) {
322         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
323       } else {
324         tail_restore(node, key_buf, key_buf_size, pos);
325       }
326       if (pos < key_buf_size) {
327         std::reverse(key_buf + prev_pos, key_buf + pos);
328       }
329     } else {
330       if (pos < key_buf_size) {
331         key_buf[pos] = labels_[node];
332       }
333       ++pos;
334     }
335     node = get_parent(node);
336   }
337   if (pos < key_buf_size) {
338     key_buf[pos] = '\0';
339     std::reverse(key_buf, key_buf + pos);
340   }
341   return pos;
342 }
343 
trie_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const344 void Trie::trie_restore(UInt32 node, char *key_buf,
345     std::size_t key_buf_size, std::size_t &pos) const {
346   do {
347     if (has_link(node)) {
348       if (has_trie()) {
349         trie_->trie_restore(get_link(node), key_buf, key_buf_size, pos);
350       } else {
351         tail_restore(node, key_buf, key_buf_size, pos);
352       }
353     } else {
354       if (pos < key_buf_size) {
355         key_buf[pos] = labels_[node];
356       }
357       ++pos;
358     }
359     node = get_parent(node);
360   } while (node != 0);
361 }
362 
tail_restore(UInt32 node,char * key_buf,std::size_t key_buf_size,std::size_t & pos) const363 void Trie::tail_restore(UInt32 node, char *key_buf,
364     std::size_t key_buf_size, std::size_t &pos) const {
365   const UInt32 link_id = link_flags_.rank1(node);
366   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
367   if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
368     const UInt8 *ptr = tail_[offset];
369     const UInt32 length = (links_[link_id + 1] * 256)
370         + labels_[link_flags_.select1(link_id + 1)] - offset;
371     for (UInt32 i = 0; i < length; ++i) {
372       if (pos < key_buf_size) {
373         key_buf[pos] = ptr[i];
374       }
375       ++pos;
376     }
377   } else {
378     for (const UInt8 *str = tail_[offset]; *str != '\0'; ++str) {
379       if (pos < key_buf_size) {
380         key_buf[pos] = *str;
381       }
382       ++pos;
383     }
384   }
385 }
386 
387 template <typename T>
lookup_(T query) const388 UInt32 Trie::lookup_(T query) const {
389   UInt32 node = 0;
390   std::size_t pos = 0;
391   while (!query.ends_at(pos)) {
392     if (!find_child<T>(node, query, pos)) {
393       return notfound();
394     }
395   }
396   return terminal_flags_[node] ? node_to_key_id(node) : notfound();
397 }
398 
399 template <typename T>
trie_match(UInt32 node,T query,std::size_t pos) const400 std::size_t Trie::trie_match(UInt32 node, T query,
401     std::size_t pos) const {
402   if (has_link(node)) {
403     std::size_t next_pos;
404     if (has_trie()) {
405       next_pos = trie_->trie_match<T>(get_link(node), query, pos);
406     } else {
407       next_pos = tail_match<T>(node, get_link_id(node), query, pos);
408     }
409     if ((next_pos == mismatch()) || (next_pos == pos)) {
410       return next_pos;
411     }
412     pos = next_pos;
413   } else if (labels_[node] != query[pos]) {
414     return pos;
415   } else {
416     ++pos;
417   }
418   node = get_parent(node);
419   while (node != 0) {
420     if (query.ends_at(pos)) {
421       return mismatch();
422     }
423     if (has_link(node)) {
424       std::size_t next_pos;
425       if (has_trie()) {
426         next_pos = trie_->trie_match<T>(get_link(node), query, pos);
427       } else {
428         next_pos = tail_match<T>(node, get_link_id(node), query, pos);
429       }
430       if ((next_pos == mismatch()) || (next_pos == pos)) {
431         return mismatch();
432       }
433       pos = next_pos;
434     } else if (labels_[node] != query[pos]) {
435       return mismatch();
436     } else {
437       ++pos;
438     }
439     node = get_parent(node);
440   }
441   return pos;
442 }
443 
444 template std::size_t Trie::trie_match<CQuery>(UInt32 node,
445     CQuery query, std::size_t pos) const;
446 template std::size_t Trie::trie_match<const Query &>(UInt32 node,
447     const Query &query, std::size_t pos) const;
448 
449 template <typename T>
tail_match(UInt32 node,UInt32 link_id,T query,std::size_t pos) const450 std::size_t Trie::tail_match(UInt32 node, UInt32 link_id,
451     T query, std::size_t pos) const {
452   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
453   const UInt8 *ptr = tail_[offset];
454   if (*ptr != query[pos]) {
455     return pos;
456   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
457     const UInt32 length = (links_[link_id + 1] * 256)
458         + labels_[link_flags_.select1(link_id + 1)] - offset;
459     for (UInt32 i = 1; i < length; ++i) {
460       if (query.ends_at(pos + i) || (ptr[i] != query[pos + i])) {
461         return mismatch();
462       }
463     }
464     return pos + length;
465   } else {
466     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
467       if (query.ends_at(pos) || (*ptr != query[pos])) {
468         return mismatch();
469       }
470     }
471     return pos;
472   }
473 }
474 
475 template std::size_t Trie::tail_match<CQuery>(UInt32 node,
476     UInt32 link_id, CQuery query, std::size_t pos) const;
477 template std::size_t Trie::tail_match<const Query &>(UInt32 node,
478     UInt32 link_id, const Query &query, std::size_t pos) const;
479 
480 template <typename T, typename U, typename V>
find_(T query,U key_ids,V key_lengths,std::size_t max_num_results) const481 std::size_t Trie::find_(T query, U key_ids, V key_lengths,
482     std::size_t max_num_results) const try {
483   if (max_num_results == 0) {
484     return 0;
485   }
486   std::size_t count = 0;
487   UInt32 node = 0;
488   std::size_t pos = 0;
489   do {
490     if (terminal_flags_[node]) {
491       if (key_ids.is_valid()) {
492         key_ids.insert(count, node_to_key_id(node));
493       }
494       if (key_lengths.is_valid()) {
495         key_lengths.insert(count, pos);
496       }
497       if (++count >= max_num_results) {
498         return count;
499       }
500     }
501   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
502   return count;
503 } catch (const std::bad_alloc &) {
504   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
505 } catch (const std::length_error &) {
506   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
507 }
508 
509 template <typename T>
find_first_(T query,std::size_t * key_length) const510 UInt32 Trie::find_first_(T query, std::size_t *key_length) const {
511   UInt32 node = 0;
512   std::size_t pos = 0;
513   do {
514     if (terminal_flags_[node]) {
515       if (key_length != NULL) {
516         *key_length = pos;
517       }
518       return node_to_key_id(node);
519     }
520   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
521   return notfound();
522 }
523 
524 template <typename T>
find_last_(T query,std::size_t * key_length) const525 UInt32 Trie::find_last_(T query, std::size_t *key_length) const {
526   UInt32 node = 0;
527   UInt32 node_found = notfound();
528   std::size_t pos = 0;
529   std::size_t pos_found = mismatch();
530   do {
531     if (terminal_flags_[node]) {
532       node_found = node;
533       pos_found = pos;
534     }
535   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
536   if (node_found != notfound()) {
537     if (key_length != NULL) {
538       *key_length = pos_found;
539     }
540     return node_to_key_id(node_found);
541   }
542   return notfound();
543 }
544 
545 template <typename T, typename U, typename V>
predict_breadth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const546 std::size_t Trie::predict_breadth_first_(T query, U key_ids, V keys,
547     std::size_t max_num_results) const try {
548   if (max_num_results == 0) {
549     return 0;
550   }
551   UInt32 node = 0;
552   std::size_t pos = 0;
553   while (!query.ends_at(pos)) {
554     if (!predict_child<T>(node, query, pos, NULL)) {
555       return 0;
556     }
557   }
558   std::string key;
559   std::size_t count = 0;
560   if (terminal_flags_[node]) {
561     const UInt32 key_id = node_to_key_id(node);
562     if (key_ids.is_valid()) {
563       key_ids.insert(count, key_id);
564     }
565     if (keys.is_valid()) {
566       restore(key_id, &key);
567       keys.insert(count, key);
568     }
569     if (++count >= max_num_results) {
570       return count;
571     }
572   }
573   const UInt32 louds_pos = get_child(node);
574   if (!louds_[louds_pos]) {
575     return count;
576   }
577   UInt32 node_begin = louds_pos_to_node(louds_pos, node);
578   UInt32 node_end = louds_pos_to_node(get_child(node + 1), node + 1);
579   while (node_begin < node_end) {
580     const UInt32 key_id_begin = node_to_key_id(node_begin);
581     const UInt32 key_id_end = node_to_key_id(node_end);
582     if (key_ids.is_valid()) {
583       UInt32 temp_count = count;
584       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
585         key_ids.insert(temp_count, key_id);
586         if (++temp_count >= max_num_results) {
587           break;
588         }
589       }
590     }
591     if (keys.is_valid()) {
592       UInt32 temp_count = count;
593       for (UInt32 key_id = key_id_begin; key_id < key_id_end; ++key_id) {
594         key.clear();
595         restore(key_id, &key);
596         keys.insert(temp_count, key);
597         if (++temp_count >= max_num_results) {
598           break;
599         }
600       }
601     }
602     count += key_id_end - key_id_begin;
603     if (count >= max_num_results) {
604       return max_num_results;
605     }
606     node_begin = louds_pos_to_node(get_child(node_begin), node_begin);
607     node_end = louds_pos_to_node(get_child(node_end), node_end);
608   }
609   return count;
610 } catch (const std::bad_alloc &) {
611   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
612 } catch (const std::length_error &) {
613   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
614 }
615 
616 template <typename T, typename U, typename V>
predict_depth_first_(T query,U key_ids,V keys,std::size_t max_num_results) const617 std::size_t Trie::predict_depth_first_(T query, U key_ids, V keys,
618     std::size_t max_num_results) const try {
619   if (max_num_results == 0) {
620     return 0;
621   } else if (keys.is_valid()) {
622     PredictCallback<U, V> callback(key_ids, keys, max_num_results);
623     return predict_callback_(query, callback);
624   }
625 
626   UInt32 node = 0;
627   std::size_t pos = 0;
628   while (!query.ends_at(pos)) {
629     if (!predict_child<T>(node, query, pos, NULL)) {
630       return 0;
631     }
632   }
633   std::size_t count = 0;
634   if (terminal_flags_[node]) {
635     if (key_ids.is_valid()) {
636       key_ids.insert(count, node_to_key_id(node));
637     }
638     if (++count >= max_num_results) {
639       return count;
640     }
641   }
642   Cell cell;
643   cell.set_louds_pos(get_child(node));
644   if (!louds_[cell.louds_pos()]) {
645     return count;
646   }
647   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
648   cell.set_key_id(node_to_key_id(cell.node()));
649   Vector<Cell> stack;
650   stack.push_back(cell);
651   std::size_t stack_pos = 1;
652   while (stack_pos != 0) {
653     Cell &cur = stack[stack_pos - 1];
654     if (!louds_[cur.louds_pos()]) {
655       cur.set_louds_pos(cur.louds_pos() + 1);
656       --stack_pos;
657       continue;
658     }
659     cur.set_louds_pos(cur.louds_pos() + 1);
660     if (terminal_flags_[cur.node()]) {
661       if (key_ids.is_valid()) {
662         key_ids.insert(count, cur.key_id());
663       }
664       if (++count >= max_num_results) {
665         return count;
666       }
667       cur.set_key_id(cur.key_id() + 1);
668     }
669     if (stack_pos == stack.size()) {
670       cell.set_louds_pos(get_child(cur.node()));
671       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
672       cell.set_key_id(node_to_key_id(cell.node()));
673       stack.push_back(cell);
674     }
675     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
676     ++stack_pos;
677   }
678   return count;
679 } catch (const std::bad_alloc &) {
680   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
681 } catch (const std::length_error &) {
682   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
683 }
684 
685 template <typename T>
trie_prefix_match(UInt32 node,T query,std::size_t pos,std::string * key) const686 std::size_t Trie::trie_prefix_match(UInt32 node, T query,
687     std::size_t pos, std::string *key) const {
688   if (has_link(node)) {
689     std::size_t next_pos;
690     if (has_trie()) {
691       next_pos = trie_->trie_prefix_match<T>(get_link(node), query, pos, key);
692     } else {
693       next_pos = tail_prefix_match<T>(
694           node, get_link_id(node), query, pos, key);
695     }
696     if ((next_pos == mismatch()) || (next_pos == pos)) {
697       return next_pos;
698     }
699     pos = next_pos;
700   } else if (labels_[node] != query[pos]) {
701     return pos;
702   } else {
703     ++pos;
704   }
705   node = get_parent(node);
706   while (node != 0) {
707     if (query.ends_at(pos)) {
708       if (key != NULL) {
709         trie_restore(node, key);
710       }
711       return pos;
712     }
713     if (has_link(node)) {
714       std::size_t next_pos;
715       if (has_trie()) {
716         next_pos = trie_->trie_prefix_match<T>(
717             get_link(node), query, pos, key);
718       } else {
719         next_pos = tail_prefix_match<T>(
720             node, get_link_id(node), query, pos, key);
721       }
722       if ((next_pos == mismatch()) || (next_pos == pos)) {
723         return next_pos;
724       }
725       pos = next_pos;
726     } else if (labels_[node] != query[pos]) {
727       return mismatch();
728     } else {
729       ++pos;
730     }
731     node = get_parent(node);
732   }
733   return pos;
734 }
735 
736 template std::size_t Trie::trie_prefix_match<CQuery>(UInt32 node,
737     CQuery query, std::size_t pos, std::string *key) const;
738 template std::size_t Trie::trie_prefix_match<const Query &>(UInt32 node,
739     const Query &query, std::size_t pos, std::string *key) const;
740 
741 template <typename T>
tail_prefix_match(UInt32 node,UInt32 link_id,T query,std::size_t pos,std::string * key) const742 std::size_t Trie::tail_prefix_match(UInt32 node, UInt32 link_id,
743     T query, std::size_t pos, std::string *key) const {
744   const UInt32 offset = (links_[link_id] * 256) + labels_[node];
745   const UInt8 *ptr = tail_[offset];
746   if (*ptr != query[pos]) {
747     return pos;
748   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
749     const UInt32 length = (links_[link_id + 1] * 256)
750         + labels_[link_flags_.select1(link_id + 1)] - offset;
751     for (UInt32 i = 1; i < length; ++i) {
752       if (query.ends_at(pos + i)) {
753         if (key != NULL) {
754           key->append(reinterpret_cast<const char *>(ptr + i), length - i);
755         }
756         return pos + i;
757       } else if (ptr[i] != query[pos + i]) {
758         return mismatch();
759       }
760     }
761     return pos + length;
762   } else {
763     for (++ptr, ++pos; *ptr != '\0'; ++ptr, ++pos) {
764       if (query.ends_at(pos)) {
765         if (key != NULL) {
766           key->append(reinterpret_cast<const char *>(ptr));
767         }
768         return pos;
769       } else if (*ptr != query[pos]) {
770         return mismatch();
771       }
772     }
773     return pos;
774   }
775 }
776 
777 template std::size_t Trie::tail_prefix_match<CQuery>(
778     UInt32 node, UInt32 link_id,
779     CQuery query, std::size_t pos, std::string *key) const;
780 template std::size_t Trie::tail_prefix_match<const Query &>(
781     UInt32 node, UInt32 link_id,
782     const Query &query, std::size_t pos, std::string *key) const;
783 
784 }  // namespace marisa_alpha
785