• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

#include <marisa.h>

#include "assert.h"

7 namespace {
8 
9 class FindCallback {
10  public:
FindCallback(std::vector<marisa::UInt32> * key_ids,std::vector<std::size_t> * key_lengths)11   FindCallback(std::vector<marisa::UInt32> *key_ids,
12       std::vector<std::size_t> *key_lengths)
13       : key_ids_(key_ids), key_lengths_(key_lengths) {}
FindCallback(const FindCallback & callback)14   FindCallback(const FindCallback &callback)
15       : key_ids_(callback.key_ids_), key_lengths_(callback.key_lengths_) {}
16 
operator ()(marisa::UInt32 key_id,std::size_t key_length) const17   bool operator()(marisa::UInt32 key_id, std::size_t key_length) const {
18     key_ids_->push_back(key_id);
19     key_lengths_->push_back(key_length);
20     return true;
21   }
22 
23  private:
24   std::vector<marisa::UInt32> *key_ids_;
25   std::vector<std::size_t> *key_lengths_;
26 
27   // Disallows assignment.
28   FindCallback &operator=(const FindCallback &);
29 };
30 
31 class PredictCallback {
32  public:
PredictCallback(std::vector<marisa::UInt32> * key_ids,std::vector<std::string> * keys)33   PredictCallback(std::vector<marisa::UInt32> *key_ids,
34       std::vector<std::string> *keys)
35       : key_ids_(key_ids), keys_(keys) {}
PredictCallback(const PredictCallback & callback)36   PredictCallback(const PredictCallback &callback)
37       : key_ids_(callback.key_ids_), keys_(callback.keys_) {}
38 
operator ()(marisa::UInt32 key_id,const std::string & key) const39   bool operator()(marisa::UInt32 key_id, const std::string &key) const {
40     key_ids_->push_back(key_id);
41     keys_->push_back(key);
42     return true;
43   }
44 
45  private:
46   std::vector<marisa::UInt32> *key_ids_;
47   std::vector<std::string> *keys_;
48 
49   // Disallows assignment.
50   PredictCallback &operator=(const PredictCallback &);
51 };
52 
TestTrie()53 void TestTrie() {
54   TEST_START();
55 
56   marisa::Trie trie;
57 
58   ASSERT(trie.num_tries() == 0);
59   ASSERT(trie.num_keys() == 0);
60   ASSERT(trie.num_nodes() == 0);
61   ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23));
62 
63   std::vector<std::string> keys;
64   trie.build(keys);
65   ASSERT(trie.num_tries() == 1);
66   ASSERT(trie.num_keys() == 0);
67   ASSERT(trie.num_nodes() == 1);
68 
69   keys.push_back("apple");
70   keys.push_back("and");
71   keys.push_back("Bad");
72   keys.push_back("apple");
73   keys.push_back("app");
74 
75   std::vector<marisa::UInt32> key_ids;
76   trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_LABEL_ORDER);
77 
78   ASSERT(trie.num_tries() == 1);
79   ASSERT(trie.num_keys() == 4);
80   ASSERT(trie.num_nodes() == 11);
81 
82   ASSERT(key_ids.size() == 5);
83   ASSERT(key_ids[0] == 3);
84   ASSERT(key_ids[1] == 1);
85   ASSERT(key_ids[2] == 0);
86   ASSERT(key_ids[3] == 3);
87   ASSERT(key_ids[4] == 2);
88 
89   char key_buf[256];
90   std::size_t key_length;
91   for (std::size_t i = 0; i < keys.size(); ++i) {
92     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
93 
94     ASSERT(trie[keys[i]] == key_ids[i]);
95     ASSERT(trie[key_ids[i]] == keys[i]);
96     ASSERT(key_length == keys[i].length());
97     ASSERT(keys[i] == key_buf);
98   }
99 
100   trie.clear();
101 
102   ASSERT(trie.num_tries() == 0);
103   ASSERT(trie.num_keys() == 0);
104   ASSERT(trie.num_nodes() == 0);
105   ASSERT(trie.total_size() == (sizeof(marisa::UInt32) * 23));
106 
107   trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
108 
109   ASSERT(trie.num_tries() == 1);
110   ASSERT(trie.num_keys() == 4);
111   ASSERT(trie.num_nodes() == 11);
112 
113   ASSERT(key_ids.size() == 5);
114   ASSERT(key_ids[0] == 3);
115   ASSERT(key_ids[1] == 1);
116   ASSERT(key_ids[2] == 2);
117   ASSERT(key_ids[3] == 3);
118   ASSERT(key_ids[4] == 0);
119 
120   for (std::size_t i = 0; i < keys.size(); ++i) {
121     ASSERT(trie[keys[i]] == key_ids[i]);
122     ASSERT(trie[key_ids[i]] == keys[i]);
123   }
124 
125   ASSERT(trie["appl"] == trie.notfound());
126   ASSERT(trie["applex"] == trie.notfound());
127   ASSERT(trie.find_first("ap") == trie.notfound());
128   ASSERT(trie.find_first("applex") == trie["app"]);
129   ASSERT(trie.find_last("ap") == trie.notfound());
130   ASSERT(trie.find_last("applex") == trie["apple"]);
131 
132   std::vector<marisa::UInt32> ids;
133   ASSERT(trie.find("ap", &ids) == 0);
134   ASSERT(trie.find("applex", &ids) == 2);
135   ASSERT(ids.size() == 2);
136   ASSERT(ids[0] == trie["app"]);
137   ASSERT(ids[1] == trie["apple"]);
138 
139   std::vector<std::size_t> lengths;
140   ASSERT(trie.find("Baddie", &ids, &lengths) == 1);
141   ASSERT(ids.size() == 3);
142   ASSERT(ids[2] == trie["Bad"]);
143   ASSERT(lengths.size() == 1);
144   ASSERT(lengths[0] == 3);
145 
146   ASSERT(trie.find_callback("anderson", FindCallback(&ids, &lengths)) == 1);
147   ASSERT(ids.size() == 4);
148   ASSERT(ids[3] == trie["and"]);
149   ASSERT(lengths.size() == 2);
150   ASSERT(lengths[1] == 3);
151 
152   ASSERT(trie.predict("") == 4);
153   ASSERT(trie.predict("a") == 3);
154   ASSERT(trie.predict("ap") == 2);
155   ASSERT(trie.predict("app") == 2);
156   ASSERT(trie.predict("appl") == 1);
157   ASSERT(trie.predict("apple") == 1);
158   ASSERT(trie.predict("appleX") == 0);
159   ASSERT(trie.predict("X") == 0);
160 
161   ids.clear();
162   ASSERT(trie.predict("a", &ids) == 3);
163   ASSERT(ids.size() == 3);
164   ASSERT(ids[0] == trie["app"]);
165   ASSERT(ids[1] == trie["and"]);
166   ASSERT(ids[2] == trie["apple"]);
167 
168   std::vector<std::string> strs;
169   ASSERT(trie.predict("a", &ids, &strs) == 3);
170   ASSERT(ids.size() == 6);
171   ASSERT(ids[3] == trie["app"]);
172   ASSERT(ids[4] == trie["apple"]);
173   ASSERT(ids[5] == trie["and"]);
174   ASSERT(strs[0] == "app");
175   ASSERT(strs[1] == "apple");
176   ASSERT(strs[2] == "and");
177 
178   TEST_END();
179 }
180 
TestPrefixTrie()181 void TestPrefixTrie() {
182   TEST_START();
183 
184   std::vector<std::string> keys;
185   keys.push_back("after");
186   keys.push_back("bar");
187   keys.push_back("car");
188   keys.push_back("caster");
189 
190   marisa::Trie trie;
191   std::vector<marisa::UInt32> key_ids;
192   trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE
193       | MARISA_TEXT_TAIL | MARISA_LABEL_ORDER);
194 
195   ASSERT(trie.num_tries() == 1);
196   ASSERT(trie.num_keys() == 4);
197   ASSERT(trie.num_nodes() == 7);
198 
199   char key_buf[256];
200   std::size_t key_length;
201   for (std::size_t i = 0; i < keys.size(); ++i) {
202     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
203 
204     ASSERT(trie[keys[i]] == key_ids[i]);
205     ASSERT(trie[key_ids[i]] == keys[i]);
206     ASSERT(key_length == keys[i].length());
207     ASSERT(keys[i] == key_buf);
208   }
209 
210   key_length = trie.restore(key_ids[0], NULL, 0);
211 
212   ASSERT(key_length == keys[0].length());
213   EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR);
214 
215   key_length = trie.restore(key_ids[0], key_buf, 5);
216 
217   ASSERT(key_length == keys[0].length());
218 
219   key_length = trie.restore(key_ids[0], key_buf, 6);
220 
221   ASSERT(key_length == keys[0].length());
222 
223   trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE
224       | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
225 
226   ASSERT(trie.num_tries() == 2);
227   ASSERT(trie.num_keys() == 4);
228   ASSERT(trie.num_nodes() == 16);
229 
230   for (std::size_t i = 0; i < keys.size(); ++i) {
231     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
232 
233     ASSERT(trie[keys[i]] == key_ids[i]);
234     ASSERT(trie[key_ids[i]] == keys[i]);
235     ASSERT(key_length == keys[i].length());
236     ASSERT(keys[i] == key_buf);
237   }
238 
239   key_length = trie.restore(key_ids[0], NULL, 0);
240 
241   ASSERT(key_length == keys[0].length());
242   EXCEPT(trie.restore(key_ids[0], NULL, 5), MARISA_PARAM_ERROR);
243 
244   key_length = trie.restore(key_ids[0], key_buf, 5);
245 
246   ASSERT(key_length == keys[0].length());
247 
248   key_length = trie.restore(key_ids[0], key_buf, 6);
249 
250   ASSERT(key_length == keys[0].length());
251 
252   trie.build(keys, &key_ids, 2 | MARISA_PREFIX_TRIE
253       | MARISA_TEXT_TAIL | MARISA_LABEL_ORDER);
254 
255   ASSERT(trie.num_tries() == 2);
256   ASSERT(trie.num_keys() == 4);
257   ASSERT(trie.num_nodes() == 14);
258 
259   for (std::size_t i = 0; i < keys.size(); ++i) {
260     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
261 
262     ASSERT(trie[keys[i]] == key_ids[i]);
263     ASSERT(trie[key_ids[i]] == keys[i]);
264     ASSERT(key_length == keys[i].length());
265     ASSERT(keys[i] == key_buf);
266   }
267 
268   trie.save("trie-test.dat");
269   trie.clear();
270   marisa::Mapper mapper;
271   trie.mmap(&mapper, "trie-test.dat");
272 
273   ASSERT(mapper.is_open());
274   ASSERT(trie.num_tries() == 2);
275   ASSERT(trie.num_keys() == 4);
276   ASSERT(trie.num_nodes() == 14);
277 
278   for (std::size_t i = 0; i < keys.size(); ++i) {
279     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
280 
281     ASSERT(trie[keys[i]] == key_ids[i]);
282     ASSERT(trie[key_ids[i]] == keys[i]);
283     ASSERT(key_length == keys[i].length());
284     ASSERT(keys[i] == key_buf);
285   }
286 
287   std::stringstream stream;
288   trie.write(stream);
289   trie.clear();
290   trie.read(stream);
291 
292   ASSERT(trie.num_tries() == 2);
293   ASSERT(trie.num_keys() == 4);
294   ASSERT(trie.num_nodes() == 14);
295 
296   for (std::size_t i = 0; i < keys.size(); ++i) {
297     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
298 
299     ASSERT(trie[keys[i]] == key_ids[i]);
300     ASSERT(trie[key_ids[i]] == keys[i]);
301     ASSERT(key_length == keys[i].length());
302     ASSERT(keys[i] == key_buf);
303   }
304 
305   trie.build(keys, &key_ids, 3 | MARISA_PREFIX_TRIE
306       | MARISA_WITHOUT_TAIL | MARISA_WEIGHT_ORDER);
307 
308   ASSERT(trie.num_tries() == 3);
309   ASSERT(trie.num_keys() == 4);
310   ASSERT(trie.num_nodes() == 19);
311 
312   for (std::size_t i = 0; i < keys.size(); ++i) {
313     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
314 
315     ASSERT(trie[keys[i]] == key_ids[i]);
316     ASSERT(trie[key_ids[i]] == keys[i]);
317     ASSERT(key_length == keys[i].length());
318     ASSERT(keys[i] == key_buf);
319   }
320 
321   ASSERT(trie["ca"] == trie.notfound());
322   ASSERT(trie["card"] == trie.notfound());
323 
324   std::size_t length = 0;
325   ASSERT(trie.find_first("ca") == trie.notfound());
326   ASSERT(trie.find_first("car") == trie["car"]);
327   ASSERT(trie.find_first("card", &length) == trie["car"]);
328   ASSERT(length == 3);
329 
330   ASSERT(trie.find_last("afte") == trie.notfound());
331   ASSERT(trie.find_last("after") == trie["after"]);
332   ASSERT(trie.find_last("afternoon", &length) == trie["after"]);
333   ASSERT(length == 5);
334 
335   {
336     std::vector<marisa::UInt32> ids;
337     std::vector<std::size_t> lengths;
338     ASSERT(trie.find("card", &ids, &lengths) == 1);
339     ASSERT(ids.size() == 1);
340     ASSERT(ids[0] == trie["car"]);
341     ASSERT(lengths.size() == 1);
342     ASSERT(lengths[0] == 3);
343 
344     ASSERT(trie.predict("ca", &ids) == 2);
345     ASSERT(ids.size() == 3);
346     ASSERT(ids[1] == trie["car"]);
347     ASSERT(ids[2] == trie["caster"]);
348 
349     ASSERT(trie.predict("ca", &ids, NULL, 1) == 1);
350     ASSERT(ids.size() == 4);
351     ASSERT(ids[3] == trie["car"]);
352 
353     std::vector<std::string> strs;
354     ASSERT(trie.predict("ca", &ids, &strs, 1) == 1);
355     ASSERT(ids.size() == 5);
356     ASSERT(ids[4] == trie["car"]);
357     ASSERT(strs.size() == 1);
358     ASSERT(strs[0] == "car");
359 
360     ASSERT(trie.predict_callback("", PredictCallback(&ids, &strs)) == 4);
361     ASSERT(ids.size() == 9);
362     ASSERT(ids[5] == trie["car"]);
363     ASSERT(ids[6] == trie["caster"]);
364     ASSERT(ids[7] == trie["after"]);
365     ASSERT(ids[8] == trie["bar"]);
366     ASSERT(strs.size() == 5);
367     ASSERT(strs[1] == "car");
368     ASSERT(strs[2] == "caster");
369     ASSERT(strs[3] == "after");
370     ASSERT(strs[4] == "bar");
371   }
372 
373   {
374     marisa::UInt32 ids[10];
375     std::size_t lengths[10];
376     ASSERT(trie.find("card", ids, lengths, 10) == 1);
377     ASSERT(ids[0] == trie["car"]);
378     ASSERT(lengths[0] == 3);
379 
380     ASSERT(trie.predict("ca", ids, NULL, 10) == 2);
381     ASSERT(ids[0] == trie["car"]);
382     ASSERT(ids[1] == trie["caster"]);
383 
384     ASSERT(trie.predict("ca", ids, NULL, 1) == 1);
385     ASSERT(ids[0] == trie["car"]);
386 
387     std::string strs[10];
388     ASSERT(trie.predict("ca", ids, strs, 1) == 1);
389     ASSERT(ids[0] == trie["car"]);
390     ASSERT(strs[0] == "car");
391 
392     ASSERT(trie.predict("", ids, strs, 10) == 4);
393     ASSERT(ids[0] == trie["car"]);
394     ASSERT(ids[1] == trie["caster"]);
395     ASSERT(ids[2] == trie["after"]);
396     ASSERT(ids[3] == trie["bar"]);
397     ASSERT(strs[0] == "car");
398     ASSERT(strs[1] == "caster");
399     ASSERT(strs[2] == "after");
400     ASSERT(strs[3] == "bar");
401   }
402 
403   TEST_END();
404 }
405 
TestPatriciaTrie()406 void TestPatriciaTrie() {
407   TEST_START();
408 
409   std::vector<std::string> keys;
410   keys.push_back("bach");
411   keys.push_back("bet");
412   keys.push_back("chat");
413   keys.push_back("check");
414   keys.push_back("check");
415 
416   marisa::Trie trie;
417   std::vector<marisa::UInt32> key_ids;
418   trie.build(keys, &key_ids, 1);
419 
420   ASSERT(trie.num_tries() == 1);
421   ASSERT(trie.num_keys() == 4);
422   ASSERT(trie.num_nodes() == 7);
423 
424   ASSERT(key_ids.size() == 5);
425   ASSERT(key_ids[0] == 2);
426   ASSERT(key_ids[1] == 3);
427   ASSERT(key_ids[2] == 1);
428   ASSERT(key_ids[3] == 0);
429   ASSERT(key_ids[4] == 0);
430 
431   char key_buf[256];
432   std::size_t key_length;
433   for (std::size_t i = 0; i < keys.size(); ++i) {
434     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
435 
436     ASSERT(trie[keys[i]] == key_ids[i]);
437     ASSERT(trie[key_ids[i]] == keys[i]);
438     ASSERT(key_length == keys[i].length());
439     ASSERT(keys[i] == key_buf);
440   }
441 
442   trie.build(keys, &key_ids, 2 | MARISA_WITHOUT_TAIL);
443 
444   ASSERT(trie.num_tries() == 2);
445   ASSERT(trie.num_keys() == 4);
446   ASSERT(trie.num_nodes() == 17);
447 
448   for (std::size_t i = 0; i < keys.size(); ++i) {
449     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
450 
451     ASSERT(trie[keys[i]] == key_ids[i]);
452     ASSERT(trie[key_ids[i]] == keys[i]);
453     ASSERT(key_length == keys[i].length());
454     ASSERT(keys[i] == key_buf);
455   }
456 
457   trie.build(keys, &key_ids, 2);
458 
459   ASSERT(trie.num_tries() == 2);
460   ASSERT(trie.num_keys() == 4);
461   ASSERT(trie.num_nodes() == 14);
462 
463   for (std::size_t i = 0; i < keys.size(); ++i) {
464     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
465 
466     ASSERT(trie[keys[i]] == key_ids[i]);
467     ASSERT(trie[key_ids[i]] == keys[i]);
468     ASSERT(key_length == keys[i].length());
469     ASSERT(keys[i] == key_buf);
470   }
471 
472   trie.build(keys, &key_ids, 3 | MARISA_WITHOUT_TAIL);
473 
474   ASSERT(trie.num_tries() == 3);
475   ASSERT(trie.num_keys() == 4);
476   ASSERT(trie.num_nodes() == 20);
477 
478   for (std::size_t i = 0; i < keys.size(); ++i) {
479     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
480 
481     ASSERT(trie[keys[i]] == key_ids[i]);
482     ASSERT(trie[key_ids[i]] == keys[i]);
483     ASSERT(key_length == keys[i].length());
484     ASSERT(keys[i] == key_buf);
485   }
486 
487   std::stringstream stream;
488   trie.write(stream);
489   trie.clear();
490   trie.read(stream);
491 
492   ASSERT(trie.num_tries() == 3);
493   ASSERT(trie.num_keys() == 4);
494   ASSERT(trie.num_nodes() == 20);
495 
496   for (std::size_t i = 0; i < keys.size(); ++i) {
497     key_length = trie.restore(key_ids[i], key_buf, sizeof(key_buf));
498 
499     ASSERT(trie[keys[i]] == key_ids[i]);
500     ASSERT(trie[key_ids[i]] == keys[i]);
501     ASSERT(key_length == keys[i].length());
502     ASSERT(keys[i] == key_buf);
503   }
504 
505   TEST_END();
506 }
507 
TestEmptyString()508 void TestEmptyString() {
509   TEST_START();
510 
511   std::vector<std::string> keys;
512   keys.push_back("");
513 
514   marisa::Trie trie;
515   std::vector<marisa::UInt32> key_ids;
516   trie.build(keys, &key_ids);
517 
518   ASSERT(trie.num_tries() == 1);
519   ASSERT(trie.num_keys() == 1);
520   ASSERT(trie.num_nodes() == 1);
521 
522   ASSERT(key_ids.size() == 1);
523   ASSERT(key_ids[0] == 0);
524 
525   ASSERT(trie[""] == 0);
526   ASSERT(trie[(marisa::UInt32)0] == "");
527 
528   ASSERT(trie["x"] == trie.notfound());
529   ASSERT(trie.find_first("") == 0);
530   ASSERT(trie.find_first("x") == 0);
531   ASSERT(trie.find_last("") == 0);
532   ASSERT(trie.find_last("x") == 0);
533 
534   std::vector<marisa::UInt32> ids;
535   ASSERT(trie.find("xyz", &ids) == 1);
536   ASSERT(ids.size() == 1);
537   ASSERT(ids[0] == trie[""]);
538 
539   std::vector<std::size_t> lengths;
540   ASSERT(trie.find("xyz", &ids, &lengths) == 1);
541   ASSERT(ids.size() == 2);
542   ASSERT(ids[0] == trie[""]);
543   ASSERT(ids[1] == trie[""]);
544   ASSERT(lengths.size() == 1);
545   ASSERT(lengths[0] == 0);
546 
547   ASSERT(trie.find_callback("xyz", FindCallback(&ids, &lengths)) == 1);
548   ASSERT(ids.size() == 3);
549   ASSERT(ids[2] == trie[""]);
550   ASSERT(lengths.size() == 2);
551   ASSERT(lengths[1] == 0);
552 
553   ASSERT(trie.predict("xyz", &ids) == 0);
554 
555   ASSERT(trie.predict("", &ids) == 1);
556   ASSERT(ids.size() == 4);
557   ASSERT(ids[3] == trie[""]);
558 
559   std::vector<std::string> strs;
560   ASSERT(trie.predict("", &ids, &strs) == 1);
561   ASSERT(ids.size() == 5);
562   ASSERT(ids[4] == trie[""]);
563   ASSERT(strs[0] == "");
564 
565   TEST_END();
566 }
567 
TestBinaryKey()568 void TestBinaryKey() {
569   TEST_START();
570 
571   std::string binary_key = "NP";
572   binary_key += '\0';
573   binary_key += "Trie";
574 
575   std::vector<std::string> keys;
576   keys.push_back(binary_key);
577 
578   marisa::Trie trie;
579   std::vector<marisa::UInt32> key_ids;
580   trie.build(keys, &key_ids, 1 | MARISA_WITHOUT_TAIL);
581 
582   ASSERT(trie.num_tries() == 1);
583   ASSERT(trie.num_keys() == 1);
584   ASSERT(trie.num_nodes() == 8);
585   ASSERT(key_ids.size() == 1);
586 
587   char key_buf[256];
588   std::size_t key_length;
589   key_length = trie.restore(0, key_buf, sizeof(key_buf));
590 
591   ASSERT(trie[keys[0]] == key_ids[0]);
592   ASSERT(trie[key_ids[0]] == keys[0]);
593   ASSERT(std::string(key_buf, key_length) == keys[0]);
594 
595   trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_BINARY_TAIL);
596 
597   ASSERT(trie.num_tries() == 1);
598   ASSERT(trie.num_keys() == 1);
599   ASSERT(trie.num_nodes() == 2);
600   ASSERT(key_ids.size() == 1);
601 
602   key_length = trie.restore(0, key_buf, sizeof(key_buf));
603 
604   ASSERT(trie[keys[0]] == key_ids[0]);
605   ASSERT(trie[key_ids[0]] == keys[0]);
606   ASSERT(std::string(key_buf, key_length) == keys[0]);
607 
608   trie.build(keys, &key_ids, 1 | MARISA_PREFIX_TRIE | MARISA_TEXT_TAIL);
609 
610   ASSERT(trie.num_tries() == 1);
611   ASSERT(trie.num_keys() == 1);
612   ASSERT(trie.num_nodes() == 2);
613   ASSERT(key_ids.size() == 1);
614 
615   key_length = trie.restore(0, key_buf, sizeof(key_buf));
616 
617   ASSERT(trie[keys[0]] == key_ids[0]);
618   ASSERT(trie[key_ids[0]] == keys[0]);
619   ASSERT(std::string(key_buf, key_length) == keys[0]);
620 
621   std::vector<marisa::UInt32> ids;
622   ASSERT(trie.predict_breadth_first("", &ids) == 1);
623   ASSERT(ids.size() == 1);
624   ASSERT(ids[0] == key_ids[0]);
625 
626   std::vector<std::string> strs;
627   ASSERT(trie.predict_depth_first("NP", &ids, &strs) == 1);
628   ASSERT(ids.size() == 2);
629   ASSERT(ids[1] == key_ids[0]);
630   ASSERT(strs[0] == keys[0]);
631 
632   TEST_END();
633 }
634 
635 }  // namespace
636 
main()637 int main() {
638   TestTrie();
639   TestPrefixTrie();
640   TestPatriciaTrie();
641   TestEmptyString();
642   TestBinaryKey();
643 
644   return 0;
645 }
646