diff options
Diffstat (limited to 'native/jni/src')
68 files changed, 1377 insertions, 564 deletions
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 7d2898b7a..ea438922f 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -74,8 +74,9 @@ namespace latinime { } const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext( dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap); - if (dicNode->hasMultipleWords() - && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord())) { + if (wordAttributes.getProbability() == NOT_A_PROBABILITY + || (dicNode->hasMultipleWords() + && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()))) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 697e99ffb..6a5df9d95 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi } const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); + if (wordAttributes.getProbability() == NOT_A_PROBABILITY) { + return; + } mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, wordAttributes.getProbability()); } @@ -140,10 +143,9 @@ bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) { return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints); } -bool Dictionary::addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty) { +bool Dictionary::addNgramEntry(const NgramProperty *const ngramProperty) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramContext, ngramProperty); + return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramProperty); } bool Dictionary::removeNgramEntry(const NgramContext *const ngramContext, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 843aec473..a5e986d15 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -85,8 +85,7 @@ class Dictionary { bool removeUnigramEntry(const CodePointArrayView codePoints); - bool addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty); + bool addNgramEntry(const NgramProperty *const ngramProperty); bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView codePoints); diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index 1e2494e92..8f07ce275 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -31,6 +31,7 @@ const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_A_PERFECT_MATCH = NOT_AN_ERROR; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index fd1d5fcff..e92c509fa 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -52,6 +52,10 @@ class ErrorTypeUtils { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0; } + static bool isPerfectMatch(const ErrorType containedErrorTypes) { + return (containedErrorTypes & ~ERRORS_TREATED_AS_A_PERFECT_MATCH) == 0; + } + static bool isExactMatchWithIntentionalOmission(const ErrorType containedErrorTypes) { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION) == 0; @@ -73,6 +77,7 @@ class ErrorTypeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(ErrorTypeUtils); static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH; + static const ErrorType ERRORS_TREATED_AS_A_PERFECT_MATCH; static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION; }; } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index e9b3c1aaf..2eb5e9fd1 100644 --- a/native/jni/src/suggest/core/dictionary/ngram_listener.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -26,6 +26,8 @@ namespace latinime { */ class NgramListener { public: + // ngramProbability is always 0 for v403 decaying dictionary. + // TODO: Remove ngramProbability. virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; diff --git a/native/jni/src/suggest/core/dictionary/property/historical_info.h b/native/jni/src/suggest/core/dictionary/property/historical_info.h index f9bd6fd8c..e5ce1ea25 100644 --- a/native/jni/src/suggest/core/dictionary/property/historical_info.h +++ b/native/jni/src/suggest/core/dictionary/property/historical_info.h @@ -38,6 +38,7 @@ class HistoricalInfo { return mTimestamp; } + // TODO: Remove int getLevel() const { return mLevel; } diff --git a/native/jni/src/suggest/core/dictionary/property/ngram_property.h b/native/jni/src/suggest/core/dictionary/property/ngram_property.h index 8709799f9..e67b4da31 100644 --- a/native/jni/src/suggest/core/dictionary/property/ngram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/ngram_property.h @@ -21,15 +21,20 @@ #include "defines.h" #include "suggest/core/dictionary/property/historical_info.h" +#include "suggest/core/session/ngram_context.h" namespace latinime { class NgramProperty { public: - NgramProperty(const std::vector<int> &&targetCodePoints, const int probability, - const HistoricalInfo historicalInfo) - : mTargetCodePoints(std::move(targetCodePoints)), mProbability(probability), - mHistoricalInfo(historicalInfo) {} + NgramProperty(const NgramContext &ngramContext, const std::vector<int> &&targetCodePoints, + const int probability, const HistoricalInfo historicalInfo) + : mNgramContext(ngramContext), mTargetCodePoints(std::move(targetCodePoints)), + mProbability(probability), mHistoricalInfo(historicalInfo) {} + + const NgramContext *getNgramContext() const { + return &mNgramContext; + } const std::vector<int> *getTargetCodePoints() const { return &mTargetCodePoints; @@ -48,6 +53,7 @@ class NgramProperty { DISALLOW_DEFAULT_CONSTRUCTOR(NgramProperty); DISALLOW_ASSIGNMENT_OPERATOR(NgramProperty); + const NgramContext mNgramContext; const std::vector<int> mTargetCodePoints; const int mProbability; const HistoricalInfo mHistoricalInfo; diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h index 5ed2e2602..f194f979a 100644 --- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h @@ -49,21 +49,44 @@ class UnigramProperty { }; UnigramProperty() - : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), mIsBlacklisted(false), - mProbability(NOT_A_PROBABILITY), mHistoricalInfo(), mShortcuts() {} + : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), + mIsBlacklisted(false), mIsPossiblyOffensive(false), mProbability(NOT_A_PROBABILITY), + mHistoricalInfo(), mShortcuts() {} + // In contexts which do not support the Blacklisted flag (v2, v4<403) UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, - const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo, - const std::vector<ShortcutProperty> &&shortcuts) + const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts) : mRepresentsBeginningOfSentence(representsBeginningOfSentence), - mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability), + mIsNotAWord(isNotAWord), mIsBlacklisted(false), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {} - // Without shortcuts. + // Without shortcuts, in contexts which do not support the Blacklisted flag (v2, v4<403) UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, - const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo) + const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo) : mRepresentsBeginningOfSentence(representsBeginningOfSentence), - mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability), + mIsNotAWord(isNotAWord), mIsBlacklisted(false), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts() {} + + // In contexts which DO support the Blacklisted flag (v403) + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isBlacklisted, const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {} + + // Without shortcuts, in contexts which DO support the Blacklisted flag (v403) + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isBlacklisted, const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), mHistoricalInfo(historicalInfo), mShortcuts() {} bool representsBeginningOfSentence() const { @@ -74,13 +97,12 @@ class UnigramProperty { return mIsNotAWord; } - bool isBlacklisted() const { - return mIsBlacklisted; + bool isPossiblyOffensive() const { + return mIsPossiblyOffensive; } - bool isPossiblyOffensive() const { - // TODO: Have dedicated flag. - return mProbability == 0; + bool isBlacklisted() const { + return mIsBlacklisted; } bool hasShortcuts() const { @@ -106,6 +128,7 @@ class UnigramProperty { const bool mRepresentsBeginningOfSentence; const bool mIsNotAWord; const bool mIsBlacklisted; + const bool mIsPossiblyOffensive; const int mProbability; const HistoricalInfo mHistoricalInfo; const std::vector<ShortcutProperty> mShortcuts; diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp index caac8fe79..019f0880f 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp +++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp @@ -22,13 +22,14 @@ namespace latinime { void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, - jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outBigramTargets, - jobject outBigramProbabilities, jobject outShortcutTargets, + jbooleanArray outFlags, jintArray outProbabilityInfo, + jobject outNgramPrevWordsArray, jobject outNgramPrevWordIsBeginningOfSentenceArray, + jobject outNgramTargets, jobject outNgramProbabilities, jobject outShortcutTargets, jobject outShortcutProbabilities) const { JniDataUtils::outputCodePoints(env, outCodePoints, 0 /* start */, MAX_WORD_LENGTH /* maxLength */, mCodePoints.data(), mCodePoints.size(), false /* needsNullTermination */); - jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isBlacklisted(), + jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isPossiblyOffensive(), !mNgrams.empty(), mUnigramProperty.hasShortcuts(), mUnigramProperty.representsBeginningOfSentence()}; env->SetBooleanArrayRegion(outFlags, 0 /* start */, NELEMS(flags), flags); @@ -43,16 +44,39 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, jclass arrayListClass = env->FindClass("java/util/ArrayList"); jmethodID addMethodId = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z"); - // Output bigrams. - // TODO: Support n-gram + // Output ngrams. + jclass intArrayClass = env->FindClass("[I"); for (const auto &ngramProperty : mNgrams) { - const std::vector<int> *const word1CodePoints = ngramProperty.getTargetCodePoints(); - jintArray bigramWord1CodePointArray = env->NewIntArray(word1CodePoints->size()); - JniDataUtils::outputCodePoints(env, bigramWord1CodePointArray, 0 /* start */, - word1CodePoints->size(), word1CodePoints->data(), word1CodePoints->size(), - false /* needsNullTermination */); - env->CallBooleanMethod(outBigramTargets, addMethodId, bigramWord1CodePointArray); - env->DeleteLocalRef(bigramWord1CodePointArray); + const NgramContext *const ngramContext = ngramProperty.getNgramContext(); + jobjectArray prevWordWordCodePointsArray = env->NewObjectArray( + ngramContext->getPrevWordCount(), intArrayClass, nullptr); + jbooleanArray prevWordIsBeginningOfSentenceArray = + env->NewBooleanArray(ngramContext->getPrevWordCount()); + for (size_t i = 0; i < ngramContext->getPrevWordCount(); ++i) { + const CodePointArrayView codePoints = ngramContext->getNthPrevWordCodePoints(i + 1); + jintArray prevWordCodePoints = env->NewIntArray(codePoints.size()); + JniDataUtils::outputCodePoints(env, prevWordCodePoints, 0 /* start */, + codePoints.size(), codePoints.data(), codePoints.size(), + false /* needsNullTermination */); + env->SetObjectArrayElement(prevWordWordCodePointsArray, i, prevWordCodePoints); + env->DeleteLocalRef(prevWordCodePoints); + JniDataUtils::putBooleanToArray(env, prevWordIsBeginningOfSentenceArray, i, + ngramContext->isNthPrevWordBeginningOfSentence(i + 1)); + } + env->CallBooleanMethod(outNgramPrevWordsArray, addMethodId, prevWordWordCodePointsArray); + env->CallBooleanMethod(outNgramPrevWordIsBeginningOfSentenceArray, addMethodId, + prevWordIsBeginningOfSentenceArray); + env->DeleteLocalRef(prevWordWordCodePointsArray); + env->DeleteLocalRef(prevWordIsBeginningOfSentenceArray); + + const std::vector<int> *const targetWordCodePoints = ngramProperty.getTargetCodePoints(); + jintArray targetWordCodePointArray = env->NewIntArray(targetWordCodePoints->size()); + JniDataUtils::outputCodePoints(env, targetWordCodePointArray, 0 /* start */, + targetWordCodePoints->size(), targetWordCodePoints->data(), + targetWordCodePoints->size(), false /* needsNullTermination */); + env->CallBooleanMethod(outNgramTargets, addMethodId, targetWordCodePointArray); + env->DeleteLocalRef(targetWordCodePointArray); + const HistoricalInfo &ngramHistoricalInfo = ngramProperty.getHistoricalInfo(); int bigramProbabilityInfo[] = {ngramProperty.getProbability(), ngramHistoricalInfo.getTimestamp(), ngramHistoricalInfo.getLevel(), @@ -60,7 +84,7 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, jintArray bigramProbabilityInfoArray = env->NewIntArray(NELEMS(bigramProbabilityInfo)); env->SetIntArrayRegion(bigramProbabilityInfoArray, 0 /* start */, NELEMS(bigramProbabilityInfo), bigramProbabilityInfo); - env->CallBooleanMethod(outBigramProbabilities, addMethodId, bigramProbabilityInfoArray); + env->CallBooleanMethod(outNgramProbabilities, addMethodId, bigramProbabilityInfoArray); env->DeleteLocalRef(bigramProbabilityInfoArray); } diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h index 0c23e8225..b5314faaa 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.h +++ b/native/jni/src/suggest/core/dictionary/property/word_property.h @@ -34,13 +34,15 @@ class WordProperty { : mCodePoints(), mUnigramProperty(), mNgrams() {} WordProperty(const std::vector<int> &&codePoints, const UnigramProperty *const unigramProperty, - const std::vector<NgramProperty> *const bigrams) + const std::vector<NgramProperty> *const ngrams) : mCodePoints(std::move(codePoints)), mUnigramProperty(*unigramProperty), - mNgrams(*bigrams) {} + mNgrams(*ngrams) {} void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags, - jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities, - jobject outShortcutTargets, jobject outShortcutProbabilities) const; + jintArray outProbabilityInfo, jobject outNgramPrevWordsArray, + jobject outNgramPrevWordIsBeginningOfSentenceArray, jobject outNgramTargets, + jobject outNgramProbabilities, jobject outShortcutTargets, + jobject outShortcutProbabilities) const; const UnigramProperty *getUnigramProperty() const { return &mUnigramProperty; diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h index 6e9da3570..5351e7d7d 100644 --- a/native/jni/src/suggest/core/dictionary/word_attributes.h +++ b/native/jni/src/suggest/core/dictionary/word_attributes.h @@ -43,6 +43,14 @@ class WordAttributes { return mIsNotAWord; } + // Whether or not a word is possibly offensive. + // * Static dictionaries <v202, as well as dynamic dictionaries <v403, will set this based on + // whether or not the probability of the word is zero. + // * Static dictionaries >=v203 will set this based on the IS_POSSIBLY_OFFENSIVE PtNode flag. + // * Dynamic dictionaries >=v403 will set this based on the IS_POSSIBLY_OFFENSIVE language model + // flag (the PtNode flag IS_BLACKLISTED is ignored and kept as zero) + // + // See the ::getWordAttributes function for each of these dictionary policies for more details. bool isPossiblyOffensive() const { return mIsPossiblyOffensive; } diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index ceda5c03f..33a0fbc19 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -40,7 +40,6 @@ class UnigramProperty; * This class abstracts the structure of dictionaries. * Implement this policy to support additional dictionaries. */ -// TODO: Use word id instead of terminal PtNode position. class DictionaryStructureWithBufferPolicy { public: typedef std::unique_ptr<DictionaryStructureWithBufferPolicy> StructurePolicyPtr; @@ -81,8 +80,7 @@ class DictionaryStructureWithBufferPolicy { virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0; // Returns whether the update was success or not. - virtual bool addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty) = 0; + virtual bool addNgramEntry(const NgramProperty *const ngramProperty) = 0; // Returns whether the update was success or not. virtual bool removeNgramEntry(const NgramContext *const ngramContext, @@ -106,7 +104,6 @@ class DictionaryStructureWithBufferPolicy { virtual void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength) = 0; - // Used for testing. virtual const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const = 0; // Method to iterate all words in the dictionary. diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index ce3684a1c..b9dda83ad 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -30,7 +30,7 @@ class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const = 0; + const bool boostExactMatches, const bool hasProbabilityZero) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const = 0; diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index a06e7d070..450203d98 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -119,7 +119,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getSubstitutionCost() + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_NEW_WORD_SPACE_OMISSION: - return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); + return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index bd6b3cf41..863c4eabe 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -57,7 +57,7 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + virtual float getSpaceOmissionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0; virtual float getNewWordBigramLanguageCost( diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 3283f6deb..74db95953 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -76,6 +76,52 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults); } +/* static */ bool SuggestionsOutputUtils::shouldBlockWord( + const SuggestOptions *const suggestOptions, const DicNode *const terminalDicNode, + const WordAttributes wordAttributes, const bool isLastWord) { + const bool currentWordExactMatch = + ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); + // When we have to block offensive words, non-exact matched offensive words should not be + // output. + const bool shouldBlockOffensiveWords = suggestOptions->blockOffensiveWords(); + + const bool isBlockedOffensiveWord = shouldBlockOffensiveWords && + wordAttributes.isPossiblyOffensive(); + + // This function is called in two situations: + // + // 1) At the end of a search, in which case terminalDicNode will point to the last DicNode + // of the search, and isLastWord will be true. + // "fuck" + // | + // \ terminalDicNode (isLastWord=true, currentWordExactMatch=true) + // In this case, if the current word is an exact match, we will always let the word + // through, even if the user is blocking offensive words (it's exactly what they typed!) + // + // 2) In the middle of the search, when we hit a terminal node, to decide whether or not + // to start a new search at root, to try to match the rest of the input. In this case, + // terminalDicNode will point to the terminal node we just hit, and isLastWord will be + // false. + // "fuckvthis" + // | + // \ terminalDicNode (isLastWord=false, currentWordExactMatch=true) + // + // In this case, we should NOT allow the match through (correcting "fuckthis" to "fuck this" + // when offensive words are blocked would be a bad idea). + // + // In the case of a multi-word correction where the offensive word is typed last (eg. + // for the input "allfuck"), this function will be called with isLastWord==true, but + // currentWordExactMatch==false. So we are OK in this case as well. + // "allfuck" + // | + // \ terminalDicNode (isLastWord=true, currentWordExactMatch=false) + if (isLastWord && currentWordExactMatch) { + return false; + } else { + return isBlockedOffensiveWord; + } +} + /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel, @@ -98,24 +144,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const bool isExactMatchWithIntentionalOmission = ErrorTypeUtils::isExactMatchWithIntentionalOmission( terminalDicNode->getContainedErrorTypes()); - const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); - // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. - // (e.g. "AMD" and "and") - const bool isSafeExactMatch = isExactMatch - && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase); const int outputTypeFlags = (wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) + | ((isExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) | (isExactMatchWithIntentionalOmission ? Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0); - // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()); - // When we have to block offensive words, non-exact matched offensive words should not be - // output. - const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords(); - const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive() - && !isSafeExactMatch; + + const bool shouldBlockThisWord = shouldBlockWord(traverseSession->getSuggestOptions(), + terminalDicNode, wordAttributes, true /* isLastWord */); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. @@ -123,11 +161,11 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()), - boostExactMatches); + boostExactMatches, wordAttributes.getProbability() == 0); // Don't output invalid or blocked offensive words. However, we still need to submit their // shortcuts if any. - if (isValidWord && !isBlockedOffensiveWord) { + if (isValidWord && !shouldBlockThisWord) { int codePoints[MAX_WORD_LENGTH]; terminalDicNode->outputResult(codePoints); const int indexToPartialCommit = outputSecondWordFirstLetterInputIndex ? diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index bf8497828..eca1f78b2 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -18,6 +18,7 @@ #define LATINIME_SUGGESTIONS_OUTPUT_UTILS #include "defines.h" +#include "suggest/core/dictionary/word_attributes.h" namespace latinime { @@ -25,11 +26,19 @@ class BinaryDictionaryShortcutIterator; class DicNode; class DicTraverseSession; class Scoring; +class SuggestOptions; class SuggestionResults; class SuggestionsOutputUtils { public: /** + * Returns true if we should block the incoming word, in the context of the user's + * preferences to include or not include possibly offensive words + */ + static bool shouldBlockWord(const SuggestOptions *const suggestOptions, + const DicNode *const terminalDicNode, const WordAttributes wordAttributes, + const bool isLastWord); + /** * Outputs the final list of suggestions (i.e., terminal nodes). */ static void outputSuggestions(const Scoring *const scoringPolicy, diff --git a/native/jni/src/suggest/core/session/ngram_context.cpp b/native/jni/src/suggest/core/session/ngram_context.cpp new file mode 100644 index 000000000..17ef9ae60 --- /dev/null +++ b/native/jni/src/suggest/core/session/ngram_context.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/session/ngram_context.h" + +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/char_utils.h" + +namespace latinime { + +NgramContext::NgramContext() : mPrevWordCount(0) {} + +NgramContext::NgramContext(const NgramContext &ngramContext) + : mPrevWordCount(ngramContext.mPrevWordCount) { + for (size_t i = 0; i < mPrevWordCount; ++i) { + mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i]; + memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); + mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i]; + } +} + +NgramContext::NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH], + const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, + const size_t prevWordCount) + : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { + clear(); + for (size_t i = 0; i < mPrevWordCount; ++i) { + if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { + continue; + } + memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); + mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; + mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; + } +} + +NgramContext::NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount, + const bool isBeginningOfSentence) : mPrevWordCount(1) { + clear(); + if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { + return; + } + memmove(mPrevWordCodePoints[0], prevWordCodePoints, + sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); + mPrevWordCodePointCount[0] = prevWordCodePointCount; + mIsBeginningOfSentence[0] = isBeginningOfSentence; +} + +bool NgramContext::isValid() const { + if (mPrevWordCodePointCount[0] > 0) { + return true; + } + if (mIsBeginningOfSentence[0]) { + return true; + } + return false; +} + +const CodePointArrayView NgramContext::getNthPrevWordCodePoints(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { + return CodePointArrayView(); + } + return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); +} + +bool NgramContext::isNthPrevWordBeginningOfSentence(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { + return false; + } + return mIsBeginningOfSentence[n - 1]; +} + +/* static */ int NgramContext::getWordId( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + const int *const wordCodePoints, const int wordCodePointCount, + const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { + if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { + return NOT_A_WORD_ID; + } + int codePoints[MAX_WORD_LENGTH]; + int codePointCount = wordCodePointCount; + memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); + if (isBeginningOfSentence) { + codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount, + MAX_WORD_LENGTH); + if (codePointCount <= 0) { + return NOT_A_WORD_ID; + } + } + const CodePointArrayView codePointArrayView(codePoints, codePointCount); + const int wordId = dictStructurePolicy->getWordId(codePointArrayView, + false /* forceLowerCaseSearch */); + if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) { + // Return the id when when the word was found or doesn't try lower case search. + return wordId; + } + // Check bigrams for lower-cased previous word if original was not found. Useful for + // auto-capitalized words like "The [current_word]". + return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */); +} + +void NgramContext::clear() { + for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + mPrevWordCodePointCount[i] = 0; + mIsBeginningOfSentence[i] = false; + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/session/ngram_context.h b/native/jni/src/suggest/core/session/ngram_context.h index 64c71410f..9b36199c9 100644 --- a/native/jni/src/suggest/core/session/ngram_context.h +++ b/native/jni/src/suggest/core/session/ngram_context.h @@ -20,145 +20,54 @@ #include <array> #include "defines.h" -#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" -#include "utils/char_utils.h" #include "utils/int_array_view.h" namespace latinime { -// Rename to NgramContext. +class DictionaryStructureWithBufferPolicy; + class NgramContext { public: // No prev word information. - NgramContext() : mPrevWordCount(0) { - clear(); - } - - NgramContext(const NgramContext &ngramContext) - : mPrevWordCount(ngramContext.mPrevWordCount) { - for (size_t i = 0; i < mPrevWordCount; ++i) { - mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i]; - memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i], - sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); - mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i]; - } - } - + NgramContext(); + // Copy constructor to use this class with std::vector and use this class as a return value. + NgramContext(const NgramContext &ngramContext); // Construct from previous words. NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH], const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, - const size_t prevWordCount) - : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { - clear(); - for (size_t i = 0; i < mPrevWordCount; ++i) { - if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { - continue; - } - memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], - sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); - mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; - mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; - } - } - + const size_t prevWordCount); // Construct from a previous word. NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount, - const bool isBeginningOfSentence) : mPrevWordCount(1) { - clear(); - if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { - return; - } - memmove(mPrevWordCodePoints[0], prevWordCodePoints, - sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); - mPrevWordCodePointCount[0] = prevWordCodePointCount; - mIsBeginningOfSentence[0] = isBeginningOfSentence; - } + const bool isBeginningOfSentence); size_t getPrevWordCount() const { return mPrevWordCount; } - - // TODO: Remove. - const NgramContext getTrimmedNgramContext(const size_t maxPrevWordCount) const { - return NgramContext(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence, - std::min(mPrevWordCount, maxPrevWordCount)); - } - - bool isValid() const { - if (mPrevWordCodePointCount[0] > 0) { - return true; - } - if (mIsBeginningOfSentence[0]) { - return true; - } - return false; - } + bool isValid() const; template<size_t N> const WordIdArrayView getPrevWordIds( const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { + WordIdArray<N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) { - prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, - mPrevWordCodePoints[i], mPrevWordCodePointCount[i], - mIsBeginningOfSentence[i], tryLowerCaseSearch); + prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], + mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); } return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount); } // n is 1-indexed. - const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const { - if (n <= 0 || n > mPrevWordCount) { - return CodePointArrayView(); - } - return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); - } - + const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const; // n is 1-indexed. - bool isNthPrevWordBeginningOfSentence(const size_t n) const { - if (n <= 0 || n > mPrevWordCount) { - return false; - } - return mIsBeginningOfSentence[n - 1]; - } + bool isNthPrevWordBeginningOfSentence(const size_t n) const; private: DISALLOW_ASSIGNMENT_OPERATOR(NgramContext); static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, const int *const wordCodePoints, const int wordCodePointCount, - const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { - if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return NOT_A_WORD_ID; - } - int codePoints[MAX_WORD_LENGTH]; - int codePointCount = wordCodePointCount; - memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); - if (isBeginningOfSentence) { - codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, - codePointCount, MAX_WORD_LENGTH); - if (codePointCount <= 0) { - return NOT_A_WORD_ID; - } - } - const CodePointArrayView codePointArrayView(codePoints, codePointCount); - const int wordId = dictStructurePolicy->getWordId( - codePointArrayView, false /* forceLowerCaseSearch */); - if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) { - // Return the id when when the word was found or doesn't try lower case search. - return wordId; - } - // Check bigrams for lower-cased previous word if original was not found. Useful for - // auto-capitalized words like "The [current_word]". - return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */); - } - - void clear() { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - mPrevWordCodePointCount[i] = 0; - mIsBeginningOfSentence[i] = false; - } - } + const bool isBeginningOfSentence, const bool tryLowerCaseSearch); + void clear(); const size_t mPrevWordCount; int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index c71526293..c372d668b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -160,8 +160,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { // TODO: Remove. Do not prune node here. const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode); // Process for handling space substitution (e.g., hevis => he is) - if (allowsErrorCorrections - && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { + if (TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */); } @@ -417,6 +416,11 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext( dicNode->getPrevWordIds(), dicNode->getWordId(), traverseSession->getMultiBigramMap()); + if (SuggestionsOutputUtils::shouldBlockWord(traverseSession->getSuggestOptions(), + dicNode, wordAttributes, false /* isLastWord */)) { + return; + } + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) { return; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 4c4dfc578..300e96c4e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -30,6 +30,7 @@ const char *const HeaderPolicy::DATE_KEY = "date"; const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME"; const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT"; const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT"; +const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT"; const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE"; // Historical info is information that is needed to support decaying such as timestamp, level and // count. @@ -38,15 +39,17 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; -const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT"; -const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT"; +const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT"; const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000; -const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000; +const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000; +const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000; // Used for logging. Question mark is used to indicate that the key is not found. void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, @@ -92,12 +95,11 @@ bool HeaderPolicy::readRequiresGermanUmlautProcessing() const { } bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, - const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const { + const EntryCounts &entryCounts, const int extendedRegionSize, + BufferWithExtendableBuffer *const outBuffer) const { int writingPos = 0; DictionaryHeaderStructurePolicy::AttributeMap attributeMapToWrite(mAttributeMap); - fillInHeader(updatesLastDecayedTime, unigramCount, bigramCount, - extendedRegionSize, &attributeMapToWrite); + fillInHeader(updatesLastDecayedTime, entryCounts, extendedRegionSize, &attributeMapToWrite); if (!HeaderReadWriteUtils::writeDictionaryVersion(outBuffer, mDictFormatVersion, &writingPos)) { return false; @@ -124,11 +126,15 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim return true; } -void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int unigramCount, - const int bigramCount, const int extendedRegionSize, +void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, + const EntryCounts &entryCounts, const int extendedRegionSize, DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const { - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, unigramCount); - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, bigramCount); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, + entryCounts.getUnigramCount()); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, + entryCounts.getBigramCount()); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY, + entryCounts.getTrigramCount()); HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY, extendedRegionSize); // Set the current time as the generation time. diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index bc8eaded3..7a5acd7d5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "utils/char_utils.h" #include "utils/time_keeper.h" @@ -49,6 +50,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { UNIGRAM_COUNT_KEY, 0 /* defaultValue */)), mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, BIGRAM_COUNT_KEY, 0 /* defaultValue */)), + mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + TRIGRAM_COUNT_KEY, 0 /* defaultValue */)), mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( @@ -60,6 +63,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( + &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Constructs header information using an attribute map. @@ -77,7 +82,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), - mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), + mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( @@ -87,6 +92,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( + &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Copy header information @@ -99,12 +106,14 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mIsDecayingDict(headerPolicy->mIsDecayingDict), mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime), mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount), + mTrigramCount(headerPolicy->mTrigramCount), mExtendedRegionSize(headerPolicy->mExtendedRegionSize), mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords), mForgettingCurveProbabilityValuesTableId( headerPolicy->mForgettingCurveProbabilityValuesTableId), mMaxUnigramCount(headerPolicy->mMaxUnigramCount), mMaxBigramCount(headerPolicy->mMaxBigramCount), + mMaxTrigramCount(headerPolicy->mMaxTrigramCount), mCodePointTable(headerPolicy->mCodePointTable) {} // Temporary dummy header. @@ -112,10 +121,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { : mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0), mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), - mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), + mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0), - mCodePointTable(nullptr) {} + mMaxTrigramCount(0), mCodePointTable(nullptr) {} ~HeaderPolicy() {} @@ -125,15 +134,17 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { // same so we use them for both here. switch (mDictFormatVersion) { case FormatUtils::VERSION_2: - return FormatUtils::VERSION_2; case FormatUtils::VERSION_201: - return FormatUtils::VERSION_201; + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + return FormatUtils::UNKNOWN_VERSION; + case FormatUtils::VERSION_202: + return FormatUtils::VERSION_202; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: return FormatUtils::VERSION_4_ONLY_FOR_TESTING; - case FormatUtils::VERSION_4: - return FormatUtils::VERSION_4; - case FormatUtils::VERSION_4_DEV: - return FormatUtils::VERSION_4_DEV; + case FormatUtils::VERSION_402: + return FormatUtils::VERSION_402; + case FormatUtils::VERSION_403: + return FormatUtils::VERSION_403; default: return FormatUtils::UNKNOWN_VERSION; } @@ -183,6 +194,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mBigramCount; } + AK_FORCE_INLINE int getTrigramCount() const { + return mTrigramCount; + } + AK_FORCE_INLINE int getExtendedRegionSize() const { return mExtendedRegionSize; } @@ -212,15 +227,19 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mMaxBigramCount; } + AK_FORCE_INLINE int getMaxTrigramCount() const { + return mMaxTrigramCount; + } + void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; bool fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, - const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const; + const EntryCounts &entryCounts, const int extendedRegionSize, + BufferWithExtendableBuffer *const outBuffer) const; - void fillInHeader(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, const int extendedRegionSize, + void fillInHeader(const bool updatesLastDecayedTime, const EntryCounts &entryCounts, + const int extendedRegionSize, DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const; AK_FORCE_INLINE const std::vector<int> *getLocale() const { @@ -228,7 +247,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { } bool supportsBeginningOfSentence() const { - return mDictFormatVersion >= FormatUtils::VERSION_4; + return mDictFormatVersion >= FormatUtils::VERSION_402; } const int *getCodePointTable() const { @@ -245,6 +264,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const LAST_DECAYED_TIME_KEY; static const char *const UNIGRAM_COUNT_KEY; static const char *const BIGRAM_COUNT_KEY; + static const char *const TRIGRAM_COUNT_KEY; static const char *const EXTENDED_REGION_SIZE_KEY; static const char *const HAS_HISTORICAL_INFO_KEY; static const char *const LOCALE_KEY; @@ -253,11 +273,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY; static const char *const MAX_UNIGRAM_COUNT_KEY; static const char *const MAX_BIGRAM_COUNT_KEY; + static const char *const MAX_TRIGRAM_COUNT_KEY; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; static const int DEFAULT_MAX_UNIGRAM_COUNT; static const int DEFAULT_MAX_BIGRAM_COUNT; + static const int DEFAULT_MAX_TRIGRAM_COUNT; const FormatUtils::FORMAT_VERSION mDictFormatVersion; const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; @@ -271,11 +293,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const int mLastDecayedTime; const int mUnigramCount; const int mBigramCount; + const int mTrigramCount; const int mExtendedRegionSize; const bool mHasHistoricalInfoOfWords; const int mForgettingCurveProbabilityValuesTableId; const int mMaxUnigramCount; const int mMaxBigramCount; + const int mMaxTrigramCount; const int *const mCodePointTable; const std::vector<int> readLocale() const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index 41a8b13b8..19ed0d468 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -111,11 +111,12 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; switch (version) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: - // Version 2 or 201 dictionary writing is not supported. + case FormatUtils::VERSION_202: + // None of the static dictionaries (v2x) support writing return false; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: - case FormatUtils::VERSION_4: - case FormatUtils::VERSION_4_DEV: + case FormatUtils::VERSION_402: + case FormatUtils::VERSION_403: return buffer->writeUintAndAdvancePosition(version /* data */, HEADER_DICTIONARY_VERSION_SIZE, writingPos); default: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp index 9e1adff70..15ac88319 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp @@ -65,6 +65,8 @@ const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( (encodedTargetTerminalId == Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID) ? Ver4DictConstants::NOT_A_TERMINAL_ID : encodedTargetTerminalId; if (mHasHistoricalInfo) { + // Hack for better migration. + count += level; const HistoricalInfo historicalInfo(timestamp, level, count); return BigramEntry(hasNext, probability, &historicalInfo, targetTerminalId); } else { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp index ef6166ffd..61ef4aa42 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp @@ -50,7 +50,8 @@ const ProbabilityEntry ProbabilityDictContent::getProbabilityEntry(const int ter Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, &entryPos); const int count = buffer->readUintAndAdvancePosition( Ver4DictConstants::WORD_COUNT_FIELD_SIZE, &entryPos); - const HistoricalInfo historicalInfo(timestamp, level, count); + // Hack for better migration. + const HistoricalInfo historicalInfo(timestamp, level, count + level); return ProbabilityEntry(flags, probability, &historicalInfo); } else { return ProbabilityEntry(flags, probability); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 6243f14cc..d558b949a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -245,7 +245,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds if (!sourcePtNodeParams.hasBigrams()) { // Update has bigrams flag. return updatePtNodeFlags(sourcePtNodeParams.getHeadPos(), - sourcePtNodeParams.isBlacklisted(), sourcePtNodeParams.isNotAWord(), + sourcePtNodeParams.isPossiblyOffensive(), sourcePtNodeParams.isNotAWord(), sourcePtNodeParams.isTerminal(), sourcePtNodeParams.hasShortcutTargets(), true /* hasBigrams */, sourcePtNodeParams.getCodePointCount() > 1 /* hasMultipleChars */); @@ -316,7 +316,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN if (!ptNodeParams->hasShortcutTargets()) { // Update has shortcut targets flag. return updatePtNodeFlags(ptNodeParams->getHeadPos(), - ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), + ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), true /* hasShortcutTargets */, ptNodeParams->hasBigrams(), ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); @@ -330,7 +330,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeHasBigramsAndShortcutTargetsFlags( ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; const bool hasShortcutTargets = mBuffers->getShortcutDictContent()->getShortcutListHeadPos( ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; - return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isBlacklisted(), + return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), hasShortcutTargets, hasBigrams, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } @@ -386,8 +386,9 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { return false; } - return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), - isTerminal, ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(), + return updatePtNodeFlags(nodePos, ptNodeParams->isPossiblyOffensive(), + ptNodeParams->isNotAWord(), isTerminal, ptNodeParams->hasShortcutTargets(), + ptNodeParams->hasBigrams(), ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 0eae934ae..9455222dd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -140,7 +140,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const { - return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); } @@ -164,7 +164,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI } const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { + if (ptNodeParams.isDeleted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } if (prevWordIds.empty()) { @@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { - mUnigramCount++; + mEntryCounters.incrementUnigramCount(); } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. @@ -344,8 +344,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod return mNodeWriter.suppressUnigramEntry(&ptNodeParams); } -bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty) { +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; @@ -355,6 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex mDictBuffer->getTailPosition()); return false; } + const NgramContext *const ngramContext = ngramProperty->getNgramContext(); if (!ngramContext->isValid()) { AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; @@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos, ngramProperty, &addedNewBigram)) { if (addedNewBigram) { - mBigramCount++; + mEntryCounters.incrementBigramCount(); } return true; } else { @@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { - mBigramCount--; + mEntryCounters.decrementBigramCount(); return true; } else { return false; @@ -463,9 +463,9 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( } const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ? NOT_A_PROBABILITY : probability; - const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram, + const NgramProperty ngramProperty(*ngramContext, wordCodePoints.toVector(), probabilityForNgram, historicalInfo); - if (!addNgramEntry(ngramContext, &ngramProperty)) { + if (!addNgramEntry(&ngramProperty)) { AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext()."); return false; } @@ -477,7 +477,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) { AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath); return false; } - if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) { + if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) { AKLOGE("Cannot flush the dictionary to file."); mIsCorrupted = true; return false; @@ -515,7 +515,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { // Needs to reduce dictionary size. return true; } else if (mHeaderPolicy->isDecayingDict()) { - return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount, + return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(), mHeaderPolicy); } return false; @@ -525,19 +525,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mUnigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mBigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getUnigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxUnigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getBigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxBigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } @@ -580,11 +580,15 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, bigramWord1CodePoints); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - bigramEntry.getProbability(); + const int rawBigramProbability = bigramEntry.hasHistoricalInfo() + ? ForgettingCurveUtils::decodeProbability( + bigramEntry.getHistoricalInfo(), mHeaderPolicy) + : bigramEntry.getProbability(); + const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(), + ptNodeParams.representsBeginningOfSentence(), rawBigramProbability); ngrams.emplace_back( + NgramContext(wordCodePoints.data(), wordCodePoints.size(), + ptNodeParams.representsBeginningOfSentence()), CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), probability, *historicalInfo); } @@ -608,8 +612,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - *historicalInfo, std::move(shortcuts)); + ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(), + ptNodeParams.getProbability(), *historicalInfo, std::move(shortcuts)); return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 1ad5e7e36..0480876ed 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -41,6 +41,7 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "utils/int_array_view.h" namespace latinime { @@ -75,8 +76,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()), + mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), + mHeaderPolicy->getTrigramCount()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; virtual int getRootPosition() const { @@ -112,8 +113,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool removeUnigramEntry(const CodePointArrayView wordCodePoints); - bool addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty); + bool addNgramEntry(const NgramProperty *const ngramProperty); bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints); @@ -163,8 +163,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { Ver4PatriciaTrieNodeWriter mNodeWriter; DynamicPtUpdatingHelper mUpdatingHelper; Ver4PatriciaTrieWritingHelper mWritingHelper; - int mUnigramCount; - int mBigramCount; + MutableEntryCounters mEntryCounters; std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp index 2887dc6b1..a033d396b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp @@ -43,18 +43,18 @@ namespace backward { namespace v402 { bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath, - const int unigramCount, const int bigramCount) const { + const EntryCounts &entryCounts) const { const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy(); BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); const int extendedRegionSize = headerPolicy->getExtendedRegionSize() + mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize(); if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */, - unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) { + entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. " "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " - "extendedRegionSize: %d", false, unigramCount, bigramCount, - extendedRegionSize); + "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), + entryCounts.getBigramCount(), extendedRegionSize); return false; } return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -74,7 +74,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) { + EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */), + 0 /* extendedRegionSize */, &headerBuffer)) { return false; } return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h index 9034ee656..1aad33e38 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h @@ -27,6 +27,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/terminal_position_lookup_table.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { namespace backward { @@ -46,8 +47,7 @@ class Ver4PatriciaTrieWritingHelper { Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers) : mBuffers(buffers) {} - bool writeToDictFile(const char *const dictDirPath, const int unigramCount, - const int bigramCount) const; + bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const; // This method cannot be const because the original dictionary buffer will be updated to detect // useless PtNodes during GC. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index 372c9e36f..9a9a21b6b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -58,7 +58,7 @@ namespace latinime { const DictionaryHeaderStructurePolicy::AttributeMap *const attributeMap) { FormatUtils::FORMAT_VERSION dictFormatVersion = FormatUtils::getFormatVersion(formatVersion); switch (dictFormatVersion) { - case FormatUtils::VERSION_4: { + case FormatUtils::VERSION_402: { return newPolicyForOnMemoryV4Dict<backward::v402::Ver4DictConstants, backward::v402::Ver4DictBuffers, backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr, @@ -66,7 +66,7 @@ namespace latinime { dictFormatVersion, locale, attributeMap); } case FormatUtils::VERSION_4_ONLY_FOR_TESTING: - case FormatUtils::VERSION_4_DEV: { + case FormatUtils::VERSION_403: { return newPolicyForOnMemoryV4Dict<Ver4DictConstants, Ver4DictBuffers, Ver4DictBuffers::Ver4DictBuffersPtr, Ver4PatriciaTriePolicy>( dictFormatVersion, locale, attributeMap); @@ -115,9 +115,10 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str switch (formatVersion) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: - AKLOGE("Given path is a directory but the format is version 2 or 201. path: %s", path); + case FormatUtils::VERSION_202: + AKLOGE("Given path is a directory but the format is version 2xx. path: %s", path); break; - case FormatUtils::VERSION_4: { + case FormatUtils::VERSION_402: { return newPolicyForV4Dict<backward::v402::Ver4DictConstants, backward::v402::Ver4DictBuffers, backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr, @@ -125,7 +126,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str headerFilePath, formatVersion, std::move(mmappedBuffer)); } case FormatUtils::VERSION_4_ONLY_FOR_TESTING: - case FormatUtils::VERSION_4_DEV: { + case FormatUtils::VERSION_403: { return newPolicyForV4Dict<Ver4DictConstants, Ver4DictBuffers, Ver4DictBuffers::Ver4DictBuffersPtr, Ver4PatriciaTriePolicy>( headerFilePath, formatVersion, std::move(mmappedBuffer)); @@ -177,11 +178,14 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + break; + case FormatUtils::VERSION_202: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); case FormatUtils::VERSION_4_ONLY_FOR_TESTING: - case FormatUtils::VERSION_4: - case FormatUtils::VERSION_4_DEV: + case FormatUtils::VERSION_402: + case FormatUtils::VERSION_403: AKLOGE("Given path is a file but the format is version 4. path: %s", path); break; default: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp index 92fd6f214..e524e86e5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp @@ -146,7 +146,7 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori const int movedPos = mBuffer->getTailPosition(); int writingPos = movedPos; const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams, - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, originalPtNodeParams->getParentPos(), originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, @@ -180,8 +180,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( return false; } const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, - parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability())); + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), + true /* isTerminal */, parentPtNodePos, ptNodeCodePoints, + unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -214,7 +215,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount); if (addsExtraChild) { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */, + false /* isNotAWord */, false /* isPossiblyOffensive */, false /* isTerminal */, reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints, NOT_A_PROBABILITY)); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) { @@ -222,7 +223,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( } } else { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, @@ -240,7 +241,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( // Write the 2nd part of the reallocating node. const int secondPartOfReallocatedPtNodePos = writingPos; const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams, - reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(), + reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isPossiblyOffensive(), reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos, reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount), reallocatingPtNodeParams->getProbability())); @@ -249,7 +250,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( } if (addsExtraChild) { const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, firstPartOfReallocatedPtNodePos, newPtNodeCodePoints.skip(overlappingCodePointCount), unigramProperty->getProbability())); @@ -276,20 +277,20 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams( const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, - const bool isBlacklisted, const bool isTerminal, const int parentPos, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( - isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, + isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */, false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability); } const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord, - const bool isBlacklisted, const bool isTerminal, const int parentPos, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( - isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, + isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */, false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); return PtNodeParams(flags, parentPos, codePoints, probability); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h index 2bbe2f4dc..db5f6ab17 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h @@ -85,12 +85,12 @@ class DynamicPtUpdatingHelper { const CodePointArrayView newPtNodeCodePoints); const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams, - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, + const bool isNotAWord, const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, const int probability) const; - const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted, - const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, - const int probability) const; + const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const; }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp index 6a498b2f4..b8d78bf10 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp @@ -41,8 +41,8 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08 const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04; // Flag for non-words (typically, shortcut only entries) const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02; -// Flag for blacklist -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; +// Flag for possibly offensive words +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_POSSIBLY_OFFENSIVE = 0x01; /* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition( const uint8_t *const buffer, int *const pos) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h index a69ec4435..6a2bf5d3c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h @@ -54,8 +54,8 @@ class PatriciaTrieReadingUtils { /** * Node Flags */ - static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) { - return (flags & FLAG_IS_BLACKLISTED) != 0; + static AK_FORCE_INLINE bool isPossiblyOffensive(const NodeFlags flags) { + return (flags & FLAG_IS_POSSIBLY_OFFENSIVE) != 0; } static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) { @@ -82,12 +82,12 @@ class PatriciaTrieReadingUtils { return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags); } - static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted, + static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isPossiblyOffensive, const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets, const bool hasBigrams, const bool hasMultipleChars, const int childrenPositionFieldSize) { NodeFlags nodeFlags = 0; - nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags; + nodeFlags = isPossiblyOffensive ? (nodeFlags | FLAG_IS_POSSIBLY_OFFENSIVE) : nodeFlags; nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags; nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags; nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags; @@ -127,7 +127,7 @@ class PatriciaTrieReadingUtils { static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS; static const NodeFlags FLAG_HAS_BIGRAMS; static const NodeFlags FLAG_IS_NOT_A_WORD; - static const NodeFlags FLAG_IS_BLACKLISTED; + static const NodeFlags FLAG_IS_POSSIBLY_OFFENSIVE; }; } // namespace latinime #endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index 3ff1829bd..e52706e07 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -144,8 +144,8 @@ class PtNodeParams { return PatriciaTrieReadingUtils::isTerminal(mFlags); } - AK_FORCE_INLINE bool isBlacklisted() const { - return PatriciaTrieReadingUtils::isBlacklisted(mFlags); + AK_FORCE_INLINE bool isPossiblyOffensive() const { + return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags); } AK_FORCE_INLINE bool isNotAWord() const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index b7f1199c5..59873612a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ - #include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h" #include "defines.h" @@ -317,8 +316,8 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const { - return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), - ptNodeParams.getProbability() == 0); + return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(), + ptNodeParams.isPossiblyOffensive()); } int PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -345,10 +344,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { - // If this is not a word, or if it's a blacklisted entry, it should behave as - // having no probability outside of the suggestion process (where it should be used - // for shortcuts). + if (ptNodeParams.isNotAWord()) { + // If this is not a word, it should behave as having no probability outside of the + // suggestion process (where it should be used for shortcuts). return NOT_A_PROBABILITY; } if (!prevWordIds.empty()) { @@ -451,6 +449,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty( bigramWord1CodePoints, &word1Probability); const int probability = getProbability(word1Probability, bigramsIt.getProbability()); ngrams.emplace_back( + NgramContext(wordCodePoints.data(), wordCodePoints.size(), + ptNodeParams.representsBeginningOfSentence()), CodePointArrayView(bigramWord1CodePoints, word1CodePointCount).toVector(), probability, HistoricalInfo()); } @@ -476,8 +476,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty( } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - HistoricalInfo(), std::move(shortcuts)); + ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(), + ptNodeParams.getProbability(), HistoricalInfo(), std::move(shortcuts)); return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index b17681388..8933962ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -93,8 +93,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return false; } - bool addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty) { + bool addNgramEntry(const NgramProperty *const ngramProperty) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp new file mode 100644 index 000000000..b0fbb3e72 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" + +namespace latinime { + +// These counts are used to provide stable probabilities even if the user's input count is small. +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192; +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2; +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2; + +// These are encoded backoff weights. +// Note that we give positive value for trigrams that means the weight is more than 1. +// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight. +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32; +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0; +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8; + +// This value is used to remove too old entries from the dictionary. +const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS = + 300 * 24 * 60 * 60; // 300 days + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h new file mode 100644 index 000000000..88bc58fe8 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H +#define LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H + +#include <algorithm> + +#include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" +#include "utils/time_keeper.h" + +namespace latinime { + +class DynamicLanguageModelProbabilityUtils { + public: + static float computeRawProbabilityFromCounts(const int count, const int contextCount, + const int matchedWordCountInContext) { + int minCount = 0; + switch (matchedWordCountInContext) { + case 1: + minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS; + break; + case 2: + minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS; + break; + case 3: + minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS; + break; + default: + AKLOGE("computeRawProbabilityFromCounts is called with invalid " + "matchedWordCountInContext (%d).", matchedWordCountInContext); + ASSERT(false); + return 0.0f; + } + return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount)); + } + + static float backoff(const int ngramProbability, const int matchedWordCountInContext) { + int probability = NOT_A_PROBABILITY; + + switch (matchedWordCountInContext) { + case 1: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; + break; + case 2: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; + break; + case 3: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; + break; + default: + AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).", + matchedWordCountInContext); + ASSERT(false); + return NOT_A_PROBABILITY; + } + return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY); + } + + static int getDecayedProbability(const int probability, const HistoricalInfo historicalInfo) { + const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp(); + if (elapsedTime < 0) { + AKLOGE("The elapsed time is negatime value. Timestamp overflow?"); + return NOT_A_PROBABILITY; + } + // TODO: Improve this logic. + // We don't modify probability depending on the elapsed time. + return probability; + } + + static int shouldRemoveEntryDuringGC(const HistoricalInfo historicalInfo) { + // TODO: Improve this logic. + const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp(); + return elapsedTime > DURATION_TO_DISCARD_ENTRY_IN_SECONDS; + } + + static int getPriorityToPreventFromEviction(const HistoricalInfo historicalInfo) { + // TODO: Improve this logic. + // More recently input entries get higher priority. + return historicalInfo.getTimestamp(); + } + +private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicLanguageModelProbabilityUtils); + + static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram."); + + static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS; + static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS; + static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS; + + static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; + static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; + static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; + + static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS; +}; + +} // namespace latinime +#endif /* LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index c4297f5d6..31b1ea696 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -19,28 +19,28 @@ #include <algorithm> #include <cstring> -#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { -const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0; -const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1; -const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; +const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0; +const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1; bool LanguageModelDictContent::save(FILE *const file) const { - return mTrieMap.save(file); + return mTrieMap.save(file) && mGlobalCounters.save(file); } bool LanguageModelDictContent::runGC( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const LanguageModelDictContent *const originalContent, - int *const outNgramCount) { + const LanguageModelDictContent *const originalContent) { return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(), - 0 /* nextLevelBitmapEntryIndex */, outNgramCount); + 0 /* nextLevelBitmapEntryIndex */); } const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, - const int wordId, const HeaderPolicy *const headerPolicy) const { + const int wordId, const bool mustMatchAllPrevWords, + const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxPrevWordCount = 0; @@ -54,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } + const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); + if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) { + // The word should be treated as a invalid word. + return WordAttributes(); + } for (int i = maxPrevWordCount; i >= 0; --i) { + if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) { + break; + } const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; @@ -63,38 +71,41 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); int probability = NOT_A_PROBABILITY; if (mHasHistoricalInfo) { - const int rawProbability = ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), headerPolicy); - if (rawProbability == NOT_A_PROBABILITY) { - // The entry should not be treated as a valid entry. - continue; - } + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); + int contextCount = 0; if (i == 0) { // unigram - probability = rawProbability; + contextCount = mGlobalCounters.getTotalCount(); } else { const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); if (!prevWordProbabilityEntry.isValid()) { continue; } - if (prevWordProbabilityEntry.representsBeginningOfSentence()) { - probability = rawProbability; - } else { - const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( - prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); - probability = std::min(MAX_PROBABILITY - prevWordRawProbability - + rawProbability, MAX_PROBABILITY); + if (prevWordProbabilityEntry.representsBeginningOfSentence() + && historicalInfo->getCount() == 1) { + // BoS ngram requires multiple contextCount. + continue; } + contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } + const float rawProbability = + DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( + historicalInfo->getCount(), contextCount, i + 1); + const int encodedRawProbability = + ProbabilityUtils::encodeRawProbability(rawProbability); + const int decayedProbability = + DynamicLanguageModelProbabilityUtils::getDecayedProbability( + encodedRawProbability, *historicalInfo); + probability = DynamicLanguageModelProbabilityUtils::backoff( + decayedProbability, i + 1 /* n */); } else { probability = probabilityEntry.getProbability(); } // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in // probabilityEntry. - const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); - return WordAttributes(probability, unigramProbabilityEntry.isNotAWord(), - unigramProbabilityEntry.isBlacklisted(), + return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), + unigramProbabilityEntry.isNotAWord(), unigramProbabilityEntry.isPossiblyOffensive()); } // Cannot find the word. @@ -143,28 +154,69 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); } -bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, - const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy, - int *const outEntryCounts) { - for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { - if (entryCounts[i] <= maxEntryCounts[i]) { - outEntryCounts[i] = entryCounts[i]; +std::vector<LanguageModelDictContent::DumppedFullEntryInfo> + LanguageModelDictContent::exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const { + const TrieMap::Result result = mTrieMap.getRoot(wordId); + if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) { + // The word doesn't have any related ngram entries. + return std::vector<DumppedFullEntryInfo>(); + } + std::vector<int> prevWordIds = { wordId }; + std::vector<DumppedFullEntryInfo> entries; + exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex, + &prevWordIds, &entries); + return entries; +} + +void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( + const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex, + std::vector<int> *const prevWordIds, + std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + const int wordId = entry.key(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (probabilityEntry.isValid()) { + const WordAttributes wordAttributes = getWordAttributes( + WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */, + headerPolicy); + outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, + wordAttributes, probabilityEntry); + } + if (entry.hasNextLevelMap()) { + prevWordIds->push_back(wordId); + exportAllNgramEntriesRelatedToWordInner(headerPolicy, + entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo); + prevWordIds->pop_back(); + } + } +} + +bool LanguageModelDictContent::truncateEntries(const EntryCounts ¤tEntryCounts, + const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const outEntryCounters) { + for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { + const int totalWordCount = prevWordCount + 1; + if (currentEntryCounts.getNgramCount(totalWordCount) + <= maxEntryCounts.getNgramCount(totalWordCount)) { + outEntryCounters->setNgramCount(totalWordCount, + currentEntryCounts.getNgramCount(totalWordCount)); continue; } - if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i, - &outEntryCounts[i])) { + int entryCount = 0; + if (!turncateEntriesInSpecifiedLevel(headerPolicy, + maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { return false; } + outEntryCounters->setNgramCount(totalWordCount, entryCount); } return true; } bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, const bool isValid, const HistoricalInfo historicalInfo, - const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) { - if (outAddedNewNgramEntryCount) { - *outAddedNewNgramEntryCount = 0; - } + const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) { if (!mHasHistoricalInfo) { AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info."); return false; @@ -175,6 +227,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) { return false; } + mGlobalCounters.incrementTotalCount(); + mGlobalCounters.updateMaxValueOfCounters( + updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount()); for (size_t i = 0; i < prevWordIds.size(); ++i) { if (prevWordIds[i] == NOT_A_WORD_ID) { break; @@ -188,8 +243,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) { return false; } - if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) { - *outAddedNewNgramEntryCount += 1; + mGlobalCounters.updateMaxValueOfCounters( + updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); + if (!originalNgramProbabilityEntry.isValid()) { + entryCountersToUpdate->incrementNgramCount(i + 2); } } return true; @@ -198,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { - const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalProbabilityEntry.getHistoricalInfo(), isValid ? - DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, - &historicalInfo, headerPolicy); + const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(), + 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount() + + historicalInfo.getCount()); if (originalProbabilityEntry.isValid()) { return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); } else { @@ -211,8 +267,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( bool LanguageModelDictContent::runGCInner( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const TrieMap::TrieMapRange trieMapRange, - const int nextLevelBitmapEntryIndex, int *const outNgramCount) { + const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) { for (auto &entry : trieMapRange) { const auto it = terminalIdMap->find(entry.key()); if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) { @@ -222,13 +277,9 @@ bool LanguageModelDictContent::runGCInner( if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) { return false; } - if (outNgramCount) { - *outNgramCount += 1; - } if (entry.hasNextLevelMap()) { if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(), - mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex), - outNgramCount)) { + mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) { return false; } } @@ -237,24 +288,25 @@ bool LanguageModelDictContent::runGCInner( } int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { - if (prevWordIds.empty()) { - return mTrieMap.getRootBitmapEntryIndex(); - } - const int lastBitmapEntryIndex = - getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1)); - if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { - return TrieMap::INVALID_INDEX; - } - const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID); - const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex); - if (!result.mIsValid) { - if (!mTrieMap.put(oldestPrevWordId, - ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) { - return TrieMap::INVALID_INDEX; + int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); + for (const int wordId : prevWordIds) { + const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex); + if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) { + lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + continue; } + if (!result.mIsValid) { + if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo), + lastBitmapEntryIndex)) { + AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId, + lastBitmapEntryIndex); + return TrieMap::INVALID_INDEX; + } + } + lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId, + lastBitmapEntryIndex); } - return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID), - lastBitmapEntryIndex); + return lastBitmapEntryIndex; } int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { @@ -271,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, - int *const outEntryCounts) { + const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", @@ -288,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } continue; } - if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() - && probabilityEntry.isValid()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - probabilityEntry.getHistoricalInfo(), headerPolicy); - if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { - // Update the entry. - const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); - if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), - bitmapEntryIndex)) { - return false; - } - } else { + if (mHasHistoricalInfo && probabilityEntry.isValid()) { + const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo(); + if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC( + *originalHistoricalInfo)) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } + if (needsToHalveCounters) { + const int updatedCount = originalHistoricalInfo->getCount() / 2; + if (updatedCount == 0) { + // Remove the entry. + if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(), + originalHistoricalInfo->getLevel(), updatedCount); + const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), + &historicalInfoToSave); + if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), + bitmapEntryIndex)) { + return false; + } + } } - if (!probabilityEntry.representsBeginningOfSentence()) { - outEntryCounts[prevWordCount] += 1; - } + outEntryCounters->incrementNgramCount(prevWordCount + 1); if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), - prevWordCount + 1, headerPolicy, outEntryCounts)) { + prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) { return false; } } @@ -368,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); - const int probability = (mHasHistoricalInfo) ? - ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), - headerPolicy) : probabilityEntry.getProbability(); - outEntryInfo->emplace_back(probability, - probabilityEntry.getHistoricalInfo()->getTimestamp(), + const int priority = mHasHistoricalInfo + ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction( + *probabilityEntry.getHistoricalInfo()) + : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(), entry.key(), targetLevel, prevWordIds->data()); } return true; @@ -380,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { - if (left.mProbability != right.mProbability) { - return left.mProbability < right.mProbability; + if (left.mPriority != right.mPriority) { + return left.mPriority < right.mPriority; } - if (left.mTimestamp != right.mTimestamp) { - return left.mTimestamp > right.mTimestamp; + if (left.mCount != right.mCount) { + return left.mCount < right.mCount; } if (left.mKey != right.mKey) { return left.mKey < right.mKey; @@ -401,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( return false; } -LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, - const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) - : mProbability(probability), mTimestamp(timestamp), mKey(key), - mPrevWordCount(prevWordCount) { +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority, + const int count, const int key, const int prevWordCount, const int *const prevWordIds) + : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) { memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 51ef090e1..9678c35f9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -22,9 +22,11 @@ #include "defines.h" #include "suggest/core/dictionary/word_attributes.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/trie_map.h" #include "utils/byte_array_view.h" #include "utils/int_array_view.h" @@ -40,9 +42,6 @@ class HeaderPolicy; */ class LanguageModelDictContent { public: - static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE; - static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE; - // Pair of word id and probability entry used for iteration. class WordIdAndProbabilityEntry { public: @@ -112,31 +111,54 @@ class LanguageModelDictContent { const bool mHasHistoricalInfo; }; - LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, + class DumppedFullEntryInfo { + public: + DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId, + const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry) + : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId), + mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {} + + const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); } + int getTargetWordId() const { return mTargetWordId; } + const WordAttributes &getWordAttributes() const { return mWordAttributes; } + const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo); + + const std::vector<int> mPrevWordIds; + const int mTargetWordId; + const WordAttributes mWordAttributes; + const ProbabilityEntry mProbabilityEntry; + }; + + LanguageModelDictContent(const ReadWriteByteArrayView *const buffers, const bool hasHistoricalInfo) - : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} + : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]), + mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]), + mHasHistoricalInfo(hasHistoricalInfo) {} explicit LanguageModelDictContent(const bool hasHistoricalInfo) - : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {} + : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {} bool isNearSizeLimit() const { - return mTrieMap.isNearSizeLimit(); + return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters(); } bool save(FILE *const file) const; bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const LanguageModelDictContent *const originalContent, - int *const outNgramCount); + const LanguageModelDictContent *const originalContent); const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, - const HeaderPolicy *const headerPolicy) const; + const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { + mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount()); return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } @@ -154,22 +176,30 @@ class LanguageModelDictContent { EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; + std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const; + bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, - int *const outEntryCounts) { - for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { - outEntryCounts[i] = 0; + MutableEntryCounters *const outEntryCounters) { + if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), + 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(), + outEntryCounters)) { + return false; + } + if (mGlobalCounters.needsToHalveCounters()) { + mGlobalCounters.halveCounters(); } - return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), - 0 /* prevWordCount */, headerPolicy, outEntryCounts); + return true; } // entryCounts should be created by updateAllProbabilityEntries. - bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, - const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + bool truncateEntries(const EntryCounts ¤tEntryCounts, const EntryCounts &maxEntryCounts, + const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, const bool isValid, const HistoricalInfo historicalInfo, - const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount); + const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const entryCountersToUpdate); private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); @@ -184,11 +214,12 @@ class LanguageModelDictContent { DISALLOW_ASSIGNMENT_OPERATOR(Comparator); }; - EntryInfoToTurncate(const int probability, const int timestamp, const int key, + EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds); - int mProbability; - int mTimestamp; + int mPriority; + // TODO: Remove. + int mCount; int mKey; int mPrevWordCount; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; @@ -197,19 +228,20 @@ class LanguageModelDictContent { DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); }; - // TODO: Remove - static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; + static const int TRIE_MAP_BUFFER_INDEX; + static const int GLOBAL_COUNTERS_BUFFER_INDEX; TrieMap mTrieMap; + LanguageModelDictContentGlobalCounters mGlobalCounters; const bool mHasHistoricalInfo; bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, - int *const outNgramCount); + const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex); int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, - const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters, + MutableEntryCounters *const outEntryCounters); bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, int *const outEntryCount); bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, @@ -218,6 +250,9 @@ class LanguageModelDictContent { const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const; + void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy, + const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp new file mode 100644 index 000000000..d6d91887e --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h" + +#include <climits> + +#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" + +namespace latinime { + +const int LanguageModelDictContentGlobalCounters::COUNTER_VALUE_NEAR_LIMIT_THRESHOLD = + (1 << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 64; +const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD = 1 << 30; +const int LanguageModelDictContentGlobalCounters::COUNTER_SIZE_IN_BYTES = 4; +const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_INDEX = 0; +const int LanguageModelDictContentGlobalCounters::MAX_VALUE_OF_COUNTERS_INDEX = 1; + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h new file mode 100644 index 000000000..283c2691a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H +#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H + +#include <cstdio> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "utils/byte_array_view.h" + +namespace latinime { + +class LanguageModelDictContentGlobalCounters { + public: + explicit LanguageModelDictContentGlobalCounters(const ReadWriteByteArrayView buffer) + : mBuffer(buffer, 0 /* maxAdditionalBufferSize */), + mTotalCount(readValue(mBuffer, TOTAL_COUNT_INDEX)), + mMaxValueOfCounters(readValue(mBuffer, MAX_VALUE_OF_COUNTERS_INDEX)) {} + + LanguageModelDictContentGlobalCounters() + : mBuffer(0 /* maxAdditionalBufferSize */), mTotalCount(0), mMaxValueOfCounters(0) {} + + bool needsToHalveCounters() const { + return mMaxValueOfCounters >= COUNTER_VALUE_NEAR_LIMIT_THRESHOLD + || mTotalCount >= TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD; + } + + int getTotalCount() const { + return mTotalCount; + } + + bool save(FILE *const file) const { + BufferWithExtendableBuffer bufferToWrite( + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); + if (!bufferToWrite.writeUint(mTotalCount, COUNTER_SIZE_IN_BYTES, + TOTAL_COUNT_INDEX * COUNTER_SIZE_IN_BYTES)) { + return false; + } + if (!bufferToWrite.writeUint(mMaxValueOfCounters, COUNTER_SIZE_IN_BYTES, + MAX_VALUE_OF_COUNTERS_INDEX * COUNTER_SIZE_IN_BYTES)) { + return false; + } + return DictFileWritingUtils::writeBufferToFileTail(file, &bufferToWrite); + } + + void incrementTotalCount() { + mTotalCount += 1; + } + + void addToTotalCount(const int count) { + mTotalCount += count; + } + + void updateMaxValueOfCounters(const int count) { + mMaxValueOfCounters = std::max(count, mMaxValueOfCounters); + } + + void halveCounters() { + mMaxValueOfCounters /= 2; + mTotalCount /= 2; + } + +private: + DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContentGlobalCounters); + + const static int COUNTER_VALUE_NEAR_LIMIT_THRESHOLD; + const static int TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD; + const static int COUNTER_SIZE_IN_BYTES; + const static int TOTAL_COUNT_INDEX; + const static int MAX_VALUE_OF_COUNTERS_INDEX; + + BufferWithExtendableBuffer mBuffer; + int mTotalCount; + int mMaxValueOfCounters; + + static int readValue(const BufferWithExtendableBuffer &buffer, const int index) { + const int pos = COUNTER_SIZE_IN_BYTES * index; + if (pos + COUNTER_SIZE_IN_BYTES > buffer.getTailPosition()) { + return 0; + } + return buffer.readUint(COUNTER_SIZE_IN_BYTES, pos); + } +}; +} // namespace latinime +#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index f4d340f86..9c4ab18e4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -105,7 +105,7 @@ class ProbabilityEntry { encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) | static_cast<uint8_t>(mHistoricalInfo.getLevel()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - | static_cast<uint8_t>(mHistoricalInfo.getCount()); + | static_cast<uint16_t>(mHistoricalInfo.getCount()); } else { encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) | static_cast<uint8_t>(mProbability); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 45f88e9b2..4d088dcab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -179,7 +179,7 @@ Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], + mLanguageModelDictContent(&contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], mHeaderPolicy.hasHistoricalInfoOfWords()), mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), mIsUpdatable(mDictBuffer->isUpdatable()) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 8e6cb974b..bd89b8da7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -49,8 +49,8 @@ const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3; const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0; const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4; const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4; -const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; -const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; +const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 0; +const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 2; const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2; @@ -67,6 +67,6 @@ const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT = 1; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT = 3; -const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 1; +const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 2; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 600b5ffe4..13d7a5714 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -47,6 +47,7 @@ class Ver4DictConstants { static const int NOT_A_TERMINAL_ADDRESS; static const int TERMINAL_ID_FIELD_SIZE; static const int TIME_STAMP_FIELD_SIZE; + // TODO: Remove static const int WORD_LEVEL_FIELD_SIZE; static const int WORD_COUNT_FIELD_SIZE; // Flags in probability entry. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 794c63ffd..3488f7d2a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -342,7 +342,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bo // Create node flags and write them. PatriciaTrieReadingUtils::NodeFlags nodeFlags = PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */, - false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */, + false /* isPossiblyOffensive */, isTerminal, false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE); if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) { AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index ea8c0dc22..1992d4a5a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -97,6 +97,9 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return NOT_A_WORD_ID; } const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (ptNodeParams.isDeleted()) { + return NOT_A_WORD_ID; + } return ptNodeParams.getTerminalId(); } @@ -107,7 +110,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( return WordAttributes(); } return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, - mHeaderPolicy); + false /* mustMatchAllPrevWords */, mHeaderPolicy); } int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, @@ -115,18 +118,13 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); - if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted() - || probabilityEntry.isNotAWord()) { + const WordAttributes wordAttributes = + mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, + true /* mustMatchAllPrevWords */, mHeaderPolicy); + if (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()) { return NOT_A_PROBABILITY; } - if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), - mHeaderPolicy); - } else { - return probabilityEntry.getProbability(); - } + return wordAttributes.getProbability(); } BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( @@ -148,10 +146,16 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI if (!probabilityEntry.isValid()) { continue; } - const int probability = probabilityEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : - probabilityEntry.getProbability(); + int probability = NOT_A_PROBABILITY; + if (probabilityEntry.hasHistoricalInfo()) { + // TODO: Quit checking count here. + // If count <= 1, the word can be an invaild word. The actual probability should + // be checked using getWordAttributesInContext() in onVisitEntry(). + probability = probabilityEntry.getHistoricalInfo()->getCount() <= 1 ? + NOT_A_PROBABILITY : 0; + } else { + probability = probabilityEntry.getProbability(); + } listener->onVisitEntry(probability, entry.getWordId()); } } @@ -211,7 +215,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { - mUnigramCount++; + mEntryCounters.incrementUnigramCount(); } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. @@ -259,13 +263,12 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod return false; } if (!ptNodeParams.representsNonWordInfo()) { - mUnigramCount--; + mEntryCounters.decrementUnigramCount(); } return true; } -bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty) { +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; @@ -275,6 +278,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex mDictBuffer->getTailPosition()); return false; } + const NgramContext *const ngramContext = ngramProperty->getNgramContext(); if (!ngramContext->isValid()) { AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; @@ -299,7 +303,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex } const UnigramProperty beginningOfSentenceUnigramProperty( true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo()); + false /* isBlacklisted */, false /* isPossiblyOffensive */, + MAX_PROBABILITY /* probability */, HistoricalInfo()); if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), &beginningOfSentenceUnigramProperty)) { AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); @@ -316,7 +321,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex bool addedNewEntry = false; if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) { if (addedNewEntry) { - mBigramCount++; + mEntryCounters.incrementNgramCount(prevWordIds.size() + 1); } return true; } else { @@ -354,7 +359,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon return false; } if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { - mBigramCount--; + mEntryCounters.decrementNgramCount(prevWordIds.size()); return true; } else { return false; @@ -375,38 +380,47 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( if (wordId == NOT_A_WORD_ID) { // The word is not in the dictionary. const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, - false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY, - HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + false /* isNotAWord */, false /* isBlacklisted */, false /* isPossiblyOffensive */, + NOT_A_PROBABILITY, HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, + 0 /* count */)); if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext()."); return false; } + if (!isValidWord) { + return true; + } wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID - && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, - true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY, - HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); - if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); + if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, + true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY, + HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); + return false; + } + // Refresh word ids. + ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + } + // Update entries for beginning of sentence. + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord( + prevWordIds.skip(1 /* n */), prevWordIds[0], true /* isVaild */, historicalInfo, + mHeaderPolicy, &mEntryCounters)) { return false; } - // Refresh word ids. - ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } - int addedNewNgramEntryCount = 0; if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds, - wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) { + wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) { return false; } - mBigramCount += addedNewNgramEntryCount; return true; } @@ -415,7 +429,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) { AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath); return false; } - if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) { + if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) { AKLOGE("Cannot flush the dictionary to file."); mIsCorrupted = true; return false; @@ -453,7 +467,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { // Needs to reduce dictionary size. return true; } else if (mHeaderPolicy->isDecayingDict()) { - return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount, + return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(), mHeaderPolicy); } return false; @@ -463,19 +477,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mUnigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mBigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getUnigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxUnigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getBigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxBigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } @@ -488,29 +502,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); } - const int ptNodePos = - mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); - const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry( - ptNodeParams.getTerminalId()); - const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); - // Fetch bigram information. - // TODO: Support n-gram. + const LanguageModelDictContent *const languageModelDictContent = + mBuffers->getLanguageModelDictContent(); + // Fetch ngram information. std::vector<NgramProperty> ngrams; - const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId); - int bigramWord1CodePoints[MAX_WORD_LENGTH]; - for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( - prevWordIds)) { - const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(), - MAX_WORD_LENGTH, bigramWord1CodePoints); - const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry(); - const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); - const int probability = probabilityEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) : - probabilityEntry.getProbability(); - ngrams.emplace_back(CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), - probability, *historicalInfo); + int ngramTargetCodePoints[MAX_WORD_LENGTH]; + int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; + int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord( + mHeaderPolicy, wordId)) { + const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(), + MAX_WORD_LENGTH, ngramTargetCodePoints); + const WordIdArrayView prevWordIds = entry.getPrevWordIds(); + for (size_t i = 0; i < prevWordIds.size(); ++i) { + ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i], + MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]); + ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry( + prevWordIds[i]).representsBeginningOfSentence(); + if (ngramPrevWordIsBeginningOfSentense[i]) { + ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker( + ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]); + } + } + const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount, + ngramPrevWordIsBeginningOfSentense, prevWordIds.size()); + const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry(); + const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo(); + // TODO: Output flags in WordAttributes. + ngrams.emplace_back(ngramContext, + CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(), + entry.getWordAttributes().getProbability(), *historicalInfo); } // Fetch shortcut information. std::vector<UnigramProperty::ShortcutProperty> shortcuts; @@ -530,9 +552,14 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( shortcutProbability); } } + const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes( + WordIdArrayView(), wordId, true /* mustMatchAllPrevWords */, mHeaderPolicy); + const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId); + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), - probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(), - probabilityEntry.getProbability(), *historicalInfo, std::move(shortcuts)); + wordAttributes.isNotAWord(), wordAttributes.isBlacklisted(), + wordAttributes.isPossiblyOffensive(), wordAttributes.getProbability(), + *historicalInfo, std::move(shortcuts)); return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index c0532815c..13700b390 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -30,6 +30,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "utils/int_array_view.h" namespace latinime { @@ -37,7 +38,6 @@ namespace latinime { class DicNode; class DicNodeVector; -// TODO: Support counting ngram entries. // Word id = Artificial id that is stored in the PtNode looked up by the word. class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: @@ -51,8 +51,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()), + mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), + mHeaderPolicy->getTrigramCount()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; AK_FORCE_INLINE int getRootPosition() const { @@ -92,8 +92,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool removeUnigramEntry(const CodePointArrayView wordCodePoints); - bool addNgramEntry(const NgramContext *const ngramContext, - const NgramProperty *const ngramProperty); + bool addNgramEntry(const NgramProperty *const ngramProperty); bool removeNgramEntry(const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints); @@ -141,9 +140,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { Ver4PatriciaTrieNodeWriter mNodeWriter; DynamicPtUpdatingHelper mUpdatingHelper; Ver4PatriciaTrieWritingHelper mWritingHelper; - int mUnigramCount; - // TODO: Support counting ngram entries. - int mBigramCount; + MutableEntryCounters mEntryCounters; std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index f0d59c150..7f0604ce8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -33,17 +33,18 @@ namespace latinime { bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath, - const int unigramCount, const int bigramCount) const { + const EntryCounts &entryCounts) const { const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy(); BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); const int extendedRegionSize = headerPolicy->getExtendedRegionSize() + mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize(); if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */, - unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) { + entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. " - "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " - "extendedRegionSize: %d", false, unigramCount, bigramCount, + "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d," + "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), + entryCounts.getBigramCount(), entryCounts.getTrigramCount(), extendedRegionSize); return false; } @@ -56,15 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers( Ver4DictBuffers::createVer4DictBuffers(headerPolicy, Ver4DictConstants::MAX_DICTIONARY_SIZE)); - int unigramCount = 0; - int bigramCount = 0; - if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) { + MutableEntryCounters entryCounters; + if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) { return false; } BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) { + entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) { return false; } return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -72,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, - int *const outUnigramCount, int *const outBigramCount) { + MutableEntryCounters *const outEntryCounters) { Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer()); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), @@ -80,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(), mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); - int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC( - headerPolicy, entryCountTable)) { + headerPolicy, outEntryCounters)) { AKLOGE("Failed to update probabilities in language model dict content."); return false; } if (headerPolicy->isDecayingDict()) { - int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; - maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = - headerPolicy->getMaxUnigramCount(); - maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = - headerPolicy->getMaxBigramCount(); - for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) { - // TODO: Have max n-gram count. - maxEntryCountTable[i] = headerPolicy->getMaxBigramCount(); - } - if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable, - maxEntryCountTable, headerPolicy, entryCountTable)) { + const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(), + headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount()); + if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries( + outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy, + outEntryCounters)) { AKLOGE("Failed to truncate entries in language model dict content."); return false; } @@ -141,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &terminalIdMap)) { return false; } - // Run GC for probability dict content. + // Run GC for language model dict content. if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap, - mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) { + mBuffers->getLanguageModelDictContent())) { return false; } // Run GC for shortcut dict content. @@ -166,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { return false; } - *outUnigramCount = - entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE]; - *outBigramCount = - entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE]; return true; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index 3569d0576..c56cea5cf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -20,6 +20,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { @@ -33,9 +34,7 @@ class Ver4PatriciaTrieWritingHelper { Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers) : mBuffers(buffers) {} - // TODO: Support counting ngram entries. - bool writeToDictFile(const char *const dictDirPath, const int unigramCount, - const int bigramCount) const; + bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const; // This method cannot be const because the original dictionary buffer will be updated to detect // useless PtNodes during GC. @@ -68,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper { }; bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, - Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, - int *const outBigramCount); + Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters); Ver4DictBuffers *const mBuffers; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp index b7e2a7278..edcb43678 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -27,6 +27,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "utils/time_keeper.h" @@ -43,13 +44,13 @@ const int DictFileWritingUtils::SIZE_OF_BUFFER_SIZE_FIELD = 4; TimeKeeper::setCurrentTime(); const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::getFormatVersion(dictVersion); switch (formatVersion) { - case FormatUtils::VERSION_4: + case FormatUtils::VERSION_402: return createEmptyV4DictFile<backward::v402::Ver4DictConstants, backward::v402::Ver4DictBuffers, backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr>( filePath, localeAsCodePointVector, attributeMap, formatVersion); case FormatUtils::VERSION_4_ONLY_FOR_TESTING: - case FormatUtils::VERSION_4_DEV: + case FormatUtils::VERSION_403: return createEmptyV4DictFile<Ver4DictConstants, Ver4DictBuffers, Ver4DictBuffers::Ver4DictBuffersPtr>( filePath, localeAsCodePointVector, attributeMap, formatVersion); @@ -69,8 +70,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr> DictBuffersPtr dictBuffers = DictBuffers::createVer4DictBuffers(&headerPolicy, DictConstants::MAX_DICT_EXTENDED_REGION_SIZE); headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - 0 /* unigramCount */, 0 /* bigramCount */, - 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer()); + EntryCounts(), 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer()); if (!DynamicPtWritingUtils::writeEmptyDictionary( dictBuffers->getWritableTrieBuffer(), 0 /* rootPos */)) { AKLOGE("Empty ver4 dictionary structure cannot be created on memory."); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h new file mode 100644 index 000000000..73dc42a18 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_ENTRY_COUNTERS_H +#define LATINIME_ENTRY_COUNTERS_H + +#include <array> + +#include "defines.h" + +namespace latinime { + +// Copyable but immutable +class EntryCounts final { + public: + EntryCounts() : mEntryCounts({{0, 0, 0}}) {} + + EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount) + : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {} + + explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters) + : mEntryCounts(counters) {} + + int getUnigramCount() const { + return mEntryCounts[0]; + } + + int getBigramCount() const { + return mEntryCounts[1]; + } + + int getTrigramCount() const { + return mEntryCounts[2]; + } + + int getNgramCount(const size_t n) const { + if (n < 1 || n > mEntryCounts.size()) { + return 0; + } + return mEntryCounts[n - 1]; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts); + + const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts; +}; + +class MutableEntryCounters final { + public: + MutableEntryCounters() { + mEntryCounters.fill(0); + } + + MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount) + : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {} + + const EntryCounts getEntryCounts() const { + return EntryCounts(mEntryCounters); + } + + int getUnigramCount() const { + return mEntryCounters[0]; + } + + int getBigramCount() const { + return mEntryCounters[1]; + } + + int getTrigramCount() const { + return mEntryCounters[2]; + } + + void incrementUnigramCount() { + ++mEntryCounters[0]; + } + + void decrementUnigramCount() { + ASSERT(mEntryCounters[0] != 0); + --mEntryCounters[0]; + } + + void incrementBigramCount() { + ++mEntryCounters[1]; + } + + void decrementBigramCount() { + ASSERT(mEntryCounters[1] != 0); + --mEntryCounters[1]; + } + + void incrementNgramCount(const size_t n) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + ++mEntryCounters[n - 1]; + } + + void decrementNgramCount(const size_t n) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + ASSERT(mEntryCounters[n - 1] != 0); + --mEntryCounters[n - 1]; + } + + void setNgramCount(const size_t n, const int count) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + mEntryCounters[n - 1] = count; + } + + private: + DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters); + + std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters; +}; +} // namespace latinime +#endif /* LATINIME_ENTRY_COUNTERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp index e5ef2abf8..9055f7bfc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp @@ -38,8 +38,7 @@ const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1; // 15 days const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60; -const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; -const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; +const float ForgettingCurveUtils::ENTRY_COUNT_HARD_LIMIT_WEIGHT = 1.2; const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable; @@ -126,14 +125,22 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT } /* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay, - const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy) { - if (unigramCount >= getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) { + const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) { + if (entryCounts.getUnigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) { // Unigram count exceeds the limit. return true; - } else if (bigramCount >= getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) { + } + if (entryCounts.getBigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) { // Bigram count exceeds the limit. return true; } + if (entryCounts.getTrigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) { + // Trigram count exceeds the limit. + return true; + } if (mindsBlockByDecay) { return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index ccbc4a98d..06dcae8a1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/core/dictionary/property/historical_info.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { @@ -42,22 +43,17 @@ class ForgettingCurveUtils { static bool needsToKeep(const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy); - static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount, - const int bigramCount, const HeaderPolicy *const headerPolicy); + static bool needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounters, + const HeaderPolicy *const headerPolicy); // TODO: Improve probability computation method and remove this. static int getProbabilityBiasForNgram(const int n) { return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE; } - AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) { - return static_cast<int>(static_cast<float>(maxUnigramCount) - * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT); - } - - AK_FORCE_INLINE static int getBigramCountHardLimit(const int maxBigramCount) { - return static_cast<int>(static_cast<float>(maxBigramCount) - * BIGRAM_COUNT_HARD_LIMIT_WEIGHT); + AK_FORCE_INLINE static int getEntryCountHardLimit(const int maxEntryCount) { + return static_cast<int>(static_cast<float>(maxEntryCount) + * ENTRY_COUNT_HARD_LIMIT_WEIGHT); } private: @@ -101,8 +97,7 @@ class ForgettingCurveUtils { static const int OCCURRENCES_TO_RAISE_THE_LEVEL; static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS; - static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT; - static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT; + static const float ENTRY_COUNT_HARD_LIMIT_WEIGHT; static const ProbabilityTable sProbabilityTable; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 0cffe569d..e225c235e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -28,15 +28,17 @@ const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) { switch (formatVersion) { case VERSION_2: - return VERSION_2; case VERSION_201: - return VERSION_201; + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + return UNKNOWN_VERSION; + case VERSION_202: + return VERSION_202; case VERSION_4_ONLY_FOR_TESTING: return VERSION_4_ONLY_FOR_TESTING; - case VERSION_4: - return VERSION_4; - case VERSION_4_DEV: - return VERSION_4_DEV; + case VERSION_402: + return VERSION_402; + case VERSION_403: + return VERSION_403; default: return UNKNOWN_VERSION; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 96310086b..1616efcce 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -31,11 +31,15 @@ class FormatUtils { public: enum FORMAT_VERSION { // These MUST have the same values as the relevant constants in FormatSpec.java. + // TODO: Remove VERSION_2 and VERSION_201 when we: + // * Confirm that old versions of LatinIME download old-format dictionaries + // * We no longer need the corresponding constants on the Java side for dicttool VERSION_2 = 2, VERSION_201 = 201, + VERSION_202 = 202, VERSION_4_ONLY_FOR_TESTING = 399, - VERSION_4 = 402, - VERSION_4_DEV = 403, + VERSION_402 = 402, + VERSION_403 = 403, UNKNOWN_VERSION = -1 }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp new file mode 100644 index 000000000..e8fa06942 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +namespace latinime { + +const float ProbabilityUtils::PROBABILITY_ENCODING_SCALER = 8.58923700372f; + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h index 3b339e61a..2050af1e9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h @@ -17,6 +17,9 @@ #ifndef LATINIME_PROBABILITY_UTILS_H #define LATINIME_PROBABILITY_UTILS_H +#include <algorithm> +#include <cmath> + #include "defines.h" namespace latinime { @@ -47,8 +50,20 @@ class ProbabilityUtils { + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); } + // Encode probability using the same way as we are doing for main dictionaries. + static AK_FORCE_INLINE int encodeRawProbability(const float rawProbability) { + const float probability = static_cast<float>(MAX_PROBABILITY) + + log2f(rawProbability) * PROBABILITY_ENCODING_SCALER; + if (probability < 0.0f) { + return 0; + } + return std::min(static_cast<int>(probability + 0.5f), MAX_PROBABILITY); + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(ProbabilityUtils); + + static const float PROBABILITY_ENCODING_SCALER; }; } #endif /* LATINIME_PROBABILITY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 6a2db687d..856808a74 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -24,6 +24,7 @@ const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f; +const float ScoringParams::PERFECT_MATCH_PROMOTION = 1.1f; const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f; const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f; const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; @@ -48,17 +49,17 @@ const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.5508f; const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f; const float ScoringParams::TRANSPOSITION_COST = 0.5608f; -const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.33f; +const float ScoringParams::SPACE_OMISSION_COST = 0.1f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f; const float ScoringParams::SUBSTITUTION_COST = 0.3806f; -const float ScoringParams::COST_NEW_WORD = 0.0314f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.1214f; const float ScoringParams::COST_FIRST_COMPLETION = 0.4836f; const float ScoringParams::COST_COMPLETION = 0.00624f; const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.0683f; const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.0362f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.3482f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 731424f3d..6f327a370 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -34,6 +34,7 @@ class ScoringParams { static const int THRESHOLD_SHORT_WORD_LENGTH; static const float EXACT_MATCH_PROMOTION; + static const float PERFECT_MATCH_PROMOTION; static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH; static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH; @@ -56,9 +57,9 @@ class ScoringParams { static const float INSERTION_COST_FIRST_CHAR; static const float TRANSPOSITION_COST; static const float SPACE_SUBSTITUTION_COST; + static const float SPACE_OMISSION_COST; static const float ADDITIONAL_PROXIMITY_COST; static const float SUBSTITUTION_COST; - static const float COST_NEW_WORD; static const float COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; static const float DISTANCE_WEIGHT_LANGUAGE; static const float COST_FIRST_COMPLETION; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 0240bcf54..6acd767ea 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -44,23 +44,50 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const { + const bool boostExactMatches, const bool hasProbabilityZero) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { - score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { - score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + if (hasProbabilityZero) { + // Previously, when both legitimate 0-frequency words (such as distracters) and + // offensive words were encoded in the same way, distracters would never show up + // when the user blocked offensive words (the default setting, as well as the + // setting for regression tests). + // + // When b/11031090 was fixed and a separate encoding was used for offensive words, + // 0-frequency words would no longer be blocked when they were an "exact match" + // (where case mismatches and accent mismatches would be considered an "exact + // match"). The exact match boosting functionality meant that, for example, when + // the user typed "mt" they would be suggested the word "Mt", although they most + // probably meant to type "my". + // + // For this reason, we introduced this change, which does the following: + // * Defines the "perfect match" as a really exact match, with no room for case or + // accent mismatches + // * When the target word has probability zero (as "Mt" does, because it is a + // distracter), ONLY boost its score if it is a perfect match. + // + // By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and + // they will get "my". However, if the user makes an explicit effort to type "Mt", + // we do boost the word "Mt" so that the user's input is not autocorrected to "My". + if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) { + score += ScoringParams::PERFECT_MATCH_PROMOTION; } - if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { - score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; - } - if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { - score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } else { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + score += ScoringParams::EXACT_MATCH_PROMOTION; + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { + score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { + score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { + score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } } } return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE); diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 84077174d..1338ac81a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -150,9 +150,10 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + float getSpaceOmissionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { - return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); + const float cost = ScoringParams::SPACE_OMISSION_COST; + return cost * traverseSession->getMultiWordCostMultiplier(); } float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, @@ -202,7 +203,10 @@ class TypingWeighting : public Weighting { AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; + const int inputIndex = dicNode->getInputIndex(0); + const float distanceToSpaceKey = traverseSession->getProximityInfoState(0) + ->getPointToKeyLength(inputIndex, KEYCODE_SPACE); + const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey; return cost * traverseSession->getMultiWordCostMultiplier(); } diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 5e9cdd9b2..7871c26ef 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -101,6 +101,17 @@ class CharUtils { return codePointCount + 1; } + // Returns updated code point count. + static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints, + const int codePointCount) { + if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) { + return codePointCount; + } + const int newCodePointCount = codePointCount - 1; + memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount); + return newCodePointCount; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index 25cc41742..a259e1cd0 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -50,6 +50,7 @@ class JniDataUtils { const jsize keyUtf8Length = env->GetStringUTFLength(keyString); char keyChars[keyUtf8Length + 1]; env->GetStringUTFRegion(keyString, 0, env->GetStringLength(keyString), keyChars); + env->DeleteLocalRef(keyString); keyChars[keyUtf8Length] = '\0'; DictionaryHeaderStructurePolicy::AttributeMap::key_type key; HeaderReadWriteUtils::insertCharactersIntoVector(keyChars, &key); @@ -59,6 +60,7 @@ class JniDataUtils { const jsize valueUtf8Length = env->GetStringUTFLength(valueString); char valueChars[valueUtf8Length + 1]; env->GetStringUTFRegion(valueString, 0, env->GetStringLength(valueString), valueChars); + env->DeleteLocalRef(valueString); valueChars[valueUtf8Length] = '\0'; DictionaryHeaderStructurePolicy::AttributeMap::mapped_type value; HeaderReadWriteUtils::insertCharactersIntoVector(valueChars, &value); @@ -113,6 +115,7 @@ class JniDataUtils { continue; } env->GetIntArrayRegion(prevWord, 0, prevWordLength, prevWordCodePoints[i]); + env->DeleteLocalRef(prevWord); prevWordCodePointCount[i] = prevWordLength; jboolean isBeginningOfSentenceBoolean = JNI_FALSE; env->GetBooleanArrayRegion(isBeginningOfSentenceArray, i, 1 /* len */, |