diff options
Diffstat (limited to 'native')
14 files changed, 155 insertions, 111 deletions
diff --git a/native/jni/Android.mk b/native/jni/Android.mk index fb60139d3..d5df6b62e 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -53,10 +53,10 @@ LATIN_IME_CORE_SRC_FILES := \ dic_nodes_cache.cpp) \ $(addprefix suggest/core/dictionary/, \ bigram_dictionary.cpp \ - binary_dictionary_bigrams_reading_utils.cpp \ binary_dictionary_format_utils.cpp \ binary_dictionary_header.cpp \ binary_dictionary_header_reading_utils.cpp \ + binary_dictionary_terminal_attributes_reading_utils.cpp \ bloom_filter.cpp \ byte_array_utils.cpp \ dictionary.cpp \ diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 25299948d..c700b01ca 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -28,13 +28,13 @@ #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ do { char charBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf); \ AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) #define DUMP_WORD_AND_SCORE(header) \ do { char charBuf[50]; char prevWordCharBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf); \ INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ @@ -51,6 +51,11 @@ namespace latinime { // This struct is purely a bucket to return values. No instances of this struct should be kept. struct DicNode_InputStateG { + DicNode_InputStateG() + : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), + mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), + mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} + bool mNeedsToUpdateInputStateG; int mPointerId; int16_t mInputIndex; @@ -157,7 +162,7 @@ class DicNode { const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, const uint16_t additionalSubwordLength, const int *additionalSubword) { mIsUsed = true; - uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); + uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast<uint16_t>( dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength); @@ -180,7 +185,7 @@ class DicNode { } bool isRoot() const { - return getDepth() == 0; + return getNodeCodePointCount() == 0; } bool hasChildren() const { @@ -188,12 +193,12 @@ class DicNode { } bool isLeavingNode() const { - ASSERT(getDepth() <= getLeavingDepth()); - return getDepth() == getLeavingDepth(); + ASSERT(getNodeCodePointCount() <= getLeavingDepth()); + return getNodeCodePointCount() == getLeavingDepth(); } AK_FORCE_INLINE bool isFirstLetter() const { - return getDepth() == 1; + return getNodeCodePointCount() == 1; } bool isCached() const { @@ -206,7 +211,7 @@ class DicNode { // Used to expand the node in DicNodeUtils int getNodeTypedCodePoint() const { - return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); + return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getNodeCodePointCount()); } bool isImpossibleBigramWord() const { @@ -215,7 +220,7 @@ class DicNode { } const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; - const int currentWordLen = getDepth(); + const int currentWordLen = getNodeCodePointCount(); return (prevWordLen == 1 && currentWordLen == 1); } @@ -263,13 +268,13 @@ class DicNode { AK_FORCE_INLINE bool isTerminalWordNode() const { const bool isTerminalNodes = mDicNodeProperties.isTerminal(); - const int currentNodeDepth = getDepth(); + const int currentNodeDepth = getNodeCodePointCount(); const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; } bool shouldBeFilterdBySafetyNetForBigram() const { - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); @@ -281,7 +286,7 @@ class DicNode { bool isTotalInputSizeExceedingLimit() const { const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const int currentWordDepth = getDepth(); + const int currentWordDepth = getNodeCodePointCount(); // TODO: 3 can be 2? Needs to be investigated. // TODO: Have a const variable for 3 (or 2) return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; @@ -316,7 +321,7 @@ class DicNode { void outputResult(int *dest) const { const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, prevWordLength, getOutputWordBuf(), currentDepth, dest); DUMP_WORD_AND_SCORE("OUTPUT"); @@ -475,13 +480,13 @@ class DicNode { return mDicNodeProperties.getAttributesPos(); } - inline uint16_t getDepth() const { + inline uint16_t getNodeCodePointCount() const { return mDicNodeProperties.getDepth(); } - // "Length" includes spaces. - inline uint16_t getTotalLength() const { - return getDepth() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + // Returns code point count including spaces + inline uint16_t getTotalNodeCodePointCount() const { + return getNodeCodePointCount() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); } AK_FORCE_INLINE void dump(const char *tag) const { @@ -516,8 +521,8 @@ class DicNode { } else if (diff < -MIN_DIFF) { return false; } - const int depth = getDepth(); - const int depthDiff = right->getDepth() - depth; + const int depth = getNodeCodePointCount(); + const int depthDiff = right->getNodeCodePointCount() - depth; if (depthDiff != 0) { return depthDiff > 0; } diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h index 0856840b2..f2b48e960 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h @@ -18,8 +18,8 @@ #define LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h" #include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" namespace latinime { @@ -35,15 +35,17 @@ class BinaryDictionaryBigramsIterator { } AK_FORCE_INLINE void next() { - mBigramFlags = BinaryDictionaryBigramsReadingUtils::getFlagsAndForwardPointer( + mBigramFlags = BinaryDictionaryTerminalAttributesReadingUtils::getFlagsAndForwardPointer( mBinaryDictionaryInfo, &mPos); - mBigramPos = BinaryDictionaryBigramsReadingUtils::getBigramAddressAndForwardPointer( - mBinaryDictionaryInfo, mBigramFlags, &mPos); - mHasNext = BinaryDictionaryBigramsReadingUtils::hasNext(mBigramFlags); + mBigramPos = + BinaryDictionaryTerminalAttributesReadingUtils::getBigramAddressAndForwardPointer( + mBinaryDictionaryInfo, mBigramFlags, &mPos); + mHasNext = BinaryDictionaryTerminalAttributesReadingUtils::hasNext(mBigramFlags); } AK_FORCE_INLINE int getProbability() const { - return BinaryDictionaryBigramsReadingUtils::getBigramProbability(mBigramFlags); + return BinaryDictionaryTerminalAttributesReadingUtils::getProbabilityFromFlags( + mBigramFlags); } AK_FORCE_INLINE int getBigramPos() const { @@ -59,7 +61,7 @@ class BinaryDictionaryBigramsIterator { const BinaryDictionaryInfo *const mBinaryDictionaryInfo; int mPos; - BinaryDictionaryBigramsReadingUtils::BigramFlags mBigramFlags; + BinaryDictionaryTerminalAttributesReadingUtils::BigramFlags mBigramFlags; int mBigramPos; bool mHasNext; }; diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.cpp index 78a54b141..0a7509c8b 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.cpp @@ -14,33 +14,28 @@ * limitations under the License. */ -#include "suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h" +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" #include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/byte_array_utils.h" namespace latinime { -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; +typedef BinaryDictionaryTerminalAttributesReadingUtils TaUtils; + +const TaUtils::TerminalAttributeFlags TaUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; // Flag for presence of more attributes -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; // Mask for attribute probability, stored on 4 bits inside the flags byte. -const BinaryDictionaryBigramsReadingUtils::BigramFlags - BinaryDictionaryBigramsReadingUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; -const int BinaryDictionaryBigramsReadingUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; +const TaUtils::TerminalAttributeFlags TaUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int TaUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; -/* static */ int BinaryDictionaryBigramsReadingUtils::getBigramAddressAndForwardPointer( - const BinaryDictionaryInfo *const binaryDictionaryInfo, const BigramFlags flags, +/* static */ int TaUtils::getBigramAddressAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const TerminalAttributeFlags flags, int *const pos) { int offset = 0; const int origin = *pos; diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h index e71f2a17a..f38fd5aaa 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H -#define LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H +#ifndef LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H +#define LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H #include <stdint.h> @@ -25,55 +25,57 @@ namespace latinime { -class BinaryDictionaryBigramsReadingUtils { +class BinaryDictionaryTerminalAttributesReadingUtils { public: - typedef uint8_t BigramFlags; + typedef uint8_t TerminalAttributeFlags; + typedef TerminalAttributeFlags BigramFlags; - static AK_FORCE_INLINE void skipExistingBigrams( - const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { - BigramFlags flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); - while (hasNext(flags)) { - *pos += attributeAddressSize(flags); - flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); - } - *pos += attributeAddressSize(flags); - } - - static AK_FORCE_INLINE BigramFlags getFlagsAndForwardPointer( + static AK_FORCE_INLINE TerminalAttributeFlags getFlagsAndForwardPointer( const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { return ByteArrayUtils::readUint8andAdvancePosition( binaryDictionaryInfo->getDictRoot(), pos); } - static AK_FORCE_INLINE int getBigramProbability(const BigramFlags flags) { + static AK_FORCE_INLINE int getProbabilityFromFlags(const TerminalAttributeFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; } - static AK_FORCE_INLINE bool isOffsetNegative(const BigramFlags flags) { - return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; + static AK_FORCE_INLINE bool hasNext(const TerminalAttributeFlags flags) { + return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; } - static AK_FORCE_INLINE bool hasNext(const BigramFlags flags) { - return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; + // Bigrams reading methods + static AK_FORCE_INLINE void skipExistingBigrams( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + BigramFlags flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + while (hasNext(flags)) { + *pos += attributeAddressSize(flags); + flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + } + *pos += attributeAddressSize(flags); } static int getBigramAddressAndForwardPointer( - const BinaryDictionaryInfo *const binaryDictionaryInfo, - const BigramFlags flags, int *const pos); + const BinaryDictionaryInfo *const binaryDictionaryInfo, const BigramFlags flags, + int *const pos); private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryBigramsReadingUtils); + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryTerminalAttributesReadingUtils); - static const BigramFlags MASK_ATTRIBUTE_ADDRESS_TYPE; - static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; - static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; - static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; - static const BigramFlags FLAG_ATTRIBUTE_OFFSET_NEGATIVE; - static const BigramFlags FLAG_ATTRIBUTE_HAS_NEXT; - static const BigramFlags MASK_ATTRIBUTE_PROBABILITY; + static const TerminalAttributeFlags MASK_ATTRIBUTE_ADDRESS_TYPE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_HAS_NEXT; + static const TerminalAttributeFlags MASK_ATTRIBUTE_PROBABILITY; static const int ATTRIBUTE_ADDRESS_SHIFT; - static AK_FORCE_INLINE int attributeAddressSize(const BigramFlags flags) { + static AK_FORCE_INLINE bool isOffsetNegative(const TerminalAttributeFlags flags) { + return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; + } + + static AK_FORCE_INLINE int attributeAddressSize(const TerminalAttributeFlags flags) { return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; /* Note: this is a value-dependant optimization of what may probably be more readably written this way: @@ -87,4 +89,4 @@ class BinaryDictionaryBigramsReadingUtils { } }; } -#endif /* LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H */ +#endif /* LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/byte_array_utils.h b/native/jni/src/suggest/core/dictionary/byte_array_utils.h index d3321f624..daa822ffa 100644 --- a/native/jni/src/suggest/core/dictionary/byte_array_utils.h +++ b/native/jni/src/suggest/core/dictionary/byte_array_utils.h @@ -57,6 +57,17 @@ class ByteArrayUtils { return value; } + static AK_FORCE_INLINE int readSint24andAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t value = readUint8(buffer, *pos); + if (value < 0x80) { + return readUint24andAdvancePosition(buffer, pos); + } else { + (*pos)++; + return -(((value & 0x7F) << 16) ^ readUint16andAdvancePosition(buffer, pos)); + } + } + static AK_FORCE_INLINE uint32_t readUint24andAdvancePosition( const uint8_t *const buffer, int *const pos) { const uint32_t value = readUint24(buffer, *pos); diff --git a/native/jni/src/suggest/core/layout/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp index 05826a5a1..e64476d82 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info.cpp @@ -215,22 +215,30 @@ int ProximityInfo::getKeyCenterXOfKeyIdG( return centerX; } -// referencePointY is currently not used because we don't specially handle keys higher than the -// most common key height. When the referencePointY is NOT_A_COORDINATE, this method should -// calculate the return value without using the line segment. +// When the referencePointY is NOT_A_COORDINATE, this method calculates the return value without +// using the line segment. int ProximityInfo::getKeyCenterYOfKeyIdG( const int keyId, const int referencePointY, const bool isGeometric) const { // TODO: Remove "isGeometric" and have separate "proximity_info"s for gesture and typing. if (keyId < 0) { return 0; } + int centerY; if (!hasTouchPositionCorrectionData()) { - return mCenterYsG[keyId]; + centerY = mCenterYsG[keyId]; } else if (isGeometric) { - return static_cast<int>(mSweetSpotCenterYsG[keyId]); + centerY = static_cast<int>(mSweetSpotCenterYsG[keyId]); } else { - return static_cast<int>(mSweetSpotCenterYs[keyId]); + centerY = static_cast<int>(mSweetSpotCenterYs[keyId]); } + if (referencePointY != NOT_A_COORDINATE && + centerY + mKeyHeights[keyId] > KEYBOARD_HEIGHT && centerY < referencePointY) { + // When the distance between center point and bottom edge of the keyboard is shorter than + // the key height, we assume the key is located at the bottom row of the keyboard. + // The center point is extended to the bottom edge for such keys. + return referencePointY; + } + return centerY; } int ProximityInfo::getKeyKeyDistanceG(const int keyId0, const int keyId1) const { diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index cc6410af1..dbcd54488 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -90,20 +90,7 @@ class ProximityInfoState { return false; } - // TODO: Promote insertion letter correction if that letter is a proximity of the previous - // letter like follows: - // // Demotion for a word with excessive character - // if (excessiveCount > 0) { - // multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq); - // if (!lastCharExceeded - // && !proximityInfoState->existsAdjacentProximityChars(excessivePos)) { - // // If an excessive character is not adjacent to the left char or the right char, - // // we will demote this word. - // multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, - // &finalFreq); - // } - // } - inline bool existsAdjacentProximityChars(const int index) const { + AK_FORCE_INLINE bool existsAdjacentProximityChars(const int index) const { if (index < 0 || index >= mSampledInputSize) return false; const int currentCodePoint = getPrimaryCodePointAt(index); const int leftIndex = index - 1; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index a8f16c8cb..173a612be 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -36,6 +36,7 @@ namespace latinime { const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; +const int Suggest::FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD = 1; /** * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates @@ -148,6 +149,8 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen &doubleLetterTerminalIndex, &doubleLetterLevel); int maxScore = S_INT_MIN; + int bestExactMatchedNodeTerminalIndex = -1; + int bestExactMatchedNodeOutputWordIndex = -1; // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -186,7 +189,6 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen const int finalScore = SCORING->calculateFinalScore( compoundDistance, traverseSession->getInputSize(), isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord())); - maxScore = max(maxScore, finalScore); if (TRAVERSAL->allowPartialCommit()) { @@ -200,6 +202,25 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen if (isValidWord) { outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; + if (isSafeExactMatch) { + // Demote exact matches that are not the highest probable node among all exact + // matches. + const bool isBestTerminal = bestExactMatchedNodeTerminalIndex < 0 + || terminals[bestExactMatchedNodeTerminalIndex].getProbability() + < terminalDicNode->getProbability(); + const int outputWordIndexToBeDemoted = isBestTerminal ? + bestExactMatchedNodeOutputWordIndex : outputWordIndex; + if (outputWordIndexToBeDemoted >= 0) { + frequencies[outputWordIndexToBeDemoted] -= + FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD; + } + if (isBestTerminal) { + // Updates the best exact matched node index. + bestExactMatchedNodeTerminalIndex = terminalIndex; + // Updates the best exact matched output word index. + bestExactMatchedNodeOutputWordIndex = outputWordIndex; + } + } // Populate the outputChars array with the suggested word. const int startIndex = outputWordIndex * MAX_WORD_LENGTH; terminalDicNode->outputResult(&outputCodePoints[startIndex]); diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 875cbe4e0..752bde9ac 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -82,6 +82,8 @@ class Suggest : public SuggestInterface { // Threshold for autocorrection classifier static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; + // Final score penalty to exact match words that are not the most probable exact match. + static const int FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD; const Traversal *const TRAVERSAL; const Scoring *const SCORING; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index f87989286..2659e4a23 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -33,6 +33,7 @@ const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; const float ScoringParams::INSERTION_COST = 0.730f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; +const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; const float ScoringParams::TRANSPOSITION_COST = 0.516f; const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.319f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 53ac999c1..c39c41779 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -42,6 +42,7 @@ class ScoringParams { static const float OMISSION_COST_FIRST_CHAR; static const float INSERTION_COST; static const float INSERTION_COST_SAME_CHAR; + static const float INSERTION_COST_PROXIMITY_CHAR; static const float INSERTION_COST_FIRST_CHAR; static const float TRANSPOSITION_COST; static const float SPACE_SUBSTITUTION_COST; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index e21b318e6..5ae396e64 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -147,7 +147,7 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool sameAsTyped( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { return traverseSession->getProximityInfoState(0)->sameAsTyped( - dicNode->getOutputWordBuf(), dicNode->getDepth()); + dicNode->getOutputWordBuf(), dicNode->getNodeCodePointCount()); } AK_FORCE_INLINE int getMaxCacheSize() const { @@ -171,7 +171,7 @@ class TypingTraversal : public Traversal { return false; } const int c = dicNode->getOutputWordBuf()[0]; - const bool shortCappedWord = dicNode->getDepth() + const bool shortCappedWord = dicNode->getNodeCodePointCount() < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && CharUtils::isAsciiUpper(c); return !shortCappedWord || probability >= ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index a1c99182a..e098f353e 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -55,7 +55,7 @@ class TypingWeighting : public Weighting { const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the traversal omitted the first letter then the dicNode should now be on the second. - const bool isFirstLetterOmission = dicNode->getDepth() == 2; + const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; float cost = 0.0f; if (isZeroCostOmission) { cost = 0.0f; @@ -83,7 +83,7 @@ class TypingWeighting : public Weighting { const bool isProximity = isProximityDicNode(traverseSession, dicNode); float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST : ScoringParams::PROXIMITY_COST) : 0.0f; - if (dicNode->getDepth() == 2) { + if (dicNode->getNodeCodePointCount() == 2) { // At the second character of the current word, we check if the first char is uppercase // and the word is a second or later word of a multiple word suggestion. We demote it // if so. @@ -122,19 +122,25 @@ class TypingWeighting : public Weighting { float getInsertionCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const { - const int16_t parentPointIndex = parentDicNode->getInputIndex(0); - const int prevCodePoint = - traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); - + const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( + insertedPointIndex); const int currentCodePoint = dicNode->getNodeCodePoint(); const bool sameCodePoint = prevCodePoint == currentCodePoint; + const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) + ->existsAdjacentProximityChars(insertedPointIndex); const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex + 1, currentCodePoint); + insertedPointIndex + 1, dicNode->getNodeCodePoint()); const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; - const bool singleChar = dicNode->getDepth() == 1; - const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) - + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR - : ScoringParams::INSERTION_COST); + const bool singleChar = dicNode->getNodeCodePointCount() == 1; + float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); + if (sameCodePoint) { + cost += ScoringParams::INSERTION_COST_SAME_CHAR; + } else if (existsAdjacentProximityChars) { + cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; + } else { + cost += ScoringParams::INSERTION_COST; + } return cost + weightedDistance; } @@ -163,6 +169,9 @@ class TypingWeighting : public Weighting { float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { + // We promote exact matches here to prevent them from being pruned. The final score of + // exact match nodes might be demoted later in Suggest::outputSuggestions if there are + // multiple exact matches. const float languageImprobability = (dicNode->isExactMatch()) ? 0.0f : dicNodeLanguageImprobability; return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; |