• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef MARISA_ALPHA_TRIE_INLINE_H_
2 #define MARISA_ALPHA_TRIE_INLINE_H_
3 
4 #include <stdexcept>
5 
6 #include "cell.h"
7 
8 namespace marisa_alpha {
9 
10 inline std::string Trie::operator[](UInt32 key_id) const {
11   std::string key;
12   restore(key_id, &key);
13   return key;
14 }
15 
16 inline UInt32 Trie::operator[](const char *str) const {
17   return lookup(str);
18 }
19 
20 inline UInt32 Trie::operator[](const std::string &str) const {
21   return lookup(str);
22 }
23 
lookup(const std::string & str)24 inline UInt32 Trie::lookup(const std::string &str) const {
25   return lookup(str.c_str(), str.length());
26 }
27 
find(const std::string & str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results)28 inline std::size_t Trie::find(const std::string &str,
29     UInt32 *key_ids, std::size_t *key_lengths,
30     std::size_t max_num_results) const {
31   return find(str.c_str(), str.length(),
32       key_ids, key_lengths, max_num_results);
33 }
34 
find(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results)35 inline std::size_t Trie::find(const std::string &str,
36     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
37     std::size_t max_num_results) const {
38   return find(str.c_str(), str.length(),
39       key_ids, key_lengths, max_num_results);
40 }
41 
find_first(const std::string & str,std::size_t * key_length)42 inline UInt32 Trie::find_first(const std::string &str,
43     std::size_t *key_length) const {
44   return find_first(str.c_str(), str.length(), key_length);
45 }
46 
find_last(const std::string & str,std::size_t * key_length)47 inline UInt32 Trie::find_last(const std::string &str,
48     std::size_t *key_length) const {
49   return find_last(str.c_str(), str.length(), key_length);
50 }
51 
52 template <typename T>
find_callback(const char * str,T callback)53 inline std::size_t Trie::find_callback(const char *str,
54     T callback) const {
55   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
56   MARISA_ALPHA_THROW_IF(str == NULL, MARISA_ALPHA_PARAM_ERROR);
57   return find_callback_<CQuery>(CQuery(str), callback);
58 }
59 
60 template <typename T>
find_callback(const char * ptr,std::size_t length,T callback)61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
62     T callback) const {
63   MARISA_ALPHA_THROW_IF(empty(), MARISA_ALPHA_STATE_ERROR);
64   MARISA_ALPHA_THROW_IF((ptr == NULL) && (length != 0),
65       MARISA_ALPHA_PARAM_ERROR);
66   return find_callback_<const Query &>(Query(ptr, length), callback);
67 }
68 
69 template <typename T>
find_callback(const std::string & str,T callback)70 inline std::size_t Trie::find_callback(const std::string &str,
71     T callback) const {
72   return find_callback(str.c_str(), str.length(), callback);
73 }
74 
predict(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)75 inline std::size_t Trie::predict(const std::string &str,
76     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
77   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
78 }
79 
predict(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)80 inline std::size_t Trie::predict(const std::string &str,
81     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
82     std::size_t max_num_results) const {
83   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
84 }
85 
predict_breadth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)86 inline std::size_t Trie::predict_breadth_first(const std::string &str,
87     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
88   return predict_breadth_first(str.c_str(), str.length(),
89       key_ids, keys, max_num_results);
90 }
91 
predict_breadth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)92 inline std::size_t Trie::predict_breadth_first(const std::string &str,
93     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
94     std::size_t max_num_results) const {
95   return predict_breadth_first(str.c_str(), str.length(),
96       key_ids, keys, max_num_results);
97 }
98 
predict_depth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)99 inline std::size_t Trie::predict_depth_first(const std::string &str,
100     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
101   return predict_depth_first(str.c_str(), str.length(),
102       key_ids, keys, max_num_results);
103 }
104 
predict_depth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)105 inline std::size_t Trie::predict_depth_first(const std::string &str,
106     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
107     std::size_t max_num_results) const {
108   return predict_depth_first(str.c_str(), str.length(),
109       key_ids, keys, max_num_results);
110 }
111 
112 template <typename T>
predict_callback(const char * str,T callback)113 inline std::size_t Trie::predict_callback(
114     const char *str, T callback) const {
115   return predict_callback_<CQuery>(CQuery(str), callback);
116 }
117 
118 template <typename T>
predict_callback(const char * ptr,std::size_t length,T callback)119 inline std::size_t Trie::predict_callback(
120     const char *ptr, std::size_t length,
121     T callback) const {
122   return predict_callback_<const Query &>(Query(ptr, length), callback);
123 }
124 
125 template <typename T>
predict_callback(const std::string & str,T callback)126 inline std::size_t Trie::predict_callback(
127     const std::string &str, T callback) const {
128   return predict_callback(str.c_str(), str.length(), callback);
129 }
130 
empty()131 inline bool Trie::empty() const {
132   return louds_.empty();
133 }
134 
num_keys()135 inline std::size_t Trie::num_keys() const {
136   return num_keys_;
137 }
138 
notfound()139 inline UInt32 Trie::notfound() {
140   return MARISA_ALPHA_NOT_FOUND;
141 }
142 
mismatch()143 inline std::size_t Trie::mismatch() {
144   return MARISA_ALPHA_MISMATCH;
145 }
146 
147 template <typename T>
find_child(UInt32 & node,T query,std::size_t & pos)148 inline bool Trie::find_child(UInt32 &node, T query,
149     std::size_t &pos) const {
150   UInt32 louds_pos = get_child(node);
151   if (!louds_[louds_pos]) {
152     return false;
153   }
154   node = louds_pos_to_node(louds_pos, node);
155   UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
156   do {
157     if (has_link(node)) {
158       if (link_id == MARISA_ALPHA_UINT32_MAX) {
159         link_id = get_link_id(node);
160       } else {
161         ++link_id;
162       }
163       std::size_t next_pos = has_trie() ?
164           trie_->trie_match<T>(get_link(node, link_id), query, pos) :
165           tail_match<T>(node, link_id, query, pos);
166       if (next_pos == mismatch()) {
167         return false;
168       } else if (next_pos != pos) {
169         pos = next_pos;
170         return true;
171       }
172     } else if (labels_[node] == query[pos]) {
173       ++pos;
174       return true;
175     }
176     ++node;
177     ++louds_pos;
178   } while (louds_[louds_pos]);
179   return false;
180 }
181 
182 template <typename T, typename U>
find_callback_(T query,U callback)183 std::size_t Trie::find_callback_(T query, U callback) const try {
184   std::size_t count = 0;
185   UInt32 node = 0;
186   std::size_t pos = 0;
187   do {
188     if (terminal_flags_[node]) {
189       ++count;
190       if (!callback(node_to_key_id(node), pos)) {
191         return count;
192       }
193     }
194   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
195   return count;
196 } catch (const std::bad_alloc &) {
197   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
catch(const std::length_error &)198 } catch (const std::length_error &) {
199   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
200 }
201 
202 template <typename T>
predict_child(UInt32 & node,T query,std::size_t & pos,std::string * key)203 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
204     std::string *key) const {
205   UInt32 louds_pos = get_child(node);
206   if (!louds_[louds_pos]) {
207     return false;
208   }
209   node = louds_pos_to_node(louds_pos, node);
210   UInt32 link_id = MARISA_ALPHA_UINT32_MAX;
211   do {
212     if (has_link(node)) {
213       if (link_id == MARISA_ALPHA_UINT32_MAX) {
214         link_id = get_link_id(node);
215       } else {
216         ++link_id;
217       }
218       std::size_t next_pos = has_trie() ?
219           trie_->trie_prefix_match<T>(
220               get_link(node, link_id), query, pos, key) :
221           tail_prefix_match<T>(node, link_id, query, pos, key);
222       if (next_pos == mismatch()) {
223         return false;
224       } else if (next_pos != pos) {
225         pos = next_pos;
226         return true;
227       }
228     } else if (labels_[node] == query[pos]) {
229       ++pos;
230       return true;
231     }
232     ++node;
233     ++louds_pos;
234   } while (louds_[louds_pos]);
235   return false;
236 }
237 
238 template <typename T, typename U>
predict_callback_(T query,U callback)239 std::size_t Trie::predict_callback_(T query, U callback) const try {
240   std::string key;
241   UInt32 node = 0;
242   std::size_t pos = 0;
243   while (!query.ends_at(pos)) {
244     if (!predict_child<T>(node, query, pos, &key)) {
245       return 0;
246     }
247   }
248   query.insert(&key);
249   std::size_t count = 0;
250   if (terminal_flags_[node]) {
251     ++count;
252     if (!callback(node_to_key_id(node), key)) {
253       return count;
254     }
255   }
256   Cell cell;
257   cell.set_louds_pos(get_child(node));
258   if (!louds_[cell.louds_pos()]) {
259     return count;
260   }
261   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
262   cell.set_key_id(node_to_key_id(cell.node()));
263   cell.set_length(key.length());
264   Vector<Cell> stack;
265   stack.push_back(cell);
266   std::size_t stack_pos = 1;
267   while (stack_pos != 0) {
268     Cell &cur = stack[stack_pos - 1];
269     if (!louds_[cur.louds_pos()]) {
270       cur.set_louds_pos(cur.louds_pos() + 1);
271       --stack_pos;
272       continue;
273     }
274     cur.set_louds_pos(cur.louds_pos() + 1);
275     key.resize(cur.length());
276     if (has_link(cur.node())) {
277       if (has_trie()) {
278         trie_->trie_restore(get_link(cur.node()), &key);
279       } else {
280         tail_restore(cur.node(), &key);
281       }
282     } else {
283       key += labels_[cur.node()];
284     }
285     if (terminal_flags_[cur.node()]) {
286       ++count;
287       if (!callback(cur.key_id(), key)) {
288         return count;
289       }
290       cur.set_key_id(cur.key_id() + 1);
291     }
292     if (stack_pos == stack.size()) {
293       cell.set_louds_pos(get_child(cur.node()));
294       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
295       cell.set_key_id(node_to_key_id(cell.node()));
296       stack.push_back(cell);
297     }
298     stack[stack_pos].set_length(key.length());
299     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
300     ++stack_pos;
301   }
302   return count;
303 } catch (const std::bad_alloc &) {
304   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
catch(const std::length_error &)305 } catch (const std::length_error &) {
306   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
307 }
308 
key_id_to_node(UInt32 key_id)309 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
310   return terminal_flags_.select1(key_id);
311 }
312 
node_to_key_id(UInt32 node)313 inline UInt32 Trie::node_to_key_id(UInt32 node) const {
314   return terminal_flags_.rank1(node);
315 }
316 
louds_pos_to_node(UInt32 louds_pos,UInt32 parent_node)317 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
318     UInt32 parent_node) const {
319   return louds_pos - parent_node - 1;
320 }
321 
get_child(UInt32 node)322 inline UInt32 Trie::get_child(UInt32 node) const {
323   return louds_.select0(node) + 1;
324 }
325 
get_parent(UInt32 node)326 inline UInt32 Trie::get_parent(UInt32 node) const {
327   return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
328 }
329 
has_link(UInt32 node)330 inline bool Trie::has_link(UInt32 node) const {
331   return (link_flags_.empty()) ? false : link_flags_[node];
332 }
333 
get_link_id(UInt32 node)334 inline UInt32 Trie::get_link_id(UInt32 node) const {
335   return link_flags_.rank1(node);
336 }
337 
get_link(UInt32 node)338 inline UInt32 Trie::get_link(UInt32 node) const {
339   return get_link(node, get_link_id(node));
340 }
341 
get_link(UInt32 node,UInt32 link_id)342 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
343   return (links_[link_id] * 256) + labels_[node];
344 }
345 
has_link()346 inline bool Trie::has_link() const {
347   return !link_flags_.empty();
348 }
349 
has_trie()350 inline bool Trie::has_trie() const {
351   return trie_.get() != NULL;
352 }
353 
has_tail()354 inline bool Trie::has_tail() const {
355   return !tail_.empty();
356 }
357 
358 }  // namespace marisa_alpha
359 
360 #endif  // MARISA_ALPHA_TRIE_INLINE_H_
361