• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (C) 2009 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include <assert.h>
18 #include <stdio.h>
19 #include <string.h>
20 #include "../include/dicttrie.h"
21 #include "../include/dictbuilder.h"
22 #include "../include/lpicache.h"
23 #include "../include/mystdlib.h"
24 #include "../include/ngram.h"
25 
26 namespace ime_pinyin {
27 
DictTrie()28 DictTrie::DictTrie() {
29   spl_trie_ = SpellingTrie::get_cpinstance();
30 
31   root_ = NULL;
32   splid_le0_index_ = NULL;
33   lma_node_num_le0_ = 0;
34   nodes_ge1_ = NULL;
35   lma_node_num_ge1_ = 0;
36   lma_idx_buf_ = NULL;
37   lma_idx_buf_len_ = 0;
38   total_lma_num_ = 0;
39   top_lmas_num_ = 0;
40   dict_list_ = NULL;
41 
42   parsing_marks_ = NULL;
43   mile_stones_ = NULL;
44   reset_milestones(0, kFirstValidMileStoneHandle);
45 }
46 
~DictTrie()47 DictTrie::~DictTrie() {
48   free_resource(true);
49 }
50 
free_resource(bool free_dict_list)51 void DictTrie::free_resource(bool free_dict_list) {
52   if (NULL != root_)
53     free(root_);
54   root_ = NULL;
55 
56   if (NULL != splid_le0_index_)
57     free(splid_le0_index_);
58   splid_le0_index_ = NULL;
59 
60   if (NULL != nodes_ge1_)
61     free(nodes_ge1_);
62   nodes_ge1_ = NULL;
63 
64   if (NULL != nodes_ge1_)
65     free(nodes_ge1_);
66   nodes_ge1_ = NULL;
67 
68   if (free_dict_list) {
69     if (NULL != dict_list_) {
70       delete dict_list_;
71     }
72     dict_list_ = NULL;
73   }
74 
75   if (parsing_marks_)
76     delete [] parsing_marks_;
77   parsing_marks_ = NULL;
78 
79   if (mile_stones_)
80     delete [] mile_stones_;
81   mile_stones_ = NULL;
82 
83   reset_milestones(0, kFirstValidMileStoneHandle);
84 }
85 
get_son_offset(const LmaNodeGE1 * node)86 inline size_t DictTrie::get_son_offset(const LmaNodeGE1 *node) {
87   return ((size_t)node->son_1st_off_l + ((size_t)node->son_1st_off_h << 16));
88 }
89 
get_homo_idx_buf_offset(const LmaNodeGE1 * node)90 inline size_t DictTrie::get_homo_idx_buf_offset(const LmaNodeGE1 *node) {
91   return ((size_t)node->homo_idx_buf_off_l +
92           ((size_t)node->homo_idx_buf_off_h << 16));
93 }
94 
get_lemma_id(size_t id_offset)95 inline LemmaIdType DictTrie::get_lemma_id(size_t id_offset) {
96   LemmaIdType id = 0;
97   for (uint16 pos = kLemmaIdSize - 1; pos > 0; pos--)
98     id = (id << 8) + lma_idx_buf_[id_offset * kLemmaIdSize + pos];
99   id = (id << 8) + lma_idx_buf_[id_offset * kLemmaIdSize];
100   return id;
101 }
102 
103 #ifdef ___BUILD_MODEL___
build_dict(const char * fn_raw,const char * fn_validhzs)104 bool DictTrie::build_dict(const char* fn_raw, const char* fn_validhzs) {
105   DictBuilder* dict_builder = new DictBuilder();
106 
107   free_resource(true);
108 
109   return dict_builder->build_dict(fn_raw, fn_validhzs, this);
110 }
111 
save_dict(FILE * fp)112 bool DictTrie::save_dict(FILE *fp) {
113   if (NULL == fp)
114     return false;
115 
116   if (fwrite(&lma_node_num_le0_, sizeof(size_t), 1, fp) != 1)
117     return false;
118 
119   if (fwrite(&lma_node_num_ge1_, sizeof(size_t), 1, fp) != 1)
120     return false;
121 
122   if (fwrite(&lma_idx_buf_len_, sizeof(size_t), 1, fp) != 1)
123     return false;
124 
125   if (fwrite(&top_lmas_num_, sizeof(size_t), 1, fp) != 1)
126     return false;
127 
128   if (fwrite(root_, sizeof(LmaNodeLE0), lma_node_num_le0_, fp)
129       != lma_node_num_le0_)
130     return false;
131 
132   if (fwrite(nodes_ge1_, sizeof(LmaNodeGE1), lma_node_num_ge1_, fp)
133       != lma_node_num_ge1_)
134     return false;
135 
136   if (fwrite(lma_idx_buf_, sizeof(unsigned char), lma_idx_buf_len_, fp) !=
137       lma_idx_buf_len_)
138     return false;
139 
140   return true;
141 }
142 
save_dict(const char * filename)143 bool DictTrie::save_dict(const char *filename) {
144   if (NULL == filename)
145     return false;
146 
147   if (NULL == root_ || NULL == dict_list_)
148     return false;
149 
150   SpellingTrie &spl_trie = SpellingTrie::get_instance();
151   NGram &ngram = NGram::get_instance();
152 
153   FILE *fp = fopen(filename, "wb");
154   if (NULL == fp)
155     return false;
156 
157   if (!spl_trie.save_spl_trie(fp) || !dict_list_->save_list(fp) ||
158       !save_dict(fp) || !ngram.save_ngram(fp)) {
159     fclose(fp);
160     return false;
161   }
162 
163   fclose(fp);
164   return true;
165 }
166 #endif  // ___BUILD_MODEL___
167 
load_dict(FILE * fp)168 bool DictTrie::load_dict(FILE *fp) {
169   if (NULL == fp)
170     return false;
171 
172   if (fread(&lma_node_num_le0_, sizeof(size_t), 1, fp) != 1)
173     return false;
174 
175   if (fread(&lma_node_num_ge1_, sizeof(size_t), 1, fp) != 1)
176     return false;
177 
178   if (fread(&lma_idx_buf_len_, sizeof(size_t), 1, fp) != 1)
179     return false;
180 
181   if (fread(&top_lmas_num_, sizeof(size_t), 1, fp) != 1 ||
182       top_lmas_num_ >= lma_idx_buf_len_)
183     return false;
184 
185   free_resource(false);
186 
187   root_ = static_cast<LmaNodeLE0*>
188           (malloc(lma_node_num_le0_ * sizeof(LmaNodeLE0)));
189   nodes_ge1_ = static_cast<LmaNodeGE1*>
190                (malloc(lma_node_num_ge1_ * sizeof(LmaNodeGE1)));
191   lma_idx_buf_ = (unsigned char*)malloc(lma_idx_buf_len_);
192   total_lma_num_ = lma_idx_buf_len_ / kLemmaIdSize;
193 
194   size_t buf_size = SpellingTrie::get_instance().get_spelling_num() + 1;
195   assert(lma_node_num_le0_ <= buf_size);
196   splid_le0_index_ = static_cast<uint16*>(malloc(buf_size * sizeof(uint16)));
197 
198   // Init the space for parsing.
199   parsing_marks_ = new ParsingMark[kMaxParsingMark];
200   mile_stones_ = new MileStone[kMaxMileStone];
201   reset_milestones(0, kFirstValidMileStoneHandle);
202 
203   if (NULL == root_ || NULL == nodes_ge1_ || NULL == lma_idx_buf_ ||
204       NULL == splid_le0_index_ || NULL == parsing_marks_ ||
205       NULL == mile_stones_) {
206     free_resource(false);
207     return false;
208   }
209 
210   if (fread(root_, sizeof(LmaNodeLE0), lma_node_num_le0_, fp)
211       != lma_node_num_le0_)
212     return false;
213 
214   if (fread(nodes_ge1_, sizeof(LmaNodeGE1), lma_node_num_ge1_, fp)
215       != lma_node_num_ge1_)
216     return false;
217 
218   if (fread(lma_idx_buf_, sizeof(unsigned char), lma_idx_buf_len_, fp) !=
219       lma_idx_buf_len_)
220     return false;
221 
222   // The quick index for the first level sons
223   uint16 last_splid = kFullSplIdStart;
224   size_t last_pos = 0;
225   for (size_t i = 1; i < lma_node_num_le0_; i++) {
226     for (uint16 splid = last_splid; splid < root_[i].spl_idx; splid++)
227       splid_le0_index_[splid - kFullSplIdStart] = last_pos;
228 
229     splid_le0_index_[root_[i].spl_idx - kFullSplIdStart] =
230         static_cast<uint16>(i);
231     last_splid = root_[i].spl_idx;
232     last_pos = i;
233   }
234 
235   for (uint16 splid = last_splid + 1;
236        splid < buf_size + kFullSplIdStart; splid++) {
237     assert(static_cast<size_t>(splid - kFullSplIdStart) < buf_size);
238     splid_le0_index_[splid - kFullSplIdStart] = last_pos + 1;
239   }
240 
241   return true;
242 }
243 
load_dict(const char * filename,LemmaIdType start_id,LemmaIdType end_id)244 bool DictTrie::load_dict(const char *filename, LemmaIdType start_id,
245                          LemmaIdType end_id) {
246   if (NULL == filename || end_id <= start_id)
247     return false;
248 
249   FILE *fp = fopen(filename, "rb");
250   if (NULL == fp)
251     return false;
252 
253   free_resource(true);
254 
255   dict_list_ = new DictList();
256   if (NULL == dict_list_) {
257     fclose(fp);
258     return false;
259   }
260 
261   SpellingTrie &spl_trie = SpellingTrie::get_instance();
262   NGram &ngram = NGram::get_instance();
263 
264   if (!spl_trie.load_spl_trie(fp) || !dict_list_->load_list(fp) ||
265       !load_dict(fp) || !ngram.load_ngram(fp) ||
266       total_lma_num_ > end_id - start_id + 1) {
267     free_resource(true);
268     fclose(fp);
269     return false;
270   }
271 
272   fclose(fp);
273   return true;
274 }
275 
load_dict_fd(int sys_fd,long start_offset,long length,LemmaIdType start_id,LemmaIdType end_id)276 bool DictTrie::load_dict_fd(int sys_fd, long start_offset,
277                             long length, LemmaIdType start_id,
278                             LemmaIdType end_id) {
279   if (start_offset < 0 || length <= 0 || end_id <= start_id)
280     return false;
281 
282   FILE *fp = fdopen(sys_fd, "rb");
283   if (NULL == fp)
284     return false;
285 
286   if (-1 == fseek(fp, start_offset, SEEK_SET)) {
287     fclose(fp);
288     return false;
289   }
290 
291   free_resource(true);
292 
293   dict_list_ = new DictList();
294   if (NULL == dict_list_) {
295     fclose(fp);
296     return false;
297   }
298 
299   SpellingTrie &spl_trie = SpellingTrie::get_instance();
300   NGram &ngram = NGram::get_instance();
301 
302   if (!spl_trie.load_spl_trie(fp) || !dict_list_->load_list(fp) ||
303       !load_dict(fp) || !ngram.load_ngram(fp) ||
304       ftell(fp) < start_offset + length ||
305       total_lma_num_ > end_id - start_id + 1) {
306     free_resource(true);
307     fclose(fp);
308     return false;
309   }
310 
311   fclose(fp);
312   return true;
313 }
314 
fill_lpi_buffer(LmaPsbItem lpi_items[],size_t lpi_max,LmaNodeLE0 * node)315 size_t DictTrie::fill_lpi_buffer(LmaPsbItem lpi_items[], size_t lpi_max,
316                                  LmaNodeLE0 *node) {
317   size_t lpi_num = 0;
318   NGram& ngram = NGram::get_instance();
319   for (size_t homo = 0; homo < (size_t)node->num_of_homo; homo++) {
320     lpi_items[lpi_num].id = get_lemma_id(node->homo_idx_buf_off +
321                                          homo);
322     lpi_items[lpi_num].lma_len = 1;
323     lpi_items[lpi_num].psb =
324         static_cast<LmaScoreType>(ngram.get_uni_psb(lpi_items[lpi_num].id));
325     lpi_num++;
326     if (lpi_num >= lpi_max)
327       break;
328   }
329 
330   return lpi_num;
331 }
332 
fill_lpi_buffer(LmaPsbItem lpi_items[],size_t lpi_max,size_t homo_buf_off,LmaNodeGE1 * node,uint16 lma_len)333 size_t DictTrie::fill_lpi_buffer(LmaPsbItem lpi_items[], size_t lpi_max,
334                                  size_t homo_buf_off, LmaNodeGE1 *node,
335                                  uint16 lma_len) {
336   size_t lpi_num = 0;
337   NGram& ngram = NGram::get_instance();
338   for (size_t homo = 0; homo < (size_t)node->num_of_homo; homo++) {
339     lpi_items[lpi_num].id = get_lemma_id(homo_buf_off + homo);
340     lpi_items[lpi_num].lma_len = lma_len;
341     lpi_items[lpi_num].psb =
342         static_cast<LmaScoreType>(ngram.get_uni_psb(lpi_items[lpi_num].id));
343     lpi_num++;
344     if (lpi_num >= lpi_max)
345       break;
346   }
347 
348   return lpi_num;
349 }
350 
reset_milestones(uint16 from_step,MileStoneHandle from_handle)351 void DictTrie::reset_milestones(uint16 from_step, MileStoneHandle from_handle) {
352   if (0 == from_step) {
353     parsing_marks_pos_ = 0;
354     mile_stones_pos_ = kFirstValidMileStoneHandle;
355   } else {
356     if (from_handle > 0 && from_handle < mile_stones_pos_) {
357       mile_stones_pos_ = from_handle;
358 
359       MileStone *mile_stone = mile_stones_ + from_handle;
360       parsing_marks_pos_ = mile_stone->mark_start;
361     }
362   }
363 }
364 
extend_dict(MileStoneHandle from_handle,const DictExtPara * dep,LmaPsbItem * lpi_items,size_t lpi_max,size_t * lpi_num)365 MileStoneHandle DictTrie::extend_dict(MileStoneHandle from_handle,
366                                       const DictExtPara *dep,
367                                       LmaPsbItem *lpi_items, size_t lpi_max,
368                                       size_t *lpi_num) {
369   if (NULL == dep)
370     return 0;
371 
372   // from LmaNodeLE0 (root) to LmaNodeLE0
373   if (0 == from_handle) {
374     assert(0 == dep->splids_extended);
375     return extend_dict0(from_handle, dep, lpi_items, lpi_max, lpi_num);
376   }
377 
378   // from LmaNodeLE0 to LmaNodeGE1
379   if (1 == dep->splids_extended)
380     return extend_dict1(from_handle, dep, lpi_items, lpi_max, lpi_num);
381 
382   // From LmaNodeGE1 to LmaNodeGE1
383   return extend_dict2(from_handle, dep, lpi_items, lpi_max, lpi_num);
384 }
385 
extend_dict0(MileStoneHandle from_handle,const DictExtPara * dep,LmaPsbItem * lpi_items,size_t lpi_max,size_t * lpi_num)386 MileStoneHandle DictTrie::extend_dict0(MileStoneHandle from_handle,
387                                        const DictExtPara *dep,
388                                        LmaPsbItem *lpi_items,
389                                        size_t lpi_max, size_t *lpi_num) {
390   assert(NULL != dep && 0 == from_handle);
391   *lpi_num = 0;
392   MileStoneHandle ret_handle = 0;
393 
394   uint16 splid = dep->splids[dep->splids_extended];
395   uint16 id_start = dep->id_start;
396   uint16 id_num = dep->id_num;
397 
398   LpiCache& lpi_cache = LpiCache::get_instance();
399   bool cached = lpi_cache.is_cached(splid);
400 
401   // 2. Begin exgtending
402   // 2.1 Get the LmaPsbItem list
403   LmaNodeLE0 *node = root_;
404   size_t son_start = splid_le0_index_[id_start - kFullSplIdStart];
405   size_t son_end = splid_le0_index_[id_start + id_num - kFullSplIdStart];
406   for (size_t son_pos = son_start; son_pos < son_end; son_pos++) {
407     assert(1 == node->son_1st_off);
408     LmaNodeLE0 *son = root_ + son_pos;
409     assert(son->spl_idx >= id_start && son->spl_idx < id_start + id_num);
410 
411     if (!cached && *lpi_num < lpi_max) {
412       bool need_lpi = true;
413       if (spl_trie_->is_half_id_yunmu(splid) && son_pos != son_start)
414         need_lpi = false;
415 
416       if (need_lpi)
417         *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
418                                     lpi_max - *lpi_num, son);
419     }
420 
421     // If necessary, fill in a new mile stone.
422     if (son->spl_idx == id_start) {
423       if (mile_stones_pos_ < kMaxMileStone &&
424           parsing_marks_pos_ < kMaxParsingMark) {
425         parsing_marks_[parsing_marks_pos_].node_offset = son_pos;
426         parsing_marks_[parsing_marks_pos_].node_num = id_num;
427         mile_stones_[mile_stones_pos_].mark_start = parsing_marks_pos_;
428         mile_stones_[mile_stones_pos_].mark_num = 1;
429         ret_handle = mile_stones_pos_;
430         parsing_marks_pos_++;
431         mile_stones_pos_++;
432       }
433     }
434 
435     if (son->spl_idx >= id_start + id_num -1)
436       break;
437   }
438 
439   //  printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
440   //      mile_stones_pos_);
441   return ret_handle;
442 }
443 
extend_dict1(MileStoneHandle from_handle,const DictExtPara * dep,LmaPsbItem * lpi_items,size_t lpi_max,size_t * lpi_num)444 MileStoneHandle DictTrie::extend_dict1(MileStoneHandle from_handle,
445                                        const DictExtPara *dep,
446                                        LmaPsbItem *lpi_items,
447                                        size_t lpi_max, size_t *lpi_num) {
448   assert(NULL != dep && from_handle > 0 && from_handle < mile_stones_pos_);
449 
450   MileStoneHandle ret_handle = 0;
451 
452   // 1. If this is a half Id, get its corresponding full starting Id and
453   // number of full Id.
454   size_t ret_val = 0;
455 
456   uint16 id_start = dep->id_start;
457   uint16 id_num = dep->id_num;
458 
459   // 2. Begin extending.
460   MileStone *mile_stone = mile_stones_ + from_handle;
461 
462   for (uint16 h_pos = 0; h_pos < mile_stone->mark_num; h_pos++) {
463     ParsingMark p_mark = parsing_marks_[mile_stone->mark_start + h_pos];
464     uint16 ext_num = p_mark.node_num;
465     for (uint16 ext_pos = 0; ext_pos < ext_num; ext_pos++) {
466       LmaNodeLE0 *node = root_ + p_mark.node_offset + ext_pos;
467       size_t found_start = 0;
468       size_t found_num = 0;
469       for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son; son_pos++) {
470         assert(node->son_1st_off <= lma_node_num_ge1_);
471         LmaNodeGE1 *son = nodes_ge1_ + node->son_1st_off + son_pos;
472         if (son->spl_idx >= id_start
473             && son->spl_idx < id_start + id_num) {
474           if (*lpi_num < lpi_max) {
475             size_t homo_buf_off = get_homo_idx_buf_offset(son);
476             *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
477                                         lpi_max - *lpi_num, homo_buf_off, son,
478                                         2);
479           }
480 
481           // If necessary, fill in the new DTMI
482           if (0 == found_num) {
483             found_start = son_pos;
484           }
485           found_num++;
486         }
487         if (son->spl_idx >= id_start + id_num - 1 || son_pos ==
488             (size_t)node->num_of_son - 1) {
489           if (found_num > 0) {
490             if (mile_stones_pos_ < kMaxMileStone &&
491                 parsing_marks_pos_ < kMaxParsingMark) {
492               parsing_marks_[parsing_marks_pos_].node_offset =
493                 node->son_1st_off + found_start;
494               parsing_marks_[parsing_marks_pos_].node_num = found_num;
495               if (0 == ret_val)
496                 mile_stones_[mile_stones_pos_].mark_start =
497                   parsing_marks_pos_;
498               parsing_marks_pos_++;
499             }
500 
501             ret_val++;
502           }
503           break;
504         }  // for son_pos
505       }  // for ext_pos
506     }  // for h_pos
507   }
508 
509   if (ret_val > 0) {
510     mile_stones_[mile_stones_pos_].mark_num = ret_val;
511     ret_handle = mile_stones_pos_;
512     mile_stones_pos_++;
513     ret_val = 1;
514   }
515 
516   //  printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
517   //         mile_stones_pos_);
518   return ret_handle;
519 }
520 
extend_dict2(MileStoneHandle from_handle,const DictExtPara * dep,LmaPsbItem * lpi_items,size_t lpi_max,size_t * lpi_num)521 MileStoneHandle DictTrie::extend_dict2(MileStoneHandle from_handle,
522                                        const DictExtPara *dep,
523                                        LmaPsbItem *lpi_items,
524                                        size_t lpi_max, size_t *lpi_num) {
525   assert(NULL != dep && from_handle > 0 && from_handle < mile_stones_pos_);
526 
527   MileStoneHandle ret_handle = 0;
528 
529   // 1. If this is a half Id, get its corresponding full starting Id and
530   // number of full Id.
531   size_t ret_val = 0;
532 
533   uint16 id_start = dep->id_start;
534   uint16 id_num = dep->id_num;
535 
536   // 2. Begin extending.
537   MileStone *mile_stone = mile_stones_ + from_handle;
538 
539   for (uint16 h_pos = 0; h_pos < mile_stone->mark_num; h_pos++) {
540     ParsingMark p_mark = parsing_marks_[mile_stone->mark_start + h_pos];
541     uint16 ext_num = p_mark.node_num;
542     for (uint16 ext_pos = 0; ext_pos < ext_num; ext_pos++) {
543       LmaNodeGE1 *node = nodes_ge1_ + p_mark.node_offset + ext_pos;
544       size_t found_start = 0;
545       size_t found_num = 0;
546 
547       for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son; son_pos++) {
548         assert(node->son_1st_off_l > 0 || node->son_1st_off_h > 0);
549         LmaNodeGE1 *son = nodes_ge1_ + get_son_offset(node) + son_pos;
550         if (son->spl_idx >= id_start
551             && son->spl_idx < id_start + id_num) {
552           if (*lpi_num < lpi_max) {
553             size_t homo_buf_off = get_homo_idx_buf_offset(son);
554             *lpi_num += fill_lpi_buffer(lpi_items + (*lpi_num),
555                                         lpi_max - *lpi_num, homo_buf_off, son,
556                                         dep->splids_extended + 1);
557           }
558 
559           // If necessary, fill in the new DTMI
560           if (0 == found_num) {
561             found_start = son_pos;
562           }
563           found_num++;
564         }
565         if (son->spl_idx >= id_start + id_num - 1 || son_pos ==
566             (size_t)node->num_of_son - 1) {
567           if (found_num > 0) {
568             if (mile_stones_pos_ < kMaxMileStone &&
569                 parsing_marks_pos_ < kMaxParsingMark) {
570               parsing_marks_[parsing_marks_pos_].node_offset =
571                 get_son_offset(node) + found_start;
572               parsing_marks_[parsing_marks_pos_].node_num = found_num;
573               if (0 == ret_val)
574                 mile_stones_[mile_stones_pos_].mark_start =
575                   parsing_marks_pos_;
576               parsing_marks_pos_++;
577             }
578 
579             ret_val++;
580           }
581           break;
582         }
583       }  // for son_pos
584     }  // for ext_pos
585   }  // for h_pos
586 
587   if (ret_val > 0) {
588     mile_stones_[mile_stones_pos_].mark_num = ret_val;
589     ret_handle = mile_stones_pos_;
590     mile_stones_pos_++;
591   }
592 
593   // printf("----- parsing marks: %d, mile stone: %d \n", parsing_marks_pos_,
594   //        mile_stones_pos_);
595   return ret_handle;
596 }
597 
try_extend(const uint16 * splids,uint16 splid_num,LemmaIdType id_lemma)598 bool DictTrie::try_extend(const uint16 *splids, uint16 splid_num,
599                           LemmaIdType id_lemma) {
600   if (0 == splid_num || NULL == splids)
601     return false;
602 
603   void *node = root_ + splid_le0_index_[splids[0] - kFullSplIdStart];
604 
605   for (uint16 pos = 1; pos < splid_num; pos++) {
606     if (1 == pos) {
607       LmaNodeLE0 *node_le0 = reinterpret_cast<LmaNodeLE0*>(node);
608       LmaNodeGE1 *node_son;
609       uint16 son_pos;
610       for (son_pos = 0; son_pos < static_cast<uint16>(node_le0->num_of_son);
611            son_pos++) {
612         assert(node_le0->son_1st_off <= lma_node_num_ge1_);
613         node_son = nodes_ge1_ + node_le0->son_1st_off
614             + son_pos;
615         if (node_son->spl_idx == splids[pos])
616           break;
617       }
618       if (son_pos < node_le0->num_of_son)
619         node = reinterpret_cast<void*>(node_son);
620       else
621         return false;
622     } else {
623       LmaNodeGE1 *node_ge1 = reinterpret_cast<LmaNodeGE1*>(node);
624       LmaNodeGE1 *node_son;
625       uint16 son_pos;
626       for (son_pos = 0; son_pos < static_cast<uint16>(node_ge1->num_of_son);
627            son_pos++) {
628         assert(node_ge1->son_1st_off_l > 0 || node_ge1->son_1st_off_h > 0);
629         node_son = nodes_ge1_ + get_son_offset(node_ge1) + son_pos;
630         if (node_son->spl_idx == splids[pos])
631           break;
632       }
633       if (son_pos < node_ge1->num_of_son)
634         node = reinterpret_cast<void*>(node_son);
635       else
636         return false;
637     }
638   }
639 
640   if (1 == splid_num) {
641     LmaNodeLE0* node_le0 = reinterpret_cast<LmaNodeLE0*>(node);
642     size_t num_of_homo = (size_t)node_le0->num_of_homo;
643     for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
644       LemmaIdType id_this = get_lemma_id(node_le0->homo_idx_buf_off + homo_pos);
645       char16 str[2];
646       get_lemma_str(id_this, str, 2);
647       if (id_this == id_lemma)
648         return true;
649     }
650   } else {
651     LmaNodeGE1* node_ge1 = reinterpret_cast<LmaNodeGE1*>(node);
652     size_t num_of_homo = (size_t)node_ge1->num_of_homo;
653     for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
654       size_t node_homo_off = get_homo_idx_buf_offset(node_ge1);
655       if (get_lemma_id(node_homo_off + homo_pos) == id_lemma)
656         return true;
657     }
658   }
659 
660   return false;
661 }
662 
get_lpis(const uint16 * splid_str,uint16 splid_str_len,LmaPsbItem * lma_buf,size_t max_lma_buf)663 size_t DictTrie::get_lpis(const uint16* splid_str, uint16 splid_str_len,
664                           LmaPsbItem* lma_buf, size_t max_lma_buf) {
665   if (splid_str_len > kMaxLemmaSize)
666     return 0;
667 
668 #define MAX_EXTENDBUF_LEN 200
669 
670   size_t* node_buf1[MAX_EXTENDBUF_LEN];  // use size_t for data alignment
671   size_t* node_buf2[MAX_EXTENDBUF_LEN];
672   LmaNodeLE0** node_fr_le0 =
673     reinterpret_cast<LmaNodeLE0**>(node_buf1);      // Nodes from.
674   LmaNodeLE0** node_to_le0 =
675     reinterpret_cast<LmaNodeLE0**>(node_buf2);      // Nodes to.
676   LmaNodeGE1** node_fr_ge1 = NULL;
677   LmaNodeGE1** node_to_ge1 = NULL;
678   size_t node_fr_num = 1;
679   size_t node_to_num = 0;
680   node_fr_le0[0] = root_;
681   if (NULL == node_fr_le0[0])
682     return 0;
683 
684   size_t spl_pos = 0;
685 
686   while (spl_pos < splid_str_len) {
687     uint16 id_num = 1;
688     uint16 id_start = splid_str[spl_pos];
689     // If it is a half id
690     if (spl_trie_->is_half_id(splid_str[spl_pos])) {
691       id_num = spl_trie_->half_to_full(splid_str[spl_pos], &id_start);
692       assert(id_num > 0);
693     }
694 
695     // Extend the nodes
696     if (0 == spl_pos) {  // From LmaNodeLE0 (root) to LmaNodeLE0 nodes
697       for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
698         LmaNodeLE0 *node = node_fr_le0[node_fr_pos];
699         assert(node == root_ && 1 == node_fr_num);
700         size_t son_start = splid_le0_index_[id_start - kFullSplIdStart];
701         size_t son_end =
702             splid_le0_index_[id_start + id_num - kFullSplIdStart];
703         for (size_t son_pos = son_start; son_pos < son_end; son_pos++) {
704           assert(1 == node->son_1st_off);
705           LmaNodeLE0 *node_son = root_ + son_pos;
706           assert(node_son->spl_idx >= id_start
707                  && node_son->spl_idx < id_start + id_num);
708           if (node_to_num < MAX_EXTENDBUF_LEN) {
709             node_to_le0[node_to_num] = node_son;
710             node_to_num++;
711           }
712           // id_start + id_num - 1 is the last one, which has just been
713           // recorded.
714           if (node_son->spl_idx >= id_start + id_num - 1)
715             break;
716         }
717       }
718 
719       spl_pos++;
720       if (spl_pos >= splid_str_len || node_to_num == 0)
721         break;
722       // Prepare the nodes for next extending
723       // next time, from LmaNodeLE0 to LmaNodeGE1
724       LmaNodeLE0** node_tmp = node_fr_le0;
725       node_fr_le0 = node_to_le0;
726       node_to_le0 = NULL;
727       node_to_ge1 = reinterpret_cast<LmaNodeGE1**>(node_tmp);
728     } else if (1 == spl_pos) {  // From LmaNodeLE0 to LmaNodeGE1 nodes
729       for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
730         LmaNodeLE0 *node = node_fr_le0[node_fr_pos];
731         for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son;
732              son_pos++) {
733           assert(node->son_1st_off <= lma_node_num_ge1_);
734           LmaNodeGE1 *node_son = nodes_ge1_ + node->son_1st_off
735                                   + son_pos;
736           if (node_son->spl_idx >= id_start
737               && node_son->spl_idx < id_start + id_num) {
738             if (node_to_num < MAX_EXTENDBUF_LEN) {
739               node_to_ge1[node_to_num] = node_son;
740               node_to_num++;
741             }
742           }
743           // id_start + id_num - 1 is the last one, which has just been
744           // recorded.
745           if (node_son->spl_idx >= id_start + id_num - 1)
746             break;
747         }
748       }
749 
750       spl_pos++;
751       if (spl_pos >= splid_str_len || node_to_num == 0)
752         break;
753       // Prepare the nodes for next extending
754       // next time, from LmaNodeGE1 to LmaNodeGE1
755       node_fr_ge1 = node_to_ge1;
756       node_to_ge1 = reinterpret_cast<LmaNodeGE1**>(node_fr_le0);
757       node_fr_le0 = NULL;
758       node_to_le0 = NULL;
759     } else {  // From LmaNodeGE1 to LmaNodeGE1 nodes
760       for (size_t node_fr_pos = 0; node_fr_pos < node_fr_num; node_fr_pos++) {
761         LmaNodeGE1 *node = node_fr_ge1[node_fr_pos];
762         for (size_t son_pos = 0; son_pos < (size_t)node->num_of_son;
763              son_pos++) {
764           assert(node->son_1st_off_l > 0 || node->son_1st_off_h > 0);
765           LmaNodeGE1 *node_son = nodes_ge1_
766                                   + get_son_offset(node) + son_pos;
767           if (node_son->spl_idx >= id_start
768               && node_son->spl_idx < id_start + id_num) {
769             if (node_to_num < MAX_EXTENDBUF_LEN) {
770               node_to_ge1[node_to_num] = node_son;
771               node_to_num++;
772             }
773           }
774           // id_start + id_num - 1 is the last one, which has just been
775           // recorded.
776           if (node_son->spl_idx >= id_start + id_num - 1)
777             break;
778         }
779       }
780 
781       spl_pos++;
782       if (spl_pos >= splid_str_len || node_to_num == 0)
783         break;
784       // Prepare the nodes for next extending
785       // next time, from LmaNodeGE1 to LmaNodeGE1
786       LmaNodeGE1 **node_tmp = node_fr_ge1;
787       node_fr_ge1 = node_to_ge1;
788       node_to_ge1 = node_tmp;
789     }
790 
791     // The number of node for next extending
792     node_fr_num = node_to_num;
793     node_to_num = 0;
794   }  // while
795 
796   if (0 == node_to_num)
797     return 0;
798 
799   NGram &ngram = NGram::get_instance();
800   size_t lma_num = 0;
801 
802   // If the length is 1, and the splid is a one-char Yunmu like 'a', 'o', 'e',
803   // only those candidates for the full matched one-char id will be returned.
804   if (1 == splid_str_len && spl_trie_->is_half_id_yunmu(splid_str[0]))
805     node_to_num = node_to_num > 0 ? 1 : 0;
806 
807   for (size_t node_pos = 0; node_pos < node_to_num; node_pos++) {
808     size_t num_of_homo = 0;
809     if (spl_pos <= 1) {  // Get from LmaNodeLE0 nodes
810       LmaNodeLE0* node_le0 = node_to_le0[node_pos];
811       num_of_homo = (size_t)node_le0->num_of_homo;
812       for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
813         size_t ch_pos = lma_num + homo_pos;
814         lma_buf[ch_pos].id =
815             get_lemma_id(node_le0->homo_idx_buf_off + homo_pos);
816         lma_buf[ch_pos].lma_len = 1;
817         lma_buf[ch_pos].psb =
818             static_cast<LmaScoreType>(ngram.get_uni_psb(lma_buf[ch_pos].id));
819 
820         if (lma_num + homo_pos >= max_lma_buf - 1)
821           break;
822       }
823     } else {  // Get from LmaNodeGE1 nodes
824       LmaNodeGE1* node_ge1 = node_to_ge1[node_pos];
825       num_of_homo = (size_t)node_ge1->num_of_homo;
826       for (size_t homo_pos = 0; homo_pos < num_of_homo; homo_pos++) {
827         size_t ch_pos = lma_num + homo_pos;
828         size_t node_homo_off = get_homo_idx_buf_offset(node_ge1);
829         lma_buf[ch_pos].id = get_lemma_id(node_homo_off + homo_pos);
830         lma_buf[ch_pos].lma_len = splid_str_len;
831         lma_buf[ch_pos].psb =
832             static_cast<LmaScoreType>(ngram.get_uni_psb(lma_buf[ch_pos].id));
833 
834         if (lma_num + homo_pos >= max_lma_buf - 1)
835           break;
836       }
837     }
838 
839     lma_num += num_of_homo;
840     if (lma_num >= max_lma_buf) {
841       lma_num = max_lma_buf;
842       break;
843     }
844   }
845   return lma_num;
846 }
847 
get_lemma_str(LemmaIdType id_lemma,char16 * str_buf,uint16 str_max)848 uint16 DictTrie::get_lemma_str(LemmaIdType id_lemma, char16 *str_buf,
849                                uint16 str_max) {
850   return dict_list_->get_lemma_str(id_lemma, str_buf, str_max);
851 }
852 
get_lemma_splids(LemmaIdType id_lemma,uint16 * splids,uint16 splids_max,bool arg_valid)853 uint16 DictTrie::get_lemma_splids(LemmaIdType id_lemma, uint16 *splids,
854                                   uint16 splids_max, bool arg_valid) {
855   char16 lma_str[kMaxLemmaSize + 1];
856   uint16 lma_len = get_lemma_str(id_lemma, lma_str, kMaxLemmaSize + 1);
857   assert((!arg_valid && splids_max >= lma_len) || lma_len == splids_max);
858 
859   uint16 spl_mtrx[kMaxLemmaSize * 5];
860   uint16 spl_start[kMaxLemmaSize + 1];
861   spl_start[0] = 0;
862   uint16 try_num = 1;
863 
864   for (uint16 pos = 0; pos < lma_len; pos++) {
865     uint16 cand_splids_this = 0;
866     if (arg_valid && spl_trie_->is_full_id(splids[pos])) {
867       spl_mtrx[spl_start[pos]] = splids[pos];
868       cand_splids_this = 1;
869     } else {
870       cand_splids_this = dict_list_->get_splids_for_hanzi(lma_str[pos],
871           arg_valid ? splids[pos] : 0, spl_mtrx + spl_start[pos],
872           kMaxLemmaSize * 5 - spl_start[pos]);
873       assert(cand_splids_this > 0);
874     }
875     spl_start[pos + 1] = spl_start[pos] + cand_splids_this;
876     try_num *= cand_splids_this;
877   }
878 
879   for (uint16 try_pos = 0; try_pos < try_num; try_pos++) {
880     uint16 mod = 1;
881     for (uint16 pos = 0; pos < lma_len; pos++) {
882       uint16 radix = spl_start[pos + 1] - spl_start[pos];
883       splids[pos] = spl_mtrx[ spl_start[pos] + try_pos / mod % radix];
884       mod *= radix;
885     }
886 
887     if (try_extend(splids, lma_len, id_lemma))
888       return lma_len;
889   }
890 
891   return 0;
892 }
893 
set_total_lemma_count_of_others(size_t count)894 void DictTrie::set_total_lemma_count_of_others(size_t count) {
895   NGram& ngram = NGram::get_instance();
896   ngram.set_total_freq_none_sys(count);
897 }
898 
convert_to_hanzis(char16 * str,uint16 str_len)899 void DictTrie::convert_to_hanzis(char16 *str, uint16 str_len) {
900   return dict_list_->convert_to_hanzis(str, str_len);
901 }
902 
convert_to_scis_ids(char16 * str,uint16 str_len)903 void DictTrie::convert_to_scis_ids(char16 *str, uint16 str_len) {
904   return dict_list_->convert_to_scis_ids(str, str_len);
905 }
906 
get_lemma_id(const char16 lemma_str[],uint16 lemma_len)907 LemmaIdType DictTrie::get_lemma_id(const char16 lemma_str[], uint16 lemma_len) {
908   if (NULL == lemma_str || lemma_len > kMaxLemmaSize)
909     return 0;
910 
911   return dict_list_->get_lemma_id(lemma_str, lemma_len);
912 }
913 
predict_top_lmas(size_t his_len,NPredictItem * npre_items,size_t npre_max,size_t b4_used)914 size_t DictTrie::predict_top_lmas(size_t his_len, NPredictItem *npre_items,
915                                   size_t npre_max, size_t b4_used) {
916   NGram &ngram = NGram::get_instance();
917 
918   size_t item_num = 0;
919   size_t top_lmas_id_offset = lma_idx_buf_len_ / kLemmaIdSize - top_lmas_num_;
920   size_t top_lmas_pos = 0;
921   while (item_num < npre_max && top_lmas_pos < top_lmas_num_) {
922     memset(npre_items + item_num, 0, sizeof(NPredictItem));
923     LemmaIdType top_lma_id = get_lemma_id(top_lmas_id_offset + top_lmas_pos);
924     top_lmas_pos += 1;
925     if (dict_list_->get_lemma_str(top_lma_id,
926                                   npre_items[item_num].pre_hzs,
927                                   kMaxLemmaSize - 1) == 0) {
928       continue;
929     }
930     npre_items[item_num].psb = ngram.get_uni_psb(top_lma_id);
931     npre_items[item_num].his_len = his_len;
932     item_num++;
933   }
934   return item_num;
935 }
936 
predict(const char16 * last_hzs,uint16 hzs_len,NPredictItem * npre_items,size_t npre_max,size_t b4_used)937 size_t DictTrie::predict(const char16 *last_hzs, uint16 hzs_len,
938                          NPredictItem *npre_items, size_t npre_max,
939                          size_t b4_used) {
940   return dict_list_->predict(last_hzs, hzs_len, npre_items, npre_max, b4_used);
941 }
942 }  // namespace ime_pinyin
943