• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef MARISA_ALPHA_TRIE_H_
2 #define MARISA_ALPHA_TRIE_H_
3 
4 #include "base.h"
5 
6 #ifdef __cplusplus
7 
8 #include <memory>
9 #include <vector>
10 
11 #include "progress.h"
12 #include "key.h"
13 #include "query.h"
14 #include "container.h"
15 #include "intvector.h"
16 #include "bitvector.h"
17 #include "tail.h"
18 
19 namespace marisa_alpha {
20 
// Trie is the C++ interface of the marisa_alpha trie. This header only
// declares it: construction from a key set (build), attaching/loading/saving
// (mmap, map, load, read, save, write), conversion between keys and 32-bit
// key IDs (restore, lookup, operator[]), and two query families (find*,
// predict*) that report matches as key IDs.
// NOTE(review): the definitions are not visible here; template members are
// presumably defined in "trie-inline.h", which is included right after the
// namespace closes -- confirm there before relying on any hedged note below.
21 class Trie {
22  public:
  // Default constructor; takes no arguments.
23   Trie();
24 
  // build(): three overloads that differ in how the keys are supplied -- a C
  // array of strings (with optional per-key lengths and weights), a vector of
  // strings, or a vector of (string, double) pairs. `key_ids` is an optional
  // output (defaults to NULL) and `flags` defaults to 0.
  // NOTE(review): presumably the double is a key weight, `key_ids` receives
  // the ID assigned to each input key, and `flags` selects build options --
  // confirm in the implementation.
25   void build(const char * const *keys, std::size_t num_keys,
26       const std::size_t *key_lengths = NULL,
27       const double *key_weights = NULL,
28       UInt32 *key_ids = NULL, int flags = 0);
29 
30   void build(const std::vector<std::string> &keys,
31       std::vector<UInt32> *key_ids = NULL, int flags = 0);
32   void build(const std::vector<std::pair<std::string, double> > &keys,
33       std::vector<UInt32> *key_ids = NULL, int flags = 0);
34 
  // mmap()/map(): take the trie data from a named file via a caller-supplied
  // Mapper (offset/whence default to 0/SEEK_SET), from a raw memory range, or
  // from an existing Mapper.
  // NOTE(review): presumably the trie refers to the mapped memory instead of
  // copying it, so the mapping would have to outlive the trie -- confirm.
35   void mmap(Mapper *mapper, const char *filename,
36       long offset = 0, int whence = SEEK_SET);
37   void map(const void *ptr, std::size_t size);
38   void map(Mapper &mapper);
39 
  // load()/fread()/read(): input counterparts of save()/fwrite()/write()
  // below. Sources are a file name (offset/whence default to 0/SEEK_SET), a
  // std::FILE*, a file descriptor, a std::istream, or a Reader.
40   void load(const char *filename,
41       long offset = 0, int whence = SEEK_SET);
42   void fread(std::FILE *file);
43   void read(int fd);
44   void read(std::istream &stream);
45   void read(Reader &reader);
46 
  // save()/fwrite()/write(): const output functions with the same set of
  // destinations as the input functions above. save() additionally takes
  // `trunc_flag` (default true).
  // NOTE(review): presumably `trunc_flag` decides whether an existing file is
  // truncated before writing -- confirm.
47   void save(const char *filename, bool trunc_flag = true,
48       long offset = 0, int whence = SEEK_SET) const;
49   void fwrite(std::FILE *file) const;
50   void write(int fd) const;
51   void write(std::ostream &stream) const;
52   void write(Writer &writer) const;
53 
  // operator[]: one overload maps a key ID to a string, the other two map a
  // string to a UInt32.
  // NOTE(review): these look like shorthands for restore() and lookup() --
  // confirm in the definitions.
54   std::string operator[](UInt32 key_id) const;
55 
56   UInt32 operator[](const char *str) const;
57   UInt32 operator[](const std::string &str) const;
58 
  // restore(): produces the key for `key_id` in one of three forms -- as a
  // returned std::string, written to `*key`, or written to a caller buffer of
  // `key_buf_size` bytes.
  // NOTE(review): the std::size_t returned by the buffer form is presumably
  // the key length; behaviour when the buffer is too small is not visible
  // here -- confirm.
59   std::string restore(UInt32 key_id) const;
60   void restore(UInt32 key_id, std::string *key) const;
61   std::size_t restore(UInt32 key_id, char *key_buf,
62       std::size_t key_buf_size) const;
63 
  // lookup(): takes the query as a C string, a (pointer, length) pair, or a
  // std::string and returns a UInt32.
  // NOTE(review): presumably the key ID on an exact match and notfound()
  // otherwise -- confirm.
64   UInt32 lookup(const char *str) const;
65   UInt32 lookup(const char *ptr, std::size_t length) const;
66   UInt32 lookup(const std::string &str) const;
67 
  // find(), raw-array form: results go to caller-provided `key_ids` and
  // `key_lengths` arrays, bounded by the mandatory `max_num_results`.
  // NOTE(review): presumably reports the keys that are prefixes of the query
  // and returns how many were found -- confirm.
68   std::size_t find(const char *str,
69       UInt32 *key_ids, std::size_t *key_lengths,
70       std::size_t max_num_results) const;
71   std::size_t find(const char *ptr, std::size_t length,
72       UInt32 *key_ids, std::size_t *key_lengths,
73       std::size_t max_num_results) const;
74   std::size_t find(const std::string &str,
75       UInt32 *key_ids, std::size_t *key_lengths,
76       std::size_t max_num_results) const;
77 
  // find(), vector form: both output vectors are optional (default NULL) and
  // `max_num_results` defaults to MARISA_ALPHA_MAX_NUM_KEYS.
78   std::size_t find(const char *str,
79       std::vector<UInt32> *key_ids = NULL,
80       std::vector<std::size_t> *key_lengths = NULL,
81       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
82   std::size_t find(const char *ptr, std::size_t length,
83       std::vector<UInt32> *key_ids = NULL,
84       std::vector<std::size_t> *key_lengths = NULL,
85       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
86   std::size_t find(const std::string &str,
87       std::vector<UInt32> *key_ids = NULL,
88       std::vector<std::size_t> *key_lengths = NULL,
89       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
90 
  // find_first(): single-result variant returning a UInt32; `key_length` is
  // an optional output (default NULL).
  // NOTE(review): presumably the first match in find() order -- confirm.
91   UInt32 find_first(const char *str,
92       std::size_t *key_length = NULL) const;
93   UInt32 find_first(const char *ptr, std::size_t length,
94       std::size_t *key_length = NULL) const;
95   UInt32 find_first(const std::string &str,
96       std::size_t *key_length = NULL) const;
97 
  // find_last(): same signatures as find_first().
  // NOTE(review): presumably the last match in find() order -- confirm.
98   UInt32 find_last(const char *str,
99       std::size_t *key_length = NULL) const;
100   UInt32 find_last(const char *ptr, std::size_t length,
101       std::size_t *key_length = NULL) const;
102   UInt32 find_last(const std::string &str,
103       std::size_t *key_length = NULL) const;
104 
  // find_callback(): takes a caller-supplied callable of type T instead of
  // output containers. The expected callable signature is:
105   // bool callback(UInt32 key_id, std::size_t key_length);
106   template <typename T>
107   std::size_t find_callback(const char *str, T callback) const;
108   template <typename T>
109   std::size_t find_callback(const char *ptr, std::size_t length,
110       T callback) const;
111   template <typename T>
112   std::size_t find_callback(const std::string &str, T callback) const;
113 
  // predict(), raw-array form: results go to caller-provided `key_ids` and
  // `keys` arrays, bounded by the mandatory `max_num_results`.
  // NOTE(review): presumably reports the keys that start with the query and
  // returns how many were found; which traversal order plain predict() uses
  // is not visible here -- confirm.
114   std::size_t predict(const char *str,
115       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
116   std::size_t predict(const char *ptr, std::size_t length,
117       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
118   std::size_t predict(const std::string &str,
119       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
120 
  // predict(), vector form: both output vectors are optional (default NULL)
  // and `max_num_results` defaults to MARISA_ALPHA_MAX_NUM_KEYS.
121   std::size_t predict(const char *str,
122       std::vector<UInt32> *key_ids = NULL,
123       std::vector<std::string> *keys = NULL,
124       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
125   std::size_t predict(const char *ptr, std::size_t length,
126       std::vector<UInt32> *key_ids = NULL,
127       std::vector<std::string> *keys = NULL,
128       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
129   std::size_t predict(const std::string &str,
130       std::vector<UInt32> *key_ids = NULL,
131       std::vector<std::string> *keys = NULL,
132       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
133 
  // predict_breadth_first(): same six signatures as predict(), with the
  // traversal order named explicitly.
134   std::size_t predict_breadth_first(const char *str,
135       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
136   std::size_t predict_breadth_first(const char *ptr, std::size_t length,
137       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
138   std::size_t predict_breadth_first(const std::string &str,
139       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
140 
141   std::size_t predict_breadth_first(const char *str,
142       std::vector<UInt32> *key_ids = NULL,
143       std::vector<std::string> *keys = NULL,
144       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
145   std::size_t predict_breadth_first(const char *ptr, std::size_t length,
146       std::vector<UInt32> *key_ids = NULL,
147       std::vector<std::string> *keys = NULL,
148       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
149   std::size_t predict_breadth_first(const std::string &str,
150       std::vector<UInt32> *key_ids = NULL,
151       std::vector<std::string> *keys = NULL,
152       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
153 
  // predict_depth_first(): same six signatures as predict(), with the
  // traversal order named explicitly.
154   std::size_t predict_depth_first(const char *str,
155       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
156   std::size_t predict_depth_first(const char *ptr, std::size_t length,
157       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
158   std::size_t predict_depth_first(const std::string &str,
159       UInt32 *key_ids, std::string *keys, std::size_t max_num_results) const;
160 
161   std::size_t predict_depth_first(const char *str,
162       std::vector<UInt32> *key_ids = NULL,
163       std::vector<std::string> *keys = NULL,
164       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
165   std::size_t predict_depth_first(const char *ptr, std::size_t length,
166       std::vector<UInt32> *key_ids = NULL,
167       std::vector<std::string> *keys = NULL,
168       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
169   std::size_t predict_depth_first(const std::string &str,
170       std::vector<UInt32> *key_ids = NULL,
171       std::vector<std::string> *keys = NULL,
172       std::size_t max_num_results = MARISA_ALPHA_MAX_NUM_KEYS) const;
173 
  // predict_callback(): takes a caller-supplied callable of type T instead of
  // output containers. The expected callable signature is:
174   // bool callback(UInt32 key_id, const std::string &key);
175   template <typename T>
176   std::size_t predict_callback(const char *str, T callback) const;
177   template <typename T>
178   std::size_t predict_callback(const char *ptr, std::size_t length,
179       T callback) const;
180   template <typename T>
181   std::size_t predict_callback(const std::string &str, T callback) const;
182 
  // Const accessors reporting the trie's state and size figures.
  // NOTE(review): the unit of total_size() (presumably bytes) is not visible
  // here -- confirm.
183   bool empty() const;
184   std::size_t num_tries() const;
185   std::size_t num_keys() const;
186   std::size_t num_nodes() const;
187   std::size_t total_size() const;
188 
  // clear() takes no arguments; swap() takes the other trie by pointer.
189   void clear();
190   void swap(Trie *rhs);
191 
  // Static helpers returning a UInt32 and a std::size_t respectively.
  // NOTE(review): presumably the sentinel values that the ID-returning and
  // length-returning functions use to signal "no match" -- confirm.
192   static UInt32 notfound();
193   static std::size_t mismatch();
194 
195  private:
  // Data members. Their declaration order fixes layout and initialization
  // order, so do not reorder them casually.
  // NOTE(review): the role of each member is only suggested by its name and
  // type; confirm against the implementation before documenting further.
196   BitVector louds_;
197   Vector<UInt8> labels_;
198   BitVector terminal_flags_;
199   BitVector link_flags_;
200   IntVector links_;
  // Owning pointer to another Trie.
  // NOTE(review): std::auto_ptr was deprecated in C++11 and removed in C++17;
  // replacing it needs matching changes in the definitions that use trie_,
  // which are not visible in this header.
201   std::auto_ptr<Trie> trie_;
202   Tail tail_;
203   UInt32 num_first_branches_;
204   UInt32 num_keys_;
205 
  // Build helpers. The first two build_trie() overloads mirror the two
  // key_ids output forms of the public build() (vector vs. raw array).
206   void build_trie(Vector<Key<String> > &keys,
207       std::vector<UInt32> *key_ids, int flags);
208   void build_trie(Vector<Key<String> > &keys,
209       UInt32 *key_ids, int flags);
210 
  // Templated over the key string type T; build_next() below is declared for
  // both Key<String> and Key<RString>.
211   template <typename T>
212   void build_trie(Vector<Key<T> > &keys,
213       Vector<UInt32> *terminals, Progress &progress);
214 
215   template <typename T>
216   void build_cur(Vector<Key<T> > &keys,
217       Vector<UInt32> *terminals, Progress &progress);
218 
219   void build_next(Vector<Key<String> > &keys,
220       Vector<UInt32> *terminals, Progress &progress);
221   void build_next(Vector<Key<RString> > &rkeys,
222       Vector<UInt32> *terminals, Progress &progress);
223 
224   template <typename T>
225   UInt32 sort_keys(Vector<Key<T> > &keys) const;
226 
227   template <typename T>
228   void build_terminals(const Vector<Key<T> > &keys,
229       Vector<UInt32> *terminals) const;
230 
  // Restore helpers, std::string output form.
231   void restore_(UInt32 key_id, std::string *key) const;
232   void trie_restore(UInt32 node, std::string *key) const;
233   void tail_restore(UInt32 node, std::string *key) const;
234 
  // Restore helpers, caller-buffer output form; `key_pos` is passed by
  // non-const reference, so the helpers can update the write position.
235   std::size_t restore_(UInt32 key_id, char *key_buf,
236       std::size_t key_buf_size) const;
237   void trie_restore(UInt32 node, char *key_buf,
238       std::size_t key_buf_size, std::size_t &key_pos) const;
239   void tail_restore(UInt32 node, char *key_buf,
240       std::size_t key_buf_size, std::size_t &key_pos) const;
241 
  // Lookup helpers, templated over the query type T. find_child() takes
  // `node` and `pos` by non-const reference, so it can update both.
242   template <typename T>
243   UInt32 lookup_(T query) const;
244   template <typename T>
245   bool find_child(UInt32 &node, T query, std::size_t &pos) const;
246   template <typename T>
247   std::size_t trie_match(UInt32 node, T query, std::size_t pos) const;
248   template <typename T>
249   std::size_t tail_match(UInt32 node, UInt32 link_id,
250       T query, std::size_t pos) const;
251 
  // Templated back ends for the public find*() overloads: T is the query
  // type, U/V the output (or callback) types.
252   template <typename T, typename U, typename V>
253   std::size_t find_(T query, U key_ids, V key_lengths,
254       std::size_t max_num_results) const;
255   template <typename T>
256   UInt32 find_first_(T query, std::size_t *key_length) const;
257   template <typename T>
258   UInt32 find_last_(T query, std::size_t *key_length) const;
259   template <typename T, typename U>
260   std::size_t find_callback_(T query, U callback) const;
261 
  // Templated back ends for the public predict*() overloads.
262   template <typename T, typename U, typename V>
263   std::size_t predict_breadth_first_(T query, U key_ids, V keys,
264       std::size_t max_num_results) const;
265   template <typename T, typename U, typename V>
266   std::size_t predict_depth_first_(T query, U key_ids, V keys,
267       std::size_t max_num_results) const;
268   template <typename T, typename U>
269   std::size_t predict_callback_(T query, U callback) const;
270 
  // Counterparts of find_child()/trie_match()/tail_match() that additionally
  // take a std::string *key output.
271   template <typename T>
272   bool predict_child(UInt32 &node, T query, std::size_t &pos,
273       std::string *key) const;
274   template <typename T>
275   std::size_t trie_prefix_match(UInt32 node, T query,
276       std::size_t pos, std::string *key) const;
277   template <typename T>
278   std::size_t tail_prefix_match(UInt32 node, UInt32 link_id,
279       T query, std::size_t pos, std::string *key) const;
280 
  // Conversions between key IDs, node indices and LOUDS positions.
281   UInt32 key_id_to_node(UInt32 key_id) const;
282   UInt32 node_to_key_id(UInt32 node) const;
283   UInt32 louds_pos_to_node(UInt32 louds_pos, UInt32 parent_node) const;
284 
  // Node navigation.
285   UInt32 get_child(UInt32 node) const;
286   UInt32 get_parent(UInt32 node) const;
287 
  // Per-node link queries.
288   bool has_link(UInt32 node) const;
289   UInt32 get_link_id(UInt32 node) const;
290   UInt32 get_link(UInt32 node) const;
291   UInt32 get_link(UInt32 node, UInt32 link_id) const;
292 
  // Whole-trie predicates (no node argument).
293   bool has_link() const;
294   bool has_trie() const;
295   bool has_tail() const;
296 
  // Declared private so that copying a Trie from outside the class does not
  // compile (pre-C++11 idiom; no "= delete" is used in this file).
297   // Disallows copy and assignment.
298   Trie(const Trie &);
299   Trie &operator=(const Trie &);
300 };
301 
302 }  // namespace marisa_alpha
303 
304 #include "trie-inline.h"
305 
306 #else  // __cplusplus
307 
308 #include <stdio.h>
309 
310 #endif  // __cplusplus
311 
312 #ifdef __cplusplus
313 extern "C" {
314 #endif  // __cplusplus
315 
/*
 * C API. marisa_alpha_trie is an opaque handle type: the struct is only
 * forward-declared here, so C callers can hold pointers to it but not look
 * inside. Status-returning functions report a marisa_alpha_status.
 *
 * marisa_alpha_init() receives the address of a handle pointer and
 * marisa_alpha_end() receives the handle itself.
 * NOTE(review): presumably init allocates a trie and stores it in *h, and end
 * destroys it -- confirm in the implementation.
 */
316 typedef struct marisa_alpha_trie_ marisa_alpha_trie;
317 
318 marisa_alpha_status marisa_alpha_init(marisa_alpha_trie **h);
319 marisa_alpha_status marisa_alpha_end(marisa_alpha_trie *h);
320 
/*
 * Takes the same argument list as the raw-array Trie::build() overload
 * (keys, num_keys, key_lengths, key_weights, key_ids, flags) preceded by the
 * handle. C has no default arguments, so every argument must be passed
 * explicitly.
 * NOTE(review): presumably NULL is accepted for key_lengths, key_weights and
 * key_ids, matching the C++ defaults -- confirm.
 */
321 marisa_alpha_status marisa_alpha_build(marisa_alpha_trie *h,
322     const char * const *keys, size_t num_keys, const size_t *key_lengths,
323     const double *key_weights, marisa_alpha_uint32 *key_ids, int flags);
324 
/*
 * Mapping functions. marisa_alpha_mmap() takes a file name with offset/whence
 * and, unlike the C++ Trie::mmap(), no Mapper argument; marisa_alpha_map()
 * takes a memory range.
 */
325 marisa_alpha_status marisa_alpha_mmap(marisa_alpha_trie *h,
326     const char *filename, long offset, int whence);
327 marisa_alpha_status marisa_alpha_map(marisa_alpha_trie *h, const void *ptr,
328     size_t size);
329 
/*
 * Input functions: from a file name (with offset/whence), a FILE*, or a file
 * descriptor. They take a non-const handle.
 */
330 marisa_alpha_status marisa_alpha_load(marisa_alpha_trie *h,
331     const char *filename, long offset, int whence);
332 marisa_alpha_status marisa_alpha_fread(marisa_alpha_trie *h, FILE *file);
333 marisa_alpha_status marisa_alpha_read(marisa_alpha_trie *h, int fd);
334 
/*
 * Output functions: same destinations as the input functions above, taking a
 * const handle. Note that trunc_flag is an int here, whereas the C++
 * Trie::save() declares it as bool.
 */
335 marisa_alpha_status marisa_alpha_save(const marisa_alpha_trie *h,
336     const char *filename, int trunc_flag, long offset, int whence);
337 marisa_alpha_status marisa_alpha_fwrite(const marisa_alpha_trie *h,
338     FILE *file);
339 marisa_alpha_status marisa_alpha_write(const marisa_alpha_trie *h, int fd);
340 
/*
 * Key ID -> key. The caller supplies key_buf of key_buf_size bytes and a
 * key_length output pointer; the return value is a status, so the length is
 * reported through *key_length rather than returned.
 * NOTE(review): behaviour when key_buf is too small is not visible here --
 * confirm.
 */
341 marisa_alpha_status marisa_alpha_restore(const marisa_alpha_trie *h,
342     marisa_alpha_uint32 key_id, char *key_buf, size_t key_buf_size,
343     size_t *key_length);
344 
/*
 * Key -> key ID. The query is given as (ptr, length) and the result is
 * reported through *key_id.
 */
345 marisa_alpha_status marisa_alpha_lookup(const marisa_alpha_trie *h,
346     const char *ptr, size_t length, marisa_alpha_uint32 *key_id);
347 
/*
 * find family. Every function takes the query as (ptr, length). Because the
 * return value is a status, results are reported through output pointers:
 * key_ids/key_lengths arrays bounded by max_num_results plus *num_results for
 * marisa_alpha_find(), and single key_id/key_length outputs for
 * find_first/find_last.
 */
348 marisa_alpha_status marisa_alpha_find(const marisa_alpha_trie *h,
349     const char *ptr, size_t length,
350     marisa_alpha_uint32 *key_ids, size_t *key_lengths,
351     size_t max_num_results, size_t *num_results);
352 marisa_alpha_status marisa_alpha_find_first(const marisa_alpha_trie *h,
353     const char *ptr, size_t length,
354     marisa_alpha_uint32 *key_id, size_t *key_length);
355 marisa_alpha_status marisa_alpha_find_last(const marisa_alpha_trie *h,
356     const char *ptr, size_t length,
357     marisa_alpha_uint32 *key_id, size_t *key_length);
/*
 * Callback variant. The callback receives (void *, key ID, size_t) and
 * returns int; as its name says, first_arg_to_callback is the value handed to
 * the callback as its first argument.
 * NOTE(review): the meaning of the callback's int return value (presumably
 * nonzero to continue, mirroring the C++ bool callback) is not visible here
 * -- confirm.
 */
358 marisa_alpha_status marisa_alpha_find_callback(const marisa_alpha_trie *h,
359     const char *ptr, size_t length,
360     int (*callback)(void *, marisa_alpha_uint32, size_t),
361     void *first_arg_to_callback);
362 
/*
 * predict family. Every function takes the query as (ptr, length) and
 * reports results through a key_ids array bounded by max_num_results plus
 * *num_results. Unlike the C++ predict overloads, these three have no output
 * parameter for the key strings -- only key IDs are reported.
 */
363 marisa_alpha_status marisa_alpha_predict(const marisa_alpha_trie *h,
364     const char *ptr, size_t length, marisa_alpha_uint32 *key_ids,
365     size_t max_num_results, size_t *num_results);
366 marisa_alpha_status marisa_alpha_predict_breadth_first(
367     const marisa_alpha_trie *h, const char *ptr, size_t length,
368     marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
369 marisa_alpha_status marisa_alpha_predict_depth_first(
370     const marisa_alpha_trie *h, const char *ptr, size_t length,
371     marisa_alpha_uint32 *key_ids, size_t max_num_results, size_t *num_results);
/*
 * Callback variant. The callback receives (void *, key ID, const char *,
 * size_t) and returns int; first_arg_to_callback is handed to it as its first
 * argument.
 * NOTE(review): presumably the const char * / size_t pair is the predicted
 * key and its length, and the pointer is only valid during the call --
 * confirm.
 */
372 marisa_alpha_status marisa_alpha_predict_callback(const marisa_alpha_trie *h,
373     const char *ptr, size_t length,
374     int (*callback)(void *, marisa_alpha_uint32, const char *, size_t),
375     void *first_arg_to_callback);
376 
/*
 * Size accessors. These return size_t directly instead of a status, so they
 * have no channel for reporting an error.
 * NOTE(review): behaviour for a NULL handle is not visible here -- confirm.
 */
377 size_t marisa_alpha_get_num_tries(const marisa_alpha_trie *h);
378 size_t marisa_alpha_get_num_keys(const marisa_alpha_trie *h);
379 size_t marisa_alpha_get_num_nodes(const marisa_alpha_trie *h);
380 size_t marisa_alpha_get_total_size(const marisa_alpha_trie *h);
381 
/* Counterpart of Trie::clear(); takes a non-const handle. */
382 marisa_alpha_status marisa_alpha_clear(marisa_alpha_trie *h);
383 
384 #ifdef __cplusplus
385 }  // extern "C"
386 #endif  // __cplusplus
387 
388 #endif  // MARISA_ALPHA_TRIE_H_
389