#include "trie.h"

#include <new>
2
3 extern "C" {
4
5 namespace {
6
7 class FindCallback {
8 public:
9 typedef int (*Func)(void *, marisa_uint32, size_t);
10
FindCallback(Func func,void * first_arg)11 FindCallback(Func func, void *first_arg)
12 : func_(func), first_arg_(first_arg) {}
FindCallback(const FindCallback & callback)13 FindCallback(const FindCallback &callback)
14 : func_(callback.func_), first_arg_(callback.first_arg_) {}
15
operator ()(marisa::UInt32 key_id,std::size_t key_length) const16 bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
17 return func_(first_arg_, key_id, key_length) != 0;
18 }
19
20 private:
21 Func func_;
22 void *first_arg_;
23
24 // Disallows assignment.
25 FindCallback &operator=(const FindCallback &);
26 };
27
28 class PredictCallback {
29 public:
30 typedef int (*Func)(void *, marisa_uint32, const char *, size_t);
31
PredictCallback(Func func,void * first_arg)32 PredictCallback(Func func, void *first_arg)
33 : func_(func), first_arg_(first_arg) {}
PredictCallback(const PredictCallback & callback)34 PredictCallback(const PredictCallback &callback)
35 : func_(callback.func_), first_arg_(callback.first_arg_) {}
36
operator ()(marisa::UInt32 key_id,const std::string & key) const37 bool operator()(marisa::UInt32 key_id, const std::string &key) const {
38 return func_(first_arg_, key_id, key.c_str(), key.length()) != 0;
39 }
40
41 private:
42 Func func_;
43 void *first_arg_;
44
45 // Disallows assignment.
46 PredictCallback &operator=(const PredictCallback &);
47 };
48
49 } // namespace
50
51 struct marisa_trie_ {
52 public:
marisa_trie_marisa_trie_53 marisa_trie_() : trie(), mapper() {}
54
55 marisa::Trie trie;
56 marisa::Mapper mapper;
57
58 private:
59 // Disallows copy and assignment.
60 marisa_trie_(const marisa_trie_ &);
61 marisa_trie_ &operator=(const marisa_trie_ &);
62 };
63
marisa_init(marisa_trie ** h)64 marisa_status marisa_init(marisa_trie **h) {
65 if ((h == NULL) || (*h != NULL)) {
66 return MARISA_HANDLE_ERROR;
67 }
68 *h = new (std::nothrow) marisa_trie_();
69 return (*h != NULL) ? MARISA_OK : MARISA_MEMORY_ERROR;
70 }
71
marisa_end(marisa_trie * h)72 marisa_status marisa_end(marisa_trie *h) {
73 if (h == NULL) {
74 return MARISA_HANDLE_ERROR;
75 }
76 delete h;
77 return MARISA_OK;
78 }
79
marisa_build(marisa_trie * h,const char * const * keys,size_t num_keys,const size_t * key_lengths,const double * key_weights,marisa_uint32 * key_ids,int flags)80 marisa_status marisa_build(marisa_trie *h, const char * const *keys,
81 size_t num_keys, const size_t *key_lengths, const double *key_weights,
82 marisa_uint32 *key_ids, int flags) {
83 if (h == NULL) {
84 return MARISA_HANDLE_ERROR;
85 }
86 h->trie.build(keys, num_keys, key_lengths, key_weights, key_ids, flags);
87 h->mapper.clear();
88 return MARISA_OK;
89 }
90
marisa_mmap(marisa_trie * h,const char * filename,long offset,int whence)91 marisa_status marisa_mmap(marisa_trie *h, const char *filename,
92 long offset, int whence) {
93 if (h == NULL) {
94 return MARISA_HANDLE_ERROR;
95 }
96 h->trie.mmap(&h->mapper, filename, offset, whence);
97 return MARISA_OK;
98 }
99
marisa_map(marisa_trie * h,const void * ptr,size_t size)100 marisa_status marisa_map(marisa_trie *h, const void *ptr, size_t size) {
101 if (h == NULL) {
102 return MARISA_HANDLE_ERROR;
103 }
104 h->trie.map(ptr, size);
105 h->mapper.clear();
106 return MARISA_OK;
107 }
108
marisa_load(marisa_trie * h,const char * filename,long offset,int whence)109 marisa_status marisa_load(marisa_trie *h, const char *filename,
110 long offset, int whence) {
111 if (h == NULL) {
112 return MARISA_HANDLE_ERROR;
113 }
114 h->trie.load(filename, offset, whence);
115 h->mapper.clear();
116 return MARISA_OK;
117 }
118
marisa_fread(marisa_trie * h,FILE * file)119 marisa_status marisa_fread(marisa_trie *h, FILE *file) {
120 if (h == NULL) {
121 return MARISA_HANDLE_ERROR;
122 }
123 h->trie.fread(file);
124 h->mapper.clear();
125 return MARISA_OK;
126 }
127
marisa_read(marisa_trie * h,int fd)128 marisa_status marisa_read(marisa_trie *h, int fd) {
129 if (h == NULL) {
130 return MARISA_HANDLE_ERROR;
131 }
132 h->trie.read(fd);
133 h->mapper.clear();
134 return MARISA_OK;
135 }
136
marisa_save(const marisa_trie * h,const char * filename,int trunc_flag,long offset,int whence)137 marisa_status marisa_save(const marisa_trie *h, const char *filename,
138 int trunc_flag, long offset, int whence) {
139 if (h == NULL) {
140 return MARISA_HANDLE_ERROR;
141 }
142 h->trie.save(filename, trunc_flag != 0, offset, whence);
143 return MARISA_OK;
144 }
145
marisa_fwrite(const marisa_trie * h,FILE * file)146 marisa_status marisa_fwrite(const marisa_trie *h, FILE *file) {
147 if (h == NULL) {
148 return MARISA_HANDLE_ERROR;
149 }
150 h->trie.fwrite(file);
151 return MARISA_OK;
152 }
153
marisa_write(const marisa_trie * h,int fd)154 marisa_status marisa_write(const marisa_trie *h, int fd) {
155 if (h == NULL) {
156 return MARISA_HANDLE_ERROR;
157 }
158 h->trie.write(fd);
159 return MARISA_OK;
160 }
161
marisa_restore(const marisa_trie * h,marisa_uint32 key_id,char * key_buf,size_t key_buf_size,size_t * key_length)162 marisa_status marisa_restore(const marisa_trie *h, marisa_uint32 key_id,
163 char *key_buf, size_t key_buf_size, size_t *key_length) {
164 if (h == NULL) {
165 return MARISA_HANDLE_ERROR;
166 } else if (key_length == NULL) {
167 return MARISA_PARAM_ERROR;
168 }
169 *key_length = h->trie.restore(key_id, key_buf, key_buf_size);
170 return MARISA_OK;
171 }
172
marisa_lookup(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id)173 marisa_status marisa_lookup(const marisa_trie *h,
174 const char *ptr, size_t length, marisa_uint32 *key_id) {
175 if (h == NULL) {
176 return MARISA_HANDLE_ERROR;
177 } else if (key_id == NULL) {
178 return MARISA_PARAM_ERROR;
179 }
180 if (length == MARISA_ZERO_TERMINATED) {
181 *key_id = h->trie.lookup(ptr);
182 } else {
183 *key_id = h->trie.lookup(ptr, length);
184 }
185 return MARISA_OK;
186 }
187
marisa_find(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t * key_lengths,size_t max_num_results,size_t * num_results)188 marisa_status marisa_find(const marisa_trie *h,
189 const char *ptr, size_t length,
190 marisa_uint32 *key_ids, size_t *key_lengths,
191 size_t max_num_results, size_t *num_results) {
192 if (h == NULL) {
193 return MARISA_HANDLE_ERROR;
194 } else if (num_results == NULL) {
195 return MARISA_PARAM_ERROR;
196 }
197 if (length == MARISA_ZERO_TERMINATED) {
198 *num_results = h->trie.find(ptr, key_ids, key_lengths, max_num_results);
199 } else {
200 *num_results = h->trie.find(ptr, length,
201 key_ids, key_lengths, max_num_results);
202 }
203 return MARISA_OK;
204 }
205
marisa_find_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id,size_t * key_length)206 marisa_status marisa_find_first(const marisa_trie *h,
207 const char *ptr, size_t length,
208 marisa_uint32 *key_id, size_t *key_length) {
209 if (h == NULL) {
210 return MARISA_HANDLE_ERROR;
211 } else if (key_id == NULL) {
212 return MARISA_PARAM_ERROR;
213 }
214 if (length == MARISA_ZERO_TERMINATED) {
215 *key_id = h->trie.find_first(ptr, key_length);
216 } else {
217 *key_id = h->trie.find_first(ptr, length, key_length);
218 }
219 return MARISA_OK;
220 }
221
marisa_find_last(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_id,size_t * key_length)222 marisa_status marisa_find_last(const marisa_trie *h,
223 const char *ptr, size_t length,
224 marisa_uint32 *key_id, size_t *key_length) {
225 if (h == NULL) {
226 return MARISA_HANDLE_ERROR;
227 } else if (key_id == NULL) {
228 return MARISA_PARAM_ERROR;
229 }
230 if (length == MARISA_ZERO_TERMINATED) {
231 *key_id = h->trie.find_last(ptr, key_length);
232 } else {
233 *key_id = h->trie.find_last(ptr, length, key_length);
234 }
235 return MARISA_OK;
236 }
237
marisa_find_callback(const marisa_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_uint32,size_t),void * first_arg_to_callback)238 marisa_status marisa_find_callback(const marisa_trie *h,
239 const char *ptr, size_t length,
240 int (*callback)(void *, marisa_uint32, size_t),
241 void *first_arg_to_callback) {
242 if (h == NULL) {
243 return MARISA_HANDLE_ERROR;
244 } else if (callback == NULL) {
245 return MARISA_PARAM_ERROR;
246 }
247 if (length == MARISA_ZERO_TERMINATED) {
248 h->trie.find_callback(ptr,
249 ::FindCallback(callback, first_arg_to_callback));
250 } else {
251 h->trie.find_callback(ptr, length,
252 ::FindCallback(callback, first_arg_to_callback));
253 }
254 return MARISA_OK;
255 }
256
marisa_predict(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)257 marisa_status marisa_predict(const marisa_trie *h,
258 const char *ptr, size_t length, marisa_uint32 *key_ids,
259 size_t max_num_results, size_t *num_results) {
260 return marisa_predict_breadth_first(h, ptr, length,
261 key_ids, max_num_results, num_results);
262 }
263
marisa_predict_breadth_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)264 marisa_status marisa_predict_breadth_first(const marisa_trie *h,
265 const char *ptr, size_t length, marisa_uint32 *key_ids,
266 size_t max_num_results, size_t *num_results) {
267 if (h == NULL) {
268 return MARISA_HANDLE_ERROR;
269 } else if (num_results == NULL) {
270 return MARISA_PARAM_ERROR;
271 }
272 if (length == MARISA_ZERO_TERMINATED) {
273 *num_results = h->trie.predict_breadth_first(
274 ptr, key_ids, NULL, max_num_results);
275 } else {
276 *num_results = h->trie.predict_breadth_first(
277 ptr, length, key_ids, NULL, max_num_results);
278 }
279 return MARISA_OK;
280 }
281
marisa_predict_depth_first(const marisa_trie * h,const char * ptr,size_t length,marisa_uint32 * key_ids,size_t max_num_results,size_t * num_results)282 marisa_status marisa_predict_depth_first(const marisa_trie *h,
283 const char *ptr, size_t length, marisa_uint32 *key_ids,
284 size_t max_num_results, size_t *num_results) {
285 if (h == NULL) {
286 return MARISA_HANDLE_ERROR;
287 } else if (num_results == NULL) {
288 return MARISA_PARAM_ERROR;
289 }
290 if (length == MARISA_ZERO_TERMINATED) {
291 *num_results = h->trie.predict_depth_first(
292 ptr, key_ids, NULL, max_num_results);
293 } else {
294 *num_results = h->trie.predict_depth_first(
295 ptr, length, key_ids, NULL, max_num_results);
296 }
297 return MARISA_OK;
298 }
299
marisa_predict_callback(const marisa_trie * h,const char * ptr,size_t length,int (* callback)(void *,marisa_uint32,const char *,size_t),void * first_arg_to_callback)300 marisa_status marisa_predict_callback(const marisa_trie *h,
301 const char *ptr, size_t length,
302 int (*callback)(void *, marisa_uint32, const char *, size_t),
303 void *first_arg_to_callback) {
304 if (h == NULL) {
305 return MARISA_HANDLE_ERROR;
306 } else if (callback == NULL) {
307 return MARISA_PARAM_ERROR;
308 }
309 if (length == MARISA_ZERO_TERMINATED) {
310 h->trie.predict_callback(ptr,
311 ::PredictCallback(callback, first_arg_to_callback));
312 } else {
313 h->trie.predict_callback(ptr, length,
314 ::PredictCallback(callback, first_arg_to_callback));
315 }
316 return MARISA_OK;
317 }
318
marisa_get_num_tries(const marisa_trie * h)319 size_t marisa_get_num_tries(const marisa_trie *h) {
320 return (h != NULL) ? h->trie.num_tries() : 0;
321 }
322
marisa_get_num_keys(const marisa_trie * h)323 size_t marisa_get_num_keys(const marisa_trie *h) {
324 return (h != NULL) ? h->trie.num_keys() : 0;
325 }
326
marisa_get_num_nodes(const marisa_trie * h)327 size_t marisa_get_num_nodes(const marisa_trie *h) {
328 return (h != NULL) ? h->trie.num_nodes() : 0;
329 }
330
marisa_get_total_size(const marisa_trie * h)331 size_t marisa_get_total_size(const marisa_trie *h) {
332 return (h != NULL) ? h->trie.total_size() : 0;
333 }
334
marisa_clear(marisa_trie * h)335 marisa_status marisa_clear(marisa_trie *h) {
336 if (h == NULL) {
337 return MARISA_HANDLE_ERROR;
338 }
339 h->trie.clear();
340 h->mapper.clear();
341 return MARISA_OK;
342 }
343
344 } // extern "C"
345