1 #include <algorithm>
2 #include <functional>
3 #include <queue>
4 #include <stdexcept>
5
6 #include "range.h"
7 #include "trie.h"
8
9 namespace marisa {
10
build(const char * const * keys,std::size_t num_keys,const std::size_t * key_lengths,const double * key_weights,UInt32 * key_ids,int flags)11 void Trie::build(const char * const *keys, std::size_t num_keys,
12 const std::size_t *key_lengths, const double *key_weights,
13 UInt32 *key_ids, int flags) {
14 MARISA_THROW_IF((keys == NULL) && (num_keys != 0), MARISA_PARAM_ERROR);
15 Vector<Key<String> > temp_keys;
16 temp_keys.resize(num_keys);
17 for (std::size_t i = 0; i < temp_keys.size(); ++i) {
18 MARISA_THROW_IF(keys[i] == NULL, MARISA_PARAM_ERROR);
19 std::size_t length = 0;
20 if (key_lengths == NULL) {
21 while (keys[i][length] != '\0') {
22 ++length;
23 }
24 } else {
25 length = key_lengths[i];
26 }
27 MARISA_THROW_IF(length > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR);
28 temp_keys[i].set_str(String(keys[i], length));
29 temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0);
30 }
31 build_trie(temp_keys, key_ids, flags);
32 }
33
build(const std::vector<std::string> & keys,std::vector<UInt32> * key_ids,int flags)34 void Trie::build(const std::vector<std::string> &keys,
35 std::vector<UInt32> *key_ids, int flags) {
36 Vector<Key<String> > temp_keys;
37 temp_keys.resize(keys.size());
38 for (std::size_t i = 0; i < temp_keys.size(); ++i) {
39 MARISA_THROW_IF(keys[i].length() > MARISA_MAX_LENGTH, MARISA_SIZE_ERROR);
40 temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length()));
41 temp_keys[i].set_weight(1.0);
42 }
43 build_trie(temp_keys, key_ids, flags);
44 }
45
build(const std::vector<std::pair<std::string,double>> & keys,std::vector<UInt32> * key_ids,int flags)46 void Trie::build(const std::vector<std::pair<std::string, double> > &keys,
47 std::vector<UInt32> *key_ids, int flags) {
48 Vector<Key<String> > temp_keys;
49 temp_keys.resize(keys.size());
50 for (std::size_t i = 0; i < temp_keys.size(); ++i) {
51 MARISA_THROW_IF(keys[i].first.length() > MARISA_MAX_LENGTH,
52 MARISA_SIZE_ERROR);
53 temp_keys[i].set_str(String(
54 keys[i].first.c_str(), keys[i].first.length()));
55 temp_keys[i].set_weight(keys[i].second);
56 }
57 build_trie(temp_keys, key_ids, flags);
58 }
59
build_trie(Vector<Key<String>> & keys,std::vector<UInt32> * key_ids,int flags)60 void Trie::build_trie(Vector<Key<String> > &keys,
61 std::vector<UInt32> *key_ids, int flags) {
62 if (key_ids == NULL) {
63 build_trie(keys, static_cast<UInt32 *>(NULL), flags);
64 return;
65 }
66 std::vector<UInt32> temp_key_ids(keys.size());
67 build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags);
68 key_ids->swap(temp_key_ids);
69 }
70
build_trie(Vector<Key<String>> & keys,UInt32 * key_ids,int flags)71 void Trie::build_trie(Vector<Key<String> > &keys,
72 UInt32 *key_ids, int flags) {
73 Trie temp;
74 Vector<UInt32> terminals;
75 Progress progress(flags);
76 MARISA_THROW_IF(!progress.is_valid(), MARISA_PARAM_ERROR);
77 temp.build_trie(keys, &terminals, progress);
78
79 typedef std::pair<UInt32, UInt32> TerminalIdPair;
80 Vector<TerminalIdPair> pairs;
81 pairs.resize(terminals.size());
82 for (UInt32 i = 0; i < pairs.size(); ++i) {
83 pairs[i].first = terminals[i];
84 pairs[i].second = i;
85 }
86 terminals.clear();
87 std::sort(pairs.begin(), pairs.end());
88
89 UInt32 node = 0;
90 for (UInt32 i = 0; i < pairs.size(); ++i) {
91 while (node < pairs[i].first) {
92 temp.terminal_flags_.push_back(false);
93 ++node;
94 }
95 if (node == pairs[i].first) {
96 temp.terminal_flags_.push_back(true);
97 ++node;
98 }
99 }
100 while (node < temp.labels_.size()) {
101 temp.terminal_flags_.push_back(false);
102 ++node;
103 }
104 terminal_flags_.push_back(false);
105 temp.terminal_flags_.build();
106 temp.terminal_flags_.clear_select0s();
107 progress.test_total_size(temp.terminal_flags_.total_size());
108
109 if (key_ids != NULL) {
110 for (UInt32 i = 0; i < pairs.size(); ++i) {
111 key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first);
112 }
113 }
114 MARISA_THROW_IF(progress.total_size() != temp.total_size(),
115 MARISA_UNEXPECTED_ERROR);
116 temp.swap(this);
117 }
118
119 template <typename T>
build_trie(Vector<Key<T>> & keys,Vector<UInt32> * terminals,Progress & progress)120 void Trie::build_trie(Vector<Key<T> > &keys,
121 Vector<UInt32> *terminals, Progress &progress) {
122 build_cur(keys, terminals, progress);
123 progress.test_total_size(louds_.total_size());
124 progress.test_total_size(sizeof(num_first_branches_));
125 progress.test_total_size(sizeof(num_keys_));
126 if (link_flags_.empty()) {
127 labels_.shrink();
128 progress.test_total_size(labels_.total_size());
129 progress.test_total_size(link_flags_.total_size());
130 progress.test_total_size(links_.total_size());
131 progress.test_total_size(tail_.total_size());
132 return;
133 }
134
135 Vector<UInt32> next_terminals;
136 build_next(keys, &next_terminals, progress);
137
138 if (has_trie()) {
139 progress.test_total_size(trie_->terminal_flags_.total_size());
140 } else if (tail_.mode() == MARISA_BINARY_TAIL) {
141 labels_.push_back('\0');
142 link_flags_.push_back(true);
143 }
144 link_flags_.build();
145
146 for (UInt32 i = 0; i < next_terminals.size(); ++i) {
147 labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256);
148 next_terminals[i] /= 256;
149 }
150 link_flags_.clear_select0s();
151 if (has_trie() || (tail_.mode() == MARISA_TEXT_TAIL)) {
152 link_flags_.clear_select1s();
153 }
154
155 links_.build(next_terminals);
156 labels_.shrink();
157 progress.test_total_size(labels_.total_size());
158 progress.test_total_size(link_flags_.total_size());
159 progress.test_total_size(links_.total_size());
160 progress.test_total_size(tail_.total_size());
161 }
162
163 template <typename T>
build_cur(Vector<Key<T>> & keys,Vector<UInt32> * terminals,Progress & progress)164 void Trie::build_cur(Vector<Key<T> > &keys,
165 Vector<UInt32> *terminals, Progress &progress) {
166 num_keys_ = sort_keys(keys);
167 louds_.push_back(true);
168 louds_.push_back(false);
169 labels_.push_back('\0');
170 link_flags_.push_back(false);
171
172 Vector<Key<T> > rest_keys;
173 std::queue<Range> queue;
174 Vector<WRange> wranges;
175 queue.push(Range(0, (UInt32)keys.size(), 0));
176 while (!queue.empty()) {
177 const UInt32 node = (UInt32)(link_flags_.size() - queue.size());
178 Range range = queue.front();
179 queue.pop();
180
181 while ((range.begin() < range.end()) &&
182 (keys[range.begin()].str().length() == range.pos())) {
183 keys[range.begin()].set_terminal(node);
184 range.set_begin(range.begin() + 1);
185 }
186 if (range.begin() == range.end()) {
187 louds_.push_back(false);
188 continue;
189 }
190
191 wranges.clear();
192 double weight = keys[range.begin()].weight();
193 for (UInt32 i = range.begin() + 1; i < range.end(); ++i) {
194 if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) {
195 wranges.push_back(WRange(range.begin(), i, range.pos(), weight));
196 range.set_begin(i);
197 weight = 0.0;
198 }
199 weight += keys[i].weight();
200 }
201 wranges.push_back(WRange(range, weight));
202 if (progress.order() == MARISA_WEIGHT_ORDER) {
203 std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>());
204 }
205 if (node == 0) {
206 num_first_branches_ = wranges.size();
207 }
208 for (UInt32 i = 0; i < wranges.size(); ++i) {
209 const WRange &wrange = wranges[i];
210 UInt32 pos = wrange.pos() + 1;
211 if ((progress.tail() != MARISA_WITHOUT_TAIL) || !progress.is_last()) {
212 while (pos < keys[wrange.begin()].str().length()) {
213 UInt32 j;
214 for (j = wrange.begin() + 1; j < wrange.end(); ++j) {
215 if (keys[j - 1].str()[pos] != keys[j].str()[pos]) {
216 break;
217 }
218 }
219 if (j < wrange.end()) {
220 break;
221 }
222 ++pos;
223 }
224 }
225 if ((progress.trie() != MARISA_PATRICIA_TRIE) &&
226 (pos != keys[wrange.end() - 1].str().length())) {
227 pos = wrange.pos() + 1;
228 }
229 louds_.push_back(true);
230 if (pos == wrange.pos() + 1) {
231 labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]);
232 link_flags_.push_back(false);
233 } else {
234 labels_.push_back('\0');
235 link_flags_.push_back(true);
236 Key<T> rest_key;
237 rest_key.set_str(keys[wrange.begin()].str().substr(
238 wrange.pos(), pos - wrange.pos()));
239 rest_key.set_weight(wrange.weight());
240 rest_keys.push_back(rest_key);
241 }
242 wranges[i].set_pos(pos);
243 queue.push(wranges[i].range());
244 }
245 louds_.push_back(false);
246 }
247 louds_.push_back(false);
248 louds_.build();
249 if (progress.trie_id() != 0) {
250 louds_.clear_select0s();
251 }
252 if (rest_keys.empty()) {
253 link_flags_.clear();
254 }
255
256 build_terminals(keys, terminals);
257 keys.swap(&rest_keys);
258 }
259
build_next(Vector<Key<String>> & keys,Vector<UInt32> * terminals,Progress & progress)260 void Trie::build_next(Vector<Key<String> > &keys,
261 Vector<UInt32> *terminals, Progress &progress) {
262 if (progress.is_last()) {
263 Vector<String> strs;
264 strs.resize(keys.size());
265 for (UInt32 i = 0; i < strs.size(); ++i) {
266 strs[i] = keys[i].str();
267 }
268 tail_.build(strs, terminals, progress.tail());
269 return;
270 }
271 Vector<Key<RString> > rkeys;
272 rkeys.resize(keys.size());
273 for (UInt32 i = 0; i < rkeys.size(); ++i) {
274 rkeys[i].set_str(RString(keys[i].str()));
275 rkeys[i].set_weight(keys[i].weight());
276 }
277 keys.clear();
278 trie_.reset(new (std::nothrow) Trie);
279 MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR);
280 trie_->build_trie(rkeys, terminals, ++progress);
281 }
282
build_next(Vector<Key<RString>> & rkeys,Vector<UInt32> * terminals,Progress & progress)283 void Trie::build_next(Vector<Key<RString> > &rkeys,
284 Vector<UInt32> *terminals, Progress &progress) {
285 if (progress.is_last()) {
286 Vector<String> strs;
287 strs.resize(rkeys.size());
288 for (UInt32 i = 0; i < strs.size(); ++i) {
289 strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length());
290 }
291 tail_.build(strs, terminals, progress.tail());
292 return;
293 }
294 trie_.reset(new (std::nothrow) Trie);
295 MARISA_THROW_IF(!has_trie(), MARISA_MEMORY_ERROR);
296 trie_->build_trie(rkeys, terminals, ++progress);
297 }
298
299 template <typename T>
sort_keys(Vector<Key<T>> & keys) const300 UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const {
301 if (keys.empty()) {
302 return 0;
303 }
304 for (UInt32 i = 0; i < keys.size(); ++i) {
305 keys[i].set_id(i);
306 }
307 std::sort(keys.begin(), keys.end());
308 UInt32 count = 1;
309 for (UInt32 i = 1; i < keys.size(); ++i) {
310 if (keys[i - 1].str() != keys[i].str()) {
311 ++count;
312 }
313 }
314 return count;
315 }
316
317 template <typename T>
build_terminals(const Vector<Key<T>> & keys,Vector<UInt32> * terminals) const318 void Trie::build_terminals(const Vector<Key<T> > &keys,
319 Vector<UInt32> *terminals) const {
320 Vector<UInt32> temp_terminals;
321 temp_terminals.resize(keys.size());
322 for (UInt32 i = 0; i < keys.size(); ++i) {
323 temp_terminals[keys[i].id()] = keys[i].terminal();
324 }
325 temp_terminals.swap(terminals);
326 }
327
328 } // namespace marisa
329