1 #include <algorithm>
2 #include <functional>
3 #include <queue>
4
5 #include "marisa/grimoire/algorithm.h"
6 #include "marisa/grimoire/trie/header.h"
7 #include "marisa/grimoire/trie/range.h"
8 #include "marisa/grimoire/trie/state.h"
9 #include "marisa/grimoire/trie/louds-trie.h"
10
11 namespace marisa {
12 namespace grimoire {
13 namespace trie {
14
LoudsTrie()15 LoudsTrie::LoudsTrie()
16 : louds_(), terminal_flags_(), link_flags_(), bases_(), extras_(),
17 tail_(), next_trie_(), cache_(), cache_mask_(0), num_l1_nodes_(0),
18 config_(), mapper_() {}
19
~LoudsTrie()20 LoudsTrie::~LoudsTrie() {}
21
build(Keyset & keyset,int flags)22 void LoudsTrie::build(Keyset &keyset, int flags) {
23 Config config;
24 config.parse(flags);
25
26 LoudsTrie temp;
27 temp.build_(keyset, config);
28 swap(temp);
29 }
30
map(Mapper & mapper)31 void LoudsTrie::map(Mapper &mapper) {
32 Header().map(mapper);
33
34 LoudsTrie temp;
35 temp.map_(mapper);
36 temp.mapper_.swap(mapper);
37 swap(temp);
38 }
39
read(Reader & reader)40 void LoudsTrie::read(Reader &reader) {
41 Header().read(reader);
42
43 LoudsTrie temp;
44 temp.read_(reader);
45 swap(temp);
46 }
47
write(Writer & writer) const48 void LoudsTrie::write(Writer &writer) const {
49 Header().write(writer);
50
51 write_(writer);
52 }
53
lookup(Agent & agent) const54 bool LoudsTrie::lookup(Agent &agent) const {
55 MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
56
57 State &state = agent.state();
58 state.lookup_init();
59 while (state.query_pos() < agent.query().length()) {
60 if (!find_child(agent)) {
61 return false;
62 }
63 }
64 if (!terminal_flags_[state.node_id()]) {
65 return false;
66 }
67 agent.set_key(agent.query().ptr(), agent.query().length());
68 agent.set_key(terminal_flags_.rank1(state.node_id()));
69 return true;
70 }
71
reverse_lookup(Agent & agent) const72 void LoudsTrie::reverse_lookup(Agent &agent) const {
73 MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
74 MARISA_THROW_IF(agent.query().id() >= size(), MARISA_BOUND_ERROR);
75
76 State &state = agent.state();
77 state.reverse_lookup_init();
78
79 state.set_node_id(terminal_flags_.select1(agent.query().id()));
80 if (state.node_id() == 0) {
81 agent.set_key(state.key_buf().begin(), state.key_buf().size());
82 agent.set_key(agent.query().id());
83 return;
84 }
85 for ( ; ; ) {
86 if (link_flags_[state.node_id()]) {
87 const std::size_t prev_key_pos = state.key_buf().size();
88 restore(agent, get_link(state.node_id()));
89 std::reverse(state.key_buf().begin() + prev_key_pos,
90 state.key_buf().end());
91 } else {
92 state.key_buf().push_back((char)bases_[state.node_id()]);
93 }
94
95 if (state.node_id() <= num_l1_nodes_) {
96 std::reverse(state.key_buf().begin(), state.key_buf().end());
97 agent.set_key(state.key_buf().begin(), state.key_buf().size());
98 agent.set_key(agent.query().id());
99 return;
100 }
101 state.set_node_id(louds_.select1(state.node_id()) - state.node_id() - 1);
102 }
103 }
104
common_prefix_search(Agent & agent) const105 bool LoudsTrie::common_prefix_search(Agent &agent) const {
106 MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
107
108 State &state = agent.state();
109 if (state.status_code() == MARISA_END_OF_COMMON_PREFIX_SEARCH) {
110 return false;
111 }
112
113 if (state.status_code() != MARISA_READY_TO_COMMON_PREFIX_SEARCH) {
114 state.common_prefix_search_init();
115 if (terminal_flags_[state.node_id()]) {
116 agent.set_key(agent.query().ptr(), state.query_pos());
117 agent.set_key(terminal_flags_.rank1(state.node_id()));
118 return true;
119 }
120 }
121
122 while (state.query_pos() < agent.query().length()) {
123 if (!find_child(agent)) {
124 state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
125 return false;
126 } else if (terminal_flags_[state.node_id()]) {
127 agent.set_key(agent.query().ptr(), state.query_pos());
128 agent.set_key(terminal_flags_.rank1(state.node_id()));
129 return true;
130 }
131 }
132 state.set_status_code(MARISA_END_OF_COMMON_PREFIX_SEARCH);
133 return false;
134 }
135
predictive_search(Agent & agent) const136 bool LoudsTrie::predictive_search(Agent &agent) const {
137 MARISA_DEBUG_IF(!agent.has_state(), MARISA_STATE_ERROR);
138
139 State &state = agent.state();
140 if (state.status_code() == MARISA_END_OF_PREDICTIVE_SEARCH) {
141 return false;
142 }
143
144 if (state.status_code() != MARISA_READY_TO_PREDICTIVE_SEARCH) {
145 state.predictive_search_init();
146 while (state.query_pos() < agent.query().length()) {
147 if (!predictive_find_child(agent)) {
148 state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
149 return false;
150 }
151 }
152
153 History history;
154 history.set_node_id(state.node_id());
155 history.set_key_pos(state.key_buf().size());
156 state.history().push_back(history);
157 state.set_history_pos(1);
158
159 if (terminal_flags_[state.node_id()]) {
160 agent.set_key(state.key_buf().begin(), state.key_buf().size());
161 agent.set_key(terminal_flags_.rank1(state.node_id()));
162 return true;
163 }
164 }
165
166 for ( ; ; ) {
167 if (state.history_pos() == state.history().size()) {
168 const History ¤t = state.history().back();
169 History next;
170 next.set_louds_pos(louds_.select0(current.node_id()) + 1);
171 next.set_node_id(next.louds_pos() - current.node_id() - 1);
172 state.history().push_back(next);
173 }
174
175 History &next = state.history()[state.history_pos()];
176 const bool link_flag = louds_[next.louds_pos()];
177 next.set_louds_pos(next.louds_pos() + 1);
178 if (link_flag) {
179 state.set_history_pos(state.history_pos() + 1);
180 if (link_flags_[next.node_id()]) {
181 next.set_link_id(update_link_id(next.link_id(), next.node_id()));
182 restore(agent, get_link(next.node_id(), next.link_id()));
183 } else {
184 state.key_buf().push_back((char)bases_[next.node_id()]);
185 }
186 next.set_key_pos(state.key_buf().size());
187
188 if (terminal_flags_[next.node_id()]) {
189 if (next.key_id() == MARISA_INVALID_KEY_ID) {
190 next.set_key_id(terminal_flags_.rank1(next.node_id()));
191 } else {
192 next.set_key_id(next.key_id() + 1);
193 }
194 agent.set_key(state.key_buf().begin(), state.key_buf().size());
195 agent.set_key(next.key_id());
196 return true;
197 }
198 } else if (state.history_pos() != 1) {
199 History ¤t = state.history()[state.history_pos() - 1];
200 current.set_node_id(current.node_id() + 1);
201 const History &prev =
202 state.history()[state.history_pos() - 2];
203 state.key_buf().resize(prev.key_pos());
204 state.set_history_pos(state.history_pos() - 1);
205 } else {
206 state.set_status_code(MARISA_END_OF_PREDICTIVE_SEARCH);
207 return false;
208 }
209 }
210 }
211
total_size() const212 std::size_t LoudsTrie::total_size() const {
213 return louds_.total_size() + terminal_flags_.total_size()
214 + link_flags_.total_size() + bases_.total_size()
215 + extras_.total_size() + tail_.total_size()
216 + ((next_trie_.get() != NULL) ? next_trie_->total_size() : 0)
217 + cache_.total_size();
218 }
219
io_size() const220 std::size_t LoudsTrie::io_size() const {
221 return Header().io_size() + louds_.io_size()
222 + terminal_flags_.io_size() + link_flags_.io_size()
223 + bases_.io_size() + extras_.io_size() + tail_.io_size()
224 + ((next_trie_.get() != NULL) ?
225 (next_trie_->io_size() - Header().io_size()) : 0)
226 + cache_.io_size() + (sizeof(UInt32) * 2);
227 }
228
clear()229 void LoudsTrie::clear() {
230 LoudsTrie().swap(*this);
231 }
232
swap(LoudsTrie & rhs)233 void LoudsTrie::swap(LoudsTrie &rhs) {
234 louds_.swap(rhs.louds_);
235 terminal_flags_.swap(rhs.terminal_flags_);
236 link_flags_.swap(rhs.link_flags_);
237 bases_.swap(rhs.bases_);
238 extras_.swap(rhs.extras_);
239 tail_.swap(rhs.tail_);
240 next_trie_.swap(rhs.next_trie_);
241 cache_.swap(rhs.cache_);
242 marisa::swap(cache_mask_, rhs.cache_mask_);
243 marisa::swap(num_l1_nodes_, rhs.num_l1_nodes_);
244 config_.swap(rhs.config_);
245 mapper_.swap(rhs.mapper_);
246 }
247
build_(Keyset & keyset,const Config & config)248 void LoudsTrie::build_(Keyset &keyset, const Config &config) {
249 Vector<Key> keys;
250 keys.resize(keyset.size());
251 for (std::size_t i = 0; i < keyset.size(); ++i) {
252 keys[i].set_str(keyset[i].ptr(), keyset[i].length());
253 keys[i].set_weight(keyset[i].weight());
254 }
255
256 Vector<UInt32> terminals;
257 build_trie(keys, &terminals, config, 1);
258
259 typedef std::pair<UInt32, UInt32> TerminalIdPair;
260
261 Vector<TerminalIdPair> pairs;
262 pairs.resize(terminals.size());
263 for (std::size_t i = 0; i < pairs.size(); ++i) {
264 pairs[i].first = terminals[i];
265 pairs[i].second = (UInt32)i;
266 }
267 terminals.clear();
268 std::sort(pairs.begin(), pairs.end());
269
270 std::size_t node_id = 0;
271 for (std::size_t i = 0; i < pairs.size(); ++i) {
272 while (node_id < pairs[i].first) {
273 terminal_flags_.push_back(false);
274 ++node_id;
275 }
276 if (node_id == pairs[i].first) {
277 terminal_flags_.push_back(true);
278 ++node_id;
279 }
280 }
281 while (node_id < bases_.size()) {
282 terminal_flags_.push_back(false);
283 ++node_id;
284 }
285 terminal_flags_.push_back(false);
286 terminal_flags_.build(false, true);
287
288 for (std::size_t i = 0; i < keyset.size(); ++i) {
289 keyset[pairs[i].second].set_id(terminal_flags_.rank1(pairs[i].first));
290 }
291 }
292
293 template <typename T>
build_trie(Vector<T> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)294 void LoudsTrie::build_trie(Vector<T> &keys,
295 Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
296 build_current_trie(keys, terminals, config, trie_id);
297
298 Vector<UInt32> next_terminals;
299 if (!keys.empty()) {
300 build_next_trie(keys, &next_terminals, config, trie_id);
301 }
302
303 if (next_trie_.get() != NULL) {
304 config_.parse(static_cast<int>((next_trie_->num_tries() + 1)) |
305 next_trie_->tail_mode() | next_trie_->node_order());
306 } else {
307 config_.parse(1 | tail_.mode() | config.node_order() |
308 config.cache_level());
309 }
310
311 link_flags_.build(false, false);
312 std::size_t node_id = 0;
313 for (std::size_t i = 0; i < next_terminals.size(); ++i) {
314 while (!link_flags_[node_id]) {
315 ++node_id;
316 }
317 bases_[node_id] = (UInt8)(next_terminals[i] % 256);
318 next_terminals[i] /= 256;
319 ++node_id;
320 }
321 extras_.build(next_terminals);
322 fill_cache();
323 }
324
// Builds the LOUDS structure of one trie level from `keys`.
//
// Keys are sorted and then expanded breadth-first. Each queue entry is a
// Range of keys sharing a prefix of length key_pos, and it becomes one node.
// A run of bytes on which no key in a group branches is collapsed into a
// single "link" node, and that label fragment is pushed onto `next_keys`.
// On return:
//   - `keys` holds those fragments, for the next level or the tail;
//   - `*terminals` holds, per original key index, the node where it ends.
//
// Any std::bad_alloc thrown while building is rethrown as MARISA_MEMORY_ERROR.
template <typename T>
void LoudsTrie::build_current_trie(Vector<T> &keys,
    Vector<UInt32> *terminals, const Config &config,
    std::size_t trie_id) try {
  // Remember each key's original position, because sorting reorders `keys`.
  for (std::size_t i = 0; i < keys.size(); ++i) {
    keys[i].set_id(i);
  }
  // NOTE(review): the result is presumably the number of distinct keys; here
  // it only sizes the cache.
  const std::size_t num_keys = Algorithm().sort(keys.begin(), keys.end());
  reserve_cache(config, trie_id, num_keys);

  // Leading "10" of the LOUDS bit string, plus the root node (ID 0). The root
  // has no label and is never a link.
  louds_.push_back(true);
  louds_.push_back(false);
  bases_.push_back('\0');
  link_flags_.push_back(false);

  Vector<T> next_keys;
  std::queue<Range> queue;
  Vector<WeightedRange> w_ranges;

  queue.push(make_range(0, keys.size(), 0));
  while (!queue.empty()) {
    // Every created node is in link_flags_ and every node still awaiting
    // expansion is in the queue. Their difference is the front entry's ID.
    const std::size_t node_id = link_flags_.size() - queue.size();

    Range range = queue.front();
    queue.pop();

    // Keys that end exactly here make this a terminal node.
    while ((range.begin() < range.end()) &&
        (keys[range.begin()].length() == range.key_pos())) {
      keys[range.begin()].set_terminal(node_id);
      range.set_begin(range.begin() + 1);
    }

    // No key continues past this node: a leaf (just the closing 0-bit).
    if (range.begin() == range.end()) {
      louds_.push_back(false);
      continue;
    }

    // Split the remaining keys into one group per distinct byte at key_pos,
    // and sum each group's weight.
    w_ranges.clear();
    double weight = keys[range.begin()].weight();
    for (std::size_t i = range.begin() + 1; i < range.end(); ++i) {
      if (keys[i - 1][range.key_pos()] != keys[i][range.key_pos()]) {
        w_ranges.push_back(make_weighted_range(
            range.begin(), i, range.key_pos(), (float)weight));
        range.set_begin(i);
        weight = 0.0;
      }
      weight += keys[i].weight();
    }
    w_ranges.push_back(make_weighted_range(
        range.begin(), range.end(), range.key_pos(), (float)weight));
    // Children get consecutive IDs in w_ranges order. With weight order they
    // are stably sorted descending (presumably by weight); otherwise they
    // keep label order.
    if (config.node_order() == MARISA_WEIGHT_ORDER) {
      std::stable_sort(w_ranges.begin(), w_ranges.end(),
          std::greater<WeightedRange>());
    }

    // The root's children are the level-1 nodes, IDs 1..num_l1_nodes_.
    if (node_id == 0) {
      num_l1_nodes_ = w_ranges.size();
    }

    for (std::size_t i = 0; i < w_ranges.size(); ++i) {
      WeightedRange &w_range = w_ranges[i];
      // Extend the edge while every key in the group agrees on the next
      // byte. The length of the group's first key bounds the extension.
      std::size_t key_pos = w_range.key_pos() + 1;
      while (key_pos < keys[w_range.begin()].length()) {
        std::size_t j;
        for (j = w_range.begin() + 1; j < w_range.end(); ++j) {
          if (keys[j - 1][key_pos] != keys[j][key_pos]) {
            break;
          }
        }
        if (j < w_range.end()) {
          break;
        }
        ++key_pos;
      }
      // The new child's ID is bases_.size(); it is appended just below.
      cache<T>(node_id, bases_.size(), w_range.weight(),
          keys[w_range.begin()][w_range.key_pos()]);

      if (key_pos == w_range.key_pos() + 1) {
        // Single-byte edge: the label is stored directly in bases_.
        bases_.push_back(static_cast<unsigned char>(
            keys[w_range.begin()][w_range.key_pos()]));
        link_flags_.push_back(false);
      } else {
        // Multi-byte edge: mark a link node and hand the fragment starting
        // at w_range.key_pos() to the next level. The '\0' in bases_ is a
        // placeholder that build_trie() later overwrites with the link's
        // low byte.
        bases_.push_back('\0');
        link_flags_.push_back(true);
        T next_key;
        next_key.set_str(keys[w_range.begin()].ptr(),
            keys[w_range.begin()].length());
        next_key.substr(w_range.key_pos(), key_pos - w_range.key_pos());
        next_key.set_weight(w_range.weight());
        next_keys.push_back(next_key);
      }
      w_range.set_key_pos(key_pos);
      queue.push(w_range.range());
      louds_.push_back(true);
    }
    // Closes this node's run of child bits.
    louds_.push_back(false);
  }

  louds_.push_back(false);
  // NOTE(review): the first flag presumably enables select0, which only the
  // top-level trie uses (find_child / predictive_search); confirm in
  // BitVector::build.
  louds_.build(trie_id == 1, true);
  bases_.shrink();

  build_terminals(keys, terminals);
  // Return the collapsed fragments to the caller through `keys`.
  keys.swap(next_keys);
} catch (const std::bad_alloc &) {
  MARISA_THROW(MARISA_MEMORY_ERROR, "std::bad_alloc");
}
432
433 template <>
build_next_trie(Vector<Key> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)434 void LoudsTrie::build_next_trie(Vector<Key> &keys,
435 Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
436 if (trie_id == config.num_tries()) {
437 Vector<Entry> entries;
438 entries.resize(keys.size());
439 for (std::size_t i = 0; i < keys.size(); ++i) {
440 entries[i].set_str(keys[i].ptr(), keys[i].length());
441 }
442 tail_.build(entries, terminals, config.tail_mode());
443 return;
444 }
445 Vector<ReverseKey> reverse_keys;
446 reverse_keys.resize(keys.size());
447 for (std::size_t i = 0; i < keys.size(); ++i) {
448 reverse_keys[i].set_str(keys[i].ptr(), keys[i].length());
449 reverse_keys[i].set_weight(keys[i].weight());
450 }
451 keys.clear();
452 next_trie_.reset(new (std::nothrow) LoudsTrie);
453 MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
454 next_trie_->build_trie(reverse_keys, terminals, config, trie_id + 1);
455 }
456
457 template <>
build_next_trie(Vector<ReverseKey> & keys,Vector<UInt32> * terminals,const Config & config,std::size_t trie_id)458 void LoudsTrie::build_next_trie(Vector<ReverseKey> &keys,
459 Vector<UInt32> *terminals, const Config &config, std::size_t trie_id) {
460 if (trie_id == config.num_tries()) {
461 Vector<Entry> entries;
462 entries.resize(keys.size());
463 for (std::size_t i = 0; i < keys.size(); ++i) {
464 entries[i].set_str(keys[i].ptr(), keys[i].length());
465 }
466 tail_.build(entries, terminals, config.tail_mode());
467 return;
468 }
469 next_trie_.reset(new (std::nothrow) LoudsTrie);
470 MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
471 next_trie_->build_trie(keys, terminals, config, trie_id + 1);
472 }
473
474 template <typename T>
build_terminals(const Vector<T> & keys,Vector<UInt32> * terminals) const475 void LoudsTrie::build_terminals(const Vector<T> &keys,
476 Vector<UInt32> *terminals) const {
477 Vector<UInt32> temp;
478 temp.resize(keys.size());
479 for (std::size_t i = 0; i < keys.size(); ++i) {
480 temp[keys[i].id()] = (UInt32)keys[i].terminal();
481 }
482 terminals->swap(temp);
483 }
484
485 template <>
cache(std::size_t parent,std::size_t child,float weight,char label)486 void LoudsTrie::cache<Key>(std::size_t parent, std::size_t child,
487 float weight, char label) {
488 MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
489
490 const std::size_t cache_id = get_cache_id(parent, label);
491 if (weight > cache_[cache_id].weight()) {
492 cache_[cache_id].set_parent(parent);
493 cache_[cache_id].set_child(child);
494 cache_[cache_id].set_weight(weight);
495 }
496 }
497
reserve_cache(const Config & config,std::size_t trie_id,std::size_t num_keys)498 void LoudsTrie::reserve_cache(const Config &config, std::size_t trie_id,
499 std::size_t num_keys) {
500 std::size_t cache_size = (trie_id == 1) ? 256 : 1;
501 while (cache_size < (num_keys / config.cache_level())) {
502 cache_size *= 2;
503 }
504 cache_.resize(cache_size);
505 cache_mask_ = cache_size - 1;
506 }
507
508 template <>
cache(std::size_t parent,std::size_t child,float weight,char)509 void LoudsTrie::cache<ReverseKey>(std::size_t parent, std::size_t child,
510 float weight, char) {
511 MARISA_DEBUG_IF(parent >= child, MARISA_RANGE_ERROR);
512
513 const std::size_t cache_id = get_cache_id(child);
514 if (weight > cache_[cache_id].weight()) {
515 cache_[cache_id].set_parent(parent);
516 cache_[cache_id].set_child(child);
517 cache_[cache_id].set_weight(weight);
518 }
519 }
520
fill_cache()521 void LoudsTrie::fill_cache() {
522 for (std::size_t i = 0; i < cache_.size(); ++i) {
523 const std::size_t node_id = cache_[i].child();
524 if (node_id != 0) {
525 cache_[i].set_base(bases_[node_id]);
526 cache_[i].set_extra(!link_flags_[node_id] ?
527 MARISA_INVALID_EXTRA : extras_[link_flags_.rank1(node_id)]);
528 } else {
529 cache_[i].set_parent(MARISA_UINT32_MAX);
530 cache_[i].set_child(MARISA_UINT32_MAX);
531 }
532 }
533 }
534
map_(Mapper & mapper)535 void LoudsTrie::map_(Mapper &mapper) {
536 louds_.map(mapper);
537 terminal_flags_.map(mapper);
538 link_flags_.map(mapper);
539 bases_.map(mapper);
540 extras_.map(mapper);
541 tail_.map(mapper);
542 if ((link_flags_.num_1s() != 0) && tail_.empty()) {
543 next_trie_.reset(new (std::nothrow) LoudsTrie);
544 MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
545 next_trie_->map_(mapper);
546 }
547 cache_.map(mapper);
548 cache_mask_ = cache_.size() - 1;
549 {
550 UInt32 temp_num_l1_nodes;
551 mapper.map(&temp_num_l1_nodes);
552 num_l1_nodes_ = temp_num_l1_nodes;
553 }
554 {
555 UInt32 temp_config_flags;
556 mapper.map(&temp_config_flags);
557 config_.parse((int)temp_config_flags);
558 }
559 }
560
read_(Reader & reader)561 void LoudsTrie::read_(Reader &reader) {
562 louds_.read(reader);
563 terminal_flags_.read(reader);
564 link_flags_.read(reader);
565 bases_.read(reader);
566 extras_.read(reader);
567 tail_.read(reader);
568 if ((link_flags_.num_1s() != 0) && tail_.empty()) {
569 next_trie_.reset(new (std::nothrow) LoudsTrie);
570 MARISA_THROW_IF(next_trie_.get() == NULL, MARISA_MEMORY_ERROR);
571 next_trie_->read_(reader);
572 }
573 cache_.read(reader);
574 cache_mask_ = cache_.size() - 1;
575 {
576 UInt32 temp_num_l1_nodes;
577 reader.read(&temp_num_l1_nodes);
578 num_l1_nodes_ = temp_num_l1_nodes;
579 }
580 {
581 UInt32 temp_config_flags;
582 reader.read(&temp_config_flags);
583 config_.parse((int)temp_config_flags);
584 }
585 }
586
write_(Writer & writer) const587 void LoudsTrie::write_(Writer &writer) const {
588 louds_.write(writer);
589 terminal_flags_.write(writer);
590 link_flags_.write(writer);
591 bases_.write(writer);
592 extras_.write(writer);
593 tail_.write(writer);
594 if (next_trie_.get() != NULL) {
595 next_trie_->write_(writer);
596 }
597 cache_.write(writer);
598 writer.write((UInt32)num_l1_nodes_);
599 writer.write((UInt32)config_.flags());
600 }
601
// Moves the agent's state from the current node to the child whose edge
// matches the query at state.query_pos(). query_pos advances past the
// matched label: one byte, or a whole linked fragment. Returns false if no
// child matches.
bool LoudsTrie::find_child(Agent &agent) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);

  State &state = agent.state();
  // Fast path: the cache slot for (current node, next query byte).
  const std::size_t cache_id = get_cache_id(state.node_id(),
      agent.query()[state.query_pos()]);
  if (state.node_id() == cache_[cache_id].parent()) {
    // NOTE(review): single-byte hits are accepted without comparing the
    // label. This appears to rely on distinct labels of one parent never
    // sharing a slot: the top-level cache has >= 256 slots and the label is
    // XORed into the low byte. Confirm before changing get_cache_id().
    if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
      // Cached edge is a multi-byte fragment; it must match the query.
      if (!match(agent, cache_[cache_id].link())) {
        return false;
      }
    } else {
      state.set_query_pos(state.query_pos() + 1);
    }
    state.set_node_id(cache_[cache_id].child());
    return true;
  }

  // Slow path: scan the children in LOUDS order. select0(node) + 1 is the
  // node's first child bit, and a 0 there means the node has no children.
  std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
  if (!louds_[louds_pos]) {
    return false;
  }
  // ID of the first child.
  state.set_node_id(louds_pos - state.node_id() - 1);
  // Rank of the current link among link nodes; computed on first use and
  // then advanced by update_link_id().
  std::size_t link_id = MARISA_INVALID_LINK_ID;
  do {
    if (link_flags_[state.node_id()]) {
      link_id = update_link_id(link_id, state.node_id());
      const std::size_t prev_query_pos = state.query_pos();
      if (match(agent, get_link(state.node_id(), link_id))) {
        return true;
      } else if (state.query_pos() != prev_query_pos) {
        // match() consumed part of the query before failing. Presumably the
        // fragment's first byte matched, so no sibling can match either.
        return false;
      }
    } else if (bases_[state.node_id()] ==
        (UInt8)agent.query()[state.query_pos()]) {
      state.set_query_pos(state.query_pos() + 1);
      return true;
    }
    // Next sibling: consecutive node ID and LOUDS bit.
    state.set_node_id(state.node_id() + 1);
    ++louds_pos;
  } while (louds_[louds_pos]);
  return false;
}
646
// Predictive-search variant of find_child(). It also appends the labels
// stored in the trie to state.key_buf(), and it matches linked fragments
// with prefix_match(), which lets the query end inside a fragment.
// Returns false if no child matches.
bool LoudsTrie::predictive_find_child(Agent &agent) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);

  State &state = agent.state();
  // Fast path: the cache slot for (current node, next query byte).
  const std::size_t cache_id = get_cache_id(state.node_id(),
      agent.query()[state.query_pos()]);
  if (state.node_id() == cache_[cache_id].parent()) {
    // NOTE(review): as in find_child(), single-byte hits are accepted without
    // an explicit label comparison.
    if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
      if (!prefix_match(agent, cache_[cache_id].link())) {
        return false;
      }
    } else {
      state.key_buf().push_back(cache_[cache_id].label());
      state.set_query_pos(state.query_pos() + 1);
    }
    state.set_node_id(cache_[cache_id].child());
    return true;
  }

  // Slow path: scan the children in LOUDS order starting at the first child
  // bit. A 0 there means the node has no children.
  std::size_t louds_pos = louds_.select0(state.node_id()) + 1;
  if (!louds_[louds_pos]) {
    return false;
  }
  // ID of the first child.
  state.set_node_id(louds_pos - state.node_id() - 1);
  // Rank of the current link among link nodes, computed lazily.
  std::size_t link_id = MARISA_INVALID_LINK_ID;
  do {
    if (link_flags_[state.node_id()]) {
      link_id = update_link_id(link_id, state.node_id());
      const std::size_t prev_query_pos = state.query_pos();
      if (prefix_match(agent, get_link(state.node_id(), link_id))) {
        return true;
      } else if (state.query_pos() != prev_query_pos) {
        // Part of the query was consumed before the mismatch, so no sibling
        // can match either.
        return false;
      }
    } else if (bases_[state.node_id()] ==
        (UInt8)agent.query()[state.query_pos()]) {
      state.key_buf().push_back((char)bases_[state.node_id()]);
      state.set_query_pos(state.query_pos() + 1);
      return true;
    }
    // Next sibling.
    state.set_node_id(state.node_id() + 1);
    ++louds_pos;
  } while (louds_[louds_pos]);
  return false;
}
693
restore(Agent & agent,std::size_t link) const694 void LoudsTrie::restore(Agent &agent, std::size_t link) const {
695 if (next_trie_.get() != NULL) {
696 next_trie_->restore_(agent, link);
697 } else {
698 tail_.restore(agent, link);
699 }
700 }
701
match(Agent & agent,std::size_t link) const702 bool LoudsTrie::match(Agent &agent, std::size_t link) const {
703 if (next_trie_.get() != NULL) {
704 return next_trie_->match_(agent, link);
705 } else {
706 return tail_.match(agent, link);
707 }
708 }
709
prefix_match(Agent & agent,std::size_t link) const710 bool LoudsTrie::prefix_match(Agent &agent, std::size_t link) const {
711 if (next_trie_.get() != NULL) {
712 return next_trie_->prefix_match_(agent, link);
713 } else {
714 return tail_.prefix_match(agent, link);
715 }
716 }
717
// Appends to state.key_buf() the label fragment that ends at node_id in this
// (nested) trie. Labels are emitted bottom-up: node_id's label first, then
// each ancestor's. The walk stops after a level-1 node, or when a cached
// parent is the root. Link nodes recurse through restore() into the next
// level.
void LoudsTrie::restore_(Agent &agent, std::size_t node_id) const {
  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);

  State &state = agent.state();
  for ( ; ; ) {
    // Fast path: nested-trie cache slots are keyed by child node ID.
    const std::size_t cache_id = get_cache_id(node_id);
    if (node_id == cache_[cache_id].child()) {
      if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
        restore(agent, cache_[cache_id].link());
      } else {
        state.key_buf().push_back(cache_[cache_id].label());
      }

      node_id = cache_[cache_id].parent();
      if (node_id == 0) {
        return;
      }
      continue;
    }

    if (link_flags_[node_id]) {
      restore(agent, get_link(node_id));
    } else {
      state.key_buf().push_back((char)bases_[node_id]);
    }

    // Level-1 nodes are children of the root, so the fragment is complete.
    if (node_id <= num_l1_nodes_) {
      return;
    }
    // Move to the parent node.
    node_id = louds_.select1(node_id) - node_id - 1;
  }
}
750
match_(Agent & agent,std::size_t node_id) const751 bool LoudsTrie::match_(Agent &agent, std::size_t node_id) const {
752 MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
753 MARISA_BOUND_ERROR);
754 MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);
755
756 State &state = agent.state();
757 for ( ; ; ) {
758 const std::size_t cache_id = get_cache_id(node_id);
759 if (node_id == cache_[cache_id].child()) {
760 if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
761 if (!match(agent, cache_[cache_id].link())) {
762 return false;
763 }
764 } else if (cache_[cache_id].label() ==
765 agent.query()[state.query_pos()]) {
766 state.set_query_pos(state.query_pos() + 1);
767 } else {
768 return false;
769 }
770
771 node_id = cache_[cache_id].parent();
772 if (node_id == 0) {
773 return true;
774 } else if (state.query_pos() >= agent.query().length()) {
775 return false;
776 }
777 continue;
778 }
779
780 if (link_flags_[node_id]) {
781 if (next_trie_.get() != NULL) {
782 if (!match(agent, get_link(node_id))) {
783 return false;
784 }
785 } else if (!tail_.match(agent, get_link(node_id))) {
786 return false;
787 }
788 } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) {
789 state.set_query_pos(state.query_pos() + 1);
790 } else {
791 return false;
792 }
793
794 if (node_id <= num_l1_nodes_) {
795 return true;
796 } else if (state.query_pos() >= agent.query().length()) {
797 return false;
798 }
799 node_id = louds_.select1(node_id) - node_id - 1;
800 }
801 }
802
// Predictive-search counterpart of match_(). Every matched byte is also
// appended to state.key_buf(), and the query may end inside the fragment.
// In that case the rest of the fragment is restored into key_buf without
// comparison and the call succeeds.
//
// Returns false on the first mismatching byte. Returns true once the
// fragment has fully matched or the query has been exhausted.
bool LoudsTrie::prefix_match_(Agent &agent, std::size_t node_id) const {
  MARISA_DEBUG_IF(agent.state().query_pos() >= agent.query().length(),
      MARISA_BOUND_ERROR);
  MARISA_DEBUG_IF(node_id == 0, MARISA_RANGE_ERROR);

  State &state = agent.state();
  for ( ; ; ) {
    // Fast path: nested-trie cache slots are keyed by child node ID.
    const std::size_t cache_id = get_cache_id(node_id);
    if (node_id == cache_[cache_id].child()) {
      if (cache_[cache_id].extra() != MARISA_INVALID_EXTRA) {
        if (!prefix_match(agent, cache_[cache_id].link())) {
          return false;
        }
      } else if (cache_[cache_id].label() ==
          agent.query()[state.query_pos()]) {
        state.key_buf().push_back(cache_[cache_id].label());
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      node_id = cache_[cache_id].parent();
      if (node_id == 0) {
        return true;
      }
    } else {
      if (link_flags_[node_id]) {
        if (!prefix_match(agent, get_link(node_id))) {
          return false;
        }
      } else if (bases_[node_id] == (UInt8)agent.query()[state.query_pos()]) {
        state.key_buf().push_back((char)bases_[node_id]);
        state.set_query_pos(state.query_pos() + 1);
      } else {
        return false;
      }

      // Level-1 nodes are children of the root, so the fragment is complete.
      if (node_id <= num_l1_nodes_) {
        return true;
      }
      // Move to the parent node.
      node_id = louds_.select1(node_id) - node_id - 1;
    }

    // Query exhausted mid-fragment: append the remaining labels, starting
    // from the ancestor reached above, without comparing them.
    if (state.query_pos() >= agent.query().length()) {
      restore_(agent, node_id);
      return true;
    }
  }
}
852
get_cache_id(std::size_t node_id,char label) const853 std::size_t LoudsTrie::get_cache_id(std::size_t node_id, char label) const {
854 return (node_id ^ (node_id << 5) ^ (UInt8)label) & cache_mask_;
855 }
856
get_cache_id(std::size_t node_id) const857 std::size_t LoudsTrie::get_cache_id(std::size_t node_id) const {
858 return node_id & cache_mask_;
859 }
860
get_link(std::size_t node_id) const861 std::size_t LoudsTrie::get_link(std::size_t node_id) const {
862 return bases_[node_id] | (extras_[link_flags_.rank1(node_id)] * 256);
863 }
864
get_link(std::size_t node_id,std::size_t link_id) const865 std::size_t LoudsTrie::get_link(std::size_t node_id,
866 std::size_t link_id) const {
867 return bases_[node_id] | (extras_[link_id] * 256);
868 }
869
update_link_id(std::size_t link_id,std::size_t node_id) const870 std::size_t LoudsTrie::update_link_id(std::size_t link_id,
871 std::size_t node_id) const {
872 return (link_id == MARISA_INVALID_LINK_ID) ?
873 link_flags_.rank1(node_id) : (link_id + 1);
874 }
875
876 } // namespace trie
877 } // namespace grimoire
878 } // namespace marisa
879