1 /*
2 * Copyright (C) 2011 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LATINIME_CORRECTION_H
18 #define LATINIME_CORRECTION_H
19
20 #include <cstring> // for memset()
21
22 #include "correction_state.h"
23 #include "defines.h"
24 #include "proximity_info_state.h"
25
26 namespace latinime {
27
28 class ProximityInfo;
29
30 class Correction {
31 public:
32 typedef enum {
33 TRAVERSE_ALL_ON_TERMINAL,
34 TRAVERSE_ALL_NOT_ON_TERMINAL,
35 UNRELATED,
36 ON_TERMINAL,
37 NOT_ON_TERMINAL
38 } CorrectionType;
39
Correction()40 Correction()
41 : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
42 mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
43 mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
44 mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
45 mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
46 mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
47 mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
48 mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
49 mSkipping(false), mProximityInfoState() {
50 memset(mWord, 0, sizeof(mWord));
51 memset(mDistances, 0, sizeof(mDistances));
52 memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
53 // NOTE: mCorrectionStates is an array of instances.
54 // No need to initialize it explicitly here.
55 }
56
57 // Non virtual inline destructor -- never inherit this class
~Correction()58 ~Correction() {}
59 void resetCorrection();
60 void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth);
61 void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);
62
63 // TODO: remove
64 void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
65 const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
66 const bool doAutoCompletion, const int maxErrors);
67 void checkState() const;
68 bool sameAsTyped() const;
69 bool initProcessState(const int index);
70
71 int getInputIndex() const;
72
73 bool needsToPrune() const;
74
pushAndGetTotalTraverseCount()75 int pushAndGetTotalTraverseCount() {
76 return ++mTotalTraverseCount;
77 }
78
79 int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
80 const int wordCount, const bool isSpaceProximity, const int *word) const;
81 int getFinalProbability(const int probability, int **word, int *wordLength);
82 int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
83 const int inputSize);
84
85 CorrectionType processCharAndCalcState(const int c, const bool isTerminal);
86
87 /////////////////////////
88 // Tree helper methods
89 int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);
90
getTreeSiblingPos(const int index)91 inline int getTreeSiblingPos(const int index) const {
92 return mCorrectionStates[index].mSiblingPos;
93 }
94
setTreeSiblingPos(const int index,const int pos)95 inline void setTreeSiblingPos(const int index, const int pos) {
96 mCorrectionStates[index].mSiblingPos = pos;
97 }
98
getTreeParentIndex(const int index)99 inline int getTreeParentIndex(const int index) const {
100 return mCorrectionStates[index].mParentIndex;
101 }
102
103 class RankingAlgorithm {
104 public:
105 static int calculateFinalProbability(const int inputIndex, const int depth,
106 const int probability, int *editDistanceTable, const Correction *correction,
107 const int inputSize);
108 static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
109 const int wordCount, const Correction *correction, const bool isSpaceProximity,
110 const int *word);
111 static float calcNormalizedScore(const int *before, const int beforeLength,
112 const int *after, const int afterLength, const int score);
113 static int editDistance(const int *before, const int beforeLength, const int *after,
114 const int afterLength);
115 private:
116 static const int MAX_INITIAL_SCORE = 255;
117 };
118
119 // proximity info state
initInputParams(const ProximityInfo * proximityInfo,const int * inputCodes,const int inputSize,const int * xCoordinates,const int * yCoordinates)120 void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
121 const int inputSize, const int *xCoordinates, const int *yCoordinates) {
122 mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING),
123 proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
124 }
125
getPrimaryInputWord()126 const int *getPrimaryInputWord() const {
127 return mProximityInfoState.getPrimaryInputWord();
128 }
129
getPrimaryCodePointAt(const int index)130 int getPrimaryCodePointAt(const int index) const {
131 return mProximityInfoState.getPrimaryCodePointAt(index);
132 }
133
134 private:
135 DISALLOW_COPY_AND_ASSIGN(Correction);
136
137 /////////////////////////
138 // static inline utils //
139 /////////////////////////
140 static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num)141 static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
142 return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
143 }
144
145 static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
multiplyIntCapped(const int multiplier,int * base)146 AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
147 const int temp = *base;
148 if (temp != S_INT_MAX) {
149 // Branch if multiplier == 2 for the optimization
150 if (multiplier < 0) {
151 if (DEBUG_DICT) {
152 ASSERT(false);
153 }
154 AKLOGI("--- Invalid multiplier: %d", multiplier);
155 } else if (multiplier == 0) {
156 *base = 0;
157 } else if (multiplier == 2) {
158 *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
159 } else {
160 // TODO: This overflow check gives a wrong answer when, for example,
161 // temp = 2^16 + 1 and multiplier = 2^17 + 1.
162 // Fix this behavior.
163 const int tempRetval = temp * multiplier;
164 *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
165 }
166 }
167 }
168
powerIntCapped(const int base,const int n)169 AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
170 if (n <= 0) return 1;
171 if (base == 2) {
172 return n < 31 ? 1 << n : S_INT_MAX;
173 }
174 int ret = base;
175 for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
176 return ret;
177 }
178
multiplyRate(const int rate,int * freq)179 AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
180 if (*freq != S_INT_MAX) {
181 if (*freq > 1000000) {
182 *freq /= 100;
183 multiplyIntCapped(rate, freq);
184 } else {
185 multiplyIntCapped(rate, freq);
186 *freq /= 100;
187 }
188 }
189 }
190
getSpaceProximityPos()191 inline int getSpaceProximityPos() const {
192 return mSpaceProximityPos;
193 }
getMissingSpacePos()194 inline int getMissingSpacePos() const {
195 return mMissingSpacePos;
196 }
197
getSkipPos()198 inline int getSkipPos() const {
199 return mSkipPos;
200 }
201
getExcessivePos()202 inline int getExcessivePos() const {
203 return mExcessivePos;
204 }
205
getTransposedPos()206 inline int getTransposedPos() const {
207 return mTransposedPos;
208 }
209
210 inline void incrementInputIndex();
211 inline void incrementOutputIndex();
212 inline void startToTraverseAllNodes();
213 inline bool isSingleQuote(const int c);
214 inline CorrectionType processSkipChar(const int c, const bool isTerminal,
215 const bool inputIndexIncremented);
216 inline CorrectionType processUnrelatedCorrectionType();
217 inline void addCharToCurrentWord(const int c);
218 inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
219 const int inputSize);
220
221 static const int TYPED_LETTER_MULTIPLIER = 2;
222 static const int FULL_WORD_MULTIPLIER = 2;
223 const ProximityInfo *mProximityInfo;
224
225 bool mUseFullEditDistance;
226 bool mDoAutoCompletion;
227 int mMaxEditDistance;
228 int mMaxDepth;
229 int mInputSize;
230 int mSpaceProximityPos;
231 int mMissingSpacePos;
232 int mTerminalInputIndex;
233 int mTerminalOutputIndex;
234 int mMaxErrors;
235
236 int mTotalTraverseCount;
237
238 // The following arrays are state buffer.
239 int mWord[MAX_WORD_LENGTH];
240 int mDistances[MAX_WORD_LENGTH];
241
242 // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
243 // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
244 int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)];
245
246 CorrectionState mCorrectionStates[MAX_WORD_LENGTH];
247
248 // The following member variables are being used as cache values of the correction state.
249 bool mNeedsToTraverseAllNodes;
250 int mOutputIndex;
251 int mInputIndex;
252
253 int mEquivalentCharCount;
254 int mProximityCount;
255 int mExcessiveCount;
256 int mTransposedCount;
257 int mSkippedCount;
258
259 int mTransposedPos;
260 int mExcessivePos;
261 int mSkipPos;
262
263 bool mLastCharExceeded;
264
265 bool mMatching;
266 bool mProximityMatching;
267 bool mAdditionalProximityMatching;
268 bool mExceeding;
269 bool mTransposing;
270 bool mSkipping;
271 ProximityInfoState mProximityInfoState;
272 };
273
incrementInputIndex()274 inline void Correction::incrementInputIndex() {
275 ++mInputIndex;
276 }
277
incrementOutputIndex()278 AK_FORCE_INLINE void Correction::incrementOutputIndex() {
279 ++mOutputIndex;
280 mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex;
281 mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount;
282 mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
283 mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
284 mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
285
286 mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount;
287 mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount;
288 mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount;
289 mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount;
290 mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
291
292 mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
293 mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos;
294 mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos;
295
296 mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded;
297
298 mCorrectionStates[mOutputIndex].mMatching = mMatching;
299 mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
300 mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching;
301 mCorrectionStates[mOutputIndex].mTransposing = mTransposing;
302 mCorrectionStates[mOutputIndex].mExceeding = mExceeding;
303 mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
304 }
305
startToTraverseAllNodes()306 inline void Correction::startToTraverseAllNodes() {
307 mNeedsToTraverseAllNodes = true;
308 }
309
isSingleQuote(const int c)310 AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) {
311 const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex);
312 return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE);
313 }
314
processSkipChar(const int c,const bool isTerminal,const bool inputIndexIncremented)315 AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
316 const bool isTerminal, const bool inputIndexIncremented) {
317 addCharToCurrentWord(c);
318 mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
319 mTerminalOutputIndex = mOutputIndex;
320 incrementOutputIndex();
321 if (mNeedsToTraverseAllNodes && isTerminal) {
322 return TRAVERSE_ALL_ON_TERMINAL;
323 }
324 return TRAVERSE_ALL_NOT_ON_TERMINAL;
325 }
326
processUnrelatedCorrectionType()327 inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
328 // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType
329 mTerminalInputIndex = mInputIndex;
330 mTerminalOutputIndex = mOutputIndex;
331 return UNRELATED;
332 }
333
calcEditDistanceOneStep(int * editDistanceTable,const int * input,const int inputSize,const int * output,const int outputLength)334 AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
335 const int inputSize, const int *output, const int outputLength) {
336 // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched.
337 // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j].
338 // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated,
339 // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize].
340 int *const current = editDistanceTable + outputLength * (inputSize + 1);
341 const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1);
342 const int *const prevprev =
343 outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0;
344 current[0] = outputLength;
345 const int co = toBaseLowerCase(output[outputLength - 1]);
346 const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0;
347 for (int i = 1; i <= inputSize; ++i) {
348 const int ci = toBaseLowerCase(input[i - 1]);
349 const int cost = (ci == co) ? 0 : 1;
350 current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
351 if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) {
352 current[i] = min(current[i], prevprev[i - 2] + 1);
353 }
354 }
355 }
356
addCharToCurrentWord(const int c)357 AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
358 mWord[mOutputIndex] = c;
359 const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord();
360 calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord,
361 mOutputIndex + 1);
362 }
363
getFinalProbabilityInternal(const int probability,int ** word,int * wordLength,const int inputSize)364 inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
365 int *wordLength, const int inputSize) {
366 const int outputIndex = mTerminalOutputIndex;
367 const int inputIndex = mTerminalInputIndex;
368 *wordLength = outputIndex + 1;
369 *word = mWord;
370 int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
371 inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
372 return finalProbability;
373 }
374
375 } // namespace latinime
376 #endif // LATINIME_CORRECTION_H
377