diff options
Diffstat (limited to 'native/jni/src/correction.h')
-rw-r--r-- | native/jni/src/correction.h | 159 |
1 files changed, 131 insertions, 28 deletions
diff --git a/native/jni/src/correction.h b/native/jni/src/correction.h index f016d5453..912cd838e 100644 --- a/native/jni/src/correction.h +++ b/native/jni/src/correction.h @@ -56,7 +56,8 @@ class Correction { // No need to initialize it explicitly here. } - virtual ~Correction() {} + // Non virtual inline destructor -- never inherit this class + ~Correction() {} void resetCorrection(); void initCorrection( const ProximityInfo *pi, const int inputSize, const int maxWordLength); @@ -78,14 +79,13 @@ class Correction { return ++mTotalTraverseCount; } - int getFreqForSplitMultipleWords( - const int *freqArray, const int *wordLengthArray, const int wordCount, - const bool isSpaceProximity, const unsigned short *word); - int getFinalProbability(const int probability, unsigned short **word, int *wordLength); - int getFinalProbabilityForSubQueue(const int probability, unsigned short **word, - int *wordLength, const int inputSize); + int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, + const int wordCount, const bool isSpaceProximity, const int *word); + int getFinalProbability(const int probability, int **word, int *wordLength); + int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, + const int inputSize); - CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal); + CorrectionType processCharAndCalcState(const int c, const bool isTerminal); ///////////////////////// // Tree helper methods @@ -110,29 +110,28 @@ class Correction { const int inputSize); static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, const int wordCount, const Correction *correction, const bool isSpaceProximity, - const unsigned short *word); - static float calcNormalizedScore(const unsigned short *before, const int beforeLength, - const unsigned short *after, const int afterLength, const int score); - static int editDistance(const unsigned short *before, - const int beforeLength, const unsigned short *after, const int afterLength); + const int *word); + static float calcNormalizedScore(const int *before, const int beforeLength, + const int *after, const int afterLength, const int score); + static int editDistance(const int *before, const int beforeLength, const int *after, + const int afterLength); private: - static const int CODE_SPACE = ' '; static const int MAX_INITIAL_SCORE = 255; }; // proximity info state - void initInputParams(const ProximityInfo *proximityInfo, const int32_t *inputCodes, + void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes, const int inputSize, const int *xCoordinates, const int *yCoordinates) { mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH, proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false); } - const unsigned short *getPrimaryInputWord() const { + const int *getPrimaryInputWord() const { return mProximityInfoState.getPrimaryInputWord(); } - unsigned short getPrimaryCharAt(const int index) const { - return mProximityInfoState.getPrimaryCharAt(index); + int getPrimaryCodePointAt(const int index) const { + return mProximityInfoState.getPrimaryCodePointAt(index); } private: @@ -147,7 +146,7 @@ class Correction { } static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; - inline static void multiplyIntCapped(const int multiplier, int *base) { + AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) { const int temp = *base; if (temp != S_INT_MAX) { // Branch if multiplier == 2 for the optimization @@ -170,7 +169,7 @@ class Correction { } } - inline static int powerIntCapped(const int base, const int n) { + AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) { if (n <= 0) return 1; if (base == 2) { return n < 31 ? 1 << n : S_INT_MAX; @@ -181,7 +180,7 @@ class Correction { } } - inline static void multiplyRate(const int rate, int *freq) { + AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) { if (*freq != S_INT_MAX) { if (*freq > 1000000) { *freq /= 100; @@ -215,13 +214,13 @@ class Correction { inline void incrementInputIndex(); inline void incrementOutputIndex(); inline void startToTraverseAllNodes(); - inline bool isSingleQuote(const unsigned short c); - inline CorrectionType processSkipChar( - const int32_t c, const bool isTerminal, const bool inputIndexIncremented); + inline bool isSingleQuote(const int c); + inline CorrectionType processSkipChar(const int c, const bool isTerminal, + const bool inputIndexIncremented); inline CorrectionType processUnrelatedCorrectionType(); - inline void addCharToCurrentWord(const int32_t c); - inline int getFinalProbabilityInternal(const int probability, unsigned short **word, - int *wordLength, const int inputSize); + inline void addCharToCurrentWord(const int c); + inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength, + const int inputSize); static const int TYPED_LETTER_MULTIPLIER = 2; static const int FULL_WORD_MULTIPLIER = 2; @@ -241,7 +240,7 @@ class Correction { uint8_t mTotalTraverseCount; // The following arrays are state buffer. - unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; + int mWord[MAX_WORD_LENGTH_INTERNAL]; int mDistances[MAX_WORD_LENGTH_INTERNAL]; // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N. @@ -275,5 +274,109 @@ class Correction { bool mSkipping; ProximityInfoState mProximityInfoState; }; + +inline void Correction::incrementInputIndex() { + ++mInputIndex; +} + +AK_FORCE_INLINE void Correction::incrementOutputIndex() { + ++mOutputIndex; + mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; + mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; + mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; + mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; + mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; + + mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; + mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; + mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; + mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; + mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; + + mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; + mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; + mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; + + mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; + + mCorrectionStates[mOutputIndex].mMatching = mMatching; + mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; + mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; + mCorrectionStates[mOutputIndex].mTransposing = mTransposing; + mCorrectionStates[mOutputIndex].mExceeding = mExceeding; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; +} + +inline void Correction::startToTraverseAllNodes() { + mNeedsToTraverseAllNodes = true; +} + +inline bool Correction::isSingleQuote(const int c) { + const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); + return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); +} + +AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c, + const bool isTerminal, const bool inputIndexIncremented) { + addCharToCurrentWord(c); + mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); + mTerminalOutputIndex = mOutputIndex; + if (mNeedsToTraverseAllNodes && isTerminal) { + incrementOutputIndex(); + return TRAVERSE_ALL_ON_TERMINAL; + } else { + incrementOutputIndex(); + return TRAVERSE_ALL_NOT_ON_TERMINAL; + } +} + +inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() { + // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType + mTerminalInputIndex = mInputIndex; + mTerminalOutputIndex = mOutputIndex; + return UNRELATED; +} + +AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, + const int inputSize, const int *output, const int outputLength) { + // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH_INTERNAL] is not touched. + // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. + // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, + // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. + int *const current = editDistanceTable + outputLength * (inputSize + 1); + const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); + const int *const prevprev = + outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; + current[0] = outputLength; + const int co = toBaseLowerCase(output[outputLength - 1]); + const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0; + for (int i = 1; i <= inputSize; ++i) { + const int ci = toBaseLowerCase(input[i - 1]); + const uint16_t cost = (ci == co) ? 0 : 1; + current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); + if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) { + current[i] = min(current[i], prevprev[i - 2] + 1); + } + } +} + +AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) { + mWord[mOutputIndex] = c; + const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); + calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, + mOutputIndex + 1); +} + +inline int Correction::getFinalProbabilityInternal(const int probability, int **word, + int *wordLength, const int inputSize) { + const int outputIndex = mTerminalOutputIndex; + const int inputIndex = mTerminalInputIndex; + *wordLength = outputIndex + 1; + *word = mWord; + int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( + inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); + return finalProbability; +} + } // namespace latinime #endif // LATINIME_CORRECTION_H |