diff options
Diffstat (limited to 'native')
-rw-r--r-- | native/src/correction.cpp | 244 | ||||
-rw-r--r-- | native/src/correction.h | 17 | ||||
-rw-r--r-- | native/src/correction_state.h | 34 | ||||
-rw-r--r-- | native/src/unigram_dictionary.cpp | 25 |
4 files changed, 213 insertions, 107 deletions
diff --git a/native/src/correction.cpp b/native/src/correction.cpp index 99412b211..fb160149d 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -56,17 +56,22 @@ void Correction::initCorrectionState( const int rootPos, const int childCount, const bool traverseAll) { latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); // TODO: remove + mCorrectionStates[0].mTransposedPos = mTransposedPos; + mCorrectionStates[0].mExcessivePos = mExcessivePos; mCorrectionStates[0].mSkipPos = mSkipPos; } void Correction::setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos) { // TODO: remove + mTransposedPos = transposedPos; + mExcessivePos = excessivePos; mSkipPos = skipPos; // TODO: remove + mCorrectionStates[0].mTransposedPos = transposedPos; + mCorrectionStates[0].mExcessivePos = excessivePos; mCorrectionStates[0].mSkipPos = skipPos; - mExcessivePos = excessivePos; - mTransposedPos = transposedPos; + mSpaceProximityPos = spaceProximityPos; mMissingSpacePos = missingSpacePos; } @@ -107,12 +112,23 @@ bool Correction::initProcessState(const int outputIndex) { --(mCorrectionStates[outputIndex].mChildCount); mInputIndex = mCorrectionStates[outputIndex].mInputIndex; mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; + mProximityCount = mCorrectionStates[outputIndex].mProximityCount; + mTransposedCount = mCorrectionStates[outputIndex].mTransposedCount; + mExcessiveCount = mCorrectionStates[outputIndex].mExcessiveCount; mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; + mLastCharExceeded = mCorrectionStates[outputIndex].mLastCharExceeded; + + mTransposedPos = mCorrectionStates[outputIndex].mTransposedPos; + mExcessivePos = mCorrectionStates[outputIndex].mExcessivePos; mSkipPos = mCorrectionStates[outputIndex].mSkipPos; - mSkipping = false; - mProximityMatching = false; + mMatching = false; + 
mProximityMatching = false; + mTransposing = false; + mExceeding = false; + mSkipping = false; + return true; } @@ -150,12 +166,23 @@ void Correction::incrementOutputIndex() { mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; + mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; + mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; + mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; - mCorrectionStates[mOutputIndex].mSkipping = mSkipping; + mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; + mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; + mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; + + mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; + mCorrectionStates[mOutputIndex].mMatching = mMatching; mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; + mCorrectionStates[mOutputIndex].mTransposing = mTransposing; + mCorrectionStates[mOutputIndex].mExceeding = mExceeding; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; } void Correction::startToTraverseAllNodes() { @@ -183,102 +210,138 @@ Correction::CorrectionType Correction::processSkipChar( Correction::CorrectionType Correction::processCharAndCalcState( const int32_t c, const bool isTerminal) { - CorrectionType currentStateType = NOT_ON_TERMINAL; - // This has to be done for each virtual char (this forwards the "inputIndex" which - // is the index in the user-inputted chars, as read by proximity chars. 
- if (mExcessivePos == mOutputIndex && mInputIndex < mInputLength - 1) { - incrementInputIndex(); + + if (mNeedsToTraverseAllNodes || isQuote(c)) { + if (mLastCharExceeded > 0 && mInputIndex == mInputLength - 1 + && mProximityInfo->getMatchedProximityId(mInputIndex, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + mLastCharExceeded = false; + --mExcessiveCount; + } + return processSkipChar(c, isTerminal); + } + + if (mExcessivePos >= 0) { + if (mExcessiveCount == 0 && mExcessivePos < mOutputIndex) { + mExcessivePos = mOutputIndex; + } + if (mExcessivePos < mInputLength - 1) { + mExceeding = mExcessivePos == mInputIndex; + } } - bool skip = false; if (mSkipPos >= 0) { if (mSkippedCount == 0 && mSkipPos < mOutputIndex) { if (DEBUG_DICT) { assert(mSkipPos == mOutputIndex - 1); } - ++mSkipPos; + mSkipPos = mOutputIndex; } - skip = mSkipPos == mOutputIndex; - mSkipping = true; + mSkipping = mSkipPos == mOutputIndex; } - if (mNeedsToTraverseAllNodes || isQuote(c)) { - return processSkipChar(c, isTerminal); - } else { - int inputIndexForProximity = mInputIndex; + if (mTransposedPos >= 0) { + if (mTransposedCount == 0 && mTransposedPos < mOutputIndex) { + mTransposedPos = mOutputIndex; + } + if (mTransposedPos < mInputLength - 1) { + mTransposing = mInputIndex == mTransposedPos; + } + } - if (mTransposedPos >= 0) { - if (mInputIndex == mTransposedPos) { - ++inputIndexForProximity; - } - if (mInputIndex == (mTransposedPos + 1)) { - --inputIndexForProximity; - } + bool secondTransposing = false; + if (mTransposedCount % 2 == 1) { + if (mProximityInfo->getMatchedProximityId(mInputIndex - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + ++mTransposedCount; + secondTransposing = true; + } else if (mCorrectionStates[mOutputIndex].mExceeding) { + --mTransposedCount; + ++mExcessiveCount; + incrementInputIndex(); + } else { + --mTransposedCount; + return UNRELATED; } + } - // TODO: sum counters - const bool checkProximityChars = - 
!(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); - int matchedProximityCharId = mProximityInfo->getMatchedProximityId( - inputIndexForProximity, c, checkProximityChars); - - if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { - if (skip && mProximityCount == 0) { - // Skip this letter and continue deeper - ++mSkippedCount; - return processSkipChar(c, isTerminal); - } else if (checkProximityChars - && inputIndexForProximity > 0 - && mCorrectionStates[mOutputIndex].mProximityMatching - && mCorrectionStates[mOutputIndex].mSkipping - && mProximityInfo->getMatchedProximityId( - inputIndexForProximity - 1, c, false) - == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { - // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of - // proximity chars of "s", but it should rather be handled as a skipped char. - ++mSkippedCount; - --mProximityCount; - return processSkipChar(c, isTerminal); + // TODO: sum counters + const bool checkProximityChars = + !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0); + const int matchedProximityCharId = secondTransposing + ? ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR + : mProximityInfo->getMatchedProximityId(mInputIndex, c, checkProximityChars); + + if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { + if (mInputIndex - 1 < mInputLength && (mExceeding || mTransposing) + && mProximityInfo->getMatchedProximityId(mInputIndex + 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + if (mTransposing) { + ++mTransposedCount; } else { - return UNRELATED; + ++mExcessiveCount; + incrementInputIndex(); } - } else if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - // If inputIndex is greater than mInputLength, that means there is no - // proximity chars. So, we don't need to check proximity. 
- mMatching = true; - } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { - mProximityMatching = true; - incrementProximityCount(); + } else if (mSkipping && mProximityCount == 0) { + // Skip this letter and continue deeper + ++mSkippedCount; + return processSkipChar(c, isTerminal); + } else if (checkProximityChars + && mInputIndex > 0 + && mCorrectionStates[mOutputIndex].mProximityMatching + && mCorrectionStates[mOutputIndex].mSkipping + && mProximityInfo->getMatchedProximityId(mInputIndex - 1, c, false) + == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) { + // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of + // proximity chars of "s", but it should rather be handled as a skipped char. + ++mSkippedCount; + --mProximityCount; + return processSkipChar(c, isTerminal); + } else { + return UNRELATED; } + } else if (secondTransposing + || ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { + // If inputIndex is greater than mInputLength, that means there is no + // proximity chars. So, we don't need to check proximity. + mMatching = true; + } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) { + mProximityMatching = true; + incrementProximityCount(); + } - mWord[mOutputIndex] = c; + mWord[mOutputIndex] = c; - const bool isSameAsUserTypedLength = mInputLength - == getInputIndex() + 1 - || (mExcessivePos == mInputLength - 1 - && getInputIndex() == mInputLength - 2); - if (isSameAsUserTypedLength && isTerminal) { - mTerminalInputIndex = mInputIndex; - mTerminalOutputIndex = mOutputIndex; - currentStateType = ON_TERMINAL; - } - // Start traversing all nodes after the index exceeds the user typed length - if (isSameAsUserTypedLength) { - startToTraverseAllNodes(); - } + mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 + && mProximityCount == 0 && mTransposedCount == 0 + // TODO: remove this line once excessive correction is combined to others. 
+ && mExcessivePos >= 0 && (mInputIndex == mInputLength - 2); + const bool isSameAsUserTypedLength = (mInputLength == mInputIndex + 1) || mLastCharExceeded; + if (mLastCharExceeded) { + ++mExcessiveCount; + } - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - incrementInputIndex(); + // Start traversing all nodes after the index exceeds the user typed length + if (isSameAsUserTypedLength) { + startToTraverseAllNodes(); } + // Finally, we are ready to go to the next character, the next "virtual node". + // We should advance the input index. + // We do this in this branch of the 'if traverseAllNodes' because we are still matching + // characters to input; the other branch is not matching them but searching for + // completions, this is why it does not have to do it. + incrementInputIndex(); // Also, the next char is one "virtual node" depth more than this char. 
incrementOutputIndex(); - return currentStateType; + if (isSameAsUserTypedLength && isTerminal) { + mTerminalInputIndex = mInputIndex - 1; + mTerminalOutputIndex = mOutputIndex - 1; + return ON_TERMINAL; + } else { + return NOT_ON_TERMINAL; + } } Correction::~Correction() { @@ -395,20 +458,33 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const ProximityInfo *proximityInfo = correction->mProximityInfo; - const int skipCount = correction->mSkippedCount; + const int skippedCount = correction->mSkippedCount; + const int transposedCount = correction->mTransposedCount; + const int excessiveCount = correction->mExcessiveCount; const int proximityMatchedCount = correction->mProximityCount; - if (skipCount >= inputLength || inputLength == 0) { + const bool lastCharExceeded = correction->mLastCharExceeded; + if (skippedCount >= inputLength || inputLength == 0) { + return -1; + } + + // TODO: remove + if (transposedPos >= 0 && transposedCount == 0) { return -1; } - const bool sameLength = (excessivePos == inputLength - 1) ? (inputLength == inputIndex + 2) - : (inputLength == inputIndex + 1); + // TODO: remove + if (excessivePos >= 0 && excessiveCount == 0) { + return -1; + } + + const bool sameLength = lastCharExceeded ? (inputLength == inputIndex + 2) + : (inputLength == inputIndex + 1); // TODO: use mExcessiveCount int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 
1 : 0); const unsigned short* word = correction->mWord; - const bool skipped = skipCount > 0; + const bool skipped = skippedCount > 0; const int quoteDiffCount = max(0, getQuoteCount(word, outputIndex + 1) - getQuoteCount(proximityInfo->getPrimaryInputWord(), inputLength)); @@ -417,6 +493,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int matchWeight; int ed = 0; int adJustedProximityMatchedCount = proximityMatchedCount; + + // TODO: Optimize this. if (excessivePos < 0 && transposedPos < 0 && (proximityMatchedCount > 0 || skipped)) { const unsigned short* primaryInputWord = proximityInfo->getPrimaryInputWord(); ed = editDistance(editDistanceTable, primaryInputWord, @@ -475,7 +553,7 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); } - const int errorCount = proximityMatchedCount + skipCount; + const int errorCount = proximityMatchedCount + skippedCount; multiplyRate( 100 - CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE * errorCount / inputLength, &finalFreq); diff --git a/native/src/correction.h b/native/src/correction.h index 871a04251..3cd600cf0 100644 --- a/native/src/correction.h +++ b/native/src/correction.h @@ -113,8 +113,6 @@ private: int mMaxEditDistance; int mMaxDepth; int mInputLength; - int mExcessivePos; - int mTransposedPos; int mSpaceProximityPos; int mMissingSpacePos; int mTerminalInputIndex; @@ -126,15 +124,26 @@ private: CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; // The following member variables are being used as cache values of the correction state. 
+ bool mNeedsToTraverseAllNodes; int mOutputIndex; int mInputIndex; + int mProximityCount; + int mExcessiveCount; + int mTransposedCount; int mSkippedCount; + + int mTransposedPos; + int mExcessivePos; int mSkipPos; - bool mNeedsToTraverseAllNodes; + + bool mLastCharExceeded; + bool mMatching; - bool mSkipping; bool mProximityMatching; + bool mExceeding; + bool mTransposing; + bool mSkipping; class RankingAlgorithm { public: diff --git a/native/src/correction_state.h b/native/src/correction_state.h index 267deda9b..93f8a8aab 100644 --- a/native/src/correction_state.h +++ b/native/src/correction_state.h @@ -28,12 +28,25 @@ struct CorrectionState { int mSiblingPos; uint16_t mChildCount; uint8_t mInputIndex; + uint8_t mProximityCount; + uint8_t mTransposedCount; + uint8_t mExcessiveCount; uint8_t mSkippedCount; + + int8_t mTransposedPos; + int8_t mExcessivePos; int8_t mSkipPos; // should be signed + + // TODO: int? + bool mLastCharExceeded; + bool mMatching; + bool mTransposing; + bool mExceeding; bool mSkipping; bool mProximityMatching; + bool mNeedsToTraverseAllNodes; }; @@ -43,14 +56,27 @@ inline static void initCorrectionState(CorrectionState *state, const int rootPos state->mParentIndex = -1; state->mChildCount = childCount; state->mInputIndex = 0; - state->mProximityCount = 0; state->mSiblingPos = rootPos; + state->mNeedsToTraverseAllNodes = traverseAll; + + state->mTransposedPos = -1; + state->mExcessivePos = -1; + state->mSkipPos = -1; + + + state->mProximityCount = 0; + state->mTransposedCount = 0; + state->mExcessiveCount = 0; state->mSkippedCount = 0; + + state->mLastCharExceeded = false; + state->mMatching = false; - state->mSkipping = false; state->mProximityMatching = false; - state->mNeedsToTraverseAllNodes = traverseAll; - state->mSkipPos = -1; + state->mTransposing = false; + state->mExceeding = false; + state->mSkipping = false; + } } // namespace latinime diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index 
6bc350505..805e1cbb7 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -194,34 +194,27 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, PROF_START(2); // Suggestion with missing character - LOGI("--- Suggest missing characters"); + if (DEBUG_DICT) { + LOGI("--- Suggest missing characters"); + } getSuggestionCandidates(0, -1, -1); PROF_END(2); PROF_START(3); // Suggestion with excessive character - if (SUGGEST_WORDS_WITH_EXCESSIVE_CHARACTER - && mInputLength >= MIN_USER_TYPED_LENGTH_FOR_EXCESSIVE_CHARACTER_SUGGESTION) { - for (int i = 0; i < codesSize; ++i) { - if (DEBUG_DICT) { - LOGI("--- Suggest excessive characters %d", i); - } - getSuggestionCandidates(-1, i, -1); - } + if (DEBUG_DICT) { + LOGI("--- Suggest excessive characters"); } + getSuggestionCandidates(-1, 0, -1); PROF_END(3); PROF_START(4); // Suggestion with transposed characters // Only suggest words that length is mInputLength - if (SUGGEST_WORDS_WITH_TRANSPOSED_CHARACTERS) { - for (int i = 0; i < codesSize; ++i) { - if (DEBUG_DICT) { - LOGI("--- Suggest transposed characters %d", i); - } - getSuggestionCandidates(-1, -1, i); - } + if (DEBUG_DICT) { + LOGI("--- Suggest transposed characters"); } + getSuggestionCandidates(-1, -1, 0); PROF_END(4); PROF_START(5); |