1 /*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LATINIME_BINARY_FORMAT_H
18 #define LATINIME_BINARY_FORMAT_H
19
20 #include <cstdlib>
21 #include <map>
22 #include <stdint.h>
23
24 #include "bloom_filter.h"
25 #include "char_utils.h"
26 #include "hash_map_compat.h"
27
28 namespace latinime {
29
// Static helpers to read binary dictionaries: format/header detection, header attribute lookup,
// and decoding of character groups and of their attributes (shortcuts and bigrams).
// NOTE(review): instantiation is presumably prevented by DISALLOW_IMPLICIT_CONSTRUCTORS, a macro
// defined elsewhere.
class BinaryFormat {
 public:
    // Mask and flags for children address type selection. The masked bits tell how many bytes
    // (0 to 3) store the children address; see the private FLAG_GROUP_ADDRESS_TYPE_* values.
    static const int MASK_GROUP_ADDRESS_TYPE = 0xC0;

    // Flag for single/multiple char group
    static const int FLAG_HAS_MULTIPLE_CHARS = 0x20;

    // Flag for terminal groups (only terminal groups store a one-byte probability)
    static const int FLAG_IS_TERMINAL = 0x10;

    // Flag for shortcut targets presence
    static const int FLAG_HAS_SHORTCUT_TARGETS = 0x08;
    // Flag for bigram presence
    static const int FLAG_HAS_BIGRAMS = 0x04;
    // Flag for non-words (typically, shortcut only entries)
    static const int FLAG_IS_NOT_A_WORD = 0x02;
    // Flag for blacklist
    static const int FLAG_IS_BLACKLISTED = 0x01;

    // Attribute (bigram/shortcut) related flags:
    // Flag for presence of more attributes
    static const int FLAG_ATTRIBUTE_HAS_NEXT = 0x80;
    // Flag for sign of offset. If this flag is set, the offset value must be negated.
    static const int FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40;

    // Mask for attribute probability, stored on 4 bits inside the flags byte.
    static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F;
    // The numeric value of the shortcut probability that means 'whitelist'.
    static const int WHITELIST_SHORTCUT_PROBABILITY = 15;

    // Mask and flags for attribute address type selection. The masked bits tell how many bytes
    // (0 to 3) store an attribute address; see the private FLAG_ATTRIBUTE_ADDRESS_TYPE_* values.
    static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30;

    // Returned by detectFormat() when the buffer is not a dictionary in a known format.
    static const int UNKNOWN_FORMAT = -1;
    // NOTE(review): presumably the size in bytes of the field storing a shortcut list's size.
    static const int SHORTCUT_LIST_SIZE_SIZE = 2;

    // Header inspection.
    static int detectFormat(const uint8_t *const dict, const int dictSize);
    static int getHeaderSize(const uint8_t *const dict, const int dictSize);
    static int getFlags(const uint8_t *const dict, const int dictSize);
    static bool hasBlacklistedOrNotAWordFlag(const int flags);
    static void readHeaderValue(const uint8_t *const dict, const int dictSize,
            const char *const key, int *outValue, const int outValueSize);
    static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
            const char *const key);
    // Character group decoding. The *AndForwardPointer functions advance *pos past what they
    // read; the skip* functions return the position after the skipped part.
    static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
    static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
    static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
    static int readProbabilityWithoutMovingPointer(const uint8_t *const dict, const int pos);
    static int skipOtherCharacters(const uint8_t *const dict, const int pos);
    static int skipChildrenPosition(const uint8_t flags, const int pos);
    static int skipProbability(const uint8_t flags, const int pos);
    static int skipShortcuts(const uint8_t *const dict, const uint8_t flags, const int pos);
    static int skipChildrenPosAndAttributes(const uint8_t *const dict, const uint8_t flags,
            const int pos);
    static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos);
    static bool hasChildrenInFlags(const uint8_t flags);
    static int getAttributeAddressAndForwardPointer(const uint8_t *const dict, const uint8_t flags,
            int *pos);
    static int getAttributeProbabilityFromFlags(const int flags);
    // Word lookup in the trie, by spelling or by address.
    static int getTerminalPosition(const uint8_t *const root, const int *const inWord,
            const int length, const bool forceLowerCaseSearch);
    static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth,
            int *outWord, int *outUnigramProbability);
    // Probability computations.
    static int computeProbabilityForBigram(
            const int unigramProbability, const int bigramProbability);
    static int getProbability(const int position, const std::map<int, int> *bigramMap,
            const uint8_t *bigramFilter, const int unigramProbability);
    static int getBigramProbabilityFromHashMap(const int position,
            const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
    static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
    static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
            hash_map_compat<int, int> *bigramMap);
    static int getBigramProbability(const uint8_t *const root, int position,
            const int nextPosition, const int unigramProbability);

    // Flags for special processing
    // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
    // something very bad (like, the apocalypse) will happen. Please update both at the same time.
    enum {
        REQUIRES_GERMAN_UMLAUT_PROCESSING = 0x1,
        REQUIRES_FRENCH_LIGATURES_PROCESSING = 0x4
    };

 private:
    DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat);
    static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);

    // Values of (flags & MASK_GROUP_ADDRESS_TYPE): size of the children address in bytes.
    static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00;
    static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40;
    static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80;
    static const int FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0;
    // Values of (flags & MASK_ATTRIBUTE_ADDRESS_TYPE): size of an attribute address in bytes.
    static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10;
    static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
    static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;

    // Any file smaller than this is not a dictionary.
    static const int DICTIONARY_MINIMUM_SIZE = 4;
    // Originally, format version 1 had a 16-bit magic number, then the version number `01'
    // then options that must be 0. Hence the first 32-bits of the format are always as follows
    // and it's okay to consider them a magic number as a whole.
    static const int FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100;
    static const int FORMAT_VERSION_1_HEADER_SIZE = 5;
    // The versions of Latin IME that only handle format version 1 only test for the magic
    // number, so we had to change it so that version 2 files would be rejected by older
    // implementations. On this occasion, we made the magic number 32 bits long.
    static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
    // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
    static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;

    static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
    // A first byte below this value (other than CHARACTER_ARRAY_TERMINATOR) starts a character
    // stored on 1 + MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE bytes.
    static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
    // Ends the character array of a multiple-character group (and header strings).
    static const int CHARACTER_ARRAY_TERMINATOR = 0x1F;
    static const int MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE = 2;
    // Flags value reported for headers that carry no flags (format 1).
    static const int NO_FLAGS = 0;
    static int skipAllAttributes(const uint8_t *const dict, const uint8_t flags, const int pos);
    static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
};
148
detectFormat(const uint8_t * const dict,const int dictSize)149 AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
150 // The magic number is stored big-endian.
151 // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
152 // understand this format.
153 if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
154 const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
155 switch (magicNumber) {
156 case FORMAT_VERSION_1_MAGIC_NUMBER:
157 // Format 1 header is exactly 5 bytes long and looks like:
158 // Magic number (2 bytes) 0x78 0xB1
159 // Version number (1 byte) 0x01
160 // Options (2 bytes) must be 0x00 0x00
161 return 1;
162 case FORMAT_VERSION_2_MAGIC_NUMBER:
163 // Version 2 dictionaries are at least 12 bytes long (see below details for the header).
164 // If this dictionary has the version 2 magic number but is less than 12 bytes long, then
165 // it's an unknown format and we need to avoid confidently reading the next bytes.
166 if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
167 // Format 2 header is as follows:
168 // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
169 // Version number (2 bytes) 0x00 0x02
170 // Options (2 bytes)
171 // Header size (4 bytes) : integer, big endian
172 return (dict[4] << 8) + dict[5];
173 default:
174 return UNKNOWN_FORMAT;
175 }
176 }
177
getFlags(const uint8_t * const dict,const int dictSize)178 inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
179 switch (detectFormat(dict, dictSize)) {
180 case 1:
181 return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
182 default:
183 return (dict[6] << 8) + dict[7];
184 }
185 }
186
hasBlacklistedOrNotAWordFlag(const int flags)187 inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
188 return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
189 }
190
getHeaderSize(const uint8_t * const dict,const int dictSize)191 inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
192 switch (detectFormat(dict, dictSize)) {
193 case 1:
194 return FORMAT_VERSION_1_HEADER_SIZE;
195 case 2:
196 // See the format of the header in the comment in detectFormat() above
197 return (dict[8] << 24) + (dict[9] << 16) + (dict[10] << 8) + dict[11];
198 default:
199 return S_INT_MAX;
200 }
201 }
202
readHeaderValue(const uint8_t * const dict,const int dictSize,const char * const key,int * outValue,const int outValueSize)203 inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
204 const char *const key, int *outValue, const int outValueSize) {
205 int outValueIndex = 0;
206 // Only format 2 and above have header attributes as {key,value} string pairs. For prior
207 // formats, we just return an empty string, as if the key wasn't found.
208 if (2 <= detectFormat(dict, dictSize)) {
209 const int headerOptionsOffset = 4 /* magic number */
210 + 2 /* dictionary version */ + 2 /* flags */;
211 const int headerSize =
212 (dict[headerOptionsOffset] << 24) + (dict[headerOptionsOffset + 1] << 16)
213 + (dict[headerOptionsOffset + 2] << 8) + dict[headerOptionsOffset + 3];
214 const int headerEnd = headerOptionsOffset + 4 + headerSize;
215 int index = headerOptionsOffset + 4;
216 while (index < headerEnd) {
217 int keyIndex = 0;
218 int codePoint = getCodePointAndForwardPointer(dict, &index);
219 while (codePoint != NOT_A_CODE_POINT) {
220 if (codePoint != key[keyIndex++]) {
221 break;
222 }
223 codePoint = getCodePointAndForwardPointer(dict, &index);
224 }
225 if (codePoint == NOT_A_CODE_POINT && key[keyIndex] == 0) {
226 // We found the key! Copy and return the value.
227 codePoint = getCodePointAndForwardPointer(dict, &index);
228 while (codePoint != NOT_A_CODE_POINT && outValueIndex < outValueSize) {
229 outValue[outValueIndex++] = codePoint;
230 codePoint = getCodePointAndForwardPointer(dict, &index);
231 }
232 // Finished copying. Break to go to the termination code.
233 break;
234 }
235 // We didn't find the key, skip the remainder of it and its value
236 while (codePoint != NOT_A_CODE_POINT) {
237 codePoint = getCodePointAndForwardPointer(dict, &index);
238 }
239 codePoint = getCodePointAndForwardPointer(dict, &index);
240 while (codePoint != NOT_A_CODE_POINT) {
241 codePoint = getCodePointAndForwardPointer(dict, &index);
242 }
243 }
244 // We couldn't find it - fall through and return an empty value.
245 }
246 // Put a terminator 0 if possible at all (always unless outValueSize is <= 0)
247 if (outValueIndex >= outValueSize) outValueIndex = outValueSize - 1;
248 if (outValueIndex >= 0) outValue[outValueIndex] = 0;
249 }
250
readHeaderValueInt(const uint8_t * const dict,const int dictSize,const char * const key)251 inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
252 const char *const key) {
253 const int bufferSize = LARGEST_INT_DIGIT_COUNT;
254 int intBuffer[bufferSize];
255 char charBuffer[bufferSize];
256 BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
257 for (int i = 0; i < bufferSize; ++i) {
258 charBuffer[i] = intBuffer[i];
259 }
260 // If not a number, return S_INT_MIN
261 if (!isdigit(charBuffer[0])) return S_INT_MIN;
262 return atoi(charBuffer);
263 }
264
getGroupCountAndForwardPointer(const uint8_t * const dict,int * pos)265 AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict,
266 int *pos) {
267 const int msb = dict[(*pos)++];
268 if (msb < 0x80) return msb;
269 return ((msb & 0x7F) << 8) | dict[(*pos)++];
270 }
271
getMultiWordCostMultiplier(const uint8_t * const dict,const int dictSize)272 inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
273 const int dictSize) {
274 const int headerValue = readHeaderValueInt(dict, dictSize,
275 "MULTIPLE_WORDS_DEMOTION_RATE");
276 if (headerValue == S_INT_MIN) {
277 return 1.0f;
278 }
279 if (headerValue <= 0) {
280 return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
281 }
282 return 100.0f / static_cast<float>(headerValue);
283 }
284
getFlagsAndForwardPointer(const uint8_t * const dict,int * pos)285 inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) {
286 return dict[(*pos)++];
287 }
288
getCodePointAndForwardPointer(const uint8_t * const dict,int * pos)289 AK_FORCE_INLINE int BinaryFormat::getCodePointAndForwardPointer(const uint8_t *const dict,
290 int *pos) {
291 const int origin = *pos;
292 const int codePoint = dict[origin];
293 if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) {
294 if (codePoint == CHARACTER_ARRAY_TERMINATOR) {
295 *pos = origin + 1;
296 return NOT_A_CODE_POINT;
297 } else {
298 *pos = origin + 3;
299 const int char_1 = codePoint << 16;
300 const int char_2 = char_1 + (dict[origin + 1] << 8);
301 return char_2 + dict[origin + 2];
302 }
303 } else {
304 *pos = origin + 1;
305 return codePoint;
306 }
307 }
308
readProbabilityWithoutMovingPointer(const uint8_t * const dict,const int pos)309 inline int BinaryFormat::readProbabilityWithoutMovingPointer(const uint8_t *const dict,
310 const int pos) {
311 return dict[pos];
312 }
313
skipOtherCharacters(const uint8_t * const dict,const int pos)314 AK_FORCE_INLINE int BinaryFormat::skipOtherCharacters(const uint8_t *const dict, const int pos) {
315 int currentPos = pos;
316 int character = dict[currentPos++];
317 while (CHARACTER_ARRAY_TERMINATOR != character) {
318 if (character < MINIMAL_ONE_BYTE_CHARACTER_VALUE) {
319 currentPos += MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE;
320 }
321 character = dict[currentPos++];
322 }
323 return currentPos;
324 }
325
attributeAddressSize(const uint8_t flags)326 static inline int attributeAddressSize(const uint8_t flags) {
327 static const int ATTRIBUTE_ADDRESS_SHIFT = 4;
328 return (flags & BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT;
329 /* Note: this is a value-dependant optimization of what may probably be
330 more readably written this way:
331 switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) {
332 case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1;
333 case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2;
334 case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3;
335 default: return 0;
336 }
337 */
338 }
339
skipExistingBigrams(const uint8_t * const dict,const int pos)340 static AK_FORCE_INLINE int skipExistingBigrams(const uint8_t *const dict, const int pos) {
341 int currentPos = pos;
342 uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos);
343 while (flags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT) {
344 currentPos += attributeAddressSize(flags);
345 flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos);
346 }
347 currentPos += attributeAddressSize(flags);
348 return currentPos;
349 }
350
childrenAddressSize(const uint8_t flags)351 static inline int childrenAddressSize(const uint8_t flags) {
352 static const int CHILDREN_ADDRESS_SHIFT = 6;
353 return (BinaryFormat::MASK_GROUP_ADDRESS_TYPE & flags) >> CHILDREN_ADDRESS_SHIFT;
354 /* See the note in attributeAddressSize. The same applies here */
355 }
356
shortcutByteSize(const uint8_t * const dict,const int pos)357 static AK_FORCE_INLINE int shortcutByteSize(const uint8_t *const dict, const int pos) {
358 return (static_cast<int>(dict[pos] << 8)) + (dict[pos + 1]);
359 }
360
skipChildrenPosition(const uint8_t flags,const int pos)361 inline int BinaryFormat::skipChildrenPosition(const uint8_t flags, const int pos) {
362 return pos + childrenAddressSize(flags);
363 }
364
skipProbability(const uint8_t flags,const int pos)365 inline int BinaryFormat::skipProbability(const uint8_t flags, const int pos) {
366 return FLAG_IS_TERMINAL & flags ? pos + 1 : pos;
367 }
368
skipShortcuts(const uint8_t * const dict,const uint8_t flags,const int pos)369 AK_FORCE_INLINE int BinaryFormat::skipShortcuts(const uint8_t *const dict, const uint8_t flags,
370 const int pos) {
371 if (FLAG_HAS_SHORTCUT_TARGETS & flags) {
372 return pos + shortcutByteSize(dict, pos);
373 } else {
374 return pos;
375 }
376 }
377
skipBigrams(const uint8_t * const dict,const uint8_t flags,const int pos)378 AK_FORCE_INLINE int BinaryFormat::skipBigrams(const uint8_t *const dict, const uint8_t flags,
379 const int pos) {
380 if (FLAG_HAS_BIGRAMS & flags) {
381 return skipExistingBigrams(dict, pos);
382 } else {
383 return pos;
384 }
385 }
386
skipAllAttributes(const uint8_t * const dict,const uint8_t flags,const int pos)387 AK_FORCE_INLINE int BinaryFormat::skipAllAttributes(const uint8_t *const dict, const uint8_t flags,
388 const int pos) {
389 // This function skips all attributes: shortcuts and bigrams.
390 int newPos = pos;
391 newPos = skipShortcuts(dict, flags, newPos);
392 newPos = skipBigrams(dict, flags, newPos);
393 return newPos;
394 }
395
skipChildrenPosAndAttributes(const uint8_t * const dict,const uint8_t flags,const int pos)396 AK_FORCE_INLINE int BinaryFormat::skipChildrenPosAndAttributes(const uint8_t *const dict,
397 const uint8_t flags, const int pos) {
398 int currentPos = pos;
399 currentPos = skipChildrenPosition(flags, currentPos);
400 currentPos = skipAllAttributes(dict, flags, currentPos);
401 return currentPos;
402 }
403
readChildrenPosition(const uint8_t * const dict,const uint8_t flags,const int pos)404 AK_FORCE_INLINE int BinaryFormat::readChildrenPosition(const uint8_t *const dict,
405 const uint8_t flags, const int pos) {
406 int offset = 0;
407 switch (MASK_GROUP_ADDRESS_TYPE & flags) {
408 case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE:
409 offset = dict[pos];
410 break;
411 case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES:
412 offset = dict[pos] << 8;
413 offset += dict[pos + 1];
414 break;
415 case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES:
416 offset = dict[pos] << 16;
417 offset += dict[pos + 1] << 8;
418 offset += dict[pos + 2];
419 break;
420 default:
421 // If we come here, it means we asked for the children of a word with
422 // no children.
423 return -1;
424 }
425 return pos + offset;
426 }
427
hasChildrenInFlags(const uint8_t flags)428 inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) {
429 return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags));
430 }
431
getAttributeAddressAndForwardPointer(const uint8_t * const dict,const uint8_t flags,int * pos)432 AK_FORCE_INLINE int BinaryFormat::getAttributeAddressAndForwardPointer(const uint8_t *const dict,
433 const uint8_t flags, int *pos) {
434 int offset = 0;
435 const int origin = *pos;
436 switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) {
437 case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE:
438 offset = dict[origin];
439 *pos = origin + 1;
440 break;
441 case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES:
442 offset = dict[origin] << 8;
443 offset += dict[origin + 1];
444 *pos = origin + 2;
445 break;
446 case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES:
447 offset = dict[origin] << 16;
448 offset += dict[origin + 1] << 8;
449 offset += dict[origin + 2];
450 *pos = origin + 3;
451 break;
452 }
453 if (FLAG_ATTRIBUTE_OFFSET_NEGATIVE & flags) {
454 return origin - offset;
455 } else {
456 return origin + offset;
457 }
458 }
459
getAttributeProbabilityFromFlags(const int flags)460 inline int BinaryFormat::getAttributeProbabilityFromFlags(const int flags) {
461 return flags & MASK_ATTRIBUTE_PROBABILITY;
462 }
463
464 // This function gets the byte position of the last chargroup of the exact matching word in the
465 // dictionary. If no match is found, it returns NOT_VALID_WORD.
getTerminalPosition(const uint8_t * const root,const int * const inWord,const int length,const bool forceLowerCaseSearch)466 AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root,
467 const int *const inWord, const int length, const bool forceLowerCaseSearch) {
468 int pos = 0;
469 int wordPos = 0;
470
471 while (true) {
472 // If we already traversed the tree further than the word is long, there means
473 // there was no match (or we would have found it).
474 if (wordPos >= length) return NOT_VALID_WORD;
475 int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos);
476 const int wChar = forceLowerCaseSearch ? toLowerCase(inWord[wordPos]) : inWord[wordPos];
477 while (true) {
478 // If there are no more character groups in this node, it means we could not
479 // find a matching character for this depth, therefore there is no match.
480 if (0 >= charGroupCount) return NOT_VALID_WORD;
481 const int charGroupPos = pos;
482 const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
483 int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
484 if (character == wChar) {
485 // This is the correct node. Only one character group may start with the same
486 // char within a node, so either we found our match in this node, or there is
487 // no match and we can return NOT_VALID_WORD. So we will check all the characters
488 // in this character group indeed does match.
489 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
490 character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
491 while (NOT_A_CODE_POINT != character) {
492 ++wordPos;
493 // If we shoot the length of the word we search for, or if we find a single
494 // character that does not match, as explained above, it means the word is
495 // not in the dictionary (by virtue of this chargroup being the only one to
496 // match the word on the first character, but not matching the whole word).
497 if (wordPos >= length) return NOT_VALID_WORD;
498 if (inWord[wordPos] != character) return NOT_VALID_WORD;
499 character = BinaryFormat::getCodePointAndForwardPointer(root, &pos);
500 }
501 }
502 // If we come here we know that so far, we do match. Either we are on a terminal
503 // and we match the length, in which case we found it, or we traverse children.
504 // If we don't match the length AND don't have children, then a word in the
505 // dictionary fully matches a prefix of the searched word but not the full word.
506 ++wordPos;
507 if (FLAG_IS_TERMINAL & flags) {
508 if (wordPos == length) {
509 return charGroupPos;
510 }
511 pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos);
512 }
513 if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) {
514 return NOT_VALID_WORD;
515 }
516 // We have children and we are still shorter than the word we are searching for, so
517 // we need to traverse children. Put the pointer on the children position, and
518 // break
519 pos = BinaryFormat::readChildrenPosition(root, flags, pos);
520 break;
521 } else {
522 // This chargroup does not match, so skip the remaining part and go to the next.
523 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
524 pos = BinaryFormat::skipOtherCharacters(root, pos);
525 }
526 pos = BinaryFormat::skipProbability(flags, pos);
527 pos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos);
528 }
529 --charGroupCount;
530 }
531 }
532 }
533
534 // This function searches for a terminal in the dictionary by its address.
535 // Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
536 // it is possible to check for this with advantageous complexity. For each node, we search
537 // for groups with children and compare the children address with the address we look for.
538 // When we shoot the address we look for, it means the word we look for is in the children
539 // of the previous group. The only tricky part is the fact that if we arrive at the end of a
540 // node with the last group's children address still less than what we are searching for, we
541 // must descend the last group's children (for example, if the word we are searching for starts
542 // with a z, it's the last group of the root node, so all children addresses will be smaller
543 // than the address we look for, and we have to descend the z node).
544 /* Parameters :
545 * root: the dictionary buffer
546 * address: the byte position of the last chargroup of the word we are searching for (this is
547 * what is stored as the "bigram address" in each bigram)
548 * outword: an array to write the found word, with MAX_WORD_LENGTH size.
549 * outUnigramProbability: a pointer to an int to write the probability into.
550 * Return value : the length of the word, of 0 if the word was not found.
551 */
getWordAtAddress(const uint8_t * const root,const int address,const int maxDepth,int * outWord,int * outUnigramProbability)552 AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address,
553 const int maxDepth, int *outWord, int *outUnigramProbability) {
554 int pos = 0;
555 int wordPos = 0;
556
557 // One iteration of the outer loop iterates through nodes. As stated above, we will only
558 // traverse nodes that are actually a part of the terminal we are searching, so each time
559 // we enter this loop we are one depth level further than last time.
560 // The only reason we count nodes is because we want to reduce the probability of infinite
561 // looping in case there is a bug. Since we know there is an upper bound to the depth we are
562 // supposed to traverse, it does not hurt to count iterations.
563 for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
564 int lastCandidateGroupPos = 0;
565 // Let's loop through char groups in this node searching for either the terminal
566 // or one of its ascendants.
567 for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0;
568 --charGroupCount) {
569 const int startPos = pos;
570 const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
571 const int character = getCodePointAndForwardPointer(root, &pos);
572 if (address == startPos) {
573 // We found the address. Copy the rest of the word in the buffer and return
574 // the length.
575 outWord[wordPos] = character;
576 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
577 int nextChar = getCodePointAndForwardPointer(root, &pos);
578 // We count chars in order to avoid infinite loops if the file is broken or
579 // if there is some other bug
580 int charCount = maxDepth;
581 while (NOT_A_CODE_POINT != nextChar && --charCount > 0) {
582 outWord[++wordPos] = nextChar;
583 nextChar = getCodePointAndForwardPointer(root, &pos);
584 }
585 }
586 *outUnigramProbability = readProbabilityWithoutMovingPointer(root, pos);
587 return ++wordPos;
588 }
589 // We need to skip past this char group, so skip any remaining chars after the
590 // first and possibly the probability.
591 if (FLAG_HAS_MULTIPLE_CHARS & flags) {
592 pos = skipOtherCharacters(root, pos);
593 }
594 pos = skipProbability(flags, pos);
595
596 // The fact that this group has children is very important. Since we already know
597 // that this group does not match, if it has no children we know it is irrelevant
598 // to what we are searching for.
599 const bool hasChildren = (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS !=
600 (MASK_GROUP_ADDRESS_TYPE & flags));
601 // We will write in `found' whether we have passed the children address we are
602 // searching for. For example if we search for "beer", the children of b are less
603 // than the address we are searching for and the children of c are greater. When we
604 // come here for c, we realize this is too big, and that we should descend b.
605 bool found;
606 if (hasChildren) {
607 // Here comes the tricky part. First, read the children position.
608 const int childrenPos = readChildrenPosition(root, flags, pos);
609 if (childrenPos > address) {
610 // If the children pos is greater than address, it means the previous chargroup,
611 // which address is stored in lastCandidateGroupPos, was the right one.
612 found = true;
613 } else if (1 >= charGroupCount) {
614 // However if we are on the LAST group of this node, and we have NOT shot the
615 // address we should descend THIS node. So we trick the lastCandidateGroupPos
616 // so that we will descend this node, not the previous one.
617 lastCandidateGroupPos = startPos;
618 found = true;
619 } else {
620 // Else, we should continue looking.
621 found = false;
622 }
623 } else {
624 // Even if we don't have children here, we could still be on the last group of this
625 // node. If this is the case, we should descend the last group that had children,
626 // and their address is already in lastCandidateGroup.
627 found = (1 >= charGroupCount);
628 }
629
630 if (found) {
631 // Okay, we found the group we should descend. Its address is in
632 // the lastCandidateGroupPos variable, so we just re-read it.
633 if (0 != lastCandidateGroupPos) {
634 const uint8_t lastFlags =
635 getFlagsAndForwardPointer(root, &lastCandidateGroupPos);
636 const int lastChar =
637 getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
638 // We copy all the characters in this group to the buffer
639 outWord[wordPos] = lastChar;
640 if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
641 int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
642 int charCount = maxDepth;
643 while (-1 != nextChar && --charCount > 0) {
644 outWord[++wordPos] = nextChar;
645 nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos);
646 }
647 }
648 ++wordPos;
649 // Now we only need to branch to the children address. Skip the probability if
650 // it's there, read pos, and break to resume the search at pos.
651 lastCandidateGroupPos = skipProbability(lastFlags, lastCandidateGroupPos);
652 pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos);
653 break;
654 } else {
655 // Here is a little tricky part: we come here if we found out that all children
656 // addresses in this group are bigger than the address we are searching for.
657 // Should we conclude the word is not in the dictionary? No! It could still be
658 // one of the remaining chargroups in this node, so we have to keep looking in
659 // this node until we find it (or we realize it's not there either, in which
660 // case it's actually not in the dictionary). Pass the end of this group, ready
661 // to start the next one.
662 pos = skipChildrenPosAndAttributes(root, flags, pos);
663 }
664 } else {
665 // If we did not find it, we should record the last children address for the next
666 // iteration.
667 if (hasChildren) lastCandidateGroupPos = startPos;
668 // Now skip the end of this group (children pos and the attributes if any) so that
669 // our pos is after the end of this char group, at the start of the next one.
670 pos = skipChildrenPosAndAttributes(root, flags, pos);
671 }
672
673 }
674 }
675 // If we have looked through all the chargroups and found no match, the address is
676 // not the address of a terminal in this dictionary.
677 return 0;
678 }
679
// Returns the probability to use for a word pair that has no bigram entry.
//
// At present this is the identity: the unigram probability is passed through unchanged.
// A real backoff weight (halving the probability, which in the stored format means
// lowering the score by 8) was tried and gave bad results in tests.
// TODO: figure out what's wrong with this. The disabled formula was:
//   return unigramProbability > 8 ? unigramProbability - 8 : (0 == unigramProbability ? 0 : 8);
static inline int backoff(const int unigramProbability) {
    const int backedOffProbability = unigramProbability;
    return backedOffProbability;
}
688
computeProbabilityForBigram(const int unigramProbability,const int bigramProbability)689 inline int BinaryFormat::computeProbabilityForBigram(
690 const int unigramProbability, const int bigramProbability) {
691 // We divide the range [unigramProbability..255] in 16.5 steps - in other words, we want the
692 // unigram probability to be the median value of the 17th step from the top. A value of
693 // 0 for the bigram probability represents the middle of the 16th step from the top,
694 // while a value of 15 represents the middle of the top step.
695 // See makedict.BinaryDictInputOutput for details.
696 const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability)
697 / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY);
698 return unigramProbability
699 + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize);
700 }
701
702 // This returns a probability in log space.
getProbability(const int position,const std::map<int,int> * bigramMap,const uint8_t * bigramFilter,const int unigramProbability)703 inline int BinaryFormat::getProbability(const int position, const std::map<int, int> *bigramMap,
704 const uint8_t *bigramFilter, const int unigramProbability) {
705 if (!bigramMap || !bigramFilter) return backoff(unigramProbability);
706 if (!isInFilter(bigramFilter, position)) return backoff(unigramProbability);
707 const std::map<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
708 if (bigramProbabilityIt != bigramMap->end()) {
709 const int bigramProbability = bigramProbabilityIt->second;
710 return computeProbabilityForBigram(unigramProbability, bigramProbability);
711 }
712 return backoff(unigramProbability);
713 }
714
715 // This returns a probability in log space.
getBigramProbabilityFromHashMap(const int position,const hash_map_compat<int,int> * bigramMap,const int unigramProbability)716 inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position,
717 const hash_map_compat<int, int> *bigramMap, const int unigramProbability) {
718 if (!bigramMap) return backoff(unigramProbability);
719 const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
720 if (bigramProbabilityIt != bigramMap->end()) {
721 const int bigramProbability = bigramProbabilityIt->second;
722 return computeProbabilityForBigram(unigramProbability, bigramProbability);
723 }
724 return backoff(unigramProbability);
725 }
726
fillBigramProbabilityToHashMap(const uint8_t * const root,int position,hash_map_compat<int,int> * bigramMap)727 AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap(
728 const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) {
729 position = getBigramListPositionForWordPosition(root, position);
730 if (0 == position) return;
731
732 uint8_t bigramFlags;
733 do {
734 bigramFlags = getFlagsAndForwardPointer(root, &position);
735 const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
736 const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags,
737 &position);
738 (*bigramMap)[bigramPos] = probability;
739 } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
740 }
741
getBigramProbability(const uint8_t * const root,int position,const int nextPosition,const int unigramProbability)742 AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position,
743 const int nextPosition, const int unigramProbability) {
744 position = getBigramListPositionForWordPosition(root, position);
745 if (0 == position) return backoff(unigramProbability);
746
747 uint8_t bigramFlags;
748 do {
749 bigramFlags = getFlagsAndForwardPointer(root, &position);
750 const int bigramPos = getAttributeAddressAndForwardPointer(
751 root, bigramFlags, &position);
752 if (bigramPos == nextPosition) {
753 const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
754 return computeProbabilityForBigram(unigramProbability, bigramProbability);
755 }
756 } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
757 return backoff(unigramProbability);
758 }
759
760 // Returns a pointer to the start of the bigram list.
getBigramListPositionForWordPosition(const uint8_t * const root,int position)761 AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition(
762 const uint8_t *const root, int position) {
763 if (NOT_VALID_WORD == position) return 0;
764 const uint8_t flags = getFlagsAndForwardPointer(root, &position);
765 if (!(flags & FLAG_HAS_BIGRAMS)) return 0;
766 if (flags & FLAG_HAS_MULTIPLE_CHARS) {
767 position = skipOtherCharacters(root, position);
768 } else {
769 getCodePointAndForwardPointer(root, &position);
770 }
771 position = skipProbability(flags, position);
772 position = skipChildrenPosition(flags, position);
773 position = skipShortcuts(root, flags, position);
774 return position;
775 }
776
777 } // namespace latinime
778 #endif // LATINIME_BINARY_FORMAT_H
779