• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <algorithm>
2 #include <functional>
3 #include <queue>
4 
5 #include "marisa/grimoire/algorithm.h"
6 #include "marisa/grimoire/trie/header.h"
7 #include "marisa/grimoire/trie/range.h"
8 #include "marisa/grimoire/trie/state.h"
9 #include "marisa/grimoire/trie/louds-trie.h"
10 
11 namespace marisa {
12 namespace grimoire {
13 namespace trie {
14 
LoudsTrie()15 LoudsTrie::LoudsTrie()
16     : louds_(), terminal_flags_(), link_flags_(), bases_(), extras_(),
17       tail_(), next_trie_(), cache_(), cache_mask_(0), num_l1_nodes_(0),
18       config_(), mapper_() {}
19 
~LoudsTrie()20 LoudsTrie::~LoudsTrie() {}
21 
build(Keyset & keyset,int flags)22 void LoudsTrie::build(Keyset &keyset, int flags) {
23   Config config;
24   config.parse(flags);
25 
26   LoudsTrie temp;
27   temp.build_(keyset, config);
28   swap(temp);
29 }
30 
map(Mapper & mapper)31 void LoudsTrie::map(Mapper &mapper) {
32   Header().map(mapper);
33 
34   LoudsTrie temp;
35   temp.map_(mapper);
36   temp.mapper_.swap(mapper);
37   swap(temp);
38 }
39 
read(Reader & reader)40 void LoudsTrie::read(Reader &reader) {
41   Header().read(reader);
42 
43   LoudsTrie temp;
44   temp.read_(reader);
45   swap(temp);
46 }
47 
write(Writer & writer) const48 void LoudsTrie::write(Writer &writer) const {
49   Header().write(writer);
50 
51   write_(writer);
52 }
53 
lookup(Agent & agent) const54 bool LoudsTrie::lookup(Agent &agent) const {
55   MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
56 
57   State &state = agent.state();
58   state.lookup_init();
59   while (state.query_pos() < agent.query().length()) {
60     if (!find_child(agent)) {
61       return false;
62     }
63   }
64   if (!terminal_flags_[state.node_id()]) {
65     return false;
66   }
67   agent.set_key(agent.query().ptr(), agent.query().length());
68   agent.set_key(terminal_flags_.rank1(state.node_id()));
69   return true;
70 }
71 
reverse_lookup(Agent & agent) const72 void LoudsTrie::reverse_lookup(Agent &agent) const {
73   MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
74   MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR);
75 
76   State &state = agent.state();
77   state.reverse_lookup_init();
78 
79   state.set_node_id(terminal_flags_.select1(agent.query().id()));
80   if (state.node_id() == 0) {
81     agent.set_key(state.key_buf().begin(), state.key_buf().size());
82     agent.set_key(agent.query().id());
83     return;
84   }
85   for ( ; ; ) {
86     if (link_flags_[state.node_id()]) {
87       const std::size_t prev_key_pos = state.key_buf().size();
88       restore(agent, get_link(state.node_id()));
89       std::reverse(state.key_buf().begin() + prev_key_pos,
90           state.key_buf().end());
91     } else {
92       state.key_buf().push_back((char)bases_[state.node_id()]);
93     }
94 
95     if (state.node_id() <= num_l1_nodes_) {
96       std::reverse(state.key_buf().begin(), state.key_buf().end());
97       agent.set_key(state.key_buf().begin(), state.key_buf().size());
98       agent.set_key(agent.query().id());
99       return;
100     }
101     state.set_node_id(louds_.select1(state.node_id()) - state.node_id() - 1);
102   }
103 }
104 
common_prefix_search(Agent & agent) const105 bool LoudsTrie::common_prefix_search(Agent &agent) const {
106   MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
107 
108   State &state = agent.state();
109   if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) {
110     return false;
111   }
112 
113   if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) {
114     state.common_prefix_search_init();
115     if (terminal_flags_[state.node_id()]) {
116       agent.set_key(agent.query().ptr(), state.query_pos());
117       agent.set_key(terminal_flags_.rank1(state.node_id()));
118       return true;
119     }
120   }
121 
122   while (state.query_pos() < agent.query().length()) {
123     if (!find_child(agent)) {
124       state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
125       return false;
126     } else if (terminal_flags_[state.node_id()]) {
127       agent.set_key(agent.query().ptr(), state.query_pos());
128       agent.set_key(terminal_flags_.rank1(state.node_id()));
129       return true;
130     }
131   }
132   state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
133   return false;
134 }
135 
predictive_search(Agent & agent) const136 bool LoudsTrie::predictive_search(Agent &agent) const {
137   MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
138 
139   State &state = agent.state();
140   if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) {
141     return false;
142   }
143 
144   if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) {
145     state.predictive_search_init();
146     while (state.query_pos() < agent.query().length()) {
147       if (!predictive_find_child(agent)) {
148         state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
149         return false;
150       }
151     }
152 
153     History history;
154     history.set_node_id(state.node_id());
155     history.set_key_pos(state.key_buf().size());
156     state.history().push_back(history);
157     state.set_history_pos(1);
158 
159     if (terminal_flags_[state.node_id()]) {
160       agent.set_key(state.key_buf().begin(), state.key_buf().size());
161       agent.set_key(terminal_flags_.rank1(state.node_id()));
162       return true;
163     }
164   }
165 
166   for ( ; ; ) {
167     if (state.history_pos() == state.history().size()) {
168       const History &current = state.history().back();
169       History next;
170       next.set_louds_pos(louds_.select0(current.node_id()) + 1);
171       next.set_node_id(next.louds_pos() - current.node_id() - 1);
172       state.history().push_back(next);
173     }
174 
175     History &next = state.history()[state.history_pos()];
176     const bool link_flag = louds_[next.louds_pos()];
177     next.set_louds_pos(next.louds_pos() + 1);
178     if (link_flag) {
179       state.set_history_pos(state.history_pos() + 1);
180       if (link_flags_[next.node_id()]) {
181         next.set_link_id(update_link_id(next.link_id(), next.node_id()));
182         restore(agent, get_link(next.node_id(), next.link_id()));
183       } else {
184         state.key_buf().push_back((char)bases_[next.node_id()]);
185       }
186       next.set_key_pos(state.key_buf().size());
187 
188       if (terminal_flags_[next.node_id()]) {
189         if (next.key_id() == MARISA_INVALID_KEY_ID) {
190           next.set_key_id(terminal_flags_.rank1(next.node_id()));
191         } else {
192           next.set_key_id(next.key_id() + 1);
193         }
194         agent.set_key(state.key_buf().begin(), state.key_buf().size());
195         agent.set_key(next.key_id());
196         return true;
197       }
198     } else if (state.history_pos() != 1) {
199       History &current = state.history()[state.history_pos() - 1];
200       current.set_node_id(current.node_id() + 1);
201       const History &prev =
202           state.history()[state.history_pos() - 2];
203       state.key_buf().resize(prev.key_pos());
204       state.set_history_pos(state.history_pos() - 1);
205     } else {
206       state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
207       return false;
208     }
209   }
210 }
211 
total_size() const212 std::size_t LoudsTrie::total_size() const {
213   return louds_.total_size() + terminal_flags_.total_size()
214       + link_flags_.total_size() + bases_.total_size()
215       + extras_.total_size() + tail_.total_size()
216       + ((next_trie_.get() != NULL) ? next_trie_->total_size() : 0)
217       + cache_.total_size();
218 }
219 
io_size() const220 std::size_t LoudsTrie::io_size() const {
221   return Header().io_size() + louds_.io_size()
222       + terminal_flags_.io_size() + link_flags_.io_size()
223       + bases_.io_size() + extras_.io_size() + tail_.io_size()
224       + ((next_trie_.get() != NULL) ?
225           (next_trie_->io_size() - Header().io_size()) : 0)
226       + cache_.io_size() + (sizeof(UInt32) * 2);
227 }
228 
clear()229 void LoudsTrie::clear() {
230   LoudsTrie().swap(*this);
231 }
232 
swap(LoudsTrie & rhs)233 void LoudsTrie::swap(LoudsTrie &rhs) {
234   louds_.swap(rhs.louds_);
235   terminal_flags_.swap(rhs.terminal_flags_);
236   link_flags_.swap(rhs.link_flags_);
237   bases_.swap(rhs.bases_);
238   extras_.swap(rhs.extras_);
239   tail_.swap(rhs.tail_);
240   next_trie_.swap(rhs.next_trie_);
241   cache_.swap(rhs.cache_);
242   marisa::swap(cache_mask_, rhs.cache_mask_);
243   marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_);
244   config_.swap(rhs.config_);
245   mapper_.swap(rhs.mapper_);
246 }
247 
build_(Keyset & keyset,const Config & config)248 void LoudsTrie::build_(Keyset &keyset, const Config &config) {
249   Vector<Key> keys;
250   keys.resize(keyset.size());
251   for (std::size_t i = 0; i < keyset.size(); ++i) {
252     keys[i].set_str(keyset[i].ptr(), keyset[i].length());
253     keys[i].set_weight(keyset[i].weight());
254   }
255 
256   Vector<UInt32> terminals;
257   build_trie(keys, &terminals, config, 1);
258 
259   typedef std::pair<UInt32, UInt32> TerminalIdPair;
260 
261   Vector<TerminalIdPair> pairs;
262   pairs.resize(terminals.size());
263   for (std::size_t i = 0; i < pairs.size(); ++i) {
264     pairs[i].first = terminals[i];
265     pairs[i].second = (UInt32)i;
266   }
267   terminals.clear();
268   std::sort(pairs.begin(), pairs.end());
269 
270   std::size_t node_id = 0;
271   for (std::size_t i = 0; i < pairs.size(); ++i) {
272     while (node_id < pairs[i].first) {
273       terminal_flags_.push_back(false);
274       ++node_id;
275     }
276     if (node_id == pairs[i].first) {
277       terminal_flags_.push_back(true);
278       ++node_id;
279     }
280   }
281   while (node_id < bases_.size()) {
282     terminal_flags_.push_back(false);
283     ++node_id;
284   }
285   terminal_flags_.push_back(false);
286   terminal_flags_.build(false, true);
287 
288   for (std::size_t i = 0; i < keyset.size(); ++i) {
289     keyset[pairs[i].second].set_id(terminal_flags_.rank1(pairs[i].first));
290   }
291 }
292 
293 template <typename T>
build_trie(Vector<T> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)294 void LoudsTrie::build_trie(Vector<T> &keys,
295     Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
296   build_current_trie(keys, terminals, config, trie_id);
297 
298   Vector<UInt32> next_terminals;
299   if (!keys.empty()) {
300     build_next_trie(keys, &next_terminals, config, trie_id);
301   }
302 
303   if (next_trie_.get() != NULL) {
304     config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) |
305         next_trie_->tail_mode() | next_trie_->node_order());
306   } else {
307     config_.parse(1 | tail_.mode() | config.node_order() |
308         config.cache_level());
309   }
310 
311   link_flags_.build(false, false);
312   std::size_t node_id = 0;
313   for (std::size_t i = 0; i < next_terminals.size(); ++i) {
314     while (!link_flags_[node_id]) {
315       ++node_id;
316     }
317     bases_[node_id] = (UInt8)(next_terminals[i] % 256);
318     next_terminals[i] /= 256;
319     ++node_id;
320   }
321   extras_.build(next_terminals);
322   fill_cache();
323 }
324 
325 template <typename T>
build_current_trie(Vector<T> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)326 void LoudsTrie::build_current_trie(Vector<T> &keys,
327     Vector<UInt32> *terminals, const Config &config,
328     std::size_t trie_id) try {
329   for (std::size_t i = 0; i < keys.size(); ++i) {
330     keys[i].set_id(i);
331   }
332   const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end());
333   reserve_cache(config, trie_id, num_keys);
334 
335   louds_.push_back(true);
336   louds_.push_back(false);
337   bases_.push_back('\0');
338   link_flags_.push_back(false);
339 
340   Vector<T> next_keys;
341   std::queue<Range> queue;
342   Vector<WeightedRange> w_ranges;
343 
344   queue.push(make_range(0, keys.size(), 0));
345   while (!queue.empty()) {
346     const std::size_t node_id = link_flags_.size() - queue.size();
347 
348     Range range = queue.front();
349     queue.pop();
350 
351     while ((range.begin() < range.end()) &&
352         (keys[range.begin()].length() == range.key_pos())) {
353       keys[range.begin()].set_terminal(node_id);
354       range.set_begin(range.begin() + 1);
355     }
356 
357     if (range.begin() == range.end()) {
358       louds_.push_back(false);
359       continue;
360     }
361 
362     w_ranges.clear();
363     double weight = keys[range.begin()].weight();
364     for (std::size_t i = range.begin() + 1; i < range.end(); ++i) {
365       if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) {
366         w_ranges.push_back(make_weighted_range(
367             range.begin(), i, range.key_pos(), (float)weight));
368         range.set_begin(i);
369         weight = 0.0;
370       }
371       weight += keys[i].weight();
372     }
373     w_ranges.push_back(make_weighted_range(
374         range.begin(), range.end(), range.key_pos(), (float)weight));
375     if (config.node_order() == MARISA_WEIGHT_ORDER) {
376       std::stable_sort(w_ranges.begin(), w_ranges.end(),
377           std::greater<WeightedRange>());
378     }
379 
380     if (node_id == 0) {
381       num_l1_nodes_ = w_ranges.size();
382     }
383 
384     for (std::size_t i = 0; i < w_ranges.size(); ++i) {
385       WeightedRange &w_range = w_ranges[i];
386       std::size_t key_pos = w_range.key_pos() + 1;
387       while (key_pos < keys[w_range.begin()].length()) {
388         std::size_t j;
389         for (j = w_range.begin() + 1; j < w_range.end(); ++j) {
390           if (keys[j - 1][key_pos] != keys[j][key_pos]) {
391             break;
392           }
393         }
394         if (j < w_range.end()) {
395           break;
396         }
397         ++key_pos;
398       }
399       cache<T>(node_id, bases_.size(), w_range.weight(),
400           keys[w_range.begin()][w_range.key_pos()]);
401 
402       if (key_pos == w_range.key_pos() + 1) {
403         bases_.push_back(static_cast<unsigned char>(
404             keys[w_range.begin()][w_range.key_pos()]));
405         link_flags_.push_back(false);
406       } else {
407         bases_.push_back('\0');
408         link_flags_.push_back(true);
409         T next_key;
410         next_key.set_str(keys[w_range.begin()].ptr(),
411             keys[w_range.begin()].length());
412         next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos());
413         next_key.set_weight(w_range.weight());
414         next_keys.push_back(next_key);
415       }
416       w_range.set_key_pos(key_pos);
417       queue.push(w_range.range());
418       louds_.push_back(true);
419     }
420     louds_.push_back(false);
421   }
422 
423   louds_.push_back(false);
424   louds_.build(trie_id == 1, true);
425   bases_.shrink();
426 
427   build_terminals(keys, terminals);
428   keys.swap(next_keys);
429 } catch (const std::bad_alloc &) {
430   MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc");
431 }
432 
433 template <>
build_next_trie(Vector<Key> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)434 void LoudsTrie::build_next_trie(Vector<Key> &keys,
435     Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
436   if (trie_id == config.num_tries()) {
437     Vector<Entry> entries;
438     entries.resize(keys.size());
439     for (std::size_t i = 0; i < keys.size(); ++i) {
440       entries[i].set_str(keys[i].ptr(), keys[i].length());
441     }
442     tail_.build(entries, terminals, config.tail_mode());
443     return;
444   }
445   Vector<ReverseKey> reverse_keys;
446   reverse_keys.resize(keys.size());
447   for (std::size_t i = 0; i < keys.size(); ++i) {
448     reverse_keys[i].set_str(keys[i].ptr(), keys[i].length());
449     reverse_keys[i].set_weight(keys[i].weight());
450   }
451   keys.clear();
452   next_trie_.reset(new (std::nothrow) LoudsTrie);
453   MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
454   next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1);
455 }
456 
457 template <>
build_next_trie(Vector<ReverseKey> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)458 void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys,
459     Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
460   if (trie_id == config.num_tries()) {
461     Vector<Entry> entries;
462     entries.resize(keys.size());
463     for (std::size_t i = 0; i < keys.size(); ++i) {
464       entries[i].set_str(keys[i].ptr(), keys[i].length());
465     }
466     tail_.build(entries, terminals, config.tail_mode());
467     return;
468   }
469   next_trie_.reset(new (std::nothrow) LoudsTrie);
470   MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
471   next_trie_->build_trie(keys, terminals, config, trie_id + 1);
472 }
473 
474 template <typename T>
build_terminals(const Vector<T> & keys,Vector<UInt32> * terminals) const475 void LoudsTrie::build_terminals(const Vector<T> &keys,
476     Vector<UInt32> *terminals) const {
477   Vector<UInt32> temp;
478   temp.resize(keys.size());
479   for (std::size_t i = 0; i < keys.size(); ++i) {
480     temp[keys[i].id()] = (UInt32)keys[i].terminal();
481   }
482   terminals->swap(temp);
483 }
484 
485 template <>
cache(std::size_t parent,std::size_t child,float weight,char label)486 void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child,
487     float weight, char label) {
488   MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
489 
490   const std::size_t cache_id = get_cache_id(parent, label);
491   if (weight > cache_[cache_id].weight()) {
492     cache_[cache_id].set_parent(parent);
493     cache_[cache_id].set_child(child);
494     cache_[cache_id].set_weight(weight);
495   }
496 }
497 
reserve_cache(const Config & config,std::size_t trie_id,std::size_t num_keys)498 void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id,
499     std::size_t num_keys) {
500   std::size_t cache_size = (trie_id == 1) ? 256 : 1;
501   while (cache_size < (num_keys / config.cache_level())) {
502     cache_size *= 2;
503   }
504   cache_.resize(cache_size);
505   cache_mask_ = cache_size - 1;
506 }
507 
508 template <>
cache(std::size_t parent,std::size_t child,float weight,char)509 void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child,
510     float weight, char) {
511   MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
512 
513   const std::size_t cache_id = get_cache_id(child);
514   if (weight > cache_[cache_id].weight()) {
515     cache_[cache_id].set_parent(parent);
516     cache_[cache_id].set_child(child);
517     cache_[cache_id].set_weight(weight);
518   }
519 }
520 
fill_cache()521 void LoudsTrie::fill_cache() {
522   for (std::size_t i = 0; i < cache_.size(); ++i) {
523     const std::size_t node_id = cache_[i].child();
524     if (node_id != 0) {
525       cache_[i].set_base(bases_[node_id]);
526       cache_[i].set_extra(!link_flags_[node_id] ?
527           MARISA_INVALID_EXTRA : extras_[link_flags_.rank1(node_id)]);
528     } else {
529       cache_[i].set_parent(MARISA_UINT32_MAX);
530       cache_[i].set_child(MARISA_UINT32_MAX);
531     }
532   }
533 }
534 
map_(Mapper & mapper)535 void LoudsTrie::map_(Mapper &mapper) {
536   louds_.map(mapper);
537   terminal_flags_.map(mapper);
538   link_flags_.map(mapper);
539   bases_.map(mapper);
540   extras_.map(mapper);
541   tail_.map(mapper);
542   if ((link_flags_.num_1s() != 0) && tail_.empty()) {
543     next_trie_.reset(new (std::nothrow) LoudsTrie);
544     MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
545     next_trie_->map_(mapper);
546   }
547   cache_.map(mapper);
548   cache_mask_ = cache_.size() - 1;
549   {
550     UInt32 temp_num_l1_nodes;
551     mapper.map(&temp_num_l1_nodes);
552     num_l1_nodes_ = temp_num_l1_nodes;
553   }
554   {
555     UInt32 temp_config_flags;
556     mapper.map(&temp_config_flags);
557     config_.parse((int)temp_config_flags);
558   }
559 }
560 
read_(Reader & reader)561 void LoudsTrie::read_(Reader &reader) {
562   louds_.read(reader);
563   terminal_flags_.read(reader);
564   link_flags_.read(reader);
565   bases_.read(reader);
566   extras_.read(reader);
567   tail_.read(reader);
568   if ((link_flags_.num_1s() != 0) && tail_.empty()) {
569     next_trie_.reset(new (std::nothrow) LoudsTrie);
570     MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
571     next_trie_->read_(reader);
572   }
573   cache_.read(reader);
574   cache_mask_ = cache_.size() - 1;
575   {
576     UInt32 temp_num_l1_nodes;
577     reader.read(&temp_num_l1_nodes);
578     num_l1_nodes_ = temp_num_l1_nodes;
579   }
580   {
581     UInt32 temp_config_flags;
582     reader.read(&temp_config_flags);
583     config_.parse((int)temp_config_flags);
584   }
585 }
586 
write_(Writer & writer) const587 void LoudsTrie::write_(Writer &writer) const {
588   louds_.write(writer);
589   terminal_flags_.write(writer);
590   link_flags_.write(writer);
591   bases_.write(writer);
592   extras_.write(writer);
593   tail_.write(writer);
594   if (next_trie_.get() != NULL) {
595     next_trie_->write_(writer);
596   }
597   cache_.write(writer);
598   writer.write((UInt32)num_l1_nodes_);
599   writer.write((UInt32)config_.flags());
600 }
601 
find_child(Agent & agent) const602 bool LoudsTrie::find_child(Agent &agent) const {
603   MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
604       MARISA_BOUND_ERROR);
605 
606   State &state = agent.state();
607   const std::size_t cache_id = get_cache_id(state.node_id(),
608       agent.query()[state.query_pos()]);
609   if (state.node_id() == cache_[cache_id].parent()) {
610     if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
611       if (!match(agent, cache_[cache_id].link())) {
612         return false;
613       }
614     } else {
615       state.set_query_pos(state.query_pos() + 1);
616     }
617     state.set_node_id(cache_[cache_id].child());
618     return true;
619   }
620 
621   std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
622   if (!louds_[louds_pos]) {
623     return false;
624   }
625   state.set_node_id(louds_pos - state.node_id() - 1);
626   std::size_t link_id = MARISA_INVALID_LINK_ID;
627   do {
628     if (link_flags_[state.node_id()]) {
629       link_id = update_link_id(link_id, state.node_id());
630       const std::size_t prev_query_pos = state.query_pos();
631       if (match(agent, get_link(state.node_id(), link_id))) {
632         return true;
633       } else if (state.query_pos() != prev_query_pos) {
634         return false;
635       }
636     } else if (bases_[state.node_id()] ==
637         (UInt8)agent.query()[state.query_pos()]) {
638       state.set_query_pos(state.query_pos() + 1);
639       return true;
640     }
641     state.set_node_id(state.node_id() + 1);
642     ++louds_pos;
643   } while (louds_[louds_pos]);
644   return false;
645 }
646 
predictive_find_child(Agent & agent) const647 bool LoudsTrie::predictive_find_child(Agent &agent) const {
648   MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
649       MARISA_BOUND_ERROR);
650 
651   State &state = agent.state();
652   const std::size_t cache_id = get_cache_id(state.node_id(),
653       agent.query()[state.query_pos()]);
654   if (state.node_id() == cache_[cache_id].parent()) {
655     if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
656       if (!prefix_match(agent, cache_[cache_id].link())) {
657         return false;
658       }
659     } else {
660       state.key_buf().push_back(cache_[cache_id].label());
661       state.set_query_pos(state.query_pos() + 1);
662     }
663     state.set_node_id(cache_[cache_id].child());
664     return true;
665   }
666 
667   std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
668   if (!louds_[louds_pos]) {
669     return false;
670   }
671   state.set_node_id(louds_pos - state.node_id() - 1);
672   std::size_t link_id = MARISA_INVALID_LINK_ID;
673   do {
674     if (link_flags_[state.node_id()]) {
675       link_id = update_link_id(link_id, state.node_id());
676       const std::size_t prev_query_pos = state.query_pos();
677       if (prefix_match(agent, get_link(state.node_id(), link_id))) {
678         return true;
679       } else if (state.query_pos() != prev_query_pos) {
680         return false;
681       }
682     } else if (bases_[state.node_id()] ==
683         (UInt8)agent.query()[state.query_pos()]) {
684       state.key_buf().push_back((char)bases_[state.node_id()]);
685       state.set_query_pos(state.query_pos() + 1);
686       return true;
687     }
688     state.set_node_id(state.node_id() + 1);
689     ++louds_pos;
690   } while (louds_[louds_pos]);
691   return false;
692 }
693 
restore(Agent & agent,std::size_t link) const694 void LoudsTrie::restore(Agent &agent, std::size_t link) const {
695   if (next_trie_.get() != NULL) {
696     next_trie_->restore_(agent,  link);
697   } else {
698     tail_.restore(agent, link);
699   }
700 }
701 
match(Agent & agent,std::size_t link) const702 bool LoudsTrie::match(Agent &agent, std::size_t link) const {
703   if (next_trie_.get() != NULL) {
704     return next_trie_->match_(agent, link);
705   } else {
706     return tail_.match(agent, link);
707   }
708 }
709 
prefix_match(Agent & agent,std::size_t link) const710 bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const {
711   if (next_trie_.get() != NULL) {
712     return next_trie_->prefix_match_(agent, link);
713   } else {
714     return tail_.prefix_match(agent, link);
715   }
716 }
717 
restore_(Agent & agent,std::size_t node_id) const718 void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const {
719   MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
720 
721   State &state = agent.state();
722   for ( ; ; ) {
723     const std::size_t cache_id = get_cache_id(node_id);
724     if (node_id == cache_[cache_id].child()) {
725       if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
726         restore(agent,  cache_[cache_id].link());
727       } else {
728         state.key_buf().push_back(cache_[cache_id].label());
729       }
730 
731       node_id = cache_[cache_id].parent();
732       if (node_id == 0) {
733         return;
734       }
735       continue;
736     }
737 
738     if (link_flags_[node_id]) {
739       restore(agent, get_link(node_id));
740     } else {
741       state.key_buf().push_back((char)bases_[node_id]);
742     }
743 
744     if (node_id <= num_l1_nodes_) {
745       return;
746     }
747     node_id = louds_.select1(node_id) - node_id - 1;
748   }
749 }
750 
match_(Agent & agent,std::size_t node_id) const751 bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const {
752   MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
753       MARISA_BOUND_ERROR);
754   MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
755 
756   State &state = agent.state();
757   for ( ; ; ) {
758     const std::size_t cache_id = get_cache_id(node_id);
759     if (node_id == cache_[cache_id].child()) {
760       if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
761         if (!match(agent, cache_[cache_id].link())) {
762           return false;
763         }
764       } else if (cache_[cache_id].label() ==
765           agent.query()[state.query_pos()]) {
766         state.set_query_pos(state.query_pos() + 1);
767       } else {
768         return false;
769       }
770 
771       node_id = cache_[cache_id].parent();
772       if (node_id == 0) {
773         return true;
774       } else if (state.query_pos() >= agent.query().length()) {
775         return false;
776       }
777       continue;
778     }
779 
780     if (link_flags_[node_id]) {
781       if (next_trie_.get() != NULL) {
782         if (!match(agent, get_link(node_id))) {
783           return false;
784         }
785       } else if (!tail_.match(agent, get_link(node_id))) {
786         return false;
787       }
788     } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) {
789       state.set_query_pos(state.query_pos() + 1);
790     } else {
791       return false;
792     }
793 
794     if (node_id <= num_l1_nodes_) {
795       return true;
796     } else if (state.query_pos() >= agent.query().length()) {
797       return false;
798     }
799     node_id = louds_.select1(node_id) - node_id - 1;
800   }
801 }
802 
prefix_match_(Agent & agent,std::size_t node_id) const803 bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const {
804   MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
805       MARISA_BOUND_ERROR);
806   MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
807 
808   State &state = agent.state();
809   for ( ; ; ) {
810     const std::size_t cache_id = get_cache_id(node_id);
811     if (node_id == cache_[cache_id].child()) {
812       if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
813         if (!prefix_match(agent, cache_[cache_id].link())) {
814           return false;
815         }
816       } else if (cache_[cache_id].label() ==
817           agent.query()[state.query_pos()]) {
818         state.key_buf().push_back(cache_[cache_id].label());
819         state.set_query_pos(state.query_pos() + 1);
820       } else {
821         return false;
822       }
823 
824       node_id = cache_[cache_id].parent();
825       if (node_id == 0) {
826         return true;
827       }
828     } else {
829       if (link_flags_[node_id]) {
830         if (!prefix_match(agent, get_link(node_id))) {
831           return false;
832         }
833       } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) {
834         state.key_buf().push_back((char)bases_[node_id]);
835         state.set_query_pos(state.query_pos() + 1);
836       } else {
837         return false;
838       }
839 
840       if (node_id <= num_l1_nodes_) {
841         return true;
842       }
843       node_id = louds_.select1(node_id) - node_id - 1;
844     }
845 
846     if (state.query_pos() >= agent.query().length()) {
847       restore_(agent, node_id);
848       return true;
849     }
850   }
851 }
852 
get_cache_id(std::size_t node_id,char label) const853 std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const {
854   return (node_id ^ (node_id << 5) ^ (UInt8)label) & cache_mask_;
855 }
856 
get_cache_id(std::size_t node_id) const857 std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const {
858   return node_id & cache_mask_;
859 }
860 
get_link(std::size_t node_id) const861 std::size_t LoudsTrie::get_link(std::size_t node_id) const {
862   return  bases_[node_id] | (extras_[link_flags_.rank1(node_id)] * 256);
863 }
864 
get_link(std::size_t node_id,std::size_t link_id) const865 std::size_t LoudsTrie::get_link(std::size_t node_id,
866     std::size_t link_id) const {
867   return  bases_[node_id] | (extras_[link_id] * 256);
868 }
869 
update_link_id(std::size_t link_id,std::size_t node_id) const870 std::size_t LoudsTrie::update_link_id(std::size_t link_id,
871     std::size_t node_id) const {
872   return (link_id == MARISA_INVALID_LINK_ID) ?
873       link_flags_.rank1(node_id) : (link_id + 1);
874 }
875 
876 }  // namespace trie
877 }  // namespace grimoire
878 }  // namespace marisa
879