diff options
Diffstat (limited to 'native/jni/src')
36 files changed, 380 insertions, 257 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 7a69d3ceb..697e99ffb 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -23,7 +23,7 @@ #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/result/suggestion_results.h" #include "suggest/core/session/dic_traverse_session.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/suggest.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" @@ -46,11 +46,11 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const PrevWordsInfo *const prevWordsInfo, + int inputSize, const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - traverseSession->init(this, prevWordsInfo, suggestOptions); + traverseSession->init(this, ngramContext, suggestOptions); const auto &suggest = suggestOptions->isGesture() ? 
mGestureSuggest : mTypingSuggest; suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, @@ -58,10 +58,10 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( - const PrevWordsInfo *const prevWordsInfo, const WordIdArrayView prevWordIds, + const NgramContext *const ngramContext, const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) - : mPrevWordsInfo(prevWordsInfo), mPrevWordIds(prevWordIds), + : mNgramContext(ngramContext), mPrevWordIds(prevWordIds), mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {} void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability, @@ -69,7 +69,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi if (targetWordId == NOT_A_WORD_ID) { return; } - if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + if (mNgramContext->isNthPrevWordBeginningOfSentence(1 /* n */) && ngramProbability == NOT_A_PROBABILITY) { return; } @@ -85,20 +85,20 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi wordAttributes.getProbability()); } -void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, +void Dictionary::getPredictions(const NgramContext *const ngramContext, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds( + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds( mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults, + 
NgramListenerForPrediction listener(ngramContext, prevWordIds, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } int Dictionary::getProbability(const CodePointArrayView codePoints) const { - return getNgramProbability(nullptr /* prevWordsInfo */, codePoints); + return getNgramProbability(nullptr /* ngramContext */, codePoints); } int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const { @@ -107,18 +107,18 @@ int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoi mDictionaryStructureWithBufferPolicy.get(), codePoints); } -int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, +int Dictionary::getNgramProbability(const NgramContext *const ngramContext, const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); const int wordId = mDictionaryStructureWithBufferPolicy->getWordId(codePoints, false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; - if (!prevWordsInfo) { + if (!ngramContext) { return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId); } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds - (mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } @@ -140,24 +140,24 @@ bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) { return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints); } -bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Dictionary::addNgramEntry(const NgramContext *const ngramContext, const 
NgramProperty *const ngramProperty) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addNgramEntry(prevWordsInfo, ngramProperty); + return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramContext, ngramProperty); } -bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Dictionary::removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, codePoints); + return mDictionaryStructureWithBufferPolicy->removeNgramEntry(ngramContext, codePoints); } -bool Dictionary::updateCounter(const PrevWordsInfo *const prevWordsInfo, +bool Dictionary::updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView codePoints, const bool isValidWord, const HistoricalInfo historicalInfo) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->updateCounter(prevWordsInfo, codePoints, - isValidWord, historicalInfo); + return mDictionaryStructureWithBufferPolicy->updateEntriesForWordWithNgramContext(ngramContext, + codePoints, isValidWord, historicalInfo); } bool Dictionary::flush(const char *const filePath) { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index a58dbfbd7..843aec473 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -33,7 +33,7 @@ namespace latinime { class DictionaryStructureWithBufferPolicy; class DicTraverseSession; -class PrevWordsInfo; +class NgramContext; class ProximityInfo; class SuggestionResults; class SuggestOptions; @@ -66,18 +66,18 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, 
const PrevWordsInfo *const prevWordsInfo, + int inputSize, const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const; - void getPredictions(const PrevWordsInfo *const prevWordsInfo, + void getPredictions(const NgramContext *const ngramContext, SuggestionResults *const outSuggestionResults) const; int getProbability(const CodePointArrayView codePoints) const; int getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const; - int getNgramProbability(const PrevWordsInfo *const prevWordsInfo, + int getNgramProbability(const NgramContext *const ngramContext, const CodePointArrayView codePoints) const; bool addUnigramEntry(const CodePointArrayView codePoints, @@ -85,13 +85,13 @@ class Dictionary { bool removeUnigramEntry(const CodePointArrayView codePoints); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView codePoints); - bool updateCounter(const PrevWordsInfo *const prevWordsInfo, + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView codePoints, const bool isValidWord, const HistoricalInfo historicalInfo); @@ -123,7 +123,7 @@ class Dictionary { class NgramListenerForPrediction : public NgramListener { public: - NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo, + NgramListenerForPrediction(const NgramContext *const ngramContext, const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy); virtual void onVisitEntry(const int ngramProbability, const int targetWordId); @@ -131,7 +131,7 @@ class Dictionary { 
private: DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction); - const PrevWordsInfo *const mPrevWordsInfo; + const NgramContext *const mNgramContext; const WordIdArrayView mPrevWordIds; SuggestionResults *const mSuggestionResults; const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index b85f3622a..9573c37bc 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -21,7 +21,7 @@ #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/int_array_view.h" @@ -33,10 +33,10 @@ namespace latinime { std::vector<DicNode> current; std::vector<DicNode> next; - // No prev words information. - PrevWordsInfo emptyPrevWordsInfo; + // No ngram context. 
+ NgramContext emptyNgramContext; WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds( + const WordIdArrayView prevWordIds = emptyNgramContext.getPrevWordIds( dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */); current.emplace_back(); DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front()); diff --git a/native/jni/src/suggest/core/dictionary/property/historical_info.h b/native/jni/src/suggest/core/dictionary/property/historical_info.h index 5ed9ebfca..f9bd6fd8c 100644 --- a/native/jni/src/suggest/core/dictionary/property/historical_info.h +++ b/native/jni/src/suggest/core/dictionary/property/historical_info.h @@ -47,12 +47,12 @@ class HistoricalInfo { } private: - // Default copy constructor and assign operator are used for using in std::vector. + // Default copy constructor is used for using in std::vector. + DISALLOW_ASSIGNMENT_OPERATOR(HistoricalInfo); - // TODO: Make members const. - int mTimestamp; - int mLevel; - int mCount; + const int mTimestamp; + const int mLevel; + const int mCount; }; } // namespace latinime #endif /* LATINIME_HISTORICAL_INFO_H */ diff --git a/native/jni/src/suggest/core/dictionary/property/ngram_property.h b/native/jni/src/suggest/core/dictionary/property/ngram_property.h index dce460099..8709799f9 100644 --- a/native/jni/src/suggest/core/dictionary/property/ngram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/ngram_property.h @@ -44,13 +44,13 @@ class NgramProperty { } private: - // Default copy constructor and assign operator are used for using in std::vector. + // Default copy constructor is used for using in std::vector. DISALLOW_DEFAULT_CONSTRUCTOR(NgramProperty); + DISALLOW_ASSIGNMENT_OPERATOR(NgramProperty); - // TODO: Make members const. 
- std::vector<int> mTargetCodePoints; - int mProbability; - HistoricalInfo mHistoricalInfo; + const std::vector<int> mTargetCodePoints; + const int mProbability; + const HistoricalInfo mHistoricalInfo; }; } // namespace latinime #endif // LATINIME_NGRAM_PROPERTY_H diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h index d1f0ab4ca..5ed2e2602 100644 --- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h @@ -41,12 +41,11 @@ class UnigramProperty { } private: - // Default copy constructor and assign operator are used for using in std::vector. + // Default copy constructor is used for using in std::vector. DISALLOW_DEFAULT_CONSTRUCTOR(ShortcutProperty); - // TODO: Make members const. - std::vector<int> mTargetCodePoints; - int mProbability; + const std::vector<int> mTargetCodePoints; + const int mProbability; }; UnigramProperty() @@ -104,13 +103,12 @@ class UnigramProperty { // Default copy constructor is used for using as a return value. DISALLOW_ASSIGNMENT_OPERATOR(UnigramProperty); - // TODO: Make members const. 
- bool mRepresentsBeginningOfSentence; - bool mIsNotAWord; - bool mIsBlacklisted; - int mProbability; - HistoricalInfo mHistoricalInfo; - std::vector<ShortcutProperty> mShortcuts; + const bool mRepresentsBeginningOfSentence; + const bool mIsNotAWord; + const bool mIsBlacklisted; + const int mProbability; + const HistoricalInfo mHistoricalInfo; + const std::vector<ShortcutProperty> mShortcuts; }; } // namespace latinime #endif // LATINIME_UNIGRAM_PROPERTY_H diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 6624b7921..ceda5c03f 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -33,7 +33,7 @@ class DicNodeVector; class DictionaryHeaderStructurePolicy; class MultiBigramMap; class NgramListener; -class PrevWordsInfo; +class NgramContext; class UnigramProperty; /* @@ -81,15 +81,15 @@ class DictionaryStructureWithBufferPolicy { virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0; // Returns whether the update was success or not. - virtual bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, + virtual bool addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty) = 0; // Returns whether the update was success or not. - virtual bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + virtual bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints) = 0; // Returns whether the update was success or not. 
- virtual bool updateCounter(const PrevWordsInfo *const prevWordsInfo, + virtual bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, const bool isValidWord, const HistoricalInfo historicalInfo) = 0; diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 6dfa7e314..5b6616d9a 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -44,7 +44,7 @@ class Traversal { virtual bool needsToTraverseAllUserInput() const = 0; virtual float getMaxSpatialDistance() const = 0; virtual int getDefaultExpandDicNodeSize() const = 0; - virtual int getMaxCacheSize(const int inputSize) const = 0; + virtual int getMaxCacheSize(const int inputSize, const float weightForLocale) const = 0; virtual int getTerminalCacheSize() const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index c202b81fe..a06e7d070 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -110,10 +110,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getOmissionCost(parentDicNode, dicNode); case CT_ADDITIONAL_PROXIMITY: // only used for typing - return weighting->getAdditionalProximityCost(); + // TODO: Quit calling getMatchedCost(). + return weighting->getAdditionalProximityCost() + + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_SUBSTITUTION: // only used for typing - return weighting->getSubstitutionCost(); + // TODO: Quit calling getMatchedCost(). 
+ return weighting->getSubstitutionCost() + + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_NEW_WORD_SPACE_OMISSION: return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: @@ -176,9 +180,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_OMISSION: return 0; case CT_ADDITIONAL_PROXIMITY: - return 0; /* 0 because CT_MATCH will be called */ + return 1; case CT_SUBSTITUTION: - return 0; /* 0 because CT_MATCH will be called */ + return 1; case CT_NEW_WORD_SPACE_OMISSION: return 0; case CT_MATCH: diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index b4d01d0f0..52dc2f86c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -20,7 +20,7 @@ #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" namespace latinime { @@ -30,12 +30,12 @@ const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_S 256 * 1024; void DicTraverseSession::init(const Dictionary *const dictionary, - const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions) { + const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), + mPrevWordIdCount = ngramContext->getPrevWordIds(getDictionaryStructurePolicy(), &mPrevWordIdArray, true /* tryLowerCaseSearch */).size(); } 
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 9f841aa3c..bc53167f0 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -30,7 +30,7 @@ namespace latinime { class Dictionary; class DictionaryStructureWithBufferPolicy; -class PrevWordsInfo; +class NgramContext; class ProximityInfo; class SuggestOptions; @@ -61,7 +61,7 @@ class DicTraverseSession { // Non virtual inline destructor -- never inherit this class AK_FORCE_INLINE ~DicTraverseSession() {} - void init(const Dictionary *dictionary, const PrevWordsInfo *const prevWordsInfo, + void init(const Dictionary *dictionary, const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions); // TODO: Remove and merge into init void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/ngram_context.h index 553d5ad07..64c71410f 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/ngram_context.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LATINIME_PREV_WORDS_INFO_H -#define LATINIME_PREV_WORDS_INFO_H +#ifndef LATINIME_NGRAM_CONTEXT_H +#define LATINIME_NGRAM_CONTEXT_H #include <array> @@ -26,25 +26,26 @@ namespace latinime { -class PrevWordsInfo { +// Rename to NgramContext. +class NgramContext { public: // No prev word information. 
- PrevWordsInfo() : mPrevWordCount(0) { + NgramContext() : mPrevWordCount(0) { clear(); } - PrevWordsInfo(const PrevWordsInfo &prevWordsInfo) - : mPrevWordCount(prevWordsInfo.mPrevWordCount) { + NgramContext(const NgramContext &ngramContext) + : mPrevWordCount(ngramContext.mPrevWordCount) { for (size_t i = 0; i < mPrevWordCount; ++i) { - mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; - memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], + mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i]; + memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i], sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); - mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i]; + mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i]; } } // Construct from previous words. - PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], + NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH], const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, const size_t prevWordCount) : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { @@ -61,7 +62,7 @@ class PrevWordsInfo { } // Construct from a previous word. - PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, + NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount, const bool isBeginningOfSentence) : mPrevWordCount(1) { clear(); if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { @@ -78,8 +79,8 @@ class PrevWordsInfo { } // TODO: Remove. 
- const PrevWordsInfo getTrimmedPrevWordsInfo(const size_t maxPrevWordCount) const { - return PrevWordsInfo(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence, + const NgramContext getTrimmedNgramContext(const size_t maxPrevWordCount) const { + return NgramContext(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence, std::min(mPrevWordCount, maxPrevWordCount)); } @@ -122,7 +123,7 @@ class PrevWordsInfo { } private: - DISALLOW_ASSIGNMENT_OPERATOR(PrevWordsInfo); + DISALLOW_ASSIGNMENT_OPERATOR(NgramContext); static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, const int *const wordCodePoints, const int wordCodePointCount, @@ -165,4 +166,4 @@ class PrevWordsInfo { bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; }; } // namespace latinime -#endif // LATINIME_PREV_WORDS_INFO_H +#endif // LATINIME_NGRAM_CONTEXT_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 457414f2b..c71526293 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -28,6 +28,7 @@ #include "suggest/core/policy/weighting.h" #include "suggest/core/result/suggestions_output_utils.h" #include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/suggest_options.h" namespace latinime { @@ -88,7 +89,8 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const { traverseSession->getDicTraverseCache()->continueSearch(); } else { // Restart recognition at the root. 
- traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()), + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize(), + traverseSession->getSuggestOptions()->weightForLocale()), TRAVERSAL->getTerminalCacheSize()); // Create a new dic node here DicNode rootNode; @@ -282,7 +284,6 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver // not treat the node as a terminal. There is no need to pass the bigram map in these cases. Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); - weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -290,7 +291,6 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); - weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -401,7 +401,7 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN if (dicNode->isCompletion(inputSize)) { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); - } else { // completion + } else { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } diff --git a/native/jni/src/suggest/core/suggest_options.h b/native/jni/src/suggest/core/suggest_options.h index d456680dd..4d331292b 100644 --- a/native/jni/src/suggest/core/suggest_options.h +++ b/native/jni/src/suggest/core/suggest_options.h @@ -42,6 +42,12 @@ class SuggestOptions{ return getBoolOption(SPACE_AWARE_GESTURE_ENABLED); } + AK_FORCE_INLINE float weightForLocale() const { + // 
The weight is in thousands and we want the real value, so we divide by 1000. + // NativeSuggestOptions#setWeightForLocale does the opposite processing in Java. + return static_cast<float>(getIntOption(WEIGHT_FOR_LOCALE_IN_THOUSANDS)) / 1000.0f; + } + AK_FORCE_INLINE bool getAdditionalFeaturesBoolOption(const int key) const { return getBoolOption(key + ADDITIONAL_FEATURES_OPTIONS); } @@ -55,9 +61,10 @@ class SuggestOptions{ static const int USE_FULL_EDIT_DISTANCE = 1; static const int BLOCK_OFFENSIVE_WORDS = 2; static const int SPACE_AWARE_GESTURE_ENABLED = 3; + static const int WEIGHT_FOR_LOCALE_IN_THOUSANDS = 4; // Additional features options are stored after the other options and used as setting values of // experimental features. - static const int ADDITIONAL_FEATURES_OPTIONS = 4; + static const int ADDITIONAL_FEATURES_OPTIONS = 5; const int *const mOptions; const int mLength; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 4c4dfc578..8fb256c54 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -38,15 +38,17 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; -const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT"; -const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT"; +const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT"; const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE 
= 100.0f; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000; -const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000; +const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000; +const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000; // Used for logging. Question mark is used to indicate that the key is not found. void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index bc8eaded3..836bbe5a1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -253,11 +253,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY; static const char *const MAX_UNIGRAM_COUNT_KEY; static const char *const MAX_BIGRAM_COUNT_KEY; + static const char *const MAX_TRIGRAM_COUNT_KEY; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; static const int DEFAULT_MAX_UNIGRAM_COUNT; static const int DEFAULT_MAX_BIGRAM_COUNT; + static const int DEFAULT_MAX_TRIGRAM_COUNT; const FormatUtils::FORMAT_VERSION mDictFormatVersion; const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 8d169743c..6243f14cc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -310,7 +310,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN const int shortcutProbability) { if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), targetCodePoints, targetCodePointCount, shortcutProbability)) { - AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); + AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId()); return false; } if (!ptNodeParams->hasShortcutTargets()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 36eafa1e9..0eae934ae 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -33,7 +33,7 @@ #include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -186,7 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI if (bigramsIt.getBigramPos() == ptNodePos && bigramsIt.getProbability() != NOT_A_PROBABILITY) { const int bigramConditionalProbability = getBigramConditionalProbability( - prevWordPtNodeParams.getProbability(), bigramsIt.getProbability()); + 
prevWordPtNodeParams.getProbability(), + prevWordPtNodeParams.representsBeginningOfSentence(), + bigramsIt.getProbability()); return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability); } } @@ -209,15 +211,19 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI while (bigramsIt.hasNext()) { bigramsIt.next(); const int bigramConditionalProbability = getBigramConditionalProbability( - prevWordPtNodeParams.getProbability(), bigramsIt.getProbability()); + prevWordPtNodeParams.getProbability(), + prevWordPtNodeParams.representsBeginningOfSentence(), bigramsIt.getProbability()); listener->onVisitEntry(bigramConditionalProbability, getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability, - const int bigramProbability) const { + const bool isInBeginningOfSentenceContext, const int bigramProbability) const { if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + if (isInBeginningOfSentenceContext) { + return bigramProbability; + } // Calculate conditional probability. 
return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability, MAX_PROBABILITY); @@ -338,7 +344,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod return mNodeWriter.suppressUnigramEntry(&ptNodeParams); } -bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); @@ -349,8 +355,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; } if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { @@ -359,23 +365,23 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); if (prevWordIds.empty()) { return false; } if (prevWordIds[0] == NOT_A_WORD_ID) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { + if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { const UnigramProperty beginningOfSentenceUnigramProperty( true /* representsBeginningOfSentence */, true /* isNotAWord */, false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo()); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), + if 
(!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), &beginningOfSentenceUnigramProperty)) { AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } // Refresh word ids. - prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } else { return false; } @@ -399,7 +405,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } } -bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); @@ -410,8 +416,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary."); return false; } if (wordCodePoints.size() > MAX_WORD_LENGTH) { @@ -419,7 +425,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor wordCodePoints.size()); } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { return false; @@ -440,26 +446,27 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor } -bool Ver4PatriciaTriePolicy::updateCounter(const PrevWordsInfo *const prevWordsInfo, - const 
CodePointArrayView wordCodePoints, const bool isValidWord, - const HistoricalInfo historicalInfo) { +bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( + const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, + const bool isValidWord, const HistoricalInfo historicalInfo) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: updateCounter() is called for non-updatable dictionary."); + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); return false; } const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY; const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, false /* isNotAWord */, false /*isBlacklisted*/, probability, historicalInfo); if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { - AKLOGE("Cannot update unigarm entry in updateCounter()."); + AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext()."); return false; } - const int probabilityForNgram = prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ? 
NOT_A_PROBABILITY : probability; const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram, historicalInfo); - if (!addNgramEntry(prevWordsInfo, &ngramProperty)) { - AKLOGE("Cannot update unigarm entry in updateCounter()."); + if (!addNgramEntry(ngramContext, &ngramProperty)) { + AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext()."); return false; } return true; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index b82563e61..1ad5e7e36 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -112,13 +112,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool removeUnigramEntry(const CodePointArrayView wordCodePoints); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints); - bool updateCounter(const PrevWordsInfo *const prevWordsInfo, + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, const bool isValidWord, const HistoricalInfo historicalInfo); @@ -175,7 +175,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const WordAttributes getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const; int getBigramConditionalProbability(const int prevWordUnigramProbability, - const int bigramProbability) const; + const bool isInBeginningOfSentenceContext, const int 
bigramProbability) const; }; } // namespace v402 } // namespace backward diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index d3d684bfa..b7f1199c5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -23,7 +23,7 @@ #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 32a95bb6c..b17681388 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -93,25 +93,26 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return false; } - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty) { // This method should not be called for non-updatable dictionary. 
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; } - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; } - bool updateCounter(const PrevWordsInfo *const prevWordsInfo, + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, const bool isValidWord, const HistoricalInfo historicalInfo) { // This method should not be called for non-updatable dictionary. - AKLOGI("Warning: updateCounter() is called for non-updatable dictionary."); + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp index dc0ed96d0..90d4687dd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp @@ -37,7 +37,7 @@ const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNo int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortuctPolicy, + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortcutPolicy, mBigramPolicy, mCodePointTable, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); if (mergedNodeCodePointCount <= 0) { diff --git 
a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h index 24ec5bcca..838d37314 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h @@ -35,7 +35,7 @@ class Ver2ParticiaTrieNodeReader : public PtNodeReader { const DictionaryBigramsStructurePolicy *const bigramPolicy, const DictionaryShortcutsStructurePolicy *const shortcutPolicy, const int *const codePointTable) - : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortuctPolicy(shortcutPolicy), + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy), mCodePointTable(codePointTable) {} virtual const PtNodeParams fetchPtNodeParamsInBufferFromPtNodePos(const int ptNodePos) const; @@ -45,7 +45,7 @@ class Ver2ParticiaTrieNodeReader : public PtNodeReader { const ReadOnlyByteArrayView mBuffer; const DictionaryBigramsStructurePolicy *const mBigramPolicy; - const DictionaryShortcutsStructurePolicy *const mShortuctPolicy; + const DictionaryShortcutsStructurePolicy *const mShortcutPolicy; const int *const mCodePointTable; }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 956dabb4f..c4297f5d6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -25,6 +25,7 @@ namespace latinime { const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0; const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1; +const int 
LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; bool LanguageModelDictContent::save(FILE *const file) const { return mTrieMap.save(file); @@ -42,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); - int maxLevel = 0; + int maxPrevWordCount = 0; for (size_t i = 0; i < prevWordIds.size(); ++i) { const int nextBitmapEntryIndex = mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { break; } - maxLevel = i + 1; + maxPrevWordCount = i + 1; bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } - for (int i = maxLevel; i >= 0; --i) { + for (int i = maxPrevWordCount; i >= 0; --i) { const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; @@ -68,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr // The entry should not be treated as a valid entry. 
continue; } - probability = std::min(rawProbability - + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), - MAX_PROBABILITY); + if (i == 0) { + // unigram + probability = rawProbability; + } else { + const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( + prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); + if (!prevWordProbabilityEntry.isValid()) { + continue; + } + if (prevWordProbabilityEntry.representsBeginningOfSentence()) { + probability = rawProbability; + } else { + const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( + prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); + probability = std::min(MAX_PROBABILITY - prevWordRawProbability + + rawProbability, MAX_PROBABILITY); + } + } } else { probability = probabilityEntry.getProbability(); } @@ -143,6 +159,56 @@ bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, return true; } +bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, + const int wordId, const bool isValid, const HistoricalInfo historicalInfo, + const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) { + if (outAddedNewNgramEntryCount) { + *outAddedNewNgramEntryCount = 0; + } + if (!mHasHistoricalInfo) { + AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info."); + return false; + } + const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId); + const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom( + originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy); + if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] == NOT_A_WORD_ID) { + break; + } + // TODO: Optimize this code. 
+ const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1); + const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry( + limitedPrevWordIds, wordId); + const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom( + originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy); + if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) { + return false; + } + if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) { + *outAddedNewNgramEntryCount += 1; + } + } + return true; +} + +const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( + const ProbabilityEntry &originalProbabilityEntry, const bool isValid, + const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { + const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( + originalProbabilityEntry.getHistoricalInfo(), isValid ? + DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, + &historicalInfo, headerPolicy); + if (originalProbabilityEntry.isValid()) { + return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); + } else { + return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo); + } +} + bool LanguageModelDictContent::runGCInner( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TrieMap::TrieMapRange trieMapRange, @@ -203,17 +269,27 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } -bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex, - const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { +bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, + const int prevWordCount, const HeaderPolicy *const headerPolicy, + int *const outEntryCounts) { for (const auto &entry : 
mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { - if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", - level, MAX_PREV_WORD_COUNT_FOR_N_GRAM); + if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", + prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM); return false; } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); - if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) { + if (prevWordCount > 0 && probabilityEntry.isValid() + && !mTrieMap.getRoot(entry.key()).mIsValid) { + // The entry is related to a word that has been removed. Remove the entry. + if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() + && probabilityEntry.isValid()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( probabilityEntry.getHistoricalInfo(), headerPolicy); if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { @@ -232,13 +308,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap } } if (!probabilityEntry.representsBeginningOfSentence()) { - outEntryCounts[level] += 1; + outEntryCounts[prevWordCount] += 1; } if (!entry.hasNextLevelMap()) { continue; } - if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1, - headerPolicy, outEntryCounts)) { + if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), + prevWordCount + 1, headerPolicy, outEntryCounts)) { return false; } } @@ -266,7 +342,7 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( for (int i = 0; i < entryCountToRemove; ++i) { const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; if (!removeNgramProbabilityEntry( - 
WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) { + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { return false; } } @@ -276,9 +352,9 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, std::vector<EntryInfoToTurncate> *const outEntryInfo) const { - const int currentLevel = prevWordIds->size(); + const int prevWordCount = prevWordIds->size(); for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { - if (currentLevel < targetLevel) { + if (prevWordCount < targetLevel) { if (!entry.hasNextLevelMap()) { continue; } @@ -313,10 +389,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( if (left.mKey != right.mKey) { return left.mKey < right.mKey; } - if (left.mEntryLevel != right.mEntryLevel) { - return left.mEntryLevel > right.mEntryLevel; + if (left.mPrevWordCount != right.mPrevWordCount) { + return left.mPrevWordCount > right.mPrevWordCount; } - for (int i = 0; i < left.mEntryLevel; ++i) { + for (int i = 0; i < left.mPrevWordCount; ++i) { if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) { return left.mPrevWordIds[i] < right.mPrevWordIds[i]; } @@ -326,9 +402,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( } LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, - const int timestamp, const int key, const int entryLevel, const int *const prevWordIds) - : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) { - memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0])); + const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) + : mProbability(probability), mTimestamp(timestamp), mKey(key), + 
mPrevWordCount(prevWordCount) { + memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index b7e4af977..51ef090e1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -154,19 +154,23 @@ class LanguageModelDictContent { EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; - bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, + bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { outEntryCounts[i] = 0; } - return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, - headerPolicy, outEntryCounts); + return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), + 0 /* prevWordCount */, headerPolicy, outEntryCounts); } // entryCounts should be created by updateAllProbabilityEntries. 
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, + const bool isValid, const HistoricalInfo historicalInfo, + const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount); + private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); @@ -181,18 +185,21 @@ class LanguageModelDictContent { }; EntryInfoToTurncate(const int probability, const int timestamp, const int key, - const int entryLevel, const int *const prevWordIds); + const int prevWordCount, const int *const prevWordIds); int mProbability; int mTimestamp; int mKey; - int mEntryLevel; + int mPrevWordCount; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; private: DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); }; + // TODO: Remove + static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; + TrieMap mTrieMap; const bool mHasHistoricalInfo; @@ -201,13 +208,16 @@ class LanguageModelDictContent { int *const outNgramCount); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; - bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, + bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, int *const outEntryCounts); bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, int *const outEntryCount); bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, std::vector<EntryInfoToTurncate> *const outEntryInfo) const; + const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, + const bool isValid, const HistoricalInfo 
historicalInfo, + const HeaderPolicy *const headerPolicy) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index fa1415633..f4d340f86 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -98,17 +98,17 @@ class ProbabilityEntry { } uint64_t encode(const bool hasHistoricalInfo) const { - uint64_t encodedEntry = static_cast<uint64_t>(mFlags); + uint64_t encodedEntry = static_cast<uint8_t>(mFlags); if (hasHistoricalInfo) { encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp()); + | static_cast<uint32_t>(mHistoricalInfo.getTimestamp()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getLevel()); + | static_cast<uint8_t>(mHistoricalInfo.getLevel()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getCount()); + | static_cast<uint8_t>(mHistoricalInfo.getCount()); } else { encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mProbability); + | static_cast<uint8_t>(mProbability); } return encodedEntry; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index f13512d5a..794c63ffd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -142,14 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( if (!toBeUpdatedPtNodeParams->isTerminal()) { return false; } - const ProbabilityEntry originalProbabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId()); const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); - const ProbabilityEntry updatedProbabilityEntry = - createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry); + toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty); } bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( @@ -203,10 +198,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( // Write probability. ProbabilityEntry newProbabilityEntry; const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); - const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( - &newProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - terminalId, &probabilityEntryToWrite); + terminalId, &probabilityEntryOfUnigramProperty); } // TODO: Support counting ngram entries. 
@@ -217,10 +210,8 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds const ProbabilityEntry probabilityEntry = languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId); const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty); - const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( - &probabilityEntry, &probabilityEntryOfNgramProperty); if (!languageModelDictContent->setNgramProbabilityEntry( - prevWordIds, wordId, &updatedProbabilityEntry)) { + prevWordIds, wordId, &probabilityEntryOfNgramProperty)) { AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d", prevWordIds[0], prevWordIds.size(), wordId); return false; @@ -285,7 +276,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN const int shortcutProbability) { if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), targetCodePoints, targetCodePointCount, shortcutProbability)) { - AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); + AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId()); return false; } return true; @@ -346,22 +337,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } -// TODO: Move probability handling code to LanguageModelDictContent. 
-const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( - const ProbabilityEntry *const originalProbabilityEntry, - const ProbabilityEntry *const probabilityEntry) const { - if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo updatedHistoricalInfo = - ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalProbabilityEntry->getHistoricalInfo(), - probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(), - mHeaderPolicy); - return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo); - } else { - return *probabilityEntry; - } -} - bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars) { // Create node flags and write them. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index ea4f09904..4ecf88729 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -38,11 +38,10 @@ class Ver4ShortcutListPolicy; class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { public: Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer, - Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy, - const PtNodeReader *const ptNodeReader, + Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader, const PtNodeArrayReader *const ptNodeArrayReader, Ver4ShortcutListPolicy *const shortcutPolicy) - : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy), + : mTrieBuffer(trieBuffer), mBuffers(buffers), mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {} virtual ~Ver4PatriciaTrieNodeWriter() {} @@ -96,20 +95,12 @@ class Ver4PatriciaTrieNodeWriter : public 
PtNodeWriter { const PtNodeParams *const ptNodeParams, int *const outTerminalId, int *const ptNodeWritingPos); - // Create updated probability entry using given probability property. In addition to the - // probability, this method updates historical information if needed. - // TODO: Update flags. - const ProbabilityEntry createUpdatedEntryFrom( - const ProbabilityEntry *const originalProbabilityEntry, - const ProbabilityEntry *const probabilityEntry) const; - bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars); static const int CHILDREN_POSITION_FIELD_SIZE; BufferWithExtendableBuffer *const mTrieBuffer; Ver4DictBuffers *const mBuffers; - const HeaderPolicy *const mHeaderPolicy; DynamicPtReadingHelper mReadingHelper; Ver4ShortcutListPolicy *const mShortcutPolicy; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 036197c41..ea8c0dc22 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -26,7 +26,7 @@ #include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -43,7 +43,6 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024; const int 
Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS = Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS; -const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const { @@ -151,8 +150,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI } const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy) - + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) : + probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : probabilityEntry.getProbability(); listener->onVisitEntry(probability, entry.getWordId()); } @@ -266,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod return true; } -bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); @@ -277,8 +275,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; } if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { @@ -287,7 +285,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, 
&prevWordIdArray, + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); if (prevWordIds.empty()) { return false; @@ -296,19 +294,19 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI if (prevWordIds[i] != NOT_A_WORD_ID) { continue; } - if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { + if (!ngramContext->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { return false; } const UnigramProperty beginningOfSentenceUnigramProperty( true /* representsBeginningOfSentence */, true /* isNotAWord */, false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo()); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), &beginningOfSentenceUnigramProperty)) { AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } // Refresh word ids. - prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } const int wordId = getWordId(CodePointArrayView(*ngramProperty->getTargetCodePoints()), false /* forceLowerCaseSearch */); @@ -326,7 +324,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } } -bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, +bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); @@ -337,8 +335,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for removing n-gram entry form the 
dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary."); return false; } if (wordCodePoints.size() > MAX_WORD_LENGTH) { @@ -346,7 +344,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor wordCodePoints.size()); } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; - const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) { return false; @@ -363,32 +361,52 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor } } -bool Ver4PatriciaTriePolicy::updateCounter(const PrevWordsInfo *const prevWordsInfo, - const CodePointArrayView wordCodePoints, const bool isValidWord, - const HistoricalInfo historicalInfo) { +bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( + const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, + const bool isValidWord, const HistoricalInfo historicalInfo) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: updateCounter() is called for non-updatable dictionary."); + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); return false; } - // TODO: Have count up method in language model dict content. - const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY; - const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, - false /* isNotAWord */, false /*isBlacklisted*/, probability, historicalInfo); - if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { - AKLOGE("Cannot update unigarm entry in updateCounter()."); - return false; + const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ? 
+ false : isValidWord; + int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { + // The word is not in the dictionary. + const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, + false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY, + HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { + AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext()."); + return false; + } + wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); } - const int probabilityForNgram = prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) - ? NOT_A_PROBABILITY : probability; - const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram, - historicalInfo); - for (size_t i = 1; i <= prevWordsInfo->getPrevWordCount(); ++i) { - const PrevWordsInfo trimmedPrevWordsInfo(prevWordsInfo->getTrimmedPrevWordsInfo(i)); - if (!addNgramEntry(&trimmedPrevWordsInfo, &ngramProperty)) { - AKLOGE("Cannot update ngram entry in updateCounter()."); + + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID + && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, + true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY, + HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); return false; } + // Refresh word ids. 
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + } + int addedNewNgramEntryCount = 0; + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds, + wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) { + return false; } + mBigramCount += addedNewNgramEntryCount; return true; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 662bb8d4b..c0532815c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -47,8 +47,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mShortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()), mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer), - mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader, - &mPtNodeArrayReader, &mShortcutPolicy), + mNodeWriter(mDictBuffer, mBuffers.get(), &mNodeReader, &mPtNodeArrayReader, + &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), mUnigramCount(mHeaderPolicy->getUnigramCount()), @@ -92,13 +92,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool removeUnigramEntry(const CodePointArrayView wordCodePoints); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool addNgramEntry(const NgramContext *const ngramContext, const NgramProperty *const ngramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints); - bool updateCounter(const PrevWordsInfo *const prevWordsInfo, + bool 
updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, const bool isValidWord, const HistoricalInfo historicalInfo); @@ -131,8 +131,6 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { // prevent the dictionary from overflowing. static const int MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS; static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS; - // TODO: Remove - static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers; const HeaderPolicy *const mHeaderPolicy; @@ -144,6 +142,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DynamicPtUpdatingHelper mUpdatingHelper; Ver4PatriciaTrieWritingHelper mWritingHelper; int mUnigramCount; + // TODO: Support counting ngram entries. int mBigramCount; std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index e1ff973de..f0d59c150 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -78,11 +78,11 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(), - mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); + mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; - if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy, - 
entryCountTable)) { + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC( + headerPolicy, entryCountTable)) { AKLOGE("Failed to update probabilities in language model dict content."); return false; } @@ -118,7 +118,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap; readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); + buffersToWrite, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers, buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap); @@ -133,7 +133,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(), buffersToWrite->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader, + buffersToWrite, &newPtNodeReader, &newPtNodeArrayreader, &newShortcutPolicy); // Re-assign terminal IDs for valid terminal PtNodes. 
TerminalPositionLookupTable::TerminalIdMap terminalIdMap; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 3fc566e7a..6a2db687d 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -31,6 +31,7 @@ const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; // TODO: Unlimit max cache dic node size const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE = 50; const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.1524f; @@ -48,7 +49,7 @@ const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f; const float ScoringParams::TRANSPOSITION_COST = 0.5608f; const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f; -const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.4576f; +const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f; const float ScoringParams::SUBSTITUTION_COST = 0.3806f; const float ScoringParams::COST_NEW_WORD = 0.0314f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f; @@ -61,4 +62,7 @@ const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION = 0.99f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION = 0.99f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE = 0.99f; } // namespace latinime diff --git 
a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index b12de6d87..731424f3d 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -30,6 +30,7 @@ class ScoringParams { static const float AUTOCORRECT_OUTPUT_THRESHOLD; static const int MAX_CACHE_DIC_NODE_SIZE; static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; + static const int MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE; static const int THRESHOLD_SHORT_WORD_LENGTH; static const float EXACT_MATCH_PROMOTION; @@ -68,6 +69,9 @@ class ScoringParams { static const float TYPING_BASE_OUTPUT_SCORE; static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT; static const float NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT; + static const float LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION; + static const float LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION; + static const float LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE; private: DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams); diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index b64ee8be4..b9b6314ae 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -26,6 +26,7 @@ #include "suggest/core/layout/proximity_info_utils.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/suggest_options.h" #include "suggest/policyimpl/typing/scoring_params.h" #include "utils/char_utils.h" @@ -77,6 +78,13 @@ class TypingTraversal : public Traversal { if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) { return false; } + if (traverseSession->getSuggestOptions()->weightForLocale() + < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION) { + // Space substitution is heavy, so we skip doing it 
if the weight for this language + // is low because we anticipate the suggestions out of this dictionary are not for + // the language the user intends to type in. + return false; + } if (!canDoLookAheadCorrection(traverseSession, dicNode)) { return false; } @@ -91,6 +99,13 @@ class TypingTraversal : public Traversal { if (!CORRECT_NEW_WORD_SPACE_OMISSION) { return false; } + if (traverseSession->getSuggestOptions()->weightForLocale() + < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION) { + // Space omission is heavy, so we skip doing it if the weight for this language + // is low because we anticipate the suggestions out of this dictionary are not for + // the language the user intends to type in. + return false; + } const int inputSize = traverseSession->getInputSize(); // TODO: Don't refer to isCompletion? if (dicNode->isCompletion(inputSize)) { @@ -141,9 +156,14 @@ class TypingTraversal : public Traversal { return DicNodeVector::DEFAULT_NODES_SIZE_FOR_OPTIMIZATION; } - AK_FORCE_INLINE int getMaxCacheSize(const int inputSize) const { - return (inputSize <= 1) ? 
ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT - : ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + AK_FORCE_INLINE int getMaxCacheSize(const int inputSize, const float weightForLocale) const { + if (inputSize <= 1) { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; + } + if (weightForLocale < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE) { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE; + } + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; } AK_FORCE_INLINE int getTerminalCacheSize() const { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 1d590c353..db7a39efb 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -68,7 +68,8 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor } break; case CT_ADDITIONAL_PROXIMITY: - return ErrorTypeUtils::PROXIMITY_CORRECTION; + // TODO: Change to EDIT_CORRECTION. + return ErrorTypeUtils::PROXIMITY_CORRECTION; case CT_OMISSION: if (parentDicNode->canBeIntentionalOmission()) { return ErrorTypeUtils::INTENTIONAL_OMISSION; @@ -77,6 +78,8 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor } break; case CT_SUBSTITUTION: + // TODO: Quit settng PROXIMITY_CORRECTION. 
+ return ErrorTypeUtils::EDIT_CORRECTION | ErrorTypeUtils::PROXIMITY_CORRECTION; case CT_INSERTION: case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index 235a03bba..25cc41742 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -21,7 +21,7 @@ #include "defines.h" #include "jni.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" #include "utils/char_utils.h" @@ -96,7 +96,7 @@ class JniDataUtils { } } - static PrevWordsInfo constructPrevWordsInfo(JNIEnv *env, jobjectArray prevWordCodePointArrays, + static NgramContext constructNgramContext(JNIEnv *env, jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, const size_t prevWordCount) { int prevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int prevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; @@ -119,7 +119,7 @@ class JniDataUtils { &isBeginningOfSentenceBoolean); isBeginningOfSentence[i] = isBeginningOfSentenceBoolean == JNI_TRUE; } - return PrevWordsInfo(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, + return NgramContext(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, prevWordCount); } |