• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <algorithm>
2 #include <cstring>
3 #include <sstream>
4 
5 #include <marisa/grimoire/trie/config.h>
6 #include <marisa/grimoire/trie/header.h>
7 #include <marisa/grimoire/trie/key.h>
8 #include <marisa/grimoire/trie/range.h>
9 #include <marisa/grimoire/trie/tail.h>
10 #include <marisa/grimoire/trie/state.h>
11 
12 #include "marisa-assert.h"
13 
14 namespace {
15 
TestConfig()16 void TestConfig() {
17   TEST_START();
18 
19   marisa::grimoire::trie::Config config;
20 
21   ASSERT(config.num_tries() == MARISA_DEFAULT_NUM_TRIES);
22   ASSERT(config.tail_mode() == MARISA_DEFAULT_TAIL);
23   ASSERT(config.node_order() == MARISA_DEFAULT_ORDER);
24   ASSERT(config.cache_level() == MARISA_DEFAULT_CACHE);
25 
26   config.parse(10 | MARISA_BINARY_TAIL | MARISA_LABEL_ORDER |
27       MARISA_TINY_CACHE);
28 
29   ASSERT(config.num_tries() == 10);
30   ASSERT(config.tail_mode() == MARISA_BINARY_TAIL);
31   ASSERT(config.node_order() == MARISA_LABEL_ORDER);
32   ASSERT(config.cache_level() == MARISA_TINY_CACHE);
33 
34   config.parse(0);
35 
36   ASSERT(config.num_tries() == MARISA_DEFAULT_NUM_TRIES);
37   ASSERT(config.tail_mode() == MARISA_DEFAULT_TAIL);
38   ASSERT(config.node_order() == MARISA_DEFAULT_ORDER);
39   ASSERT(config.cache_level() == MARISA_DEFAULT_CACHE);
40 
41   TEST_END();
42 }
43 
TestHeader()44 void TestHeader() {
45   TEST_START();
46 
47   marisa::grimoire::trie::Header header;
48 
49   {
50     marisa::grimoire::Writer writer;
51     writer.open("trie-test.dat");
52     header.write(writer);
53   }
54 
55   {
56     marisa::grimoire::Mapper mapper;
57     mapper.open("trie-test.dat");
58     header.map(mapper);
59   }
60 
61   {
62     marisa::grimoire::Reader reader;
63     reader.open("trie-test.dat");
64     header.read(reader);
65   }
66 
67   TEST_END();
68 }
69 
TestKey()70 void TestKey() {
71   TEST_START();
72 
73   marisa::grimoire::trie::Key key;
74 
75   ASSERT(key.ptr() == NULL);
76   ASSERT(key.length() == 0);
77   ASSERT(key.id() == 0);
78   ASSERT(key.terminal() == 0);
79 
80   const char *str = "xyz";
81 
82   key.set_str(str, 3);
83   key.set_weight(10.0F);
84   key.set_id(20);
85 
86 
87   ASSERT(key.ptr() == str);
88   ASSERT(key.length() == 3);
89   ASSERT(key[0] == 'x');
90   ASSERT(key[1] == 'y');
91   ASSERT(key[2] == 'z');
92   ASSERT(key.weight() == 10.0F);
93   ASSERT(key.id() == 20);
94 
95   key.set_terminal(30);
96   ASSERT(key.terminal() == 30);
97 
98   key.substr(1, 2);
99 
100   ASSERT(key.ptr() == str + 1);
101   ASSERT(key.length() == 2);
102   ASSERT(key[0] == 'y');
103   ASSERT(key[1] == 'z');
104 
105   marisa::grimoire::trie::Key key2;
106   key2.set_str("abc", 3);
107 
108   ASSERT(key == key);
109   ASSERT(key != key2);
110   ASSERT(key > key2);
111   ASSERT(key2 < key);
112 
113   marisa::grimoire::trie::ReverseKey r_key;
114 
115   ASSERT(r_key.ptr() == NULL);
116   ASSERT(r_key.length() == 0);
117   ASSERT(r_key.id() == 0);
118   ASSERT(r_key.terminal() == 0);
119 
120   r_key.set_str(str, 3);
121   r_key.set_weight(100.0F);
122   r_key.set_id(200);
123 
124   ASSERT(r_key.ptr() == str);
125   ASSERT(r_key.length() == 3);
126   ASSERT(r_key[0] == 'z');
127   ASSERT(r_key[1] == 'y');
128   ASSERT(r_key[2] == 'x');
129   ASSERT(r_key.weight() == 100.0F);
130   ASSERT(r_key.id() == 200);
131 
132   r_key.set_terminal(300);
133   ASSERT(r_key.terminal() == 300);
134 
135   r_key.substr(1, 2);
136 
137   ASSERT(r_key.ptr() == str);
138   ASSERT(r_key.length() == 2);
139   ASSERT(r_key[0] == 'y');
140   ASSERT(r_key[1] == 'x');
141 
142   marisa::grimoire::trie::ReverseKey r_key2;
143   r_key2.set_str("abc", 3);
144 
145   ASSERT(r_key == r_key);
146   ASSERT(r_key != r_key2);
147   ASSERT(r_key > r_key2);
148   ASSERT(r_key2 < r_key);
149 
150   TEST_END();
151 }
152 
TestRange()153 void TestRange() {
154   TEST_START();
155 
156   marisa::grimoire::trie::Range range;
157 
158   ASSERT(range.begin() == 0);
159   ASSERT(range.end() == 0);
160   ASSERT(range.key_pos() == 0);
161 
162   range.set_begin(1);
163   range.set_end(2);
164   range.set_key_pos(3);
165 
166   ASSERT(range.begin() == 1);
167   ASSERT(range.end() == 2);
168   ASSERT(range.key_pos() == 3);
169 
170   range = marisa::grimoire::trie::make_range(10, 20, 30);
171 
172   ASSERT(range.begin() == 10);
173   ASSERT(range.end() == 20);
174   ASSERT(range.key_pos() == 30);
175 
176   marisa::grimoire::trie::WeightedRange w_range;
177 
178   ASSERT(w_range.begin() == 0);
179   ASSERT(w_range.end() == 0);
180   ASSERT(w_range.key_pos() == 0);
181   ASSERT(w_range.weight() == 0.0F);
182 
183   w_range.set_begin(10);
184   w_range.set_end(20);
185   w_range.set_key_pos(30);
186   w_range.set_weight(40.0F);
187 
188   ASSERT(w_range.begin() == 10);
189   ASSERT(w_range.end() == 20);
190   ASSERT(w_range.key_pos() == 30);
191   ASSERT(w_range.weight() == 40.0F);
192 
193   marisa::grimoire::trie::WeightedRange w_range2 =
194       marisa::grimoire::trie::make_weighted_range(100, 200, 300, 400.0F);
195 
196   ASSERT(w_range2.begin() == 100);
197   ASSERT(w_range2.end() == 200);
198   ASSERT(w_range2.key_pos() == 300);
199   ASSERT(w_range2.weight() == 400.0F);
200 
201   ASSERT(w_range < w_range2);
202   ASSERT(w_range2 > w_range);
203 
204   TEST_END();
205 }
206 
TestEntry()207 void TestEntry() {
208   TEST_START();
209 
210   marisa::grimoire::trie::Entry entry;
211 
212   ASSERT(entry.length() == 0);
213   ASSERT(entry.id() == 0);
214 
215   const char *str = "XYZ";
216 
217   entry.set_str(str, 3);
218   entry.set_id(123);
219 
220   ASSERT(entry.ptr() == str);
221   ASSERT(entry.length() == 3);
222   ASSERT(entry[0] == 'Z');
223   ASSERT(entry[1] == 'Y');
224   ASSERT(entry[2] == 'X');
225   ASSERT(entry.id() == 123);
226 
227   TEST_END();
228 }
229 
TestTextTail()230 void TestTextTail() {
231   TEST_START();
232 
233   marisa::grimoire::trie::Tail tail;
234   marisa::grimoire::Vector<marisa::grimoire::trie::Entry> entries;
235   marisa::grimoire::Vector<marisa::UInt32> offsets;
236   tail.build(entries, &offsets, MARISA_TEXT_TAIL);
237 
238   ASSERT(tail.mode() == MARISA_TEXT_TAIL);
239   ASSERT(tail.size() == 0);
240   ASSERT(tail.empty());
241   ASSERT(tail.total_size() == tail.size());
242   ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 6));
243 
244   ASSERT(offsets.empty());
245 
246   marisa::grimoire::trie::Entry entry;
247   entry.set_str("X", 1);
248   entries.push_back(entry);
249 
250   tail.build(entries, &offsets, MARISA_TEXT_TAIL);
251 
252   ASSERT(tail.mode() == MARISA_TEXT_TAIL);
253   ASSERT(tail.size() == 2);
254   ASSERT(!tail.empty());
255   ASSERT(tail.total_size() == tail.size());
256   ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 7));
257 
258   ASSERT(offsets.size() == entries.size());
259   ASSERT(offsets[0] == 0);
260   ASSERT(tail[offsets[0]] == 'X');
261   ASSERT(tail[offsets[0] + 1] == '\0');
262 
263   entries.clear();
264   entry.set_str("abc", 3);
265   entries.push_back(entry);
266   entry.set_str("bc", 2);
267   entries.push_back(entry);
268   entry.set_str("abc", 3);
269   entries.push_back(entry);
270   entry.set_str("c", 1);
271   entries.push_back(entry);
272   entry.set_str("ABC", 3);
273   entries.push_back(entry);
274   entry.set_str("AB", 2);
275   entries.push_back(entry);
276 
277   tail.build(entries, &offsets, MARISA_TEXT_TAIL);
278   std::sort(entries.begin(), entries.end(),
279       marisa::grimoire::trie::Entry::IDComparer());
280 
281   ASSERT(tail.size() == 11);
282   ASSERT(offsets.size() == entries.size());
283   for (std::size_t i = 0; i < entries.size(); ++i) {
284     const char * const ptr = &tail[offsets[i]];
285     ASSERT(std::strlen(ptr) == entries[i].length());
286     ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
287   }
288 
289   {
290     marisa::grimoire::Writer writer;
291     writer.open("trie-test.dat");
292     tail.write(writer);
293   }
294 
295   tail.clear();
296 
297   ASSERT(tail.size() == 0);
298   ASSERT(tail.total_size() == tail.size());
299 
300   {
301     marisa::grimoire::Mapper mapper;
302     mapper.open("trie-test.dat");
303     tail.map(mapper);
304 
305     ASSERT(tail.mode() == MARISA_TEXT_TAIL);
306     ASSERT(tail.size() == 11);
307     for (std::size_t i = 0; i < entries.size(); ++i) {
308       const char * const ptr = &tail[offsets[i]];
309     ASSERT(std::strlen(ptr) == entries[i].length());
310     ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
311     }
312     tail.clear();
313   }
314 
315   {
316     marisa::grimoire::Reader reader;
317     reader.open("trie-test.dat");
318     tail.read(reader);
319   }
320 
321   ASSERT(tail.size() == 11);
322   ASSERT(offsets.size() == entries.size());
323   for (std::size_t i = 0; i < entries.size(); ++i) {
324     const char * const ptr = &tail[offsets[i]];
325     ASSERT(std::strlen(ptr) == entries[i].length());
326     ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
327   }
328 
329   {
330     std::stringstream stream;
331     marisa::grimoire::Writer writer;
332     writer.open(stream);
333     tail.write(writer);
334     tail.clear();
335     marisa::grimoire::Reader reader;
336     reader.open(stream);
337     tail.read(reader);
338   }
339 
340   ASSERT(tail.size() == 11);
341   ASSERT(offsets.size() == entries.size());
342   for (std::size_t i = 0; i < entries.size(); ++i) {
343     const char * const ptr = &tail[offsets[i]];
344     ASSERT(std::strlen(ptr) == entries[i].length());
345     ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
346   }
347 
348   TEST_END();
349 }
350 
TestBinaryTail()351 void TestBinaryTail() {
352   TEST_START();
353 
354   marisa::grimoire::trie::Tail tail;
355   marisa::grimoire::Vector<marisa::grimoire::trie::Entry> entries;
356   marisa::grimoire::Vector<marisa::UInt32> offsets;
357   tail.build(entries, &offsets, MARISA_BINARY_TAIL);
358 
359   ASSERT(tail.mode() == MARISA_TEXT_TAIL);
360   ASSERT(tail.size() == 0);
361   ASSERT(tail.empty());
362   ASSERT(tail.total_size() == tail.size());
363   ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 6));
364 
365   ASSERT(offsets.empty());
366 
367   marisa::grimoire::trie::Entry entry;
368   entry.set_str("X", 1);
369   entries.push_back(entry);
370 
371   tail.build(entries, &offsets, MARISA_BINARY_TAIL);
372 
373   ASSERT(tail.mode() == MARISA_BINARY_TAIL);
374   ASSERT(tail.size() == 1);
375   ASSERT(!tail.empty());
376   ASSERT(tail.total_size() == (tail.size() + sizeof(marisa::UInt64)));
377   ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 8));
378 
379   ASSERT(offsets.size() == entries.size());
380   ASSERT(offsets[0] == 0);
381 
382   const char binary_entry[] = { 'N', 'P', '\0', 'T', 'r', 'i', 'e' };
383   entries[0].set_str(binary_entry, sizeof(binary_entry));
384 
385   tail.build(entries, &offsets, MARISA_TEXT_TAIL);
386 
387   ASSERT(tail.mode() == MARISA_BINARY_TAIL);
388   ASSERT(tail.size() == entries[0].length());
389 
390   ASSERT(offsets.size() == entries.size());
391   ASSERT(offsets[0] == 0);
392 
393   entries.clear();
394   entry.set_str("abc", 3);
395   entries.push_back(entry);
396   entry.set_str("bc", 2);
397   entries.push_back(entry);
398   entry.set_str("abc", 3);
399   entries.push_back(entry);
400   entry.set_str("c", 1);
401   entries.push_back(entry);
402   entry.set_str("ABC", 3);
403   entries.push_back(entry);
404   entry.set_str("AB", 2);
405   entries.push_back(entry);
406 
407   tail.build(entries, &offsets, MARISA_BINARY_TAIL);
408   std::sort(entries.begin(), entries.end(),
409       marisa::grimoire::trie::Entry::IDComparer());
410 
411   ASSERT(tail.mode() == MARISA_BINARY_TAIL);
412   ASSERT(tail.size() == 8);
413   ASSERT(offsets.size() == entries.size());
414   for (std::size_t i = 0; i < entries.size(); ++i) {
415     const char * const ptr = &tail[offsets[i]];
416     ASSERT(std::memcmp(ptr, entries[i].ptr(), entries[i].length()) == 0);
417   }
418 
419   TEST_END();
420 }
421 
TestHistory()422 void TestHistory() {
423   TEST_START();
424 
425   marisa::grimoire::trie::History history;
426 
427   ASSERT(history.node_id() == 0);
428   ASSERT(history.louds_pos() == 0);
429   ASSERT(history.key_pos() == 0);
430   ASSERT(history.link_id() == MARISA_INVALID_LINK_ID);
431   ASSERT(history.key_id() == MARISA_INVALID_KEY_ID);
432 
433   history.set_node_id(100);
434   history.set_louds_pos(200);
435   history.set_key_pos(300);
436   history.set_link_id(400);
437   history.set_key_id(500);
438 
439   ASSERT(history.node_id() == 100);
440   ASSERT(history.louds_pos() == 200);
441   ASSERT(history.key_pos() == 300);
442   ASSERT(history.link_id() == 400);
443   ASSERT(history.key_id() == 500);
444 
445   TEST_END();
446 }
447 
TestState()448 void TestState() {
449   TEST_START();
450 
451   marisa::grimoire::trie::State state;
452 
453   ASSERT(state.key_buf().empty());
454   ASSERT(state.history().empty());
455   ASSERT(state.node_id() == 0);
456   ASSERT(state.query_pos() == 0);
457   ASSERT(state.history_pos() == 0);
458   ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
459 
460   state.set_node_id(10);
461   state.set_query_pos(100);
462   state.set_history_pos(1000);
463   state.set_status_code(
464       marisa::grimoire::trie::MARISA_END_OF_PREDICTIVE_SEARCH);
465 
466   ASSERT(state.node_id() == 10);
467   ASSERT(state.query_pos() == 100);
468   ASSERT(state.history_pos() == 1000);
469   ASSERT(state.status_code() ==
470       marisa::grimoire::trie::MARISA_END_OF_PREDICTIVE_SEARCH);
471 
472   state.lookup_init();
473   ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
474 
475   state.reverse_lookup_init();
476   ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
477 
478   state.common_prefix_search_init();
479   ASSERT(state.status_code() ==
480       marisa::grimoire::trie::MARISA_READY_TO_COMMON_PREFIX_SEARCH);
481 
482   state.predictive_search_init();
483   ASSERT(state.status_code() ==
484       marisa::grimoire::trie::MARISA_READY_TO_PREDICTIVE_SEARCH);
485 
486   TEST_END();
487 }
488 
489 }  // namespace
490 
main()491 int main() try {
492   TestConfig();
493   TestHeader();
494   TestKey();
495   TestRange();
496   TestEntry();
497   TestTextTail();
498   TestBinaryTail();
499   TestHistory();
500   TestState();
501 
502   return 0;
503 } catch (const marisa::Exception &ex) {
504   std::cerr << ex.what() << std::endl;
505   throw;
506 }
507