diff options
Diffstat (limited to 'native/src')
-rw-r--r-- | native/src/correction.cpp | 103 | ||||
-rw-r--r-- | native/src/correction.h | 60 | ||||
-rw-r--r-- | native/src/correction_state.h | 55 | ||||
-rw-r--r-- | native/src/defines.h | 1 | ||||
-rw-r--r-- | native/src/proximity_info.h | 1 | ||||
-rw-r--r-- | native/src/unigram_dictionary.cpp | 34 | ||||
-rw-r--r-- | native/src/unigram_dictionary.h | 11 |
7 files changed, 187 insertions, 78 deletions
diff --git a/native/src/correction.cpp b/native/src/correction.cpp index 6d682c0c9..654d4715f 100644 --- a/native/src/correction.cpp +++ b/native/src/correction.cpp @@ -49,7 +49,11 @@ void Correction::initCorrection(const ProximityInfo *pi, const int inputLength, mInputLength = inputLength; mMaxDepth = maxDepth; mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; - mSkippedOutputIndex = -1; +} + +void Correction::initCorrectionState( + const int rootPos, const int childCount, const bool traverseAll) { + latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); } void Correction::setCorrectionParams(const int skipPos, const int excessivePos, @@ -83,6 +87,12 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) { return -1; } + + // TODO: Remove this + if (mSkipPos >= 0 && mSkippedCount <= 0) { + return -1; + } + *word = mWord; const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2) : (mInputLength == inputIndex + 1); @@ -90,22 +100,28 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this); } -void Correction::initProcessState(const int matchCount, const int inputIndex, - const int outputIndex, const bool traverseAllNodes, const int diffs) { - mMatchedCharCount = matchCount; - mInputIndex = inputIndex; +bool Correction::initProcessState(const int outputIndex) { + if (mCorrectionStates[outputIndex].mChildCount <= 0) { + return false; + } mOutputIndex = outputIndex; - mTraverseAllNodes = traverseAllNodes; - mDiffs = diffs; + --(mCorrectionStates[outputIndex].mChildCount); + mMatchedCharCount = mCorrectionStates[outputIndex].mMatchedCount; + mInputIndex = mCorrectionStates[outputIndex].mInputIndex; + mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; + mDiffs = mCorrectionStates[outputIndex].mDiffs; + mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; + mSkipping = false; + mMatching = false; + return true; } -void Correction::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, - bool *traverseAllNodes, int *diffs) { - *matchedCount = mMatchedCharCount; - *inputIndex = mInputIndex; - *outputIndex = mOutputIndex; - *traverseAllNodes = mTraverseAllNodes; - *diffs = mDiffs; +int Correction::goDownTree( + const int parentIndex, const int childCount, const int firstChildPos) { + mCorrectionStates[mOutputIndex].mParentIndex = parentIndex; + mCorrectionStates[mOutputIndex].mChildCount = childCount; + mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos; + return mOutputIndex; } void Correction::charMatched() { @@ -123,8 +139,8 @@ int Correction::getInputIndex() { } // TODO: remove -bool Correction::needsToTraverseAll() { - return mTraverseAllNodes; +bool Correction::needsToTraverseAllNodes() { + return mNeedsToTraverseAllNodes; } void Correction::incrementInputIndex() { @@ -133,10 +149,20 @@ void Correction::incrementInputIndex() { void Correction::incrementOutputIndex() { ++mOutputIndex; + mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; + mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; + mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; + mCorrectionStates[mOutputIndex].mMatchedCount = mMatchedCharCount; + mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; + mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; + mCorrectionStates[mOutputIndex].mDiffs = mDiffs; + mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; + mCorrectionStates[mOutputIndex].mSkipping = mSkipping; + mCorrectionStates[mOutputIndex].mMatching = mMatching; } -void Correction::startTraverseAll() { - mTraverseAllNodes = true; +void Correction::startToTraverseAllNodes() { + mNeedsToTraverseAllNodes = true; } bool Correction::needsToPrune() const { @@ -147,7 +173,7 @@ bool Correction::needsToPrune() const { Correction::CorrectionType Correction::processSkipChar( const int32_t c, const bool isTerminal) { mWord[mOutputIndex] = c; - if (needsToTraverseAll() && isTerminal) { + if (needsToTraverseAllNodes() && isTerminal) { mTerminalInputIndex = mInputIndex; mTerminalOutputIndex = mOutputIndex; incrementOutputIndex(); @@ -170,9 +196,10 @@ Correction::CorrectionType Correction::processCharAndCalcState( bool skip = false; if (mSkipPos >= 0) { skip = mSkipPos == mOutputIndex; + mSkipping = true; } - if (mTraverseAllNodes || isQuote(c)) { + if (mNeedsToTraverseAllNodes || isQuote(c)) { return processSkipChar(c, isTerminal); } else { int inputIndexForProximity = mInputIndex; @@ -195,25 +222,23 @@ Correction::CorrectionType Correction::processCharAndCalcState( if (unrelated) { if (skip) { // Skip this letter and continue deeper - mSkippedOutputIndex = mOutputIndex; + ++mSkippedCount; return processSkipChar(c, isTerminal); } else { return UNRELATED; } } - // No need to skip. Finish traversing and increment skipPos. - // TODO: Remove this? + // TODO: remove after allowing combination errors if (skip) { - mWord[mOutputIndex] = c; - incrementOutputIndex(); - return TRAVERSE_ALL_NOT_ON_TERMINAL; + return UNRELATED; } mWord[mOutputIndex] = c; // If inputIndex is greater than mInputLength, that means there is no // proximity chars. So, we don't need to check proximity. if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { + mMatching = true; charMatched(); } @@ -232,7 +257,7 @@ Correction::CorrectionType Correction::processCharAndCalcState( } // Start traversing all nodes after the index exceeds the user typed length if (isSameAsUserTypedLength) { - startTraverseAll(); + startToTraverseAllNodes(); } // Finally, we are ready to go to the next character, the next "virtual node". @@ -302,6 +327,7 @@ inline static void multiplyRate(const int rate, int *freq) { // RankingAlgorithm // ////////////////////// +/* static */ int Correction::RankingAlgorithm::calculateFinalFreq( const int inputIndex, const int outputIndex, const int matchCount, const int freq, const bool sameLength, @@ -314,6 +340,8 @@ int Correction::RankingAlgorithm::calculateFinalFreq( const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; const ProximityInfo *proximityInfo = correction->mProximityInfo; const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); + const unsigned short* word = correction->mWord; + const int skippedCount = correction->mSkippedCount; // TODO: Demote by edit distance int finalFreq = freq * matchWeight; @@ -367,9 +395,30 @@ int Correction::RankingAlgorithm::calculateFinalFreq( LOGI("calc: %d, %d", outputIndex, sameLength); } if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq); + + // TODO: check excessive count and transposed count + /* + If the last character of the user input word is the same as the next character + of the output word, and also all of characters of the user input are matched + to the output word, we'll promote that word a bit because + that word can be considered the combination of skipped and matched characters. + This means that the 'sm' pattern wins over the 'ma' pattern. + e.g.) + shel -> shell [mmmma] or [mmmsm] + hel -> hello [mmmaa] or [mmsma] + m ... matching + s ... skipping + a ... traversing all + */ + if (matchCount == inputLength && matchCount >= 2 && skippedCount == 0 + && word[matchCount] == word[matchCount - 1]) { + multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq); + } + return finalFreq; } +/* static */ int Correction::RankingAlgorithm::calcFreqForSplitTwoWords( const int firstFreq, const int secondFreq, const Correction* correction) { const int spaceProximityPos = correction->mSpaceProximityPos; diff --git a/native/src/correction.h b/native/src/correction.h index ae6c7a421..62fe38696 100644 --- a/native/src/correction.h +++ b/native/src/correction.h @@ -18,6 +18,7 @@ #define LATINIME_CORRECTION_H #include <stdint.h> +#include "correction_state.h" #include "defines.h" @@ -39,16 +40,18 @@ public: Correction(const int typedLetterMultiplier, const int fullWordMultiplier); void initCorrection( const ProximityInfo *pi, const int inputLength, const int maxWordLength); + void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll); + + // TODO: remove void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, const int spaceProximityPos, const int missingSpacePos); void checkState(); - void initProcessState(const int matchCount, const int inputIndex, const int outputIndex, - const bool traverseAllNodes, const int diffs); + bool initProcessState(const int index); + void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex, bool *traverseAllNodes, int *diffs); int getOutputIndex(); int getInputIndex(); - bool needsToTraverseAll(); virtual ~Correction(); int getSpaceProximityPos() const { @@ -80,44 +83,63 @@ public: int getDiffs() const { return mDiffs; } + + ///////////////////////// + // Tree helper methods + int goDownTree(const int parentIndex, const int childCount, const int firstChildPos); + + inline int getTreeSiblingPos(const int index) const { + return mCorrectionStates[index].mSiblingPos; + } + + inline void setTreeSiblingPos(const int index, const int pos) { + mCorrectionStates[index].mSiblingPos = pos; + } + + inline int getTreeParentIndex(const int index) const { + return mCorrectionStates[index].mParentIndex; + } private: - void charMatched(); - void incrementInputIndex(); - void incrementOutputIndex(); - void startTraverseAll(); + inline void charMatched(); + inline void incrementInputIndex(); + inline void incrementOutputIndex(); + inline bool needsToTraverseAllNodes(); + inline void startToTraverseAllNodes(); + inline bool isQuote(const unsigned short c); + inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); // TODO: remove - - void incrementDiffs() { + inline void incrementDiffs() { ++mDiffs; } const int TYPED_LETTER_MULTIPLIER; const int FULL_WORD_MULTIPLIER; - const ProximityInfo *mProximityInfo; int mMaxEditDistance; int mMaxDepth; int mInputLength; int mSkipPos; - int mSkippedOutputIndex; int mExcessivePos; int mTransposedPos; int mSpaceProximityPos; int mMissingSpacePos; - - int mMatchedCharCount; - int mInputIndex; - int mOutputIndex; int mTerminalInputIndex; int mTerminalOutputIndex; - int mDiffs; - bool mTraverseAllNodes; unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; - inline bool isQuote(const unsigned short c); - inline CorrectionType processSkipChar(const int32_t c, const bool isTerminal); + CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL]; + + // The following member variables are being used as cache values of the correction state. + int mOutputIndex; + int mInputIndex; + int mDiffs; + int mMatchedCharCount; + int mSkippedCount; + bool mNeedsToTraverseAllNodes; + bool mMatching; + bool mSkipping; class RankingAlgorithm { public: diff --git a/native/src/correction_state.h b/native/src/correction_state.h new file mode 100644 index 000000000..731222696 --- /dev/null +++ b/native/src/correction_state.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2011 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_CORRECTION_STATE_H +#define LATINIME_CORRECTION_STATE_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +struct CorrectionState { + int mParentIndex; + int mSiblingPos; + uint16_t mChildCount; + uint8_t mInputIndex; + uint8_t mDiffs; + uint8_t mMatchedCount; + uint8_t mSkippedCount; + bool mMatching; + bool mSkipping; + bool mNeedsToTraverseAllNodes; + +}; + +inline static void initCorrectionState(CorrectionState *state, const int rootPos, + const uint16_t childCount, const bool traverseAll) { + state->mParentIndex = -1; + state->mChildCount = childCount; + state->mInputIndex = 0; + state->mDiffs = 0; + state->mSiblingPos = rootPos; + state->mMatchedCount = 0; + state->mSkippedCount = 0; + state->mMatching = false; + state->mSkipping = false; + state->mNeedsToTraverseAllNodes = traverseAll; +} + +} // namespace latinime +#endif // LATINIME_CORRECTION_STATE_H diff --git a/native/src/defines.h b/native/src/defines.h index 5a5d3ee0c..c1838d341 100644 --- a/native/src/defines.h +++ b/native/src/defines.h @@ -160,6 +160,7 @@ static void prof_out(void) { #define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 60 #define FULL_MATCHED_WORDS_PROMOTION_RATE 120 #define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90 +#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105 // This should be greater than or equal to MAX_WORD_LENGTH defined in BinaryDictionary.java // This is only used for the size of array. Not to be used in c functions. diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h index 5034c3b89..d9ed46f5b 100644 --- a/native/src/proximity_info.h +++ b/native/src/proximity_info.h @@ -46,6 +46,7 @@ public: ProximityType getMatchedProximityId( const int index, const unsigned short c, const bool checkProximityChars) const; bool sameAsTyped(const unsigned short *word, int length) const; + private: int getStartIndexFromCoordinates(const int x, const int y) const; const int MAX_PROXIMITY_CHARS_SIZE; diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index df1a2e273..bbfaea454 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -352,44 +352,28 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, int rootPosition = ROOT_POS; // Get the number of children of root, then increment the position int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition); - int depth = 0; + int outputIndex = 0; - mStackChildCount[0] = childCount; - mStackTraverseAll[0] = (mInputLength <= 0); - mStackInputIndex[0] = 0; - mStackDiffs[0] = 0; - mStackSiblingPos[0] = rootPosition; - mStackOutputIndex[0] = 0; - mStackMatchedCount[0] = 0; + mCorrection->initCorrectionState(rootPosition, childCount, (mInputLength <= 0)); // Depth first search - while (depth >= 0) { - if (mStackChildCount[depth] > 0) { - --mStackChildCount[depth]; - int siblingPos = mStackSiblingPos[depth]; + while (outputIndex >= 0) { + if (mCorrection->initProcessState(outputIndex)) { + int siblingPos = mCorrection->getTreeSiblingPos(outputIndex); int firstChildPos; - mCorrection->initProcessState( - mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth], - mStackTraverseAll[depth], mStackDiffs[depth]); - // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, mCorrection, &childCount, &firstChildPos, &siblingPos); // Update next sibling pos - mStackSiblingPos[depth] = siblingPos; + mCorrection->setTreeSiblingPos(outputIndex, siblingPos); + if (needsToTraverseChildrenNodes) { // Goes to child node - ++depth; - mStackChildCount[depth] = childCount; - mStackSiblingPos[depth] = firstChildPos; - - mCorrection->getProcessState(&mStackMatchedCount[depth], - &mStackInputIndex[depth], &mStackOutputIndex[depth], - &mStackTraverseAll[depth], &mStackDiffs[depth]); + outputIndex = mCorrection->goDownTree(outputIndex, childCount, firstChildPos); } } else { // Goes to parent sibling node - --depth; + outputIndex = mCorrection->getTreeParentIndex(outputIndex); } } } diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h index 8bcd7cea5..cfe63ff79 100644 --- a/native/src/unigram_dictionary.h +++ b/native/src/unigram_dictionary.h @@ -19,6 +19,7 @@ #include <stdint.h> #include "correction.h" +#include "correction_state.h" #include "defines.h" #include "proximity_info.h" @@ -134,13 +135,9 @@ private: // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH unsigned short mWord[MAX_WORD_LENGTH_INTERNAL]; - int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL]; - int mStackChildCount[MAX_WORD_LENGTH_INTERNAL]; - bool mStackTraverseAll[MAX_WORD_LENGTH_INTERNAL]; - int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL]; - int mStackDiffs[MAX_WORD_LENGTH_INTERNAL]; - int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL]; - int mStackOutputIndex[MAX_WORD_LENGTH_INTERNAL]; + int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];// TODO: remove + int mStackInputIndex[MAX_WORD_LENGTH_INTERNAL];// TODO: remove + int mStackSiblingPos[MAX_WORD_LENGTH_INTERNAL];// TODO: remove }; } // namespace latinime |