1 #include <algorithm>
2 #include <cstring>
3 #include <sstream>
4
5 #include <marisa/grimoire/trie/config.h>
6 #include <marisa/grimoire/trie/header.h>
7 #include <marisa/grimoire/trie/key.h>
8 #include <marisa/grimoire/trie/range.h>
9 #include <marisa/grimoire/trie/tail.h>
10 #include <marisa/grimoire/trie/state.h>
11
12 #include "marisa-assert.h"
13
14 namespace {
15
TestConfig()16 void TestConfig() {
17 TEST_START();
18
19 marisa::grimoire::trie::Config config;
20
21 ASSERT(config.num_tries() == MARISA_DEFAULT_NUM_TRIES);
22 ASSERT(config.tail_mode() == MARISA_DEFAULT_TAIL);
23 ASSERT(config.node_order() == MARISA_DEFAULT_ORDER);
24 ASSERT(config.cache_level() == MARISA_DEFAULT_CACHE);
25
26 config.parse(10 | MARISA_BINARY_TAIL | MARISA_LABEL_ORDER |
27 MARISA_TINY_CACHE);
28
29 ASSERT(config.num_tries() == 10);
30 ASSERT(config.tail_mode() == MARISA_BINARY_TAIL);
31 ASSERT(config.node_order() == MARISA_LABEL_ORDER);
32 ASSERT(config.cache_level() == MARISA_TINY_CACHE);
33
34 config.parse(0);
35
36 ASSERT(config.num_tries() == MARISA_DEFAULT_NUM_TRIES);
37 ASSERT(config.tail_mode() == MARISA_DEFAULT_TAIL);
38 ASSERT(config.node_order() == MARISA_DEFAULT_ORDER);
39 ASSERT(config.cache_level() == MARISA_DEFAULT_CACHE);
40
41 TEST_END();
42 }
43
TestHeader()44 void TestHeader() {
45 TEST_START();
46
47 marisa::grimoire::trie::Header header;
48
49 {
50 marisa::grimoire::Writer writer;
51 writer.open("trie-test.dat");
52 header.write(writer);
53 }
54
55 {
56 marisa::grimoire::Mapper mapper;
57 mapper.open("trie-test.dat");
58 header.map(mapper);
59 }
60
61 {
62 marisa::grimoire::Reader reader;
63 reader.open("trie-test.dat");
64 header.read(reader);
65 }
66
67 TEST_END();
68 }
69
TestKey()70 void TestKey() {
71 TEST_START();
72
73 marisa::grimoire::trie::Key key;
74
75 ASSERT(key.ptr() == NULL);
76 ASSERT(key.length() == 0);
77 ASSERT(key.id() == 0);
78 ASSERT(key.terminal() == 0);
79
80 const char *str = "xyz";
81
82 key.set_str(str, 3);
83 key.set_weight(10.0F);
84 key.set_id(20);
85
86
87 ASSERT(key.ptr() == str);
88 ASSERT(key.length() == 3);
89 ASSERT(key[0] == 'x');
90 ASSERT(key[1] == 'y');
91 ASSERT(key[2] == 'z');
92 ASSERT(key.weight() == 10.0F);
93 ASSERT(key.id() == 20);
94
95 key.set_terminal(30);
96 ASSERT(key.terminal() == 30);
97
98 key.substr(1, 2);
99
100 ASSERT(key.ptr() == str + 1);
101 ASSERT(key.length() == 2);
102 ASSERT(key[0] == 'y');
103 ASSERT(key[1] == 'z');
104
105 marisa::grimoire::trie::Key key2;
106 key2.set_str("abc", 3);
107
108 ASSERT(key == key);
109 ASSERT(key != key2);
110 ASSERT(key > key2);
111 ASSERT(key2 < key);
112
113 marisa::grimoire::trie::ReverseKey r_key;
114
115 ASSERT(r_key.ptr() == NULL);
116 ASSERT(r_key.length() == 0);
117 ASSERT(r_key.id() == 0);
118 ASSERT(r_key.terminal() == 0);
119
120 r_key.set_str(str, 3);
121 r_key.set_weight(100.0F);
122 r_key.set_id(200);
123
124 ASSERT(r_key.ptr() == str);
125 ASSERT(r_key.length() == 3);
126 ASSERT(r_key[0] == 'z');
127 ASSERT(r_key[1] == 'y');
128 ASSERT(r_key[2] == 'x');
129 ASSERT(r_key.weight() == 100.0F);
130 ASSERT(r_key.id() == 200);
131
132 r_key.set_terminal(300);
133 ASSERT(r_key.terminal() == 300);
134
135 r_key.substr(1, 2);
136
137 ASSERT(r_key.ptr() == str);
138 ASSERT(r_key.length() == 2);
139 ASSERT(r_key[0] == 'y');
140 ASSERT(r_key[1] == 'x');
141
142 marisa::grimoire::trie::ReverseKey r_key2;
143 r_key2.set_str("abc", 3);
144
145 ASSERT(r_key == r_key);
146 ASSERT(r_key != r_key2);
147 ASSERT(r_key > r_key2);
148 ASSERT(r_key2 < r_key);
149
150 TEST_END();
151 }
152
TestRange()153 void TestRange() {
154 TEST_START();
155
156 marisa::grimoire::trie::Range range;
157
158 ASSERT(range.begin() == 0);
159 ASSERT(range.end() == 0);
160 ASSERT(range.key_pos() == 0);
161
162 range.set_begin(1);
163 range.set_end(2);
164 range.set_key_pos(3);
165
166 ASSERT(range.begin() == 1);
167 ASSERT(range.end() == 2);
168 ASSERT(range.key_pos() == 3);
169
170 range = marisa::grimoire::trie::make_range(10, 20, 30);
171
172 ASSERT(range.begin() == 10);
173 ASSERT(range.end() == 20);
174 ASSERT(range.key_pos() == 30);
175
176 marisa::grimoire::trie::WeightedRange w_range;
177
178 ASSERT(w_range.begin() == 0);
179 ASSERT(w_range.end() == 0);
180 ASSERT(w_range.key_pos() == 0);
181 ASSERT(w_range.weight() == 0.0F);
182
183 w_range.set_begin(10);
184 w_range.set_end(20);
185 w_range.set_key_pos(30);
186 w_range.set_weight(40.0F);
187
188 ASSERT(w_range.begin() == 10);
189 ASSERT(w_range.end() == 20);
190 ASSERT(w_range.key_pos() == 30);
191 ASSERT(w_range.weight() == 40.0F);
192
193 marisa::grimoire::trie::WeightedRange w_range2 =
194 marisa::grimoire::trie::make_weighted_range(100, 200, 300, 400.0F);
195
196 ASSERT(w_range2.begin() == 100);
197 ASSERT(w_range2.end() == 200);
198 ASSERT(w_range2.key_pos() == 300);
199 ASSERT(w_range2.weight() == 400.0F);
200
201 ASSERT(w_range < w_range2);
202 ASSERT(w_range2 > w_range);
203
204 TEST_END();
205 }
206
TestEntry()207 void TestEntry() {
208 TEST_START();
209
210 marisa::grimoire::trie::Entry entry;
211
212 ASSERT(entry.length() == 0);
213 ASSERT(entry.id() == 0);
214
215 const char *str = "XYZ";
216
217 entry.set_str(str, 3);
218 entry.set_id(123);
219
220 ASSERT(entry.ptr() == str);
221 ASSERT(entry.length() == 3);
222 ASSERT(entry[0] == 'Z');
223 ASSERT(entry[1] == 'Y');
224 ASSERT(entry[2] == 'X');
225 ASSERT(entry.id() == 123);
226
227 TEST_END();
228 }
229
TestTextTail()230 void TestTextTail() {
231 TEST_START();
232
233 marisa::grimoire::trie::Tail tail;
234 marisa::grimoire::Vector<marisa::grimoire::trie::Entry> entries;
235 marisa::grimoire::Vector<marisa::UInt32> offsets;
236 tail.build(entries, &offsets, MARISA_TEXT_TAIL);
237
238 ASSERT(tail.mode() == MARISA_TEXT_TAIL);
239 ASSERT(tail.size() == 0);
240 ASSERT(tail.empty());
241 ASSERT(tail.total_size() == tail.size());
242 ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 6));
243
244 ASSERT(offsets.empty());
245
246 marisa::grimoire::trie::Entry entry;
247 entry.set_str("X", 1);
248 entries.push_back(entry);
249
250 tail.build(entries, &offsets, MARISA_TEXT_TAIL);
251
252 ASSERT(tail.mode() == MARISA_TEXT_TAIL);
253 ASSERT(tail.size() == 2);
254 ASSERT(!tail.empty());
255 ASSERT(tail.total_size() == tail.size());
256 ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 7));
257
258 ASSERT(offsets.size() == entries.size());
259 ASSERT(offsets[0] == 0);
260 ASSERT(tail[offsets[0]] == 'X');
261 ASSERT(tail[offsets[0] + 1] == '\0');
262
263 entries.clear();
264 entry.set_str("abc", 3);
265 entries.push_back(entry);
266 entry.set_str("bc", 2);
267 entries.push_back(entry);
268 entry.set_str("abc", 3);
269 entries.push_back(entry);
270 entry.set_str("c", 1);
271 entries.push_back(entry);
272 entry.set_str("ABC", 3);
273 entries.push_back(entry);
274 entry.set_str("AB", 2);
275 entries.push_back(entry);
276
277 tail.build(entries, &offsets, MARISA_TEXT_TAIL);
278 std::sort(entries.begin(), entries.end(),
279 marisa::grimoire::trie::Entry::IDComparer());
280
281 ASSERT(tail.size() == 11);
282 ASSERT(offsets.size() == entries.size());
283 for (std::size_t i = 0; i < entries.size(); ++i) {
284 const char * const ptr = &tail[offsets[i]];
285 ASSERT(std::strlen(ptr) == entries[i].length());
286 ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
287 }
288
289 {
290 marisa::grimoire::Writer writer;
291 writer.open("trie-test.dat");
292 tail.write(writer);
293 }
294
295 tail.clear();
296
297 ASSERT(tail.size() == 0);
298 ASSERT(tail.total_size() == tail.size());
299
300 {
301 marisa::grimoire::Mapper mapper;
302 mapper.open("trie-test.dat");
303 tail.map(mapper);
304
305 ASSERT(tail.mode() == MARISA_TEXT_TAIL);
306 ASSERT(tail.size() == 11);
307 for (std::size_t i = 0; i < entries.size(); ++i) {
308 const char * const ptr = &tail[offsets[i]];
309 ASSERT(std::strlen(ptr) == entries[i].length());
310 ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
311 }
312 tail.clear();
313 }
314
315 {
316 marisa::grimoire::Reader reader;
317 reader.open("trie-test.dat");
318 tail.read(reader);
319 }
320
321 ASSERT(tail.size() == 11);
322 ASSERT(offsets.size() == entries.size());
323 for (std::size_t i = 0; i < entries.size(); ++i) {
324 const char * const ptr = &tail[offsets[i]];
325 ASSERT(std::strlen(ptr) == entries[i].length());
326 ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
327 }
328
329 {
330 std::stringstream stream;
331 marisa::grimoire::Writer writer;
332 writer.open(stream);
333 tail.write(writer);
334 tail.clear();
335 marisa::grimoire::Reader reader;
336 reader.open(stream);
337 tail.read(reader);
338 }
339
340 ASSERT(tail.size() == 11);
341 ASSERT(offsets.size() == entries.size());
342 for (std::size_t i = 0; i < entries.size(); ++i) {
343 const char * const ptr = &tail[offsets[i]];
344 ASSERT(std::strlen(ptr) == entries[i].length());
345 ASSERT(std::strcmp(ptr, entries[i].ptr()) == 0);
346 }
347
348 TEST_END();
349 }
350
TestBinaryTail()351 void TestBinaryTail() {
352 TEST_START();
353
354 marisa::grimoire::trie::Tail tail;
355 marisa::grimoire::Vector<marisa::grimoire::trie::Entry> entries;
356 marisa::grimoire::Vector<marisa::UInt32> offsets;
357 tail.build(entries, &offsets, MARISA_BINARY_TAIL);
358
359 ASSERT(tail.mode() == MARISA_TEXT_TAIL);
360 ASSERT(tail.size() == 0);
361 ASSERT(tail.empty());
362 ASSERT(tail.total_size() == tail.size());
363 ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 6));
364
365 ASSERT(offsets.empty());
366
367 marisa::grimoire::trie::Entry entry;
368 entry.set_str("X", 1);
369 entries.push_back(entry);
370
371 tail.build(entries, &offsets, MARISA_BINARY_TAIL);
372
373 ASSERT(tail.mode() == MARISA_BINARY_TAIL);
374 ASSERT(tail.size() == 1);
375 ASSERT(!tail.empty());
376 ASSERT(tail.total_size() == (tail.size() + sizeof(marisa::UInt64)));
377 ASSERT(tail.io_size() == (sizeof(marisa::UInt64) * 8));
378
379 ASSERT(offsets.size() == entries.size());
380 ASSERT(offsets[0] == 0);
381
382 const char binary_entry[] = { 'N', 'P', '\0', 'T', 'r', 'i', 'e' };
383 entries[0].set_str(binary_entry, sizeof(binary_entry));
384
385 tail.build(entries, &offsets, MARISA_TEXT_TAIL);
386
387 ASSERT(tail.mode() == MARISA_BINARY_TAIL);
388 ASSERT(tail.size() == entries[0].length());
389
390 ASSERT(offsets.size() == entries.size());
391 ASSERT(offsets[0] == 0);
392
393 entries.clear();
394 entry.set_str("abc", 3);
395 entries.push_back(entry);
396 entry.set_str("bc", 2);
397 entries.push_back(entry);
398 entry.set_str("abc", 3);
399 entries.push_back(entry);
400 entry.set_str("c", 1);
401 entries.push_back(entry);
402 entry.set_str("ABC", 3);
403 entries.push_back(entry);
404 entry.set_str("AB", 2);
405 entries.push_back(entry);
406
407 tail.build(entries, &offsets, MARISA_BINARY_TAIL);
408 std::sort(entries.begin(), entries.end(),
409 marisa::grimoire::trie::Entry::IDComparer());
410
411 ASSERT(tail.mode() == MARISA_BINARY_TAIL);
412 ASSERT(tail.size() == 8);
413 ASSERT(offsets.size() == entries.size());
414 for (std::size_t i = 0; i < entries.size(); ++i) {
415 const char * const ptr = &tail[offsets[i]];
416 ASSERT(std::memcmp(ptr, entries[i].ptr(), entries[i].length()) == 0);
417 }
418
419 TEST_END();
420 }
421
TestHistory()422 void TestHistory() {
423 TEST_START();
424
425 marisa::grimoire::trie::History history;
426
427 ASSERT(history.node_id() == 0);
428 ASSERT(history.louds_pos() == 0);
429 ASSERT(history.key_pos() == 0);
430 ASSERT(history.link_id() == MARISA_INVALID_LINK_ID);
431 ASSERT(history.key_id() == MARISA_INVALID_KEY_ID);
432
433 history.set_node_id(100);
434 history.set_louds_pos(200);
435 history.set_key_pos(300);
436 history.set_link_id(400);
437 history.set_key_id(500);
438
439 ASSERT(history.node_id() == 100);
440 ASSERT(history.louds_pos() == 200);
441 ASSERT(history.key_pos() == 300);
442 ASSERT(history.link_id() == 400);
443 ASSERT(history.key_id() == 500);
444
445 TEST_END();
446 }
447
TestState()448 void TestState() {
449 TEST_START();
450
451 marisa::grimoire::trie::State state;
452
453 ASSERT(state.key_buf().empty());
454 ASSERT(state.history().empty());
455 ASSERT(state.node_id() == 0);
456 ASSERT(state.query_pos() == 0);
457 ASSERT(state.history_pos() == 0);
458 ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
459
460 state.set_node_id(10);
461 state.set_query_pos(100);
462 state.set_history_pos(1000);
463 state.set_status_code(
464 marisa::grimoire::trie::MARISA_END_OF_PREDICTIVE_SEARCH);
465
466 ASSERT(state.node_id() == 10);
467 ASSERT(state.query_pos() == 100);
468 ASSERT(state.history_pos() == 1000);
469 ASSERT(state.status_code() ==
470 marisa::grimoire::trie::MARISA_END_OF_PREDICTIVE_SEARCH);
471
472 state.lookup_init();
473 ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
474
475 state.reverse_lookup_init();
476 ASSERT(state.status_code() == marisa::grimoire::trie::MARISA_READY_TO_ALL);
477
478 state.common_prefix_search_init();
479 ASSERT(state.status_code() ==
480 marisa::grimoire::trie::MARISA_READY_TO_COMMON_PREFIX_SEARCH);
481
482 state.predictive_search_init();
483 ASSERT(state.status_code() ==
484 marisa::grimoire::trie::MARISA_READY_TO_PREDICTIVE_SEARCH);
485
486 TEST_END();
487 }
488
489 } // namespace
490
main()491 int main() try {
492 TestConfig();
493 TestHeader();
494 TestKey();
495 TestRange();
496 TestEntry();
497 TestTextTail();
498 TestBinaryTail();
499 TestHistory();
500 TestState();
501
502 return 0;
503 } catch (const marisa::Exception &ex) {
504 std::cerr << ex.what() << std::endl;
505 throw;
506 }
507