• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2011 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LATINIME_CORRECTION_H
18 #define LATINIME_CORRECTION_H
19 
20 #include <cstring> // for memset()
21 
22 #include "correction_state.h"
23 #include "defines.h"
24 #include "proximity_info_state.h"
25 
26 namespace latinime {
27 
28 class ProximityInfo;
29 
30 class Correction {
31  public:
32     typedef enum {
33         TRAVERSE_ALL_ON_TERMINAL,
34         TRAVERSE_ALL_NOT_ON_TERMINAL,
35         UNRELATED,
36         ON_TERMINAL,
37         NOT_ON_TERMINAL
38     } CorrectionType;
39 
Correction()40     Correction()
41             : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false),
42               mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0),
43               mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0),
44               mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0),
45               mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0),
46               mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0),
47               mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false),
48               mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false),
49               mSkipping(false), mProximityInfoState() {
50         memset(mWord, 0, sizeof(mWord));
51         memset(mDistances, 0, sizeof(mDistances));
52         memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable));
53         // NOTE: mCorrectionStates is an array of instances.
54         // No need to initialize it explicitly here.
55     }
56 
57     // Non virtual inline destructor -- never inherit this class
~Correction()58     ~Correction() {}
59     void resetCorrection();
60     void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth);
61     void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll);
62 
63     // TODO: remove
64     void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
65             const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance,
66             const bool doAutoCompletion, const int maxErrors);
67     void checkState() const;
68     bool sameAsTyped() const;
69     bool initProcessState(const int index);
70 
71     int getInputIndex() const;
72 
73     bool needsToPrune() const;
74 
pushAndGetTotalTraverseCount()75     int pushAndGetTotalTraverseCount() {
76         return ++mTotalTraverseCount;
77     }
78 
79     int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
80             const int wordCount, const bool isSpaceProximity, const int *word) const;
81     int getFinalProbability(const int probability, int **word, int *wordLength);
82     int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
83             const int inputSize);
84 
85     CorrectionType processCharAndCalcState(const int c, const bool isTerminal);
86 
87     /////////////////////////
88     // Tree helper methods
89     int goDownTree(const int parentIndex, const int childCount, const int firstChildPos);
90 
getTreeSiblingPos(const int index)91     inline int getTreeSiblingPos(const int index) const {
92         return mCorrectionStates[index].mSiblingPos;
93     }
94 
setTreeSiblingPos(const int index,const int pos)95     inline void setTreeSiblingPos(const int index, const int pos) {
96         mCorrectionStates[index].mSiblingPos = pos;
97     }
98 
getTreeParentIndex(const int index)99     inline int getTreeParentIndex(const int index) const {
100         return mCorrectionStates[index].mParentIndex;
101     }
102 
103     class RankingAlgorithm {
104      public:
105         static int calculateFinalProbability(const int inputIndex, const int depth,
106                 const int probability, int *editDistanceTable, const Correction *correction,
107                 const int inputSize);
108         static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
109                 const int wordCount, const Correction *correction, const bool isSpaceProximity,
110                 const int *word);
111         static float calcNormalizedScore(const int *before, const int beforeLength,
112                 const int *after, const int afterLength, const int score);
113         static int editDistance(const int *before, const int beforeLength, const int *after,
114                 const int afterLength);
115      private:
116         static const int MAX_INITIAL_SCORE = 255;
117     };
118 
119     // proximity info state
initInputParams(const ProximityInfo * proximityInfo,const int * inputCodes,const int inputSize,const int * xCoordinates,const int * yCoordinates)120     void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
121             const int inputSize, const int *xCoordinates, const int *yCoordinates) {
122         mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING),
123                 proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
124     }
125 
getPrimaryInputWord()126     const int *getPrimaryInputWord() const {
127         return mProximityInfoState.getPrimaryInputWord();
128     }
129 
getPrimaryCodePointAt(const int index)130     int getPrimaryCodePointAt(const int index) const {
131         return mProximityInfoState.getPrimaryCodePointAt(index);
132     }
133 
134  private:
135     DISALLOW_COPY_AND_ASSIGN(Correction);
136 
137     /////////////////////////
138     // static inline utils //
139     /////////////////////////
140     static const int TWO_31ST_DIV_255 = S_INT_MAX / 255;
capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num)141     static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) {
142         return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX);
143     }
144 
145     static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
multiplyIntCapped(const int multiplier,int * base)146     AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
147         const int temp = *base;
148         if (temp != S_INT_MAX) {
149             // Branch if multiplier == 2 for the optimization
150             if (multiplier < 0) {
151                 if (DEBUG_DICT) {
152                     ASSERT(false);
153                 }
154                 AKLOGI("--- Invalid multiplier: %d", multiplier);
155             } else if (multiplier == 0) {
156                 *base = 0;
157             } else if (multiplier == 2) {
158                 *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX;
159             } else {
160                 // TODO: This overflow check gives a wrong answer when, for example,
161                 //       temp = 2^16 + 1 and multiplier = 2^17 + 1.
162                 //       Fix this behavior.
163                 const int tempRetval = temp * multiplier;
164                 *base = tempRetval >= temp ? tempRetval : S_INT_MAX;
165             }
166         }
167     }
168 
powerIntCapped(const int base,const int n)169     AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
170         if (n <= 0) return 1;
171         if (base == 2) {
172             return n < 31 ? 1 << n : S_INT_MAX;
173         }
174         int ret = base;
175         for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret);
176         return ret;
177     }
178 
multiplyRate(const int rate,int * freq)179     AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
180         if (*freq != S_INT_MAX) {
181             if (*freq > 1000000) {
182                 *freq /= 100;
183                 multiplyIntCapped(rate, freq);
184             } else {
185                 multiplyIntCapped(rate, freq);
186                 *freq /= 100;
187             }
188         }
189     }
190 
getSpaceProximityPos()191     inline int getSpaceProximityPos() const {
192         return mSpaceProximityPos;
193     }
getMissingSpacePos()194     inline int getMissingSpacePos() const {
195         return mMissingSpacePos;
196     }
197 
getSkipPos()198     inline int getSkipPos() const {
199         return mSkipPos;
200     }
201 
getExcessivePos()202     inline int getExcessivePos() const {
203         return mExcessivePos;
204     }
205 
getTransposedPos()206     inline int getTransposedPos() const {
207         return mTransposedPos;
208     }
209 
210     inline void incrementInputIndex();
211     inline void incrementOutputIndex();
212     inline void startToTraverseAllNodes();
213     inline bool isSingleQuote(const int c);
214     inline CorrectionType processSkipChar(const int c, const bool isTerminal,
215             const bool inputIndexIncremented);
216     inline CorrectionType processUnrelatedCorrectionType();
217     inline void addCharToCurrentWord(const int c);
218     inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
219             const int inputSize);
220 
221     static const int TYPED_LETTER_MULTIPLIER = 2;
222     static const int FULL_WORD_MULTIPLIER = 2;
223     const ProximityInfo *mProximityInfo;
224 
225     bool mUseFullEditDistance;
226     bool mDoAutoCompletion;
227     int mMaxEditDistance;
228     int mMaxDepth;
229     int mInputSize;
230     int mSpaceProximityPos;
231     int mMissingSpacePos;
232     int mTerminalInputIndex;
233     int mTerminalOutputIndex;
234     int mMaxErrors;
235 
236     int mTotalTraverseCount;
237 
238     // The following arrays are state buffer.
239     int mWord[MAX_WORD_LENGTH];
240     int mDistances[MAX_WORD_LENGTH];
241 
242     // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
243     // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
244     int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)];
245 
246     CorrectionState mCorrectionStates[MAX_WORD_LENGTH];
247 
248     // The following member variables are being used as cache values of the correction state.
249     bool mNeedsToTraverseAllNodes;
250     int mOutputIndex;
251     int mInputIndex;
252 
253     int mEquivalentCharCount;
254     int mProximityCount;
255     int mExcessiveCount;
256     int mTransposedCount;
257     int mSkippedCount;
258 
259     int mTransposedPos;
260     int mExcessivePos;
261     int mSkipPos;
262 
263     bool mLastCharExceeded;
264 
265     bool mMatching;
266     bool mProximityMatching;
267     bool mAdditionalProximityMatching;
268     bool mExceeding;
269     bool mTransposing;
270     bool mSkipping;
271     ProximityInfoState mProximityInfoState;
272 };
273 
incrementInputIndex()274 inline void Correction::incrementInputIndex() {
275     ++mInputIndex;
276 }
277 
incrementOutputIndex()278 AK_FORCE_INLINE void Correction::incrementOutputIndex() {
279     ++mOutputIndex;
280     mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex;
281     mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount;
282     mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
283     mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
284     mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
285 
286     mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount;
287     mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount;
288     mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount;
289     mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount;
290     mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
291 
292     mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
293     mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos;
294     mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos;
295 
296     mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded;
297 
298     mCorrectionStates[mOutputIndex].mMatching = mMatching;
299     mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
300     mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching;
301     mCorrectionStates[mOutputIndex].mTransposing = mTransposing;
302     mCorrectionStates[mOutputIndex].mExceeding = mExceeding;
303     mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
304 }
305 
startToTraverseAllNodes()306 inline void Correction::startToTraverseAllNodes() {
307     mNeedsToTraverseAllNodes = true;
308 }
309 
isSingleQuote(const int c)310 AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) {
311     const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex);
312     return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE);
313 }
314 
processSkipChar(const int c,const bool isTerminal,const bool inputIndexIncremented)315 AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
316         const bool isTerminal, const bool inputIndexIncremented) {
317     addCharToCurrentWord(c);
318     mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
319     mTerminalOutputIndex = mOutputIndex;
320     incrementOutputIndex();
321     if (mNeedsToTraverseAllNodes && isTerminal) {
322         return TRAVERSE_ALL_ON_TERMINAL;
323     }
324     return TRAVERSE_ALL_NOT_ON_TERMINAL;
325 }
326 
processUnrelatedCorrectionType()327 inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
328     // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType
329     mTerminalInputIndex = mInputIndex;
330     mTerminalOutputIndex = mOutputIndex;
331     return UNRELATED;
332 }
333 
calcEditDistanceOneStep(int * editDistanceTable,const int * input,const int inputSize,const int * output,const int outputLength)334 AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
335         const int inputSize, const int *output, const int outputLength) {
336     // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched.
337     // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j].
338     // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated,
339     // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize].
340     int *const current = editDistanceTable + outputLength * (inputSize + 1);
341     const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1);
342     const int *const prevprev =
343             outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0;
344     current[0] = outputLength;
345     const int co = toBaseLowerCase(output[outputLength - 1]);
346     const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0;
347     for (int i = 1; i <= inputSize; ++i) {
348         const int ci = toBaseLowerCase(input[i - 1]);
349         const int cost = (ci == co) ? 0 : 1;
350         current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
351         if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) {
352             current[i] = min(current[i], prevprev[i - 2] + 1);
353         }
354     }
355 }
356 
addCharToCurrentWord(const int c)357 AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
358     mWord[mOutputIndex] = c;
359     const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord();
360     calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord,
361             mOutputIndex + 1);
362 }
363 
getFinalProbabilityInternal(const int probability,int ** word,int * wordLength,const int inputSize)364 inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
365         int *wordLength, const int inputSize) {
366     const int outputIndex = mTerminalOutputIndex;
367     const int inputIndex = mTerminalInputIndex;
368     *wordLength = outputIndex + 1;
369     *word = mWord;
370     int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
371             inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
372     return finalProbability;
373 }
374 
375 } // namespace latinime
376 #endif // LATINIME_CORRECTION_H
377