Diffstat (limited to 'native')
37 files changed, 720 insertions, 159 deletions
diff --git a/native/jni/NativeFileList.mk b/native/jni/NativeFileList.mk index 7299ed3c0..55bb68344 100644 --- a/native/jni/NativeFileList.mk +++ b/native/jni/NativeFileList.mk @@ -71,7 +71,9 @@ LATIN_IME_CORE_SRC_FILES := \ ver4_patricia_trie_writing_helper.cpp \ ver4_pt_node_array_reader.cpp) \ $(addprefix suggest/policyimpl/dictionary/structure/v4/content/, \ + dynamic_language_model_probability_utils.cpp \ language_model_dict_content.cpp \ + language_model_dict_content_global_counters.cpp \ shortcut_dict_content.cpp \ sparse_table_dict_content.cpp \ terminal_position_lookup_table.cpp) \ @@ -83,6 +85,7 @@ LATIN_IME_CORE_SRC_FILES := \ forgetting_curve_utils.cpp \ format_utils.cpp \ mmapped_buffer.cpp \ + probability_utils.cpp \ sparse_table.cpp \ trie_map.cpp ) \ suggest/policyimpl/gesture/gesture_suggest_policy_factory.cpp \ @@ -128,11 +131,13 @@ LATIN_IME_CORE_TEST_FILES := \ suggest/core/layout/normal_distribution_2d_test.cpp \ suggest/policyimpl/dictionary/header/header_read_write_utils_test.cpp \ suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp \ + suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters_test.cpp \ suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp \ suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table_test.cpp \ suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer_test.cpp \ suggest/policyimpl/dictionary/utils/byte_array_utils_test.cpp \ suggest/policyimpl/dictionary/utils/format_utils_test.cpp \ + suggest/policyimpl/dictionary/utils/probability_utils_test.cpp \ suggest/policyimpl/dictionary/utils/sparse_table_test.cpp \ suggest/policyimpl/dictionary/utils/trie_map_test.cpp \ suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy_test.cpp \ diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index bfe17cc4c..6a5df9d95 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi } const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); + if (wordAttributes.getProbability() == NOT_A_PROBABILITY) { + return; + } mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, wordAttributes.getProbability()); } diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index 1e2494e92..8f07ce275 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -31,6 +31,7 @@ const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_A_PERFECT_MATCH = NOT_AN_ERROR; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index fd1d5fcff..e92c509fa 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h 
+++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -52,6 +52,10 @@ class ErrorTypeUtils { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0; } + static bool isPerfectMatch(const ErrorType containedErrorTypes) { + return (containedErrorTypes & ~ERRORS_TREATED_AS_A_PERFECT_MATCH) == 0; + } + static bool isExactMatchWithIntentionalOmission(const ErrorType containedErrorTypes) { return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION) == 0; @@ -73,6 +77,7 @@ class ErrorTypeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(ErrorTypeUtils); static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH; + static const ErrorType ERRORS_TREATED_AS_A_PERFECT_MATCH; static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION; }; } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index e9b3c1aaf..2eb5e9fd1 100644 --- a/native/jni/src/suggest/core/dictionary/ngram_listener.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -26,6 +26,8 @@ namespace latinime { */ class NgramListener { public: + // ngramProbability is always 0 for v403 decaying dictionary. + // TODO: Remove ngramProbability. virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index ce3684a1c..b9dda83ad 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -30,7 +30,7 @@ class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const = 0; + const bool boostExactMatches, const bool hasProbabilityZero) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const = 0; diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index a06e7d070..450203d98 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -119,7 +119,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getSubstitutionCost() + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_NEW_WORD_SPACE_OMISSION: - return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); + return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index bd6b3cf41..863c4eabe 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -57,7 +57,7 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + virtual float getSpaceOmissionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, DicNode_InputStateG *const 
inputStateG) const = 0; virtual float getNewWordBigramLanguageCost( diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 3283f6deb..74db95953 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -76,6 +76,52 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults); } +/* static */ bool SuggestionsOutputUtils::shouldBlockWord( + const SuggestOptions *const suggestOptions, const DicNode *const terminalDicNode, + const WordAttributes wordAttributes, const bool isLastWord) { + const bool currentWordExactMatch = + ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); + // When we have to block offensive words, non-exact matched offensive words should not be + // output. + const bool shouldBlockOffensiveWords = suggestOptions->blockOffensiveWords(); + + const bool isBlockedOffensiveWord = shouldBlockOffensiveWords && + wordAttributes.isPossiblyOffensive(); + + // This function is called in two situations: + // + // 1) At the end of a search, in which case terminalDicNode will point to the last DicNode + // of the search, and isLastWord will be true. + // "fuck" + // | + // \ terminalDicNode (isLastWord=true, currentWordExactMatch=true) + // In this case, if the current word is an exact match, we will always let the word + // through, even if the user is blocking offensive words (it's exactly what they typed!) + // + // 2) In the middle of the search, when we hit a terminal node, to decide whether or not + // to start a new search at root, to try to match the rest of the input. In this case, + // terminalDicNode will point to the terminal node we just hit, and isLastWord will be + // false. + // "fuckvthis" + // | + // \ terminalDicNode (isLastWord=false, currentWordExactMatch=true) + // + // In this case, we should NOT allow the match through (correcting "fuckthis" to "fuck this" + // when offensive words are blocked would be a bad idea). + // + // In the case of a multi-word correction where the offensive word is typed last (eg. + // for the input "allfuck"), this function will be called with isLastWord==true, but + // currentWordExactMatch==false. So we are OK in this case as well. + // "allfuck" + // | + // \ terminalDicNode (isLastWord=true, currentWordExactMatch=false) + if (isLastWord && currentWordExactMatch) { + return false; + } else { + return isBlockedOffensiveWord; + } +} + /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel, @@ -98,24 +144,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const bool isExactMatchWithIntentionalOmission = ErrorTypeUtils::isExactMatchWithIntentionalOmission( terminalDicNode->getContainedErrorTypes()); - const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); - // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. - // (e.g. "AMD" and "and") - const bool isSafeExactMatch = isExactMatch - && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase); const int outputTypeFlags = (wordAttributes.isPossiblyOffensive() ? 
Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) + | ((isExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) | (isExactMatchWithIntentionalOmission ? Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0); - // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()); - // When we have to block offensive words, non-exact matched offensive words should not be - // output. - const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords(); - const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive() - && !isSafeExactMatch; + + const bool shouldBlockThisWord = shouldBlockWord(traverseSession->getSuggestOptions(), + terminalDicNode, wordAttributes, true /* isLastWord */); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. @@ -123,11 +161,11 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()), - boostExactMatches); + boostExactMatches, wordAttributes.getProbability() == 0); // Don't output invalid or blocked offensive words. However, we still need to submit their // shortcuts if any. - if (isValidWord && !isBlockedOffensiveWord) { + if (isValidWord && !shouldBlockThisWord) { int codePoints[MAX_WORD_LENGTH]; terminalDicNode->outputResult(codePoints); const int indexToPartialCommit = outputSecondWordFirstLetterInputIndex ? diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index bf8497828..eca1f78b2 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -18,6 +18,7 @@ #define LATINIME_SUGGESTIONS_OUTPUT_UTILS #include "defines.h" +#include "suggest/core/dictionary/word_attributes.h" namespace latinime { @@ -25,11 +26,19 @@ class BinaryDictionaryShortcutIterator; class DicNode; class DicTraverseSession; class Scoring; +class SuggestOptions; class SuggestionResults; class SuggestionsOutputUtils { public: /** + * Returns true if we should block the incoming word, in the context of the user's + * preferences to include or not include possibly offensive words + */ + static bool shouldBlockWord(const SuggestOptions *const suggestOptions, + const DicNode *const terminalDicNode, const WordAttributes wordAttributes, + const bool isLastWord); + /** * Outputs the final list of suggestions (i.e., terminal nodes). */ static void outputSuggestions(const Scoring *const scoringPolicy, diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index c71526293..c372d668b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -160,8 +160,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { // TODO: Remove. Do not prune node here. 
const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode); // Process for handling space substitution (e.g., hevis => he is) - if (allowsErrorCorrections - && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { + if (TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */); } @@ -417,6 +416,11 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext( dicNode->getPrevWordIds(), dicNode->getWordId(), traverseSession->getMultiBigramMap()); + if (SuggestionsOutputUtils::shouldBlockWord(traverseSession->getSuggestOptions(), + dicNode, wordAttributes, false /* isLastWord */)) { + return; + } + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) { return; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 44c2f443f..abc7f9906 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -134,9 +134,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { // same so we use them for both here. switch (mDictFormatVersion) { case FormatUtils::VERSION_2: - return FormatUtils::VERSION_2; case FormatUtils::VERSION_201: - return FormatUtils::VERSION_201; + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + return FormatUtils::UNKNOWN_VERSION; + case FormatUtils::VERSION_202: + return FormatUtils::VERSION_202; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: return FormatUtils::VERSION_4_ONLY_FOR_TESTING; case FormatUtils::VERSION_4: diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index 41a8b13b8..d69a53fce 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -111,7 +111,8 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; switch (version) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: - // Version 2 or 201 dictionary writing is not supported. 
+ case FormatUtils::VERSION_202: + // None of the static dictionaries (v2x) support writing return false; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 08e39ce43..9455222dd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -140,7 +140,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const { - return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); } @@ -164,7 +164,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI } const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { + if (ptNodeParams.isDeleted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } if (prevWordIds.empty()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index 372c9e36f..a19a384f4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -115,7 +115,8 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str switch (formatVersion) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: - AKLOGE("Given path is a directory but the format is version 2 or 201. path: %s", path); + case FormatUtils::VERSION_202: + AKLOGE("Given path is a directory but the format is version 2xx. 
path: %s", path); break; case FormatUtils::VERSION_4: { return newPolicyForV4Dict<backward::v402::Ver4DictConstants, @@ -177,6 +178,9 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: case FormatUtils::VERSION_201: + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + break; + case FormatUtils::VERSION_202: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); case FormatUtils::VERSION_4_ONLY_FOR_TESTING: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index 585e87a24..e52706e07 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -144,17 +144,6 @@ class PtNodeParams { return PatriciaTrieReadingUtils::isTerminal(mFlags); } - AK_FORCE_INLINE bool isBlacklisted() const { - // Note: this method will be removed in the next change. - // It is used in getProbabilityOfWord and getWordAttributes for both v402 and v403. - // * getProbabilityOfWord will be changed to no longer return NOT_A_PROBABILITY - // when isBlacklisted (i.e. to only check if isNotAWord or isDeleted) - // * getWordAttributes will be changed to always return blacklisted=false and - // isPossiblyOffensive according to the function below (instead of the current - // behaviour of checking if the probability is zero) - return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags); - } - AK_FORCE_INLINE bool isPossiblyOffensive() const { return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 66fd18a52..59873612a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -14,7 +14,6 @@ * limitations under the License. */ - #include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h" #include "defines.h" @@ -317,8 +316,8 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const { - return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), - ptNodeParams.getProbability() == 0); + return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(), + ptNodeParams.isPossiblyOffensive()); } int PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -345,10 +344,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { - // If this is not a word, or if it's a blacklisted entry, it should behave as - // having no probability outside of the suggestion process (where it should be used - // for shortcuts). 
+ if (ptNodeParams.isNotAWord()) { + // If this is not a word, it should behave as having no probability outside of the + // suggestion process (where it should be used for shortcuts). return NOT_A_PROBABILITY; } if (!prevWordIds.empty()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp new file mode 100644 index 000000000..b0fbb3e72 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" + +namespace latinime { + +// These counts are used to provide stable probabilities even if the user's input count is small. +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192; +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2; +const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2; + +// These are encoded backoff weights. +// Note that we give positive value for trigrams that means the weight is more than 1. +// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight. +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32; +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0; +const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8; + +// This value is used to remove too old entries from the dictionary. +const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS = + 300 * 24 * 60 * 60; // 300 days + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h new file mode 100644 index 000000000..88bc58fe8 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H +#define LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H + +#include <algorithm> + +#include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" +#include "utils/time_keeper.h" + +namespace latinime { + +class DynamicLanguageModelProbabilityUtils { + public: + static float computeRawProbabilityFromCounts(const int count, const int contextCount, + const int matchedWordCountInContext) { + int minCount = 0; + switch (matchedWordCountInContext) { + case 1: + minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS; + break; + case 2: + minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS; + break; + case 3: + minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS; + break; + default: + AKLOGE("computeRawProbabilityFromCounts is called with invalid " + "matchedWordCountInContext (%d).", matchedWordCountInContext); + ASSERT(false); + return 0.0f; + } + return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount)); + } + + static float backoff(const int ngramProbability, const int matchedWordCountInContext) { + int probability = NOT_A_PROBABILITY; + + switch (matchedWordCountInContext) { + case 1: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; + break; + case 2: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; + break; + case 3: + probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; + break; + default: + AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).", + matchedWordCountInContext); + ASSERT(false); + return NOT_A_PROBABILITY; + } + return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY); + } + + static int getDecayedProbability(const int probability, const HistoricalInfo historicalInfo) { + const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp(); + if (elapsedTime < 0) { + AKLOGE("The elapsed time is negatime value. Timestamp overflow?"); + return NOT_A_PROBABILITY; + } + // TODO: Improve this logic. + // We don't modify probability depending on the elapsed time. + return probability; + } + + static int shouldRemoveEntryDuringGC(const HistoricalInfo historicalInfo) { + // TODO: Improve this logic. + const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp(); + return elapsedTime > DURATION_TO_DISCARD_ENTRY_IN_SECONDS; + } + + static int getPriorityToPreventFromEviction(const HistoricalInfo historicalInfo) { + // TODO: Improve this logic. + // More recently input entries get higher priority. 
+ return historicalInfo.getTimestamp(); + } + +private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicLanguageModelProbabilityUtils); + + static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram."); + + static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS; + static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS; + static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS; + + static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS; + static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS; + static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS; + + static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS; +}; + +} // namespace latinime +#endif /* LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 509bd683b..31b1ea696 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -19,14 +19,16 @@ #include <algorithm> #include <cstring> -#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { -const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; +const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0; +const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1; bool LanguageModelDictContent::save(FILE *const file) const { - return mTrieMap.save(file); + return mTrieMap.save(file) && mGlobalCounters.save(file); } bool LanguageModelDictContent::runGC( @@ -37,7 +39,8 @@ bool LanguageModelDictContent::runGC( } const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, - const int wordId, const HeaderPolicy *const headerPolicy) const { + const int wordId, const bool mustMatchAllPrevWords, + const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxPrevWordCount = 0; @@ -51,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; } + const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); + if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) { + // The word should be treated as a invalid word. + return WordAttributes(); + } for (int i = maxPrevWordCount; i >= 0; --i) { + if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) { + break; + } const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); if (!result.mIsValid) { continue; @@ -60,36 +71,39 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); int probability = NOT_A_PROBABILITY; if (mHasHistoricalInfo) { - const int rawProbability = ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), headerPolicy); - if (rawProbability == NOT_A_PROBABILITY) { - // The entry should not be treated as a valid entry. 
- continue; - } + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); + int contextCount = 0; if (i == 0) { // unigram - probability = rawProbability; + contextCount = mGlobalCounters.getTotalCount(); } else { const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); if (!prevWordProbabilityEntry.isValid()) { continue; } - if (prevWordProbabilityEntry.representsBeginningOfSentence()) { - probability = rawProbability; - } else { - const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( - prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); - probability = std::min(MAX_PROBABILITY - prevWordRawProbability - + rawProbability, MAX_PROBABILITY); + if (prevWordProbabilityEntry.representsBeginningOfSentence() + && historicalInfo->getCount() == 1) { + // BoS ngram requires multiple contextCount. + continue; } + contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount(); } + const float rawProbability = + DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts( + historicalInfo->getCount(), contextCount, i + 1); + const int encodedRawProbability = + ProbabilityUtils::encodeRawProbability(rawProbability); + const int decayedProbability = + DynamicLanguageModelProbabilityUtils::getDecayedProbability( + encodedRawProbability, *historicalInfo); + probability = DynamicLanguageModelProbabilityUtils::backoff( + decayedProbability, i + 1 /* n */); } else { probability = probabilityEntry.getProbability(); } // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in // probabilityEntry. - const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), unigramProbabilityEntry.isNotAWord(), unigramProbabilityEntry.isPossiblyOffensive()); @@ -165,7 +179,8 @@ void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); if (probabilityEntry.isValid()) { const WordAttributes wordAttributes = getWordAttributes( - WordIdArrayView(*prevWordIds), wordId, headerPolicy); + WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */, + headerPolicy); outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, wordAttributes, probabilityEntry); } @@ -212,6 +227,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) { return false; } + mGlobalCounters.incrementTotalCount(); + mGlobalCounters.updateMaxValueOfCounters( + updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount()); for (size_t i = 0; i < prevWordIds.size(); ++i) { if (prevWordIds[i] == NOT_A_WORD_ID) { break; @@ -225,6 +243,8 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) { return false; } + mGlobalCounters.updateMaxValueOfCounters( + updatedNgramProbabilityEntry.getHistoricalInfo()->getCount()); if (!originalNgramProbabilityEntry.isValid()) { entryCountersToUpdate->incrementNgramCount(i + 2); } @@ -235,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( const ProbabilityEntry &originalProbabilityEntry, const bool isValid, const HistoricalInfo historicalInfo, const 
HeaderPolicy *const headerPolicy) const { - const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalProbabilityEntry.getHistoricalInfo(), isValid ? - DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, - &historicalInfo, headerPolicy); + const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(), + 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount() + + historicalInfo.getCount()); if (originalProbabilityEntry.isValid()) { return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); } else { @@ -304,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, const HeaderPolicy *const headerPolicy, - MutableEntryCounters *const outEntryCounters) { + const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) { for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", @@ -321,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b } continue; } - if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() - && probabilityEntry.isValid()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - probabilityEntry.getHistoricalInfo(), headerPolicy); - if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { - // Update the entry. - const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); - if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), - bitmapEntryIndex)) { - return false; - } - } else { + if (mHasHistoricalInfo && probabilityEntry.isValid()) { + const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo(); + if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC( + *originalHistoricalInfo)) { // Remove the entry. if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { return false; } continue; } + if (needsToHalveCounters) { + const int updatedCount = originalHistoricalInfo->getCount() / 2; + if (updatedCount == 0) { + // Remove the entry. 
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(), + originalHistoricalInfo->getLevel(), updatedCount); + const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), + &historicalInfoToSave); + if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), + bitmapEntryIndex)) { + return false; + } + } } - if (!probabilityEntry.representsBeginningOfSentence()) { - outEntryCounters->incrementNgramCount(prevWordCount + 1); - } + outEntryCounters->incrementNgramCount(prevWordCount + 1); if (!entry.hasNextLevelMap()) { continue; } if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), - prevWordCount + 1, headerPolicy, outEntryCounters)) { + prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) { return false; } } @@ -401,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); - const int probability = (mHasHistoricalInfo) ? - ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), - headerPolicy) : probabilityEntry.getProbability(); - outEntryInfo->emplace_back(probability, - probabilityEntry.getHistoricalInfo()->getTimestamp(), + const int priority = mHasHistoricalInfo + ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction( + *probabilityEntry.getHistoricalInfo()) + : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(), entry.key(), targetLevel, prevWordIds->data()); } return true; @@ -413,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { - if (left.mProbability != right.mProbability) { - return left.mProbability < right.mProbability; + if (left.mPriority != right.mPriority) { + return left.mPriority < right.mPriority; } - if (left.mTimestamp != right.mTimestamp) { - return left.mTimestamp > right.mTimestamp; + if (left.mCount != right.mCount) { + return left.mCount < right.mCount; } if (left.mKey != right.mKey) { return left.mKey < right.mKey; @@ -434,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( return false; } -LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, - const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) - : mProbability(probability), mTimestamp(timestamp), mKey(key), - mPrevWordCount(prevWordCount) { +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority, + const int count, const int key, const int prevWordCount, const int *const prevWordIds) + : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) { memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 1cccf92d2..9678c35f9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/dictionary/word_attributes.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" @@ -131,15 +132,17 @@ class LanguageModelDictContent { const ProbabilityEntry mProbabilityEntry; }; - LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, + LanguageModelDictContent(const ReadWriteByteArrayView *const buffers, const bool hasHistoricalInfo) - : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} + : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]), + mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]), + mHasHistoricalInfo(hasHistoricalInfo) {} explicit LanguageModelDictContent(const bool hasHistoricalInfo) - : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {} + : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {} bool isNearSizeLimit() const { - return mTrieMap.isNearSizeLimit(); + return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters(); } bool save(FILE *const file) const; @@ -148,13 +151,14 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent); const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, - const HeaderPolicy *const headerPolicy) const; + const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { + mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount()); return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } @@ -177,8 +181,15 @@ class LanguageModelDictContent { bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters) { - return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), - 0 /* prevWordCount */, headerPolicy, outEntryCounters); + if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), + 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(), + outEntryCounters)) { + return false; + } + if (mGlobalCounters.needsToHalveCounters()) { + mGlobalCounters.halveCounters(); + } + return true; } // entryCounts should be created by updateAllProbabilityEntries. @@ -203,11 +214,12 @@ class LanguageModelDictContent { DISALLOW_ASSIGNMENT_OPERATOR(Comparator); }; - EntryInfoToTurncate(const int probability, const int timestamp, const int key, + EntryInfoToTurncate(const int priority, const int count, const int key, const int prevWordCount, const int *const prevWordIds); - int mProbability; - int mTimestamp; + int mPriority; + // TODO: Remove. 
+ int mCount; int mKey; int mPrevWordCount; int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; @@ -216,10 +228,11 @@ class LanguageModelDictContent { DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); }; - // TODO: Remove - static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; + static const int TRIE_MAP_BUFFER_INDEX; + static const int GLOBAL_COUNTERS_BUFFER_INDEX; TrieMap mTrieMap; + LanguageModelDictContentGlobalCounters mGlobalCounters; const bool mHasHistoricalInfo; bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, @@ -227,7 +240,8 @@ class LanguageModelDictContent { int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, - const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); + const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters, + MutableEntryCounters *const outEntryCounters); bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, int *const outEntryCount); bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp new file mode 100644 index 000000000..d6d91887e --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp @@ -0,0 +1,32 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h" + +#include <climits> + +#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" + +namespace latinime { + +const int LanguageModelDictContentGlobalCounters::COUNTER_VALUE_NEAR_LIMIT_THRESHOLD = + (1 << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 64; +const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD = 1 << 30; +const int LanguageModelDictContentGlobalCounters::COUNTER_SIZE_IN_BYTES = 4; +const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_INDEX = 0; +const int LanguageModelDictContentGlobalCounters::MAX_VALUE_OF_COUNTERS_INDEX = 1; + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h new file mode 100644 index 000000000..283c2691a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h @@ -0,0 +1,101 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H +#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H + +#include <cstdio> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "utils/byte_array_view.h" + +namespace latinime { + +class LanguageModelDictContentGlobalCounters { + public: + explicit LanguageModelDictContentGlobalCounters(const ReadWriteByteArrayView buffer) + : mBuffer(buffer, 0 /* maxAdditionalBufferSize */), + mTotalCount(readValue(mBuffer, TOTAL_COUNT_INDEX)), + mMaxValueOfCounters(readValue(mBuffer, MAX_VALUE_OF_COUNTERS_INDEX)) {} + + LanguageModelDictContentGlobalCounters() + : mBuffer(0 /* maxAdditionalBufferSize */), mTotalCount(0), mMaxValueOfCounters(0) {} + + bool needsToHalveCounters() const { + return mMaxValueOfCounters >= COUNTER_VALUE_NEAR_LIMIT_THRESHOLD + || mTotalCount >= TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD; + } + + int getTotalCount() const { + return mTotalCount; + } + + bool save(FILE *const file) const { + BufferWithExtendableBuffer bufferToWrite( + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); + if (!bufferToWrite.writeUint(mTotalCount, COUNTER_SIZE_IN_BYTES, + TOTAL_COUNT_INDEX * COUNTER_SIZE_IN_BYTES)) { + return false; + } + if (!bufferToWrite.writeUint(mMaxValueOfCounters, COUNTER_SIZE_IN_BYTES, + MAX_VALUE_OF_COUNTERS_INDEX * COUNTER_SIZE_IN_BYTES)) { + return false; + } + return DictFileWritingUtils::writeBufferToFileTail(file, &bufferToWrite); + } + + void incrementTotalCount() { + mTotalCount += 1; + } + + void addToTotalCount(const int count) { + mTotalCount += count; + } + + void updateMaxValueOfCounters(const int count) { + mMaxValueOfCounters = std::max(count, mMaxValueOfCounters); + } + + void halveCounters() { + mMaxValueOfCounters /= 2; + mTotalCount /= 2; + } + +private: + DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContentGlobalCounters); + + const static int COUNTER_VALUE_NEAR_LIMIT_THRESHOLD; + const static int TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD; + const static int COUNTER_SIZE_IN_BYTES; + const static int TOTAL_COUNT_INDEX; + const static int MAX_VALUE_OF_COUNTERS_INDEX; + + BufferWithExtendableBuffer mBuffer; + int mTotalCount; + int mMaxValueOfCounters; + + static int readValue(const BufferWithExtendableBuffer &buffer, const int index) { + const int pos = COUNTER_SIZE_IN_BYTES * index; + if (pos + COUNTER_SIZE_IN_BYTES > buffer.getTailPosition()) { + return 0; + } + return buffer.readUint(COUNTER_SIZE_IN_BYTES, pos); + } +}; +} // namespace latinime +#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 45f88e9b2..4d088dcab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -179,7 +179,7 @@ Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], + mLanguageModelDictContent(&contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], 
mHeaderPolicy.hasHistoricalInfoOfWords()), mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), mIsUpdatable(mDictBuffer->isUpdatable()) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 8e6cb974b..eb6080a24 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -67,6 +67,6 @@ const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT = 1; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT = 3; -const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 1; +const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 2; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 249d822b2..96d789f58 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -97,6 +97,9 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return NOT_A_WORD_ID; } const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (ptNodeParams.isDeleted()) { + return NOT_A_WORD_ID; + } return ptNodeParams.getTerminalId(); } @@ -107,7 +110,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( return WordAttributes(); } return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, - mHeaderPolicy); + false /* mustMatchAllPrevWords */, mHeaderPolicy); } int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, @@ -115,18 +118,13 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); - if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted() - || probabilityEntry.isNotAWord()) { + const WordAttributes wordAttributes = + mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, + true /* mustMatchAllPrevWords */, mHeaderPolicy); + if (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()) { return NOT_A_PROBABILITY; } - if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), - mHeaderPolicy); - } else { - return probabilityEntry.getProbability(); - } + return wordAttributes.getProbability(); } BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( @@ -149,9 +147,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI continue; } const int probability = probabilityEntry.hasHistoricalInfo() ? 
- ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : - probabilityEntry.getProbability(); + 0 : probabilityEntry.getProbability(); listener->onVisitEntry(probability, entry.getWordId()); } } @@ -383,25 +379,35 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext()."); return false; } + if (!isValidWord) { + return true; + } wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); } WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID - && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, - true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY, - HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); - if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); + if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, + true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY, + HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); + return false; + } + // Refresh word ids. + ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + } + // Update entries for beginning of sentence. + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord( + prevWordIds.skip(1 /* n */), prevWordIds[0], true /* isVaild */, historicalInfo, + mHeaderPolicy, &mEntryCounters)) { return false; } - // Refresh word ids. 
- ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds, wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) { @@ -539,7 +545,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( } } const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes( - WordIdArrayView(), wordId, mHeaderPolicy); + WordIdArrayView(), wordId, true /* mustMatchAllPrevWords */, mHeaderPolicy); const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 0cffe569d..8b47147e1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -28,9 +28,11 @@ const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) { switch (formatVersion) { case VERSION_2: - return VERSION_2; case VERSION_201: - return VERSION_201; + AKLOGE("Dictionary versions 2 and 201 are incompatible with this version"); + return UNKNOWN_VERSION; + case VERSION_202: + return VERSION_202; case VERSION_4_ONLY_FOR_TESTING: return VERSION_4_ONLY_FOR_TESTING; case VERSION_4: diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 96310086b..05bd7eb8a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -31,8 +31,12 @@ class FormatUtils { public: enum FORMAT_VERSION { // These MUST have the same values as the relevant constants in FormatSpec.java. + // TODO: Remove VERSION_2 and VERSION_201 when we: + // * Confirm that old versions of LatinIME download old-format dictionaries + // * We no longer need the corresponding constants on the Java side for dicttool VERSION_2 = 2, VERSION_201 = 201, + VERSION_202 = 202, VERSION_4_ONLY_FOR_TESTING = 399, VERSION_4 = 402, VERSION_4_DEV = 403, diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp new file mode 100644 index 000000000..e8fa06942 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +namespace latinime { + +const float ProbabilityUtils::PROBABILITY_ENCODING_SCALER = 8.58923700372f; + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h index 3b339e61a..2050af1e9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h @@ -17,6 +17,9 @@ #ifndef LATINIME_PROBABILITY_UTILS_H #define LATINIME_PROBABILITY_UTILS_H +#include <algorithm> +#include <cmath> + #include "defines.h" namespace latinime { @@ -47,8 +50,20 @@ class ProbabilityUtils { + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); } + // Encode probability using the same way as we are doing for main dictionaries. + static AK_FORCE_INLINE int encodeRawProbability(const float rawProbability) { + const float probability = static_cast<float>(MAX_PROBABILITY) + + log2f(rawProbability) * PROBABILITY_ENCODING_SCALER; + if (probability < 0.0f) { + return 0; + } + return std::min(static_cast<int>(probability + 0.5f), MAX_PROBABILITY); + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(ProbabilityUtils); + + static const float PROBABILITY_ENCODING_SCALER; }; } #endif /* LATINIME_PROBABILITY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 6a2db687d..856808a74 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -24,6 +24,7 @@ const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f; +const float ScoringParams::PERFECT_MATCH_PROMOTION = 1.1f; const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f; const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f; const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; @@ -48,17 +49,17 @@ const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.5508f; const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f; const float ScoringParams::TRANSPOSITION_COST = 0.5608f; -const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.33f; +const float ScoringParams::SPACE_OMISSION_COST = 0.1f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f; const float ScoringParams::SUBSTITUTION_COST = 0.3806f; -const float ScoringParams::COST_NEW_WORD = 0.0314f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.1214f; const float ScoringParams::COST_FIRST_COMPLETION = 0.4836f; const float ScoringParams::COST_COMPLETION = 0.00624f; const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.0683f; const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.0362f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.3482f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; const float 
ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 731424f3d..6f327a370 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -34,6 +34,7 @@ class ScoringParams { static const int THRESHOLD_SHORT_WORD_LENGTH; static const float EXACT_MATCH_PROMOTION; + static const float PERFECT_MATCH_PROMOTION; static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH; static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH; @@ -56,9 +57,9 @@ class ScoringParams { static const float INSERTION_COST_FIRST_CHAR; static const float TRANSPOSITION_COST; static const float SPACE_SUBSTITUTION_COST; + static const float SPACE_OMISSION_COST; static const float ADDITIONAL_PROXIMITY_COST; static const float SUBSTITUTION_COST; - static const float COST_NEW_WORD; static const float COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; static const float DISTANCE_WEIGHT_LANGUAGE; static const float COST_FIRST_COMPLETION; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 0240bcf54..6acd767ea 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -44,23 +44,50 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, - const bool boostExactMatches) const { + const bool boostExactMatches, const bool hasProbabilityZero) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { - score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { - score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + if (hasProbabilityZero) { + // Previously, when both legitimate 0-frequency words (such as distracters) and + // offensive words were encoded in the same way, distracters would never show up + // when the user blocked offensive words (the default setting, as well as the + // setting for regression tests). + // + // When b/11031090 was fixed and a separate encoding was used for offensive words, + // 0-frequency words would no longer be blocked when they were an "exact match" + // (where case mismatches and accent mismatches would be considered an "exact + // match"). The exact match boosting functionality meant that, for example, when + // the user typed "mt" they would be suggested the word "Mt", although they most + // probably meant to type "my". + // + // For this reason, we introduced this change, which does the following: + // * Defines the "perfect match" as a really exact match, with no room for case or + // accent mismatches + // * When the target word has probability zero (as "Mt" does, because it is a + // distracter), ONLY boost its score if it is a perfect match. 
+ // + // By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and + // they will get "my". However, if the user makes an explicit effort to type "Mt", + // we do boost the word "Mt" so that the user's input is not autocorrected to "My". + if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) { + score += ScoringParams::PERFECT_MATCH_PROMOTION; } - if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { - score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; - } - if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { - score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } else { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + score += ScoringParams::EXACT_MATCH_PROMOTION; + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { + score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { + score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; + } + if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { + score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH; + } } } return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE); diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 84077174d..1338ac81a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -150,9 +150,10 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + float getSpaceOmissionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { - return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); + const float cost = ScoringParams::SPACE_OMISSION_COST; + return cost * traverseSession->getMultiWordCostMultiplier(); } float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, @@ -202,7 +203,10 @@ class TypingWeighting : public Weighting { AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; + const int inputIndex = dicNode->getInputIndex(0); + const float distanceToSpaceKey = traverseSession->getProximityInfoState(0) + ->getPointToKeyLength(inputIndex, KEYCODE_SPACE); + const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey; return cost * traverseSession->getMultiWordCostMultiplier(); } diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters_test.cpp new file mode 100644 index 000000000..44b5a8aaa --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters_test.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h" + +#include <gtest/gtest.h> + +#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" + +namespace latinime { +namespace { + +TEST(LanguageModelDictContentGlobalCountersTest, TestUpdateMaxValueOfCounters) { + LanguageModelDictContentGlobalCounters globalCounters; + + EXPECT_FALSE(globalCounters.needsToHalveCounters()); + globalCounters.updateMaxValueOfCounters(10); + EXPECT_FALSE(globalCounters.needsToHalveCounters()); + const int count = (1 << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 1; + globalCounters.updateMaxValueOfCounters(count); + EXPECT_TRUE(globalCounters.needsToHalveCounters()); + globalCounters.halveCounters(); + EXPECT_FALSE(globalCounters.needsToHalveCounters()); +} + +TEST(LanguageModelDictContentGlobalCountersTest, TestIncrementTotalCount) { + LanguageModelDictContentGlobalCounters globalCounters; + + EXPECT_EQ(0, globalCounters.getTotalCount()); + globalCounters.incrementTotalCount(); + EXPECT_EQ(1, globalCounters.getTotalCount()); + for (int i = 1; i < 50; ++i) { + globalCounters.incrementTotalCount(); + } + EXPECT_EQ(50, globalCounters.getTotalCount()); + globalCounters.halveCounters(); + EXPECT_EQ(25, globalCounters.getTotalCount()); + globalCounters.halveCounters(); + EXPECT_EQ(12, globalCounters.getTotalCount()); + for (int i = 0; i < 4; ++i) { + globalCounters.halveCounters(); + } + EXPECT_EQ(0, globalCounters.getTotalCount()); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index 4469dc715..86040f12c 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -108,14 +108,14 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) { languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, &bigramProbabilityEntry); EXPECT_EQ(bigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, - nullptr /* headerPolicy */).getProbability()); + false /* mustMatchAllPrevWords */, nullptr /* headerPolicy */).getProbability()); const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), prevWordIds[1], &probabilityEntry); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, &trigramProbabilityEntry); EXPECT_EQ(trigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, - nullptr /* headerPolicy */).getProbability()); + false /* mustMatchAllPrevWords */, nullptr /* headerPolicy */).getProbability()); } } // namespace diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/probability_utils_test.cpp 
b/native/jni/tests/suggest/policyimpl/dictionary/utils/probability_utils_test.cpp new file mode 100644 index 000000000..be1f278c6 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/probability_utils_test.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +#include <gtest/gtest.h> + +#include "defines.h" + +namespace latinime { +namespace { + +TEST(ProbabilityUtilsTest, TestEncodeRawProbability) { + EXPECT_EQ(MAX_PROBABILITY, ProbabilityUtils::encodeRawProbability(1.0f)); + EXPECT_EQ(MAX_PROBABILITY - 9, ProbabilityUtils::encodeRawProbability(0.5f)); + EXPECT_EQ(0, ProbabilityUtils::encodeRawProbability(0.0f)); +} + +} // namespace +} // namespace latinime
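
A note on the new probability encoding: ProbabilityUtils::encodeRawProbability() maps a raw probability in (0, 1] onto the dictionary's integer scale as MAX_PROBABILITY + log2(p) * PROBABILITY_ENCODING_SCALER, rounded to nearest and clamped to [0, MAX_PROBABILITY]. With a scaler of 8.58923700372, each halving of the raw probability costs roughly 8.6 encoded steps, which is why the new unit test expects encodeRawProbability(0.5f) to equal MAX_PROBABILITY - 9. The standalone sketch below reproduces that arithmetic; MAX_PROBABILITY is taken as 255 here (the value used in defines.h), though the 0.5f check only depends on the offset from the maximum.

// Standalone sketch of the encoding added in probability_utils.h. The scaler is
// copied from probability_utils.cpp above; MAX_PROBABILITY is assumed to be 255
// and is defined locally so the sketch is self-contained.
#include <algorithm>
#include <cassert>
#include <cmath>

namespace {
constexpr int kMaxProbability = 255;        // assumption: MAX_PROBABILITY from defines.h
constexpr float kScaler = 8.58923700372f;   // PROBABILITY_ENCODING_SCALER from the diff

int encodeRawProbabilitySketch(const float rawProbability) {
    // MAX_PROBABILITY + log2(p) * scaler, rounded to nearest, clamped to [0, MAX].
    const float probability =
            static_cast<float>(kMaxProbability) + std::log2(rawProbability) * kScaler;
    if (probability < 0.0f) {
        return 0;
    }
    return std::min(static_cast<int>(probability + 0.5f), kMaxProbability);
}
}  // namespace

int main() {
    assert(encodeRawProbabilitySketch(1.0f) == kMaxProbability);      // log2(1) = 0
    assert(encodeRawProbabilitySketch(0.5f) == kMaxProbability - 9);  // one halving ~ 8.59 steps
    assert(encodeRawProbabilitySketch(0.0f) == 0);                    // log2(0) = -inf, clamped to 0
    return 0;
}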
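
The scoring change in typing_scoring.h is easiest to see with the "mt" / "Mt" example from its comment: "Mt" is a zero-probability distracter, and an input of "mt" matches it only up to case, so it is an exact match but not a perfect match and gets no promotion, letting "my" win; typing "Mt" verbatim is a perfect match and keeps its promotion, so it is not autocorrected to "My". The sketch below models only that promotion decision. The error-type bit values are illustrative (the real assignments live in error_type_utils), isPerfectMatch() is modeled here as "no error flags set at all", and the promotion/penalty constants are the values set in scoring_params.cpp by this change.

// Sketch of the promotion branch added to TypingScoring::calculateFinalScore().
#include <cassert>

namespace {
typedef unsigned int ErrorType;
constexpr ErrorType NOT_AN_ERROR = 0x0;
constexpr ErrorType MATCH_WITH_WRONG_CASE = 0x1;      // illustrative bit values only
constexpr ErrorType MATCH_WITH_MISSING_ACCENT = 0x2;
constexpr ErrorType MATCH_WITH_DIGRAPH = 0x4;

constexpr float EXACT_MATCH_PROMOTION = 1.1f;
constexpr float PERFECT_MATCH_PROMOTION = 1.1f;
constexpr float CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f;

bool isPerfectMatch(const ErrorType errorTypes) { return errorTypes == NOT_AN_ERROR; }

bool isExactMatch(const ErrorType errorTypes) {
    // An exact match tolerates case, accent and digraph differences only.
    return (errorTypes
            & ~(MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH)) == 0;
}

float scorePromotion(const ErrorType errorTypes, const bool hasProbabilityZero) {
    if (hasProbabilityZero) {
        // Zero-probability words (distracters, blocked offensive words) are only
        // promoted when the input reproduces them exactly, including case and accents.
        return isPerfectMatch(errorTypes) ? PERFECT_MATCH_PROMOTION : 0.0f;
    }
    if (!isExactMatch(errorTypes)) {
        return 0.0f;
    }
    float promotion = EXACT_MATCH_PROMOTION;
    if ((errorTypes & MATCH_WITH_WRONG_CASE) != 0) {
        promotion -= CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
    }
    return promotion;
}
}  // namespace

int main() {
    // Typing "mt" against the distracter "Mt": case mismatch + probability zero -> no boost.
    assert(scorePromotion(MATCH_WITH_WRONG_CASE, true /* hasProbabilityZero */) == 0.0f);
    // Typing "Mt" verbatim: perfect match -> boosted, so it is not autocorrected to "My".
    assert(scorePromotion(NOT_AN_ERROR, true /* hasProbabilityZero */) == PERFECT_MATCH_PROMOTION);
    // A normal word typed with wrong case still gets (slightly less than) the exact-match boost.
    assert(scorePromotion(MATCH_WITH_WRONG_CASE, false /* hasProbabilityZero */) > 1.0f);
    return 0;
}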
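
For the reworked space-substitution cost in typing_weighting.h: instead of adding the removed flat COST_NEW_WORD term, the cost now scales with getPointToKeyLength(inputIndex, KEYCODE_SPACE), i.e. with how far the touch point was from the space key, so a tap near the space bar is cheaper to reinterpret as a space. A minimal sketch of that shape, treating the point-to-key length as a dimensionless factor:

// Sketch of the reshaped space-substitution cost; a dimensionless length stands in
// for getPointToKeyLength(inputIndex, KEYCODE_SPACE).
#include <cassert>

namespace {
constexpr float SPACE_SUBSTITUTION_COST = 0.33f;  // value set in scoring_params.cpp above

float spaceSubstitutionCost(const float pointToSpaceKeyLength,
        const float multiWordCostMultiplier) {
    return SPACE_SUBSTITUTION_COST * pointToSpaceKeyLength * multiWordCostMultiplier;
}
}  // namespace

int main() {
    // A touch close to the space bar is now cheaper to reinterpret as a space...
    assert(spaceSubstitutionCost(0.1f, 1.0f) < spaceSubstitutionCost(1.0f, 1.0f));
    // ...and the multi-word cost multiplier still scales the whole cost, as before.
    assert(spaceSubstitutionCost(1.0f, 2.0f) == 2.0f * spaceSubstitutionCost(1.0f, 1.0f));
    return 0;
}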
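
The new test for LanguageModelDictContentGlobalCounters pins down the class's observable contract: needsToHalveCounters() must report true by the time the largest counter passed to updateMaxValueOfCounters() reaches the maximum value of a count field, (1 << (WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 1, and halveCounters() halves both the tracked maximum and the total count with integer division (50 -> 25 -> 12 -> ... -> 0). The toy model below mirrors only that contract; it is not the real class (which presumably persists its counters in one of the two language-model buffers now declared in ver4_dict_constants.cpp), and the 2-byte count field used for the threshold is an assumption, since WORD_COUNT_FIELD_SIZE is defined elsewhere in Ver4DictConstants.

// Toy model of the counter-halving behavior exercised by
// language_model_dict_content_global_counters_test.cpp above.
#include <cassert>

class GlobalCountersSketch {
 public:
    // maxCounterValue: the largest value a stored count field can hold,
    // i.e. (1 << (WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 1 in the real code.
    explicit GlobalCountersSketch(const int maxCounterValue)
            : mMaxCounterValue(maxCounterValue), mMaxValueOfCounters(0), mTotalCount(0) {}

    bool needsToHalveCounters() const { return mMaxValueOfCounters >= mMaxCounterValue; }
    int getTotalCount() const { return mTotalCount; }
    void incrementTotalCount() { ++mTotalCount; }

    void updateMaxValueOfCounters(const int count) {
        if (count > mMaxValueOfCounters) {
            mMaxValueOfCounters = count;
        }
    }

    void halveCounters() {
        // Integer division, so repeated halving drives counts to 0 (50 -> 25 -> 12 -> ...).
        mMaxValueOfCounters /= 2;
        mTotalCount /= 2;
    }

 private:
    const int mMaxCounterValue;
    int mMaxValueOfCounters;
    int mTotalCount;
};

int main() {
    // Assume a 2-byte count field, giving a saturation threshold of (1 << 16) - 1.
    GlobalCountersSketch counters((1 << 16) - 1);
    counters.updateMaxValueOfCounters(10);
    assert(!counters.needsToHalveCounters());
    counters.updateMaxValueOfCounters((1 << 16) - 1);
    assert(counters.needsToHalveCounters());
    counters.halveCounters();
    assert(!counters.needsToHalveCounters());
    for (int i = 0; i < 50; ++i) {
        counters.incrementTotalCount();
    }
    counters.halveCounters();
    assert(counters.getTotalCount() == 25);
    counters.halveCounters();
    assert(counters.getTotalCount() == 12);
    return 0;
}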