diff options
Diffstat (limited to 'native')
9 files changed, 71 insertions, 12 deletions
diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp index b8106377c..e37811b88 100644 --- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp @@ -78,7 +78,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; outputAutoCommitFirstWordConfidence[0] = computeFirstWordConfidence(&terminals[0]); } - + const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()-> + getHeaderStructurePolicy()->shouldBoostExactMatches(); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -102,7 +103,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; && !(isPossiblyOffensiveWord && isFirstCharUppercase); const int outputTypeFlags = (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); + | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); @@ -113,7 +114,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); + || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), + boostExactMatches); if (maxScore < finalScore && isValidWord) { maxScore = finalScore; } @@ -147,7 +149,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; scoringPolicy->calculateFinalScore(compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), - true /* forceCommit */) : finalScore; + true /* forceCommit */, boostExactMatches) : finalScore; const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index b76b13971..417620e00 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -40,6 +40,8 @@ class DictionaryHeaderStructurePolicy { virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const = 0; + virtual bool shouldBoostExactMatches() const = 0; + protected: DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 783383450..e581a97c3 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -28,7 +28,8 @@ class DicTraverseSession; class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, - const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0; + const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, + const bool boostExactMatches) const = 0; virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, const int terminalSize, const float languageWeight, int *const outputCodePoints, int *const type, int *const freq) const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index a44f9f0fc..1320c6560 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -146,6 +146,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mHasHistoricalInfoOfWords; } + AK_FORCE_INLINE bool shouldBoostExactMatches() const { + // TODO: Investigate better ways to handle exact matches for personalized dictionaries. + return !isDecayingDict(); + } + void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index b918e0765..824d442e4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -28,6 +28,14 @@ const int DynamicPtReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 10000 const int DynamicPtReadingHelper::MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000; const size_t DynamicPtReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH; +bool DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions::onVisitingPtNode( + const PtNodeParams *const ptNodeParams) { + if (ptNodeParams->isTerminal() && !ptNodeParams->isDeleted()) { + mTerminalPositions->push_back(ptNodeParams->getHeadPos()); + } + return true; +} + // Visits all PtNodes in post-order depth first manner. // For example, visits c -> b -> y -> x -> a for the following dictionary: // a _ b _ c diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index a69490943..bcc5c7857 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -59,6 +59,21 @@ class DynamicPtReadingHelper { DISALLOW_COPY_AND_ASSIGN(TraversingEventListener); }; + class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener { + public: + TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> *const terminalPositions) + : mTerminalPositions(terminalPositions) {} + bool onAscend() { return true; } + bool onDescend(const int ptNodeArrayPos) { return true; } + bool onReadingPtNodeArrayTail() { return true; } + bool onVisitingPtNode(const PtNodeParams *const ptNodeParams); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions); + + std::vector<int> *const mTerminalPositions; + }; + DynamicPtReadingHelper(const BufferWithExtendableBuffer *const buffer, const PtNodeReader *const ptNodeReader) : mIsError(false), mReadingState(), mBuffer(buffer), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 1c420e070..75d85988c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -392,10 +392,32 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code historicalInfo->getCount(), &bigrams, &shortcuts); } -int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, - int *const outCodePoints) { - // TODO: Implement. - return 0; +int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { + if (token == 0) { + mTerminalPtNodePositionsForIteratingWords.clear(); + DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( + &mTerminalPtNodePositionsForIteratingWords); + DynamicPtReadingHelper readingHelper(mDictBuffer, &mNodeReader); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy); + } + const int terminalPtNodePositionsVectorSize = + static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size()); + if (token < 0 || token >= terminalPtNodePositionsVectorSize) { + AKLOGE("Given token %d is invalid.", token); + return 0; + } + const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; + int unigramProbability = NOT_A_PROBABILITY; + getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, + outCodePoints, &unigramProbability); + const int nextToken = token + 1; + if (nextToken >= terminalPtNodePositionsVectorSize) { + // All words have been iterated. + mTerminalPtNodePositionsForIteratingWords.clear(); + return 0; + } + return nextToken; } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 1bcd4ceea..9ba5be0c3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -17,6 +17,8 @@ #ifndef LATINIME_VER4_PATRICIA_TRIE_POLICY_H #define LATINIME_VER4_PATRICIA_TRIE_POLICY_H +#include <vector> + #include "defines.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h" @@ -50,7 +52,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()) {}; + mBigramCount(mHeaderPolicy->getBigramCount()), + mTerminalPtNodePositionsForIteratingWords() {}; AK_FORCE_INLINE int getRootPosition() const { return 0; @@ -134,6 +137,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { Ver4PatriciaTrieWritingHelper mWritingHelper; int mUnigramCount; int mBigramCount; + std::vector<int> mTerminalPtNodePositionsForIteratingWords; }; } // namespace latinime #endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index c777e7238..8b405e8de 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -50,14 +50,14 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, - const bool forceCommit) const { + const bool forceCommit, const bool boostExactMatches) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; |