#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <sstream>

#include <marisa.h>

#include "marisa-assert.h"
9
10 namespace {
11
// Exercises three degenerate tries in sequence:
//   1. a default-constructed trie that was never built,
//   2. a trie built from an empty keyset,
//   3. a trie built from a keyset holding only the empty string.
//
// Before build(), every serialization call, search and accessor is expected
// to throw marisa::Exception with MARISA_STATE_ERROR. After each build(), the
// accessors and searches are checked against the keys the trie holds.
void TestEmptyTrie() {
  TEST_START();

  marisa::Trie trie;

  // An unbuilt trie can be neither saved nor written to any kind of stream.
  EXCEPT(trie.save("marisa-test.dat"), MARISA_STATE_ERROR);
#ifdef _MSC_VER
  EXCEPT(trie.write(::_fileno(stdout)), MARISA_STATE_ERROR);
#else  // _MSC_VER
  EXCEPT(trie.write(::fileno(stdout)), MARISA_STATE_ERROR);
#endif  // _MSC_VER
  EXCEPT(std::cout << trie, MARISA_STATE_ERROR);
  EXCEPT(marisa::fwrite(stdout, trie), MARISA_STATE_ERROR);

  // The agent is never given an explicit query in this test, so every search
  // below runs with the agent's default query.
  marisa::Agent agent;

  // Searches require a built trie.
  EXCEPT(trie.lookup(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.reverse_lookup(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.common_prefix_search(agent), MARISA_STATE_ERROR);
  EXCEPT(trie.predictive_search(agent), MARISA_STATE_ERROR);

  // So do the shape and configuration accessors...
  EXCEPT(trie.num_tries(), MARISA_STATE_ERROR);
  EXCEPT(trie.num_keys(), MARISA_STATE_ERROR);
  EXCEPT(trie.num_nodes(), MARISA_STATE_ERROR);

  EXCEPT(trie.tail_mode(), MARISA_STATE_ERROR);
  EXCEPT(trie.node_order(), MARISA_STATE_ERROR);

  // ...and the size accessors.
  EXCEPT(trie.empty(), MARISA_STATE_ERROR);
  EXCEPT(trie.size(), MARISA_STATE_ERROR);
  EXCEPT(trie.total_size(), MARISA_STATE_ERROR);
  EXCEPT(trie.io_size(), MARISA_STATE_ERROR);

  // Building from an empty keyset succeeds and yields a trie with no keys.
  marisa::Keyset keyset;
  trie.build(keyset);

  // Nothing can be found, and reverse_lookup() rejects the agent's query ID
  // with MARISA_BOUND_ERROR because there are no IDs at all.
  ASSERT(!trie.lookup(agent));
  EXCEPT(trie.reverse_lookup(agent), MARISA_BOUND_ERROR);
  ASSERT(!trie.common_prefix_search(agent));
  ASSERT(!trie.predictive_search(agent));

  // A key-less trie still reports one trie and one node, uses the default
  // configuration, and occupies non-zero memory and serialized size.
  ASSERT(trie.num_tries() == 1);
  ASSERT(trie.num_keys() == 0);
  ASSERT(trie.num_nodes() == 1);

  ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
  ASSERT(trie.node_order() == MARISA_DEFAULT_ORDER);

  ASSERT(trie.empty());
  ASSERT(trie.size() == 0);
  ASSERT(trie.total_size() != 0);
  ASSERT(trie.io_size() != 0);

  // Rebuild with the empty string as the only key. The test expects the
  // agent's default query to match it.
  keyset.push_back("");
  trie.build(keyset);

  ASSERT(trie.lookup(agent));
  // The result is not checked; this only verifies that reverse_lookup() does
  // not throw for the agent's current query ID. That ID was never set
  // explicitly here. NOTE(review): presumably the default ID 0; confirm.
  trie.reverse_lookup(agent);
  // Searches are incremental: the first call on the agent reports the single
  // match, and the next call on the same agent reports exhaustion.
  ASSERT(trie.common_prefix_search(agent));
  ASSERT(!trie.common_prefix_search(agent));
  ASSERT(trie.predictive_search(agent));
  ASSERT(!trie.predictive_search(agent));

  // The empty key is stored without adding nodes beyond the single one above.
  ASSERT(trie.num_keys() == 1);
  ASSERT(trie.num_nodes() == 1);

  ASSERT(!trie.empty());
  ASSERT(trie.size() == 1);
  ASSERT(trie.total_size() != 0);
  ASSERT(trie.io_size() != 0);

  TEST_END();
}
85
TestTinyTrie()86 void TestTinyTrie() {
87 TEST_START();
88
89 marisa::Keyset keyset;
90 keyset.push_back("bach");
91 keyset.push_back("bet");
92 keyset.push_back("chat");
93 keyset.push_back("check");
94 keyset.push_back("check");
95
96 marisa::Trie trie;
97 trie.build(keyset, 1);
98
99 ASSERT(trie.num_tries() == 1);
100 ASSERT(trie.num_keys() == 4);
101 ASSERT(trie.num_nodes() == 7);
102
103 ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
104 ASSERT(trie.node_order() == MARISA_DEFAULT_ORDER);
105
106 ASSERT(keyset[0].id() == 2);
107 ASSERT(keyset[1].id() == 3);
108 ASSERT(keyset[2].id() == 1);
109 ASSERT(keyset[3].id() == 0);
110 ASSERT(keyset[4].id() == 0);
111
112 marisa::Agent agent;
113 for (std::size_t i = 0; i < keyset.size(); ++i) {
114 agent.set_query(keyset[i].ptr(), keyset[i].length());
115 ASSERT(trie.lookup(agent));
116 ASSERT(agent.key().id() == keyset[i].id());
117
118 agent.set_query(keyset[i].id());
119 trie.reverse_lookup(agent);
120 ASSERT(agent.key().length() == keyset[i].length());
121 ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
122 agent.key().length()) == 0);
123 }
124
125 agent.set_query("be");
126 ASSERT(!trie.common_prefix_search(agent));
127 agent.set_query("beX");
128 ASSERT(!trie.common_prefix_search(agent));
129 agent.set_query("bet");
130 ASSERT(trie.common_prefix_search(agent));
131 ASSERT(!trie.common_prefix_search(agent));
132 agent.set_query("betX");
133 ASSERT(trie.common_prefix_search(agent));
134 ASSERT(!trie.common_prefix_search(agent));
135
136 agent.set_query("chatX");
137 ASSERT(!trie.predictive_search(agent));
138 agent.set_query("chat");
139 ASSERT(trie.predictive_search(agent));
140 ASSERT(agent.key().length() == 4);
141 ASSERT(!trie.predictive_search(agent));
142
143 agent.set_query("cha");
144 ASSERT(trie.predictive_search(agent));
145 ASSERT(agent.key().length() == 4);
146 ASSERT(!trie.predictive_search(agent));
147
148 agent.set_query("c");
149 ASSERT(trie.predictive_search(agent));
150 ASSERT(agent.key().length() == 5);
151 ASSERT(std::memcmp(agent.key().ptr(), "check", 5) == 0);
152 ASSERT(trie.predictive_search(agent));
153 ASSERT(agent.key().length() == 4);
154 ASSERT(std::memcmp(agent.key().ptr(), "chat", 4) == 0);
155 ASSERT(!trie.predictive_search(agent));
156
157 agent.set_query("ch");
158 ASSERT(trie.predictive_search(agent));
159 ASSERT(agent.key().length() == 5);
160 ASSERT(std::memcmp(agent.key().ptr(), "check", 5) == 0);
161 ASSERT(trie.predictive_search(agent));
162 ASSERT(agent.key().length() == 4);
163 ASSERT(std::memcmp(agent.key().ptr(), "chat", 4) == 0);
164 ASSERT(!trie.predictive_search(agent));
165
166 trie.build(keyset, 1 | MARISA_LABEL_ORDER);
167
168 ASSERT(trie.num_tries() == 1);
169 ASSERT(trie.num_keys() == 4);
170 ASSERT(trie.num_nodes() == 7);
171
172 ASSERT(trie.tail_mode() == MARISA_DEFAULT_TAIL);
173 ASSERT(trie.node_order() == MARISA_LABEL_ORDER);
174
175 ASSERT(keyset[0].id() == 0);
176 ASSERT(keyset[1].id() == 1);
177 ASSERT(keyset[2].id() == 2);
178 ASSERT(keyset[3].id() == 3);
179 ASSERT(keyset[4].id() == 3);
180
181 for (std::size_t i = 0; i < keyset.size(); ++i) {
182 agent.set_query(keyset[i].ptr(), keyset[i].length());
183 ASSERT(trie.lookup(agent));
184 ASSERT(agent.key().id() == keyset[i].id());
185
186 agent.set_query(keyset[i].id());
187 trie.reverse_lookup(agent);
188 ASSERT(agent.key().length() == keyset[i].length());
189 ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
190 agent.key().length()) == 0);
191 }
192
193 agent.set_query("");
194 for (std::size_t i = 0; i < trie.size(); ++i) {
195 ASSERT(trie.predictive_search(agent));
196 ASSERT(agent.key().id() == i);
197 }
198 ASSERT(!trie.predictive_search(agent));
199
200 TEST_END();
201 }
202
MakeKeyset(std::size_t num_keys,marisa::TailMode tail_mode,marisa::Keyset * keyset)203 void MakeKeyset(std::size_t num_keys, marisa::TailMode tail_mode,
204 marisa::Keyset *keyset) {
205 char key_buf[16];
206 for (std::size_t i = 0; i < num_keys; ++i) {
207 const std::size_t length =
208 static_cast<std::size_t>(std::rand()) % sizeof(key_buf);
209 for (std::size_t j = 0; j < length; ++j) {
210 key_buf[j] = (char)(std::rand() % 10);
211 if (tail_mode == MARISA_TEXT_TAIL) {
212 key_buf[j] = static_cast<char>(key_buf[j] + '0');
213 }
214 }
215 keyset->push_back(key_buf, length);
216 }
217 }
218
TestLookup(const marisa::Trie & trie,const marisa::Keyset & keyset)219 void TestLookup(const marisa::Trie &trie, const marisa::Keyset &keyset) {
220 marisa::Agent agent;
221 for (std::size_t i = 0; i < keyset.size(); ++i) {
222 agent.set_query(keyset[i].ptr(), keyset[i].length());
223 ASSERT(trie.lookup(agent));
224 ASSERT(agent.key().id() == keyset[i].id());
225
226 agent.set_query(keyset[i].id());
227 trie.reverse_lookup(agent);
228 ASSERT(agent.key().length() == keyset[i].length());
229 ASSERT(std::memcmp(agent.key().ptr(), keyset[i].ptr(),
230 agent.key().length()) == 0);
231 }
232 }
233
TestCommonPrefixSearch(const marisa::Trie & trie,const marisa::Keyset & keyset)234 void TestCommonPrefixSearch(const marisa::Trie &trie,
235 const marisa::Keyset &keyset) {
236 marisa::Agent agent;
237 for (std::size_t i = 0; i < keyset.size(); ++i) {
238 agent.set_query(keyset[i].ptr(), keyset[i].length());
239 ASSERT(trie.common_prefix_search(agent));
240 ASSERT(agent.key().id() <= keyset[i].id());
241 while (trie.common_prefix_search(agent)) {
242 ASSERT(agent.key().id() <= keyset[i].id());
243 }
244 ASSERT(agent.key().id() == keyset[i].id());
245 }
246 }
247
TestPredictiveSearch(const marisa::Trie & trie,const marisa::Keyset & keyset)248 void TestPredictiveSearch(const marisa::Trie &trie,
249 const marisa::Keyset &keyset) {
250 marisa::Agent agent;
251 for (std::size_t i = 0; i < keyset.size(); ++i) {
252 agent.set_query(keyset[i].ptr(), keyset[i].length());
253 ASSERT(trie.predictive_search(agent));
254 ASSERT(agent.key().id() == keyset[i].id());
255 while (trie.predictive_search(agent)) {
256 ASSERT(agent.key().id() > keyset[i].id());
257 }
258 }
259 }
260
TestTrie(int num_tries,marisa::TailMode tail_mode,marisa::NodeOrder node_order,marisa::Keyset & keyset)261 void TestTrie(int num_tries, marisa::TailMode tail_mode,
262 marisa::NodeOrder node_order, marisa::Keyset &keyset) {
263 for (std::size_t i = 0; i < keyset.size(); ++i) {
264 keyset[i].set_weight(1.0F);
265 }
266
267 marisa::Trie trie;
268 trie.build(keyset, num_tries | tail_mode | node_order);
269
270 ASSERT(trie.num_tries() == (std::size_t)num_tries);
271 ASSERT(trie.num_keys() <= keyset.size());
272
273 ASSERT(trie.tail_mode() == tail_mode);
274 ASSERT(trie.node_order() == node_order);
275
276 TestLookup(trie, keyset);
277 TestCommonPrefixSearch(trie, keyset);
278 TestPredictiveSearch(trie, keyset);
279
280 trie.save("marisa-test.dat");
281
282 trie.clear();
283 trie.load("marisa-test.dat");
284
285 ASSERT(trie.num_tries() == (std::size_t)num_tries);
286 ASSERT(trie.num_keys() <= keyset.size());
287
288 ASSERT(trie.tail_mode() == tail_mode);
289 ASSERT(trie.node_order() == node_order);
290
291 TestLookup(trie, keyset);
292
293 {
294 std::FILE *file;
295 #ifdef _MSC_VER
296 ASSERT(::fopen_s(&file, "marisa-test.dat", "wb") == 0);
297 #else // _MSC_VER
298 file = std::fopen("marisa-test.dat", "wb");
299 ASSERT(file != NULL);
300 #endif // _MSC_VER
301 marisa::fwrite(file, trie);
302 std::fclose(file);
303 trie.clear();
304 #ifdef _MSC_VER
305 ASSERT(::fopen_s(&file, "marisa-test.dat", "rb") == 0);
306 #else // _MSC_VER
307 file = std::fopen("marisa-test.dat", "rb");
308 ASSERT(file != NULL);
309 #endif // _MSC_VER
310 marisa::fread(file, &trie);
311 std::fclose(file);
312 }
313
314 ASSERT(trie.num_tries() == (std::size_t)num_tries);
315 ASSERT(trie.num_keys() <= keyset.size());
316
317 ASSERT(trie.tail_mode() == tail_mode);
318 ASSERT(trie.node_order() == node_order);
319
320 TestLookup(trie, keyset);
321
322 trie.clear();
323 trie.mmap("marisa-test.dat");
324
325 ASSERT(trie.num_tries() == (std::size_t)num_tries);
326 ASSERT(trie.num_keys() <= keyset.size());
327
328 ASSERT(trie.tail_mode() == tail_mode);
329 ASSERT(trie.node_order() == node_order);
330
331 TestLookup(trie, keyset);
332
333 {
334 std::stringstream stream;
335 stream << trie;
336 trie.clear();
337 stream >> trie;
338 }
339
340 ASSERT(trie.num_tries() == (std::size_t)num_tries);
341 ASSERT(trie.num_keys() <= keyset.size());
342
343 ASSERT(trie.tail_mode() == tail_mode);
344 ASSERT(trie.node_order() == node_order);
345
346 TestLookup(trie, keyset);
347 }
348
TestTrie(marisa::TailMode tail_mode,marisa::NodeOrder node_order,marisa::Keyset & keyset)349 void TestTrie(marisa::TailMode tail_mode, marisa::NodeOrder node_order,
350 marisa::Keyset &keyset) {
351 TEST_START();
352 std::cout << ((tail_mode == MARISA_TEXT_TAIL) ? "TEXT" : "BINARY") << ", ";
353 std::cout << ((node_order == MARISA_WEIGHT_ORDER) ?
354 "WEIGHT" : "LABEL") << ": ";
355
356 for (int i = 1; i < 5; ++i) {
357 TestTrie(i, tail_mode, node_order, keyset);
358 }
359
360 TEST_END();
361 }
362
TestTrie(marisa::TailMode tail_mode)363 void TestTrie(marisa::TailMode tail_mode) {
364 marisa::Keyset keyset;
365 MakeKeyset(1000, tail_mode, &keyset);
366
367 TestTrie(tail_mode, MARISA_WEIGHT_ORDER, keyset);
368 TestTrie(tail_mode, MARISA_LABEL_ORDER, keyset);
369 }
370
// Runs the randomized round-trip suite twice: first with text tails, then
// with binary tails. Each call generates its own random keyset.
void TestTrie() {
  TestTrie(MARISA_TEXT_TAIL);
  TestTrie(MARISA_BINARY_TAIL);
}
375
376 } // namespace
377
main()378 int main() try {
379 std::srand((unsigned int)std::time(NULL));
380
381 TestEmptyTrie();
382 TestTinyTrie();
383 TestTrie();
384
385 return 0;
386 } catch (const marisa::Exception &ex) {
387 std::cerr << ex.what() << std::endl;
388 throw;
389 }
390