• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef MARISA_TRIE_INLINE_H_
2 #define MARISA_TRIE_INLINE_H_
3 
4 #include <stdexcept>
5 
6 #include "cell.h"
7 
8 namespace marisa {
9 
10 inline std::string Trie::operator[](UInt32 key_id) const {
11   std::string key;
12   restore(key_id, &key);
13   return key;
14 }
15 
16 inline UInt32 Trie::operator[](const char *str) const {
17   return lookup(str);
18 }
19 
20 inline UInt32 Trie::operator[](const std::string &str) const {
21   return lookup(str);
22 }
23 
lookup(const std::string & str)24 inline UInt32 Trie::lookup(const std::string &str) const {
25   return lookup(str.c_str(), str.length());
26 }
27 
find(const std::string & str,UInt32 * key_ids,std::size_t * key_lengths,std::size_t max_num_results)28 inline std::size_t Trie::find(const std::string &str,
29     UInt32 *key_ids, std::size_t *key_lengths,
30     std::size_t max_num_results) const {
31   return find(str.c_str(), str.length(),
32       key_ids, key_lengths, max_num_results);
33 }
34 
find(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::size_t> * key_lengths,std::size_t max_num_results)35 inline std::size_t Trie::find(const std::string &str,
36     std::vector<UInt32> *key_ids, std::vector<std::size_t> *key_lengths,
37     std::size_t max_num_results) const {
38   return find(str.c_str(), str.length(),
39       key_ids, key_lengths, max_num_results);
40 }
41 
find_first(const std::string & str,std::size_t * key_length)42 inline UInt32 Trie::find_first(const std::string &str,
43     std::size_t *key_length) const {
44   return find_first(str.c_str(), str.length(), key_length);
45 }
46 
find_last(const std::string & str,std::size_t * key_length)47 inline UInt32 Trie::find_last(const std::string &str,
48     std::size_t *key_length) const {
49   return find_last(str.c_str(), str.length(), key_length);
50 }
51 
52 template <typename T>
find_callback(const char * str,T callback)53 inline std::size_t Trie::find_callback(const char *str,
54     T callback) const {
55   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
56   MARISA_THROW_IF(str == NULL, MARISA_PARAM_ERROR);
57   return find_callback_<CQuery>(CQuery(str), callback);
58 }
59 
60 template <typename T>
find_callback(const char * ptr,std::size_t length,T callback)61 inline std::size_t Trie::find_callback(const char *ptr, std::size_t length,
62     T callback) const {
63   MARISA_THROW_IF(empty(), MARISA_STATE_ERROR);
64   MARISA_THROW_IF((ptr == NULL) && (length != 0), MARISA_PARAM_ERROR);
65   return find_callback_<const Query &>(Query(ptr, length), callback);
66 }
67 
68 template <typename T>
find_callback(const std::string & str,T callback)69 inline std::size_t Trie::find_callback(const std::string &str,
70     T callback) const {
71   return find_callback(str.c_str(), str.length(), callback);
72 }
73 
predict(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)74 inline std::size_t Trie::predict(const std::string &str,
75     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
76   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
77 }
78 
predict(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)79 inline std::size_t Trie::predict(const std::string &str,
80     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
81     std::size_t max_num_results) const {
82   return predict(str.c_str(), str.length(), key_ids, keys, max_num_results);
83 }
84 
predict_breadth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)85 inline std::size_t Trie::predict_breadth_first(const std::string &str,
86     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
87   return predict_breadth_first(str.c_str(), str.length(),
88       key_ids, keys, max_num_results);
89 }
90 
predict_breadth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)91 inline std::size_t Trie::predict_breadth_first(const std::string &str,
92     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
93     std::size_t max_num_results) const {
94   return predict_breadth_first(str.c_str(), str.length(),
95       key_ids, keys, max_num_results);
96 }
97 
predict_depth_first(const std::string & str,UInt32 * key_ids,std::string * keys,std::size_t max_num_results)98 inline std::size_t Trie::predict_depth_first(const std::string &str,
99     UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const {
100   return predict_depth_first(str.c_str(), str.length(),
101       key_ids, keys, max_num_results);
102 }
103 
predict_depth_first(const std::string & str,std::vector<UInt32> * key_ids,std::vector<std::string> * keys,std::size_t max_num_results)104 inline std::size_t Trie::predict_depth_first(const std::string &str,
105     std::vector<UInt32> *key_ids, std::vector<std::string> *keys,
106     std::size_t max_num_results) const {
107   return predict_depth_first(str.c_str(), str.length(),
108       key_ids, keys, max_num_results);
109 }
110 
111 template <typename T>
predict_callback(const char * str,T callback)112 inline std::size_t Trie::predict_callback(
113     const char *str, T callback) const {
114   return predict_callback_<CQuery>(CQuery(str), callback);
115 }
116 
117 template <typename T>
predict_callback(const char * ptr,std::size_t length,T callback)118 inline std::size_t Trie::predict_callback(
119     const char *ptr, std::size_t length,
120     T callback) const {
121   return predict_callback_<const Query &>(Query(ptr, length), callback);
122 }
123 
124 template <typename T>
predict_callback(const std::string & str,T callback)125 inline std::size_t Trie::predict_callback(
126     const std::string &str, T callback) const {
127   return predict_callback(str.c_str(), str.length(), callback);
128 }
129 
empty()130 inline bool Trie::empty() const {
131   return louds_.empty();
132 }
133 
num_keys()134 inline std::size_t Trie::num_keys() const {
135   return num_keys_;
136 }
137 
notfound()138 inline UInt32 Trie::notfound() {
139   return MARISA_NOT_FOUND;
140 }
141 
mismatch()142 inline std::size_t Trie::mismatch() {
143   return MARISA_MISMATCH;
144 }
145 
146 template <typename T>
find_child(UInt32 & node,T query,std::size_t & pos)147 inline bool Trie::find_child(UInt32 &node, T query,
148     std::size_t &pos) const {
149   UInt32 louds_pos = get_child(node);
150   if (!louds_[louds_pos]) {
151     return false;
152   }
153   node = louds_pos_to_node(louds_pos, node);
154   UInt32 link_id = MARISA_UINT32_MAX;
155   do {
156     if (has_link(node)) {
157       if (link_id == MARISA_UINT32_MAX) {
158         link_id = get_link_id(node);
159       } else {
160         ++link_id;
161       }
162       std::size_t next_pos = has_trie() ?
163           trie_->trie_match<T>(get_link(node, link_id), query, pos) :
164           tail_match<T>(node, link_id, query, pos);
165       if (next_pos == mismatch()) {
166         return false;
167       } else if (next_pos != pos) {
168         pos = next_pos;
169         return true;
170       }
171     } else if (labels_[node] == query[pos]) {
172       ++pos;
173       return true;
174     }
175     ++node;
176     ++louds_pos;
177   } while (louds_[louds_pos]);
178   return false;
179 }
180 
181 template <typename T, typename U>
find_callback_(T query,U callback)182 std::size_t Trie::find_callback_(T query, U callback) const {
183   std::size_t count = 0;
184   UInt32 node = 0;
185   std::size_t pos = 0;
186   do {
187     if (terminal_flags_[node]) {
188       ++count;
189       if (!callback(node_to_key_id(node), pos)) {
190         return count;
191       }
192     }
193   } while (!query.ends_at(pos) && find_child<T>(node, query, pos));
194   return count;
195 }
196 
197 template <typename T>
predict_child(UInt32 & node,T query,std::size_t & pos,std::string * key)198 inline bool Trie::predict_child(UInt32 &node, T query, std::size_t &pos,
199     std::string *key) const {
200   UInt32 louds_pos = get_child(node);
201   if (!louds_[louds_pos]) {
202     return false;
203   }
204   node = louds_pos_to_node(louds_pos, node);
205   UInt32 link_id = MARISA_UINT32_MAX;
206   do {
207     if (has_link(node)) {
208       if (link_id == MARISA_UINT32_MAX) {
209         link_id = get_link_id(node);
210       } else {
211         ++link_id;
212       }
213       std::size_t next_pos = has_trie() ?
214           trie_->trie_prefix_match<T>(
215               get_link(node, link_id), query, pos, key) :
216           tail_prefix_match<T>(node, link_id, query, pos, key);
217       if (next_pos == mismatch()) {
218         return false;
219       } else if (next_pos != pos) {
220         pos = next_pos;
221         return true;
222       }
223     } else if (labels_[node] == query[pos]) {
224       ++pos;
225       return true;
226     }
227     ++node;
228     ++louds_pos;
229   } while (louds_[louds_pos]);
230   return false;
231 }
232 
233 template <typename T, typename U>
predict_callback_(T query,U callback)234 std::size_t Trie::predict_callback_(T query, U callback) const {
235   std::string key;
236   UInt32 node = 0;
237   std::size_t pos = 0;
238   while (!query.ends_at(pos)) {
239     if (!predict_child<T>(node, query, pos, &key)) {
240       return 0;
241     }
242   }
243   query.insert(&key);
244   std::size_t count = 0;
245   if (terminal_flags_[node]) {
246     ++count;
247     if (!callback(node_to_key_id(node), key)) {
248       return count;
249     }
250   }
251   Cell cell;
252   cell.set_louds_pos(get_child(node));
253   if (!louds_[cell.louds_pos()]) {
254     return count;
255   }
256   cell.set_node(louds_pos_to_node(cell.louds_pos(), node));
257   cell.set_key_id(node_to_key_id(cell.node()));
258   cell.set_length(key.length());
259   Vector<Cell> stack;
260   stack.push_back(cell);
261   std::size_t stack_pos = 1;
262   while (stack_pos != 0) {
263     Cell &cur = stack[stack_pos - 1];
264     if (!louds_[cur.louds_pos()]) {
265       cur.set_louds_pos(cur.louds_pos() + 1);
266       --stack_pos;
267       continue;
268     }
269     cur.set_louds_pos(cur.louds_pos() + 1);
270     key.resize(cur.length());
271     if (has_link(cur.node())) {
272       if (has_trie()) {
273         trie_->trie_restore(get_link(cur.node()), &key);
274       } else {
275         tail_restore(cur.node(), &key);
276       }
277     } else {
278       key += labels_[cur.node()];
279     }
280     if (terminal_flags_[cur.node()]) {
281       ++count;
282       if (!callback(cur.key_id(), key)) {
283         return count;
284       }
285       cur.set_key_id(cur.key_id() + 1);
286     }
287     if (stack_pos == stack.size()) {
288       cell.set_louds_pos(get_child(cur.node()));
289       cell.set_node(louds_pos_to_node(cell.louds_pos(), cur.node()));
290       cell.set_key_id(node_to_key_id(cell.node()));
291       stack.push_back(cell);
292     }
293     stack[stack_pos].set_length(key.length());
294     stack[stack_pos - 1].set_node(stack[stack_pos - 1].node() + 1);
295     ++stack_pos;
296   }
297   return count;
298 }
299 
key_id_to_node(UInt32 key_id)300 inline UInt32 Trie::key_id_to_node(UInt32 key_id) const {
301   return terminal_flags_.select1(key_id);
302 }
303 
node_to_key_id(UInt32 node)304 inline UInt32 Trie::node_to_key_id(UInt32 node) const {
305   return terminal_flags_.rank1(node);
306 }
307 
louds_pos_to_node(UInt32 louds_pos,UInt32 parent_node)308 inline UInt32 Trie::louds_pos_to_node(UInt32 louds_pos,
309     UInt32 parent_node) const {
310   return louds_pos - parent_node - 1;
311 }
312 
get_child(UInt32 node)313 inline UInt32 Trie::get_child(UInt32 node) const {
314   return louds_.select0(node) + 1;
315 }
316 
get_parent(UInt32 node)317 inline UInt32 Trie::get_parent(UInt32 node) const {
318   return (node > num_first_branches_) ? (louds_.select1(node) - node - 1) : 0;
319 }
320 
has_link(UInt32 node)321 inline bool Trie::has_link(UInt32 node) const {
322   return (link_flags_.empty()) ? false : link_flags_[node];
323 }
324 
get_link_id(UInt32 node)325 inline UInt32 Trie::get_link_id(UInt32 node) const {
326   return link_flags_.rank1(node);
327 }
328 
get_link(UInt32 node)329 inline UInt32 Trie::get_link(UInt32 node) const {
330   return get_link(node, get_link_id(node));
331 }
332 
get_link(UInt32 node,UInt32 link_id)333 inline UInt32 Trie::get_link(UInt32 node, UInt32 link_id) const {
334   return (links_[link_id] * 256) + labels_[node];
335 }
336 
has_link()337 inline bool Trie::has_link() const {
338   return !link_flags_.empty();
339 }
340 
has_trie()341 inline bool Trie::has_trie() const {
342   return trie_.get() != NULL;
343 }
344 
has_tail()345 inline bool Trie::has_tail() const {
346   return !tail_.empty();
347 }
348 
349 }  // namespace marisa
350 
351 #endif  // MARISA_TRIE_INLINE_H_
352