diff options
Diffstat (limited to 'native/src/correction.cpp')
-rw-r--r-- | native/src/correction.cpp | 274 |
1 files changed, 216 insertions, 58 deletions
diff --git a/native/src/correction.cpp b/native/src/correction.cpp index 6d682c0c9..a4090a966 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -21,6 +21,7 @@ #define LOG_TAG "LatinIME: correction.cpp" #include "correction.h" +#include "dictionary.h" #include "proximity_info.h" namespace latinime { @@ -49,12 +50,21 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength, mInputLength = inputLength; mMaxDepth = maxDepth; mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; - mSkippedOutputIndex = -1; +} + +void Correction::initCorrectionState( + const int rootPos, const int childCount, const bool traverseAll) { + latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); + // TODO: remove + mCorrectionStates[0].mSkipPos = mSkipPos; } void Correction::setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos) { + // TODO: remove mSkipPos = skipPos; + // TODO: remove + mCorrectionStates[0].mSkipPos = skipPos; mExcessivePos = excessivePos; mTransposedPos = transposedPos; mSpaceProximityPos = spaceProximityPos; @@ -83,33 +93,37 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) { return -1; } + *word = mWord; const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) : (mInputLength == inputIndex + 1); return Correction::RankingAlgorithm::calculateFinalFreq( - inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this); + inputIndex, outputIndex, freq, sameLength, mEditDistanceTable, this); } -void Correction::initProcessState(const int matchCount, const int inputIndex, - const int outputIndex, const bool traverseAllNodes, const int diffs) { - mMatchedCharCount = matchCount; - mInputIndex = inputIndex; +bool Correction::initProcessState(const int outputIndex) { + if (mCorrectionStates[outputIndex].mChildCount <= 0) { + return false; + } mOutputIndex = outputIndex; - mTraverseAllNodes = traverseAllNodes; - mDiffs = diffs; -} - -void Correction::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, - bool *traverseAllNodes, int *diffs) { - *matchedCount = mMatchedCharCount; - *inputIndex = mInputIndex; - *outputIndex = mOutputIndex; - *traverseAllNodes = mTraverseAllNodes; - *diffs = mDiffs; + --(mCorrectionStates[outputIndex].mChildCount); + mInputIndex = mCorrectionStates[outputIndex].mInputIndex; + mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; + mProximityCount = mCorrectionStates[outputIndex].mProximityCount; + mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; + mSkipPos = mCorrectionStates[outputIndex].mSkipPos; + mSkipping = false; + mProximityMatching = false; + mMatching = false; + return true; } -void Correction::charMatched() { - ++mMatchedCharCount; +int Correction::goDownTree( + const int parentIndex, const int childCount, const int firstChildPos) { + mCorrectionStates[mOutputIndex].mParentIndex = parentIndex; + mCorrectionStates[mOutputIndex].mChildCount = childCount; + mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos; + return mOutputIndex; } // TODO: remove @@ -123,8 +137,8 @@ int Correction::getInputIndex() { } // TODO: remove -bool Correction::needsToTraverseAll() { - return mTraverseAllNodes; +bool Correction::needsToTraverseAllNodes() { + return mNeedsToTraverseAllNodes; } void Correction::incrementInputIndex() { @@ -133,21 +147,32 @@ void Correction::incrementInputIndex() { void Correction::incrementOutputIndex() { ++mOutputIndex; + mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; + mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; + mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; + mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; + mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; + mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; + mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; + mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; + mCorrectionStates[mOutputIndex].mMatching = mMatching; + mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; } -void Correction::startTraverseAll() { - mTraverseAllNodes = true; +void Correction::startToTraverseAllNodes() { + mNeedsToTraverseAllNodes = true; } bool Correction::needsToPrune() const { return (mOutputIndex - 1 >= (mTransposedPos >= 0 ? mInputLength - 1 : mMaxDepth) - || mDiffs > mMaxEditDistance); + || mProximityCount > mMaxEditDistance); } Correction::CorrectionType Correction::processSkipChar( const int32_t c, const bool isTerminal) { mWord[mOutputIndex] = c; - if (needsToTraverseAll() && isTerminal) { + if (needsToTraverseAllNodes() && isTerminal) { mTerminalInputIndex = mInputIndex; mTerminalOutputIndex = mOutputIndex; incrementOutputIndex(); @@ -169,10 +194,31 @@ Correction::CorrectionType Correction::processCharAndCalcState( bool skip = false; if (mSkipPos >= 0) { + if (mSkippedCount == 0 && mSkipPos < mOutputIndex) { + if (DEBUG_DICT) { + assert(mSkipPos == mOutputIndex - 1); + } + ++mSkipPos; + } skip = mSkipPos == mOutputIndex; + mSkipping = true; } - if (mTraverseAllNodes || isQuote(c)) { + if (mNeedsToTraverseAllNodes || isQuote(c)) { + const bool checkProximityChars = + !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); + // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of + // proximity chars of "s", but it should rather be handled as a skipped char. + if (checkProximityChars + && mInputIndex > 0 + && mCorrectionStates[mOutputIndex].mProximityMatching + && mCorrectionStates[mOutputIndex].mSkipping + && mProximityInfo->getMatchedProximityId( + mInputIndex - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + ++mSkippedCount; + --mProximityCount; + } return processSkipChar(c, isTerminal); } else { int inputIndexForProximity = mInputIndex; @@ -186,40 +232,40 @@ Correction::CorrectionType Correction::processCharAndCalcState( } } + // TODO: sum counters const bool checkProximityChars = - !(mSkipPos >= 0 || mExcessivePos >= 0 || mTransposedPos >= 0); + !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); int matchedProximityCharId = mProximityInfo->getMatchedProximityId( inputIndexForProximity, c, checkProximityChars); - const bool unrelated = ProximityInfo::UNRELATED_CHAR == matchedProximityCharId; - if (unrelated) { - if (skip) { + if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { + if (skip && mProximityCount == 0) { // Skip this letter and continue deeper - mSkippedOutputIndex = mOutputIndex; + ++mSkippedCount; + return processSkipChar(c, isTerminal); + } else if (checkProximityChars + && inputIndexForProximity > 0 + && mCorrectionStates[mOutputIndex].mProximityMatching + && mCorrectionStates[mOutputIndex].mSkipping + && mProximityInfo->getMatchedProximityId( + inputIndexForProximity - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + ++mSkippedCount; + --mProximityCount; return processSkipChar(c, isTerminal); } else { return UNRELATED; } - } - - // No need to skip. Finish traversing and increment skipPos. - // TODO: Remove this? - if (skip) { - mWord[mOutputIndex] = c; - incrementOutputIndex(); - return TRAVERSE_ALL_NOT_ON_TERMINAL; + } else if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { + // If inputIndex is greater than mInputLength, that means there is no + // proximity chars. So, we don't need to check proximity. + mMatching = true; + } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { + mProximityMatching = true; + incrementProximityCount(); } mWord[mOutputIndex] = c; - // If inputIndex is greater than mInputLength, that means there is no - // proximity chars. So, we don't need to check proximity. - if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - charMatched(); - } - - if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { - incrementDiffs(); - } const bool isSameAsUserTypedLength = mInputLength == getInputIndex() + 1 @@ -232,7 +278,7 @@ Correction::CorrectionType Correction::processCharAndCalcState( } // Start traversing all nodes after the index exceeds the user typed length if (isSameAsUserTypedLength) { - startTraverseAll(); + startToTraverseAllNodes(); } // Finally, we are ready to go to the next character, the next "virtual node". @@ -298,26 +344,117 @@ inline static void multiplyRate(const int rate, int *freq) { } } +/* static */ +inline static int editDistance( + int* editDistanceTable, const unsigned short* input, + const int inputLength, const unsigned short* output, const int outputLength) { + // dp[li][lo] dp[a][b] = dp[ a * lo + b] + int* dp = editDistanceTable; + const int li = inputLength + 1; + const int lo = outputLength + 1; + for (int i = 0; i < li; ++i) { + dp[lo * i] = i; + } + for (int i = 0; i < lo; ++i) { + dp[i] = i; + } + + for (int i = 0; i < li - 1; ++i) { + for (int j = 0; j < lo - 1; ++j) { + const uint32_t ci = Dictionary::toBaseLowerCase(input[i]); + const uint32_t co = Dictionary::toBaseLowerCase(output[j]); + const uint16_t cost = (ci == co) ? 0 : 1; + dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1, + min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost)); + if (li > 0 && lo > 0 + && ci == Dictionary::toBaseLowerCase(output[j - 1]) + && co == Dictionary::toBaseLowerCase(input[i - 1])) { + dp[(i + 1) * lo + (j + 1)] = min( + dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost); + } + } + } + + if (DEBUG_EDIT_DISTANCE) { + LOGI("IN = %d, OUT = %d", inputLength, outputLength); + for (int i = 0; i < li; ++i) { + for (int j = 0; j < lo; ++j) { + LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]); + } + } + } + return dp[li * lo - 1]; +} + ////////////////////// // RankingAlgorithm // ////////////////////// -int Correction::RankingAlgorithm::calculateFinalFreq( - const int inputIndex, const int outputIndex, - const int matchCount, const int freq, const bool sameLength, +/* static */ +int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex, + const int freq, const bool sameLength, int* editDistanceTable, const Correction* correction) { - const int skipPos = correction->getSkipPos(); const int excessivePos = correction->getExcessivePos(); const int transposedPos = correction->getTransposedPos(); const int inputLength = correction->mInputLength; const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const ProximityInfo *proximityInfo = correction->mProximityInfo; - const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); + const int skipCount = correction->mSkippedCount; + const int proximityMatchedCount = correction->mProximityCount; + + // TODO: use mExcessiveCount + int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0); + + const unsigned short* word = correction->mWord; + const bool skipped = skipCount > 0; + + // ----- TODO: use edit distance here as follows? ---------------------- / + //if (!skipped && excessivePos < 0 && transposedPos < 0) { + // const int ed = editDistance(dp, proximityInfo->getInputWord(), + // inputLength, word, outputIndex + 1); + // matchCount = outputIndex + 1 - ed; + // if (ed == 1 && !sameLength) ++matchCount; + //} + // const int ed = editDistance(dp, proximityInfo->getInputWord(), + // inputLength, word, outputIndex + 1); + // if (ed == 1 && !sameLength) ++matchCount; ------------------------ / + int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); // TODO: Demote by edit distance int finalFreq = freq * matchWeight; - if (skipPos >= 0) { + // +1 +11/-12 + /*if (inputLength == outputIndex && !skipped && excessivePos < 0 && transposedPos < 0) { + const int ed = editDistance(dp, proximityInfo->getInputWord(), + inputLength, word, outputIndex + 1); + if (ed == 1) { + multiplyRate(160, &finalFreq); + } + }*/ + if (inputLength == outputIndex && excessivePos < 0 && transposedPos < 0 + && (proximityMatchedCount > 0 || skipped)) { + const int ed = editDistance(editDistanceTable, proximityInfo->getPrimaryInputWord(), + inputLength, word, outputIndex + 1); + if (ed == 1) { + multiplyRate(160, &finalFreq); + } + } + + // TODO: Promote properly? + //if (skipCount == 1 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex + // && !sameLength) { + // multiplyRate(150, &finalFreq); + //} + //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex + // && !sameLength) { + // multiplyRate(150, &finalFreq); + //} + //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 + // && inputLength == outputIndex + 1) { + // multiplyRate(150, &finalFreq); + //} + + if (skipped) { if (inputLength >= 2) { const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE * (10 * inputLength - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X) @@ -351,10 +488,10 @@ int Correction::RankingAlgorithm::calculateFinalFreq( } multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq); } - if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0) { + if (sameLength && transposedPos < 0 && !skipped && excessivePos < 0) { finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); } - } else if (sameLength && transposedPos < 0 && skipPos < 0 && excessivePos < 0 + } else if (sameLength && transposedPos < 0 && !skipped && excessivePos < 0 && outputIndex > 0) { // A word with proximity corrections if (DEBUG_DICT) { @@ -363,13 +500,34 @@ int Correction::RankingAlgorithm::calculateFinalFreq( multiplyIntCapped(typedLetterMultiplier, &finalFreq); multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); } - if (DEBUG_DICT) { + if (DEBUG_DICT_FULL) { LOGI("calc: %d, %d", outputIndex, sameLength); } if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); + + // TODO: check excessive count and transposed count + /* + If the last character of the user input word is the same as the next character + of the output word, and also all of characters of the user input are matched + to the output word, we'll promote that word a bit because + that word can be considered the combination of skipped and matched characters. + This means that the 'sm' pattern wins over the 'ma' pattern. + e.g.) + shel -> shell [mmmma] or [mmmsm] + hel -> hello [mmmaa] or [mmsma] + m ... matching + s ... skipping + a ... traversing all + */ + if (matchCount == inputLength && matchCount >= 2 && !skipped + && word[matchCount] == word[matchCount - 1]) { + multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq); + } + return finalFreq; } +/* static */ int Correction::RankingAlgorithm::calcFreqForSplitTwoWords( const int firstFreq, const int secondFreq, const Correction* correction) { const int spaceProximityPos = correction->mSpaceProximityPos; |