• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <algorithm>
2 #include <functional>
3 #include <queue>
4 #include <stdexcept>
5 
6 #include "range.h"
7 #include "trie.h"
8 
9 namespace marisa_alpha {
10 
build(const char * const * keys,std::size_t num_keys,const std::size_t * key_lengths,const double * key_weights,UInt32 * key_ids,int flags)11 void Trie::build(const char * const *keys, std::size_t num_keys,
12     const std::size_t *key_lengths, const double *key_weights,
13     UInt32 *key_ids, int flags) {
14   MARISA_ALPHA_THROW_IF((keys == NULL) && (num_keys != 0),
15       MARISA_ALPHA_PARAM_ERROR);
16   Vector<Key<String> > temp_keys;
17   temp_keys.resize(num_keys);
18   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
19     MARISA_ALPHA_THROW_IF(keys[i] == NULL, MARISA_ALPHA_PARAM_ERROR);
20     std::size_t length = 0;
21     if (key_lengths == NULL) {
22       while (keys[i][length] != '\0') {
23         ++length;
24       }
25     } else {
26       length = key_lengths[i];
27     }
28     MARISA_ALPHA_THROW_IF(length > MARISA_ALPHA_MAX_LENGTH,
29         MARISA_ALPHA_SIZE_ERROR);
30     temp_keys[i].set_str(String(keys[i], length));
31     temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0);
32   }
33   build_trie(temp_keys, key_ids, flags);
34 }
35 
build(const std::vector<std::string> & keys,std::vector<UInt32> * key_ids,int flags)36 void Trie::build(const std::vector<std::string> &keys,
37     std::vector<UInt32> *key_ids, int flags) {
38   Vector<Key<String> > temp_keys;
39   temp_keys.resize(keys.size());
40   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
41     MARISA_ALPHA_THROW_IF(keys[i].length() > MARISA_ALPHA_MAX_LENGTH,
42         MARISA_ALPHA_SIZE_ERROR);
43     temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length()));
44     temp_keys[i].set_weight(1.0);
45   }
46   build_trie(temp_keys, key_ids, flags);
47 }
48 
build(const std::vector<std::pair<std::string,double>> & keys,std::vector<UInt32> * key_ids,int flags)49 void Trie::build(const std::vector<std::pair<std::string, double> > &keys,
50     std::vector<UInt32> *key_ids, int flags) {
51   Vector<Key<String> > temp_keys;
52   temp_keys.resize(keys.size());
53   for (std::size_t i = 0; i < temp_keys.size(); ++i) {
54     MARISA_ALPHA_THROW_IF(keys[i].first.length() > MARISA_ALPHA_MAX_LENGTH,
55         MARISA_ALPHA_SIZE_ERROR);
56     temp_keys[i].set_str(String(
57         keys[i].first.c_str(), keys[i].first.length()));
58     temp_keys[i].set_weight(keys[i].second);
59   }
60   build_trie(temp_keys, key_ids, flags);
61 }
62 
build_trie(Vector<Key<String>> & keys,std::vector<UInt32> * key_ids,int flags)63 void Trie::build_trie(Vector<Key<String> > &keys,
64     std::vector<UInt32> *key_ids, int flags) {
65   if (key_ids == NULL) {
66     build_trie(keys, static_cast<UInt32 *>(NULL), flags);
67     return;
68   }
69   try {
70     std::vector<UInt32> temp_key_ids(keys.size());
71     build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags);
72     key_ids->swap(temp_key_ids);
73   } catch (const std::bad_alloc &) {
74     MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
75   } catch (const std::length_error &) {
76     MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
77   }
78 }
79 
build_trie(Vector<Key<String>> & keys,UInt32 * key_ids,int flags)80 void Trie::build_trie(Vector<Key<String> > &keys,
81     UInt32 *key_ids, int flags) {
82   Trie temp;
83   Vector<UInt32> terminals;
84   Progress progress(flags);
85   MARISA_ALPHA_THROW_IF(!progress.is_valid(), MARISA_ALPHA_PARAM_ERROR);
86   temp.build_trie(keys, &terminals, progress);
87 
88   typedef std::pair<UInt32, UInt32> TerminalIdPair;
89   Vector<TerminalIdPair> pairs;
90   pairs.resize(terminals.size());
91   for (UInt32 i = 0; i < pairs.size(); ++i) {
92     pairs[i].first = terminals[i];
93     pairs[i].second = i;
94   }
95   terminals.clear();
96   std::sort(pairs.begin(), pairs.end());
97 
98   UInt32 node = 0;
99   for (UInt32 i = 0; i < pairs.size(); ++i) {
100     while (node < pairs[i].first) {
101       temp.terminal_flags_.push_back(false);
102       ++node;
103     }
104     if (node == pairs[i].first) {
105       temp.terminal_flags_.push_back(true);
106       ++node;
107     }
108   }
109   while (node < temp.labels_.size()) {
110     temp.terminal_flags_.push_back(false);
111     ++node;
112   }
113   terminal_flags_.push_back(false);
114   temp.terminal_flags_.build();
115   temp.terminal_flags_.clear_select0s();
116   progress.test_total_size(temp.terminal_flags_.total_size());
117 
118   if (key_ids != NULL) {
119     for (UInt32 i = 0; i < pairs.size(); ++i) {
120       key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first);
121     }
122   }
123   MARISA_ALPHA_THROW_IF(progress.total_size() != temp.total_size(),
124       MARISA_ALPHA_UNEXPECTED_ERROR);
125   temp.swap(this);
126 }
127 
128 template <typename T>
build_trie(Vector<Key<T>> & keys,Vector<UInt32> * terminals,Progress & progress)129 void Trie::build_trie(Vector<Key<T> > &keys,
130     Vector<UInt32> *terminals, Progress &progress) {
131   build_cur(keys, terminals, progress);
132   progress.test_total_size(louds_.total_size());
133   progress.test_total_size(sizeof(num_first_branches_));
134   progress.test_total_size(sizeof(num_keys_));
135   if (link_flags_.empty()) {
136     labels_.shrink();
137     progress.test_total_size(labels_.total_size());
138     progress.test_total_size(link_flags_.total_size());
139     progress.test_total_size(links_.total_size());
140     progress.test_total_size(tail_.total_size());
141     return;
142   }
143 
144   Vector<UInt32> next_terminals;
145   build_next(keys, &next_terminals, progress);
146 
147   if (has_trie()) {
148     progress.test_total_size(trie_->terminal_flags_.total_size());
149   } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
150     labels_.push_back('\0');
151     link_flags_.push_back(true);
152   }
153   link_flags_.build();
154 
155   for (UInt32 i = 0; i < next_terminals.size(); ++i) {
156     labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256);
157     next_terminals[i] /= 256;
158   }
159   link_flags_.clear_select0s();
160   if (has_trie() || (tail_.mode() == MARISA_ALPHA_TEXT_TAIL)) {
161     link_flags_.clear_select1s();
162   }
163 
164   links_.build(next_terminals);
165   labels_.shrink();
166   progress.test_total_size(labels_.total_size());
167   progress.test_total_size(link_flags_.total_size());
168   progress.test_total_size(links_.total_size());
169   progress.test_total_size(tail_.total_size());
170 }
171 
172 template <typename T>
build_cur(Vector<Key<T>> & keys,Vector<UInt32> * terminals,Progress & progress)173 void Trie::build_cur(Vector<Key<T> > &keys,
174     Vector<UInt32> *terminals, Progress &progress) try {
175   num_keys_ = sort_keys(keys);
176   louds_.push_back(true);
177   louds_.push_back(false);
178   labels_.push_back('\0');
179   link_flags_.push_back(false);
180 
181   Vector<Key<T> > rest_keys;
182   std::queue<Range> queue;
183   Vector<WRange> wranges;
184   queue.push(Range(0, (UInt32)keys.size(), 0));
185   while (!queue.empty()) {
186     const UInt32 node = (UInt32)(link_flags_.size() - queue.size());
187     Range range = queue.front();
188     queue.pop();
189 
190     while ((range.begin() < range.end()) &&
191         (keys[range.begin()].str().length() == range.pos())) {
192       keys[range.begin()].set_terminal(node);
193       range.set_begin(range.begin() + 1);
194     }
195     if (range.begin() == range.end()) {
196       louds_.push_back(false);
197       continue;
198     }
199 
200     wranges.clear();
201     double weight = keys[range.begin()].weight();
202     for (UInt32 i = range.begin() + 1; i < range.end(); ++i) {
203       if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) {
204         wranges.push_back(WRange(range.begin(), i, range.pos(), weight));
205         range.set_begin(i);
206         weight = 0.0;
207       }
208       weight += keys[i].weight();
209     }
210     wranges.push_back(WRange(range, weight));
211     if (progress.order() == MARISA_ALPHA_WEIGHT_ORDER) {
212       std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>());
213     }
214     if (node == 0) {
215       num_first_branches_ = wranges.size();
216     }
217     for (UInt32 i = 0; i < wranges.size(); ++i) {
218       const WRange &wrange = wranges[i];
219       UInt32 pos = wrange.pos() + 1;
220       if ((progress.tail() != MARISA_ALPHA_WITHOUT_TAIL) ||
221           !progress.is_last()) {
222         while (pos < keys[wrange.begin()].str().length()) {
223           UInt32 j;
224           for (j = wrange.begin() + 1; j < wrange.end(); ++j) {
225             if (keys[j - 1].str()[pos] != keys[j].str()[pos]) {
226               break;
227             }
228           }
229           if (j < wrange.end()) {
230             break;
231           }
232           ++pos;
233         }
234       }
235       if ((progress.trie() != MARISA_ALPHA_PATRICIA_TRIE) &&
236           (pos != keys[wrange.end() - 1].str().length())) {
237         pos = wrange.pos() + 1;
238       }
239       louds_.push_back(true);
240       if (pos == wrange.pos() + 1) {
241         labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]);
242         link_flags_.push_back(false);
243       } else {
244         labels_.push_back('\0');
245         link_flags_.push_back(true);
246         Key<T> rest_key;
247         rest_key.set_str(keys[wrange.begin()].str().substr(
248             wrange.pos(), pos - wrange.pos()));
249         rest_key.set_weight(wrange.weight());
250         rest_keys.push_back(rest_key);
251       }
252       wranges[i].set_pos(pos);
253       queue.push(wranges[i].range());
254     }
255     louds_.push_back(false);
256   }
257   louds_.push_back(false);
258   louds_.build();
259   if (progress.trie_id() != 0) {
260     louds_.clear_select0s();
261   }
262   if (rest_keys.empty()) {
263     link_flags_.clear();
264   }
265 
266   build_terminals(keys, terminals);
267   keys.swap(&rest_keys);
268 } catch (const std::bad_alloc &) {
269   MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
270 } catch (const std::length_error &) {
271   MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
272 }
273 
build_next(Vector<Key<String>> & keys,Vector<UInt32> * terminals,Progress & progress)274 void Trie::build_next(Vector<Key<String> > &keys,
275     Vector<UInt32> *terminals, Progress &progress) {
276   if (progress.is_last()) {
277     Vector<String> strs;
278     strs.resize(keys.size());
279     for (UInt32 i = 0; i < strs.size(); ++i) {
280       strs[i] = keys[i].str();
281     }
282     tail_.build(strs, terminals, progress.tail());
283     return;
284   }
285   Vector<Key<RString> > rkeys;
286   rkeys.resize(keys.size());
287   for (UInt32 i = 0; i < rkeys.size(); ++i) {
288     rkeys[i].set_str(RString(keys[i].str()));
289     rkeys[i].set_weight(keys[i].weight());
290   }
291   keys.clear();
292   trie_.reset(new (std::nothrow) Trie);
293   MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
294   trie_->build_trie(rkeys, terminals, ++progress);
295 }
296 
build_next(Vector<Key<RString>> & rkeys,Vector<UInt32> * terminals,Progress & progress)297 void Trie::build_next(Vector<Key<RString> > &rkeys,
298     Vector<UInt32> *terminals, Progress &progress) {
299   if (progress.is_last()) {
300     Vector<String> strs;
301     strs.resize(rkeys.size());
302     for (UInt32 i = 0; i < strs.size(); ++i) {
303       strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length());
304     }
305     tail_.build(strs, terminals, progress.tail());
306     return;
307   }
308   trie_.reset(new (std::nothrow) Trie);
309   MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
310   trie_->build_trie(rkeys, terminals, ++progress);
311 }
312 
313 template <typename T>
sort_keys(Vector<Key<T>> & keys) const314 UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const {
315   if (keys.empty()) {
316     return 0;
317   }
318   for (UInt32 i = 0; i < keys.size(); ++i) {
319     keys[i].set_id(i);
320   }
321   std::sort(keys.begin(), keys.end());
322   UInt32 count = 1;
323   for (UInt32 i = 1; i < keys.size(); ++i) {
324     if (keys[i - 1].str() != keys[i].str()) {
325       ++count;
326     }
327   }
328   return count;
329 }
330 
331 template <typename T>
build_terminals(const Vector<Key<T>> & keys,Vector<UInt32> * terminals) const332 void Trie::build_terminals(const Vector<Key<T> > &keys,
333     Vector<UInt32> *terminals) const {
334   Vector<UInt32> temp_terminals;
335   temp_terminals.resize(keys.size());
336   for (UInt32 i = 0; i < keys.size(); ++i) {
337     temp_terminals[keys[i].id()] = keys[i].terminal();
338   }
339   temp_terminals.swap(terminals);
340 }
341 
342 }  // namespace marisa_alpha
343