aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.cpp5
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp8
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.h3
-rw-r--r--native/jni/src/suggest/core/dictionary/error_type_utils.cpp1
-rw-r--r--native/jni/src/suggest/core/dictionary/error_type_utils.h5
-rw-r--r--native/jni/src/suggest/core/dictionary/ngram_listener.h2
-rw-r--r--native/jni/src/suggest/core/dictionary/property/historical_info.h1
-rw-r--r--native/jni/src/suggest/core/dictionary/property/ngram_property.h14
-rw-r--r--native/jni/src/suggest/core/dictionary/property/unigram_property.h49
-rw-r--r--native/jni/src/suggest/core/dictionary/property/word_property.cpp50
-rw-r--r--native/jni/src/suggest/core/dictionary/property/word_property.h10
-rw-r--r--native/jni/src/suggest/core/dictionary/word_attributes.h8
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h5
-rw-r--r--native/jni/src/suggest/core/policy/scoring.h2
-rw-r--r--native/jni/src/suggest/core/policy/weighting.cpp2
-rw-r--r--native/jni/src/suggest/core/policy/weighting.h2
-rw-r--r--native/jni/src/suggest/core/result/suggestions_output_utils.cpp66
-rw-r--r--native/jni/src/suggest/core/result/suggestions_output_utils.h9
-rw-r--r--native/jni/src/suggest/core/session/ngram_context.cpp123
-rw-r--r--native/jni/src/suggest/core/session/ngram_context.h121
-rw-r--r--native/jni/src/suggest/core/suggest.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp28
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h52
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp11
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp46
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h11
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp11
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp18
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp23
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h10
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp18
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp37
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h114
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp253
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h87
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp32
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h101
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h1
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp163
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h13
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp45
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h133
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h19
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp14
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp23
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h15
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.cpp7
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.h3
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_scoring.h47
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.h10
-rw-r--r--native/jni/src/utils/char_utils.h11
-rw-r--r--native/jni/src/utils/jni_data_utils.h3
68 files changed, 1377 insertions, 564 deletions
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index 7d2898b7a..ea438922f 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -74,8 +74,9 @@ namespace latinime {
}
const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext(
dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
- if (dicNode->hasMultipleWords()
- && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord())) {
+ if (wordAttributes.getProbability() == NOT_A_PROBABILITY
+ || (dicNode->hasMultipleWords()
+ && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()))) {
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 697e99ffb..6a5df9d95 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -81,6 +81,9 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
}
const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
mPrevWordIds, targetWordId, nullptr /* multiBigramMap */);
+ if (wordAttributes.getProbability() == NOT_A_PROBABILITY) {
+ return;
+ }
mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
wordAttributes.getProbability());
}
@@ -140,10 +143,9 @@ bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) {
return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints);
}
-bool Dictionary::addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty) {
+bool Dictionary::addNgramEntry(const NgramProperty *const ngramProperty) {
TimeKeeper::setCurrentTime();
- return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramContext, ngramProperty);
+ return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramProperty);
}
bool Dictionary::removeNgramEntry(const NgramContext *const ngramContext,
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index 843aec473..a5e986d15 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -85,8 +85,7 @@ class Dictionary {
bool removeUnigramEntry(const CodePointArrayView codePoints);
- bool addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty);
+ bool addNgramEntry(const NgramProperty *const ngramProperty);
bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView codePoints);
diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
index 1e2494e92..8f07ce275 100644
--- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
@@ -31,6 +31,7 @@ const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100;
const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH =
NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_A_PERFECT_MATCH = NOT_AN_ERROR;
const ErrorTypeUtils::ErrorType
ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION =
diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h
index fd1d5fcff..e92c509fa 100644
--- a/native/jni/src/suggest/core/dictionary/error_type_utils.h
+++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h
@@ -52,6 +52,10 @@ class ErrorTypeUtils {
return (containedErrorTypes & ~ERRORS_TREATED_AS_AN_EXACT_MATCH) == 0;
}
+ static bool isPerfectMatch(const ErrorType containedErrorTypes) {
+ return (containedErrorTypes & ~ERRORS_TREATED_AS_A_PERFECT_MATCH) == 0;
+ }
+
static bool isExactMatchWithIntentionalOmission(const ErrorType containedErrorTypes) {
return (containedErrorTypes
& ~ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION) == 0;
@@ -73,6 +77,7 @@ class ErrorTypeUtils {
DISALLOW_IMPLICIT_CONSTRUCTORS(ErrorTypeUtils);
static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH;
+ static const ErrorType ERRORS_TREATED_AS_A_PERFECT_MATCH;
static const ErrorType ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION;
};
} // namespace latinime
diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h
index e9b3c1aaf..2eb5e9fd1 100644
--- a/native/jni/src/suggest/core/dictionary/ngram_listener.h
+++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h
@@ -26,6 +26,8 @@ namespace latinime {
*/
class NgramListener {
public:
+ // ngramProbability is always 0 for v403 decaying dictionary.
+ // TODO: Remove ngramProbability.
virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0;
virtual ~NgramListener() {};
diff --git a/native/jni/src/suggest/core/dictionary/property/historical_info.h b/native/jni/src/suggest/core/dictionary/property/historical_info.h
index f9bd6fd8c..e5ce1ea25 100644
--- a/native/jni/src/suggest/core/dictionary/property/historical_info.h
+++ b/native/jni/src/suggest/core/dictionary/property/historical_info.h
@@ -38,6 +38,7 @@ class HistoricalInfo {
return mTimestamp;
}
+ // TODO: Remove
int getLevel() const {
return mLevel;
}
diff --git a/native/jni/src/suggest/core/dictionary/property/ngram_property.h b/native/jni/src/suggest/core/dictionary/property/ngram_property.h
index 8709799f9..e67b4da31 100644
--- a/native/jni/src/suggest/core/dictionary/property/ngram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/ngram_property.h
@@ -21,15 +21,20 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
+#include "suggest/core/session/ngram_context.h"
namespace latinime {
class NgramProperty {
public:
- NgramProperty(const std::vector<int> &&targetCodePoints, const int probability,
- const HistoricalInfo historicalInfo)
- : mTargetCodePoints(std::move(targetCodePoints)), mProbability(probability),
- mHistoricalInfo(historicalInfo) {}
+ NgramProperty(const NgramContext &ngramContext, const std::vector<int> &&targetCodePoints,
+ const int probability, const HistoricalInfo historicalInfo)
+ : mNgramContext(ngramContext), mTargetCodePoints(std::move(targetCodePoints)),
+ mProbability(probability), mHistoricalInfo(historicalInfo) {}
+
+ const NgramContext *getNgramContext() const {
+ return &mNgramContext;
+ }
const std::vector<int> *getTargetCodePoints() const {
return &mTargetCodePoints;
@@ -48,6 +53,7 @@ class NgramProperty {
DISALLOW_DEFAULT_CONSTRUCTOR(NgramProperty);
DISALLOW_ASSIGNMENT_OPERATOR(NgramProperty);
+ const NgramContext mNgramContext;
const std::vector<int> mTargetCodePoints;
const int mProbability;
const HistoricalInfo mHistoricalInfo;
diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
index 5ed2e2602..f194f979a 100644
--- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
@@ -49,21 +49,44 @@ class UnigramProperty {
};
UnigramProperty()
- : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), mIsBlacklisted(false),
- mProbability(NOT_A_PROBABILITY), mHistoricalInfo(), mShortcuts() {}
+ : mRepresentsBeginningOfSentence(false), mIsNotAWord(false),
+ mIsBlacklisted(false), mIsPossiblyOffensive(false), mProbability(NOT_A_PROBABILITY),
+ mHistoricalInfo(), mShortcuts() {}
+ // In contexts which do not support the Blacklisted flag (v2, v4<403)
UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
- const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo,
- const std::vector<ShortcutProperty> &&shortcuts)
+ const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts)
: mRepresentsBeginningOfSentence(representsBeginningOfSentence),
- mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(false),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {}
- // Without shortcuts.
+ // Without shortcuts, in contexts which do not support the Blacklisted flag (v2, v4<403)
UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
- const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo)
+ const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo)
: mRepresentsBeginningOfSentence(representsBeginningOfSentence),
- mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(false),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
+ mHistoricalInfo(historicalInfo), mShortcuts() {}
+
+ // In contexts which DO support the Blacklisted flag (v403)
+ UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
+ const bool isBlacklisted, const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts)
+ : mRepresentsBeginningOfSentence(representsBeginningOfSentence),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
+ mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {}
+
+ // Without shortcuts, in contexts which DO support the Blacklisted flag (v403)
+ UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
+ const bool isBlacklisted, const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo)
+ : mRepresentsBeginningOfSentence(representsBeginningOfSentence),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
mHistoricalInfo(historicalInfo), mShortcuts() {}
bool representsBeginningOfSentence() const {
@@ -74,13 +97,12 @@ class UnigramProperty {
return mIsNotAWord;
}
- bool isBlacklisted() const {
- return mIsBlacklisted;
+ bool isPossiblyOffensive() const {
+ return mIsPossiblyOffensive;
}
- bool isPossiblyOffensive() const {
- // TODO: Have dedicated flag.
- return mProbability == 0;
+ bool isBlacklisted() const {
+ return mIsBlacklisted;
}
bool hasShortcuts() const {
@@ -106,6 +128,7 @@ class UnigramProperty {
const bool mRepresentsBeginningOfSentence;
const bool mIsNotAWord;
const bool mIsBlacklisted;
+ const bool mIsPossiblyOffensive;
const int mProbability;
const HistoricalInfo mHistoricalInfo;
const std::vector<ShortcutProperty> mShortcuts;
diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp
index caac8fe79..019f0880f 100644
--- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp
+++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp
@@ -22,13 +22,14 @@
namespace latinime {
void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints,
- jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outBigramTargets,
- jobject outBigramProbabilities, jobject outShortcutTargets,
+ jbooleanArray outFlags, jintArray outProbabilityInfo,
+ jobject outNgramPrevWordsArray, jobject outNgramPrevWordIsBeginningOfSentenceArray,
+ jobject outNgramTargets, jobject outNgramProbabilities, jobject outShortcutTargets,
jobject outShortcutProbabilities) const {
JniDataUtils::outputCodePoints(env, outCodePoints, 0 /* start */,
MAX_WORD_LENGTH /* maxLength */, mCodePoints.data(), mCodePoints.size(),
false /* needsNullTermination */);
- jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isBlacklisted(),
+ jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isPossiblyOffensive(),
!mNgrams.empty(), mUnigramProperty.hasShortcuts(),
mUnigramProperty.representsBeginningOfSentence()};
env->SetBooleanArrayRegion(outFlags, 0 /* start */, NELEMS(flags), flags);
@@ -43,16 +44,39 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints,
jclass arrayListClass = env->FindClass("java/util/ArrayList");
jmethodID addMethodId = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z");
- // Output bigrams.
- // TODO: Support n-gram
+ // Output ngrams.
+ jclass intArrayClass = env->FindClass("[I");
for (const auto &ngramProperty : mNgrams) {
- const std::vector<int> *const word1CodePoints = ngramProperty.getTargetCodePoints();
- jintArray bigramWord1CodePointArray = env->NewIntArray(word1CodePoints->size());
- JniDataUtils::outputCodePoints(env, bigramWord1CodePointArray, 0 /* start */,
- word1CodePoints->size(), word1CodePoints->data(), word1CodePoints->size(),
- false /* needsNullTermination */);
- env->CallBooleanMethod(outBigramTargets, addMethodId, bigramWord1CodePointArray);
- env->DeleteLocalRef(bigramWord1CodePointArray);
+ const NgramContext *const ngramContext = ngramProperty.getNgramContext();
+ jobjectArray prevWordWordCodePointsArray = env->NewObjectArray(
+ ngramContext->getPrevWordCount(), intArrayClass, nullptr);
+ jbooleanArray prevWordIsBeginningOfSentenceArray =
+ env->NewBooleanArray(ngramContext->getPrevWordCount());
+ for (size_t i = 0; i < ngramContext->getPrevWordCount(); ++i) {
+ const CodePointArrayView codePoints = ngramContext->getNthPrevWordCodePoints(i + 1);
+ jintArray prevWordCodePoints = env->NewIntArray(codePoints.size());
+ JniDataUtils::outputCodePoints(env, prevWordCodePoints, 0 /* start */,
+ codePoints.size(), codePoints.data(), codePoints.size(),
+ false /* needsNullTermination */);
+ env->SetObjectArrayElement(prevWordWordCodePointsArray, i, prevWordCodePoints);
+ env->DeleteLocalRef(prevWordCodePoints);
+ JniDataUtils::putBooleanToArray(env, prevWordIsBeginningOfSentenceArray, i,
+ ngramContext->isNthPrevWordBeginningOfSentence(i + 1));
+ }
+ env->CallBooleanMethod(outNgramPrevWordsArray, addMethodId, prevWordWordCodePointsArray);
+ env->CallBooleanMethod(outNgramPrevWordIsBeginningOfSentenceArray, addMethodId,
+ prevWordIsBeginningOfSentenceArray);
+ env->DeleteLocalRef(prevWordWordCodePointsArray);
+ env->DeleteLocalRef(prevWordIsBeginningOfSentenceArray);
+
+ const std::vector<int> *const targetWordCodePoints = ngramProperty.getTargetCodePoints();
+ jintArray targetWordCodePointArray = env->NewIntArray(targetWordCodePoints->size());
+ JniDataUtils::outputCodePoints(env, targetWordCodePointArray, 0 /* start */,
+ targetWordCodePoints->size(), targetWordCodePoints->data(),
+ targetWordCodePoints->size(), false /* needsNullTermination */);
+ env->CallBooleanMethod(outNgramTargets, addMethodId, targetWordCodePointArray);
+ env->DeleteLocalRef(targetWordCodePointArray);
+
const HistoricalInfo &ngramHistoricalInfo = ngramProperty.getHistoricalInfo();
int bigramProbabilityInfo[] = {ngramProperty.getProbability(),
ngramHistoricalInfo.getTimestamp(), ngramHistoricalInfo.getLevel(),
@@ -60,7 +84,7 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints,
jintArray bigramProbabilityInfoArray = env->NewIntArray(NELEMS(bigramProbabilityInfo));
env->SetIntArrayRegion(bigramProbabilityInfoArray, 0 /* start */,
NELEMS(bigramProbabilityInfo), bigramProbabilityInfo);
- env->CallBooleanMethod(outBigramProbabilities, addMethodId, bigramProbabilityInfoArray);
+ env->CallBooleanMethod(outNgramProbabilities, addMethodId, bigramProbabilityInfoArray);
env->DeleteLocalRef(bigramProbabilityInfoArray);
}
diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h
index 0c23e8225..b5314faaa 100644
--- a/native/jni/src/suggest/core/dictionary/property/word_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/word_property.h
@@ -34,13 +34,15 @@ class WordProperty {
: mCodePoints(), mUnigramProperty(), mNgrams() {}
WordProperty(const std::vector<int> &&codePoints, const UnigramProperty *const unigramProperty,
- const std::vector<NgramProperty> *const bigrams)
+ const std::vector<NgramProperty> *const ngrams)
: mCodePoints(std::move(codePoints)), mUnigramProperty(*unigramProperty),
- mNgrams(*bigrams) {}
+ mNgrams(*ngrams) {}
void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags,
- jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities,
- jobject outShortcutTargets, jobject outShortcutProbabilities) const;
+ jintArray outProbabilityInfo, jobject outNgramPrevWordsArray,
+ jobject outNgramPrevWordIsBeginningOfSentenceArray, jobject outNgramTargets,
+ jobject outNgramProbabilities, jobject outShortcutTargets,
+ jobject outShortcutProbabilities) const;
const UnigramProperty *getUnigramProperty() const {
return &mUnigramProperty;
diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h
index 6e9da3570..5351e7d7d 100644
--- a/native/jni/src/suggest/core/dictionary/word_attributes.h
+++ b/native/jni/src/suggest/core/dictionary/word_attributes.h
@@ -43,6 +43,14 @@ class WordAttributes {
return mIsNotAWord;
}
+ // Whether or not a word is possibly offensive.
+ // * Static dictionaries <v202, as well as dynamic dictionaries <v403, will set this based on
+ // whether or not the probability of the word is zero.
+ // * Static dictionaries >=v203 will set this based on the IS_POSSIBLY_OFFENSIVE PtNode flag.
+ // * Dynamic dictionaries >=v403 will set this based on the IS_POSSIBLY_OFFENSIVE language model
+ // flag (the PtNode flag IS_BLACKLISTED is ignored and kept as zero)
+ //
+ // See the ::getWordAttributes function for each of these dictionary policies for more details.
bool isPossiblyOffensive() const {
return mIsPossiblyOffensive;
}
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index ceda5c03f..33a0fbc19 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -40,7 +40,6 @@ class UnigramProperty;
* This class abstracts the structure of dictionaries.
* Implement this policy to support additional dictionaries.
*/
-// TODO: Use word id instead of terminal PtNode position.
class DictionaryStructureWithBufferPolicy {
public:
typedef std::unique_ptr<DictionaryStructureWithBufferPolicy> StructurePolicyPtr;
@@ -81,8 +80,7 @@ class DictionaryStructureWithBufferPolicy {
virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0;
// Returns whether the update was success or not.
- virtual bool addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty) = 0;
+ virtual bool addNgramEntry(const NgramProperty *const ngramProperty) = 0;
// Returns whether the update was success or not.
virtual bool removeNgramEntry(const NgramContext *const ngramContext,
@@ -106,7 +104,6 @@ class DictionaryStructureWithBufferPolicy {
virtual void getProperty(const char *const query, const int queryLength, char *const outResult,
const int maxResultLength) = 0;
- // Used for testing.
virtual const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const = 0;
// Method to iterate all words in the dictionary.
diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h
index ce3684a1c..b9dda83ad 100644
--- a/native/jni/src/suggest/core/policy/scoring.h
+++ b/native/jni/src/suggest/core/policy/scoring.h
@@ -30,7 +30,7 @@ class Scoring {
public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
- const bool boostExactMatches) const = 0;
+ const bool boostExactMatches, const bool hasProbabilityZero) const = 0;
virtual void getMostProbableString(const DicTraverseSession *const traverseSession,
const float weightOfLangModelVsSpatialModel,
SuggestionResults *const outSuggestionResults) const = 0;
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index a06e7d070..450203d98 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -119,7 +119,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return weighting->getSubstitutionCost()
+ weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_NEW_WORD_SPACE_OMISSION:
- return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
+ return weighting->getSpaceOmissionCost(traverseSession, dicNode, inputStateG);
case CT_MATCH:
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_COMPLETION:
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index bd6b3cf41..863c4eabe 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -57,7 +57,7 @@ class Weighting {
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
- virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
+ virtual float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0;
virtual float getNewWordBigramLanguageCost(
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
index 3283f6deb..74db95953 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
@@ -76,6 +76,52 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults);
}
+/* static */ bool SuggestionsOutputUtils::shouldBlockWord(
+ const SuggestOptions *const suggestOptions, const DicNode *const terminalDicNode,
+ const WordAttributes wordAttributes, const bool isLastWord) {
+ const bool currentWordExactMatch =
+ ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes());
+ // When we have to block offensive words, non-exact matched offensive words should not be
+ // output.
+ const bool shouldBlockOffensiveWords = suggestOptions->blockOffensiveWords();
+
+ const bool isBlockedOffensiveWord = shouldBlockOffensiveWords &&
+ wordAttributes.isPossiblyOffensive();
+
+ // This function is called in two situations:
+ //
+ // 1) At the end of a search, in which case terminalDicNode will point to the last DicNode
+ // of the search, and isLastWord will be true.
+ // "fuck"
+ // |
+ // \ terminalDicNode (isLastWord=true, currentWordExactMatch=true)
+ // In this case, if the current word is an exact match, we will always let the word
+ // through, even if the user is blocking offensive words (it's exactly what they typed!)
+ //
+ // 2) In the middle of the search, when we hit a terminal node, to decide whether or not
+ // to start a new search at root, to try to match the rest of the input. In this case,
+ // terminalDicNode will point to the terminal node we just hit, and isLastWord will be
+ // false.
+ // "fuckvthis"
+ // |
+ // \ terminalDicNode (isLastWord=false, currentWordExactMatch=true)
+ //
+ // In this case, we should NOT allow the match through (correcting "fuckthis" to "fuck this"
+ // when offensive words are blocked would be a bad idea).
+ //
+ // In the case of a multi-word correction where the offensive word is typed last (eg.
+ // for the input "allfuck"), this function will be called with isLastWord==true, but
+ // currentWordExactMatch==false. So we are OK in this case as well.
+ // "allfuck"
+ // |
+ // \ terminalDicNode (isLastWord=true, currentWordExactMatch=false)
+ if (isLastWord && currentWordExactMatch) {
+ return false;
+ } else {
+ return isBlockedOffensiveWord;
+ }
+}
+
/* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode(
const Scoring *const scoringPolicy, DicTraverseSession *traverseSession,
const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel,
@@ -98,24 +144,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
const bool isExactMatchWithIntentionalOmission =
ErrorTypeUtils::isExactMatchWithIntentionalOmission(
terminalDicNode->getContainedErrorTypes());
- const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase();
- // Heuristic: We exclude probability=0 first-char-uppercase words from exact match.
- // (e.g. "AMD" and "and")
- const bool isSafeExactMatch = isExactMatch
- && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase);
const int outputTypeFlags =
(wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
- | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0)
+ | ((isExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0)
| (isExactMatchWithIntentionalOmission ?
Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0);
-
// Entries that are blacklisted or do not represent a word should not be output.
const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord());
- // When we have to block offensive words, non-exact matched offensive words should not be
- // output.
- const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords();
- const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive()
- && !isSafeExactMatch;
+
+ const bool shouldBlockThisWord = shouldBlockWord(traverseSession->getSuggestOptions(),
+ terminalDicNode, wordAttributes, true /* isLastWord */);
// Increase output score of top typing suggestion to ensure autocorrection.
// TODO: Better integration with java side autocorrection logic.
@@ -123,11 +161,11 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
compoundDistance, traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
(forceCommitMultiWords && terminalDicNode->hasMultipleWords()),
- boostExactMatches);
+ boostExactMatches, wordAttributes.getProbability() == 0);
// Don't output invalid or blocked offensive words. However, we still need to submit their
// shortcuts if any.
- if (isValidWord && !isBlockedOffensiveWord) {
+ if (isValidWord && !shouldBlockThisWord) {
int codePoints[MAX_WORD_LENGTH];
terminalDicNode->outputResult(codePoints);
const int indexToPartialCommit = outputSecondWordFirstLetterInputIndex ?
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h
index bf8497828..eca1f78b2 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.h
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h
@@ -18,6 +18,7 @@
#define LATINIME_SUGGESTIONS_OUTPUT_UTILS
#include "defines.h"
+#include "suggest/core/dictionary/word_attributes.h"
namespace latinime {
@@ -25,11 +26,19 @@ class BinaryDictionaryShortcutIterator;
class DicNode;
class DicTraverseSession;
class Scoring;
+class SuggestOptions;
class SuggestionResults;
class SuggestionsOutputUtils {
public:
/**
+ * Returns true if we should block the incoming word, in the context of the user's
+ * preferences to include or not include possibly offensive words
+ */
+ static bool shouldBlockWord(const SuggestOptions *const suggestOptions,
+ const DicNode *const terminalDicNode, const WordAttributes wordAttributes,
+ const bool isLastWord);
+ /**
* Outputs the final list of suggestions (i.e., terminal nodes).
*/
static void outputSuggestions(const Scoring *const scoringPolicy,
diff --git a/native/jni/src/suggest/core/session/ngram_context.cpp b/native/jni/src/suggest/core/session/ngram_context.cpp
new file mode 100644
index 000000000..17ef9ae60
--- /dev/null
+++ b/native/jni/src/suggest/core/session/ngram_context.cpp
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "suggest/core/session/ngram_context.h"
+
+#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
+#include "utils/char_utils.h"
+
+namespace latinime {
+
+NgramContext::NgramContext() : mPrevWordCount(0) {}
+
+NgramContext::NgramContext(const NgramContext &ngramContext)
+ : mPrevWordCount(ngramContext.mPrevWordCount) {
+ for (size_t i = 0; i < mPrevWordCount; ++i) {
+ mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i];
+ memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i],
+ sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
+ mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i];
+ }
+}
+
+NgramContext::NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH],
+ const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
+ const size_t prevWordCount)
+ : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
+ clear();
+ for (size_t i = 0; i < mPrevWordCount; ++i) {
+ if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
+ continue;
+ }
+ memmove(mPrevWordCodePoints[i], prevWordCodePoints[i],
+ sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]);
+ mPrevWordCodePointCount[i] = prevWordCodePointCount[i];
+ mIsBeginningOfSentence[i] = isBeginningOfSentence[i];
+ }
+}
+
+NgramContext::NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount,
+ const bool isBeginningOfSentence) : mPrevWordCount(1) {
+ clear();
+ if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
+ return;
+ }
+ memmove(mPrevWordCodePoints[0], prevWordCodePoints,
+ sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount);
+ mPrevWordCodePointCount[0] = prevWordCodePointCount;
+ mIsBeginningOfSentence[0] = isBeginningOfSentence;
+}
+
+bool NgramContext::isValid() const {
+ if (mPrevWordCodePointCount[0] > 0) {
+ return true;
+ }
+ if (mIsBeginningOfSentence[0]) {
+ return true;
+ }
+ return false;
+}
+
+const CodePointArrayView NgramContext::getNthPrevWordCodePoints(const size_t n) const {
+ if (n <= 0 || n > mPrevWordCount) {
+ return CodePointArrayView();
+ }
+ return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]);
+}
+
+bool NgramContext::isNthPrevWordBeginningOfSentence(const size_t n) const {
+ if (n <= 0 || n > mPrevWordCount) {
+ return false;
+ }
+ return mIsBeginningOfSentence[n - 1];
+}
+
+/* static */ int NgramContext::getWordId(
+ const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
+ const int *const wordCodePoints, const int wordCodePointCount,
+ const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
+ if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
+ return NOT_A_WORD_ID;
+ }
+ int codePoints[MAX_WORD_LENGTH];
+ int codePointCount = wordCodePointCount;
+ memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
+ if (isBeginningOfSentence) {
+ codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount,
+ MAX_WORD_LENGTH);
+ if (codePointCount <= 0) {
+ return NOT_A_WORD_ID;
+ }
+ }
+ const CodePointArrayView codePointArrayView(codePoints, codePointCount);
+ const int wordId = dictStructurePolicy->getWordId(codePointArrayView,
+ false /* forceLowerCaseSearch */);
+ if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
+ // Return the id when the word was found, or when lower case search is not tried.
+ return wordId;
+ }
+ // Check bigrams for lower-cased previous word if original was not found. Useful for
+ // auto-capitalized words like "The [current_word]".
+ return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
+}
+
+void NgramContext::clear() {
+ for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
+ mPrevWordCodePointCount[i] = 0;
+ mIsBeginningOfSentence[i] = false;
+ }
+}
+} // namespace latinime
diff --git a/native/jni/src/suggest/core/session/ngram_context.h b/native/jni/src/suggest/core/session/ngram_context.h
index 64c71410f..9b36199c9 100644
--- a/native/jni/src/suggest/core/session/ngram_context.h
+++ b/native/jni/src/suggest/core/session/ngram_context.h
@@ -20,145 +20,54 @@
#include <array>
#include "defines.h"
-#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
-#include "utils/char_utils.h"
#include "utils/int_array_view.h"
namespace latinime {
-// Rename to NgramContext.
+class DictionaryStructureWithBufferPolicy;
+
class NgramContext {
public:
// No prev word information.
- NgramContext() : mPrevWordCount(0) {
- clear();
- }
-
- NgramContext(const NgramContext &ngramContext)
- : mPrevWordCount(ngramContext.mPrevWordCount) {
- for (size_t i = 0; i < mPrevWordCount; ++i) {
- mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i];
- memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i],
- sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
- mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i];
- }
- }
-
+ NgramContext();
+ // Copy constructor to use this class with std::vector and use this class as a return value.
+ NgramContext(const NgramContext &ngramContext);
// Construct from previous words.
NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH],
const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
- const size_t prevWordCount)
- : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
- clear();
- for (size_t i = 0; i < mPrevWordCount; ++i) {
- if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
- continue;
- }
- memmove(mPrevWordCodePoints[i], prevWordCodePoints[i],
- sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]);
- mPrevWordCodePointCount[i] = prevWordCodePointCount[i];
- mIsBeginningOfSentence[i] = isBeginningOfSentence[i];
- }
- }
-
+ const size_t prevWordCount);
// Construct from a previous word.
NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount,
- const bool isBeginningOfSentence) : mPrevWordCount(1) {
- clear();
- if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
- return;
- }
- memmove(mPrevWordCodePoints[0], prevWordCodePoints,
- sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount);
- mPrevWordCodePointCount[0] = prevWordCodePointCount;
- mIsBeginningOfSentence[0] = isBeginningOfSentence;
- }
+ const bool isBeginningOfSentence);
size_t getPrevWordCount() const {
return mPrevWordCount;
}
-
- // TODO: Remove.
- const NgramContext getTrimmedNgramContext(const size_t maxPrevWordCount) const {
- return NgramContext(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence,
- std::min(mPrevWordCount, maxPrevWordCount));
- }
-
- bool isValid() const {
- if (mPrevWordCodePointCount[0] > 0) {
- return true;
- }
- if (mIsBeginningOfSentence[0]) {
- return true;
- }
- return false;
- }
+ bool isValid() const;
template<size_t N>
const WordIdArrayView getPrevWordIds(
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
- std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const {
+ WordIdArray<N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const {
for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) {
- prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy,
- mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
- mIsBeginningOfSentence[i], tryLowerCaseSearch);
+ prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i],
+ mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch);
}
return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount);
}
// n is 1-indexed.
- const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const {
- if (n <= 0 || n > mPrevWordCount) {
- return CodePointArrayView();
- }
- return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]);
- }
-
+ const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const;
// n is 1-indexed.
- bool isNthPrevWordBeginningOfSentence(const size_t n) const {
- if (n <= 0 || n > mPrevWordCount) {
- return false;
- }
- return mIsBeginningOfSentence[n - 1];
- }
+ bool isNthPrevWordBeginningOfSentence(const size_t n) const;
private:
DISALLOW_ASSIGNMENT_OPERATOR(NgramContext);
static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const int *const wordCodePoints, const int wordCodePointCount,
- const bool isBeginningOfSentence, const bool tryLowerCaseSearch) {
- if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
- return NOT_A_WORD_ID;
- }
- int codePoints[MAX_WORD_LENGTH];
- int codePointCount = wordCodePointCount;
- memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
- if (isBeginningOfSentence) {
- codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
- codePointCount, MAX_WORD_LENGTH);
- if (codePointCount <= 0) {
- return NOT_A_WORD_ID;
- }
- }
- const CodePointArrayView codePointArrayView(codePoints, codePointCount);
- const int wordId = dictStructurePolicy->getWordId(
- codePointArrayView, false /* forceLowerCaseSearch */);
- if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) {
- // Return the id when when the word was found or doesn't try lower case search.
- return wordId;
- }
- // Check bigrams for lower-cased previous word if original was not found. Useful for
- // auto-capitalized words like "The [current_word]".
- return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */);
- }
-
- void clear() {
- for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
- mPrevWordCodePointCount[i] = 0;
- mIsBeginningOfSentence[i] = false;
- }
- }
+ const bool isBeginningOfSentence, const bool tryLowerCaseSearch);
+ void clear();
const size_t mPrevWordCount;
int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index c71526293..c372d668b 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -160,8 +160,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
// TODO: Remove. Do not prune node here.
const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode);
// Process for handling space substitution (e.g., hevis => he is)
- if (allowsErrorCorrections
- && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) {
+ if (TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) {
createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */);
}
@@ -417,6 +416,11 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode
traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext(
dicNode->getPrevWordIds(), dicNode->getWordId(),
traverseSession->getMultiBigramMap());
+ if (SuggestionsOutputUtils::shouldBlockWord(traverseSession->getSuggestOptions(),
+ dicNode, wordAttributes, false /* isLastWord */)) {
+ return;
+ }
+
if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) {
return;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 4c4dfc578..300e96c4e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -30,6 +30,7 @@ const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT";
const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT";
+const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT";
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
@@ -38,15 +39,17 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
-const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
-const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";
+const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
+const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
+const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
-const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;
+const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
+const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
// Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
@@ -92,12 +95,11 @@ bool HeaderPolicy::readRequiresGermanUmlautProcessing() const {
}
bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount,
- const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const {
+ const EntryCounts &entryCounts, const int extendedRegionSize,
+ BufferWithExtendableBuffer *const outBuffer) const {
int writingPos = 0;
DictionaryHeaderStructurePolicy::AttributeMap attributeMapToWrite(mAttributeMap);
- fillInHeader(updatesLastDecayedTime, unigramCount, bigramCount,
- extendedRegionSize, &attributeMapToWrite);
+ fillInHeader(updatesLastDecayedTime, entryCounts, extendedRegionSize, &attributeMapToWrite);
if (!HeaderReadWriteUtils::writeDictionaryVersion(outBuffer, mDictFormatVersion,
&writingPos)) {
return false;
@@ -124,11 +126,15 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true;
}
-void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int unigramCount,
- const int bigramCount, const int extendedRegionSize,
+void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
+ const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, unigramCount);
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, bigramCount);
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY,
+ entryCounts.getUnigramCount());
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY,
+ entryCounts.getBigramCount());
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY,
+ entryCounts.getTrigramCount());
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize);
// Set the current time as the generation time.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index bc8eaded3..7a5acd7d5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -22,6 +22,7 @@
#include "defines.h"
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/format_utils.h"
#include "utils/char_utils.h"
#include "utils/time_keeper.h"
@@ -49,6 +50,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
+ mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@@ -60,6 +63,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
+ mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
+ &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map.
@@ -77,7 +82,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
+ mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
@@ -87,6 +92,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
+ mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
+ &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information
@@ -99,12 +106,14 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
+ mTrigramCount(headerPolicy->mTrigramCount),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount),
+ mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header.
@@ -112,10 +121,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
- mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
+ mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
- mCodePointTable(nullptr) {}
+ mMaxTrigramCount(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@@ -125,15 +134,17 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
// same so we use them for both here.
switch (mDictFormatVersion) {
case FormatUtils::VERSION_2:
- return FormatUtils::VERSION_2;
case FormatUtils::VERSION_201:
- return FormatUtils::VERSION_201;
+ AKLOGE("Dictionary versions 2 and 201 are incompatible with this version");
+ return FormatUtils::UNKNOWN_VERSION;
+ case FormatUtils::VERSION_202:
+ return FormatUtils::VERSION_202;
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
return FormatUtils::VERSION_4_ONLY_FOR_TESTING;
- case FormatUtils::VERSION_4:
- return FormatUtils::VERSION_4;
- case FormatUtils::VERSION_4_DEV:
- return FormatUtils::VERSION_4_DEV;
+ case FormatUtils::VERSION_402:
+ return FormatUtils::VERSION_402;
+ case FormatUtils::VERSION_403:
+ return FormatUtils::VERSION_403;
default:
return FormatUtils::UNKNOWN_VERSION;
}
@@ -183,6 +194,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mBigramCount;
}
+ AK_FORCE_INLINE int getTrigramCount() const {
+ return mTrigramCount;
+ }
+
AK_FORCE_INLINE int getExtendedRegionSize() const {
return mExtendedRegionSize;
}
@@ -212,15 +227,19 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mMaxBigramCount;
}
+ AK_FORCE_INLINE int getMaxTrigramCount() const {
+ return mMaxTrigramCount;
+ }
+
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
bool fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount,
- const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const;
+ const EntryCounts &entryCounts, const int extendedRegionSize,
+ BufferWithExtendableBuffer *const outBuffer) const;
- void fillInHeader(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount, const int extendedRegionSize,
+ void fillInHeader(const bool updatesLastDecayedTime, const EntryCounts &entryCounts,
+ const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const;
AK_FORCE_INLINE const std::vector<int> *getLocale() const {
@@ -228,7 +247,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
}
bool supportsBeginningOfSentence() const {
- return mDictFormatVersion >= FormatUtils::VERSION_4;
+ return mDictFormatVersion >= FormatUtils::VERSION_402;
}
const int *getCodePointTable() const {
@@ -245,6 +264,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const LAST_DECAYED_TIME_KEY;
static const char *const UNIGRAM_COUNT_KEY;
static const char *const BIGRAM_COUNT_KEY;
+ static const char *const TRIGRAM_COUNT_KEY;
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY;
@@ -253,11 +273,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
static const char *const MAX_UNIGRAM_COUNT_KEY;
static const char *const MAX_BIGRAM_COUNT_KEY;
+ static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT;
+ static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
@@ -271,11 +293,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const int mLastDecayedTime;
const int mUnigramCount;
const int mBigramCount;
+ const int mTrigramCount;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId;
const int mMaxUnigramCount;
const int mMaxBigramCount;
+ const int mMaxTrigramCount;
const int *const mCodePointTable;
const std::vector<int> readLocale() const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
index 41a8b13b8..19ed0d468 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
@@ -111,11 +111,12 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap;
switch (version) {
case FormatUtils::VERSION_2:
case FormatUtils::VERSION_201:
- // Version 2 or 201 dictionary writing is not supported.
+ case FormatUtils::VERSION_202:
+ // None of the static dictionaries (v2x) support writing
return false;
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
- case FormatUtils::VERSION_4:
- case FormatUtils::VERSION_4_DEV:
+ case FormatUtils::VERSION_402:
+ case FormatUtils::VERSION_403:
return buffer->writeUintAndAdvancePosition(version /* data */,
HEADER_DICTIONARY_VERSION_SIZE, writingPos);
default:
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp
index 9e1adff70..15ac88319 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp
@@ -65,6 +65,8 @@ const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition(
(encodedTargetTerminalId == Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID) ?
Ver4DictConstants::NOT_A_TERMINAL_ID : encodedTargetTerminalId;
if (mHasHistoricalInfo) {
+ // Hack for better migration.
+ count += level;
const HistoricalInfo historicalInfo(timestamp, level, count);
return BigramEntry(hasNext, probability, &historicalInfo, targetTerminalId);
} else {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp
index ef6166ffd..61ef4aa42 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp
@@ -50,7 +50,8 @@ const ProbabilityEntry ProbabilityDictContent::getProbabilityEntry(const int ter
Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, &entryPos);
const int count = buffer->readUintAndAdvancePosition(
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, &entryPos);
- const HistoricalInfo historicalInfo(timestamp, level, count);
+ // Hack for better migration.
+ const HistoricalInfo historicalInfo(timestamp, level, count + level);
return ProbabilityEntry(flags, probability, &historicalInfo);
} else {
return ProbabilityEntry(flags, probability);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 6243f14cc..d558b949a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -245,7 +245,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
if (!sourcePtNodeParams.hasBigrams()) {
// Update has bigrams flag.
return updatePtNodeFlags(sourcePtNodeParams.getHeadPos(),
- sourcePtNodeParams.isBlacklisted(), sourcePtNodeParams.isNotAWord(),
+ sourcePtNodeParams.isPossiblyOffensive(), sourcePtNodeParams.isNotAWord(),
sourcePtNodeParams.isTerminal(), sourcePtNodeParams.hasShortcutTargets(),
true /* hasBigrams */,
sourcePtNodeParams.getCodePointCount() > 1 /* hasMultipleChars */);
@@ -316,7 +316,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN
if (!ptNodeParams->hasShortcutTargets()) {
// Update has shortcut targets flag.
return updatePtNodeFlags(ptNodeParams->getHeadPos(),
- ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
+ ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(),
ptNodeParams->isTerminal(), true /* hasShortcutTargets */,
ptNodeParams->hasBigrams(),
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
@@ -330,7 +330,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeHasBigramsAndShortcutTargetsFlags(
ptNodeParams->getTerminalId()) != NOT_A_DICT_POS;
const bool hasShortcutTargets = mBuffers->getShortcutDictContent()->getShortcutListHeadPos(
ptNodeParams->getTerminalId()) != NOT_A_DICT_POS;
- return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isBlacklisted(),
+ return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isPossiblyOffensive(),
ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), hasShortcutTargets,
hasBigrams, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
@@ -386,8 +386,9 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
return false;
}
- return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
- isTerminal, ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(),
+ return updatePtNodeFlags(nodePos, ptNodeParams->isPossiblyOffensive(),
+ ptNodeParams->isNotAWord(), isTerminal, ptNodeParams->hasShortcutTargets(),
+ ptNodeParams->hasBigrams(),
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 0eae934ae..9455222dd 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -140,7 +140,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability,
const PtNodeParams &ptNodeParams) const {
- return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(),
ptNodeParams.getProbability() == 0);
}
@@ -164,7 +164,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
}
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
- if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
+ if (ptNodeParams.isDeleted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordIds.empty()) {
@@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mUnigramCount++;
+ mEntryCounters.incrementUnigramCount();
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -344,8 +344,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return mNodeWriter.suppressUnigramEntry(&ptNodeParams);
}
-bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty) {
+bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
return false;
@@ -355,6 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
mDictBuffer->getTailPosition());
return false;
}
+ const NgramContext *const ngramContext = ngramProperty->getNgramContext();
if (!ngramContext->isValid()) {
AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary.");
return false;
@@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) {
- mBigramCount++;
+ mEntryCounters.incrementBigramCount();
}
return true;
} else {
@@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
- mBigramCount--;
+ mEntryCounters.decrementBigramCount();
return true;
} else {
return false;
@@ -463,9 +463,9 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
}
const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
? NOT_A_PROBABILITY : probability;
- const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
+ const NgramProperty ngramProperty(*ngramContext, wordCodePoints.toVector(), probabilityForNgram,
historicalInfo);
- if (!addNgramEntry(ngramContext, &ngramProperty)) {
+ if (!addNgramEntry(&ngramProperty)) {
AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
}
@@ -477,7 +477,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) {
AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath);
return false;
}
- if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) {
+ if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) {
AKLOGE("Cannot flush the dictionary to file.");
mIsCorrupted = true;
return false;
@@ -515,7 +515,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const {
// Needs to reduce dictionary size.
return true;
} else if (mHeaderPolicy->isDecayingDict()) {
- return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount,
+ return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(),
mHeaderPolicy);
}
return false;
@@ -525,19 +525,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mUnigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mBigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getUnigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getBigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
@@ -580,11 +580,15 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
bigramWord1CodePoints);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
- const int probability = bigramEntry.hasHistoricalInfo() ?
- ForgettingCurveUtils::decodeProbability(
- bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
- bigramEntry.getProbability();
+ const int rawBigramProbability = bigramEntry.hasHistoricalInfo()
+ ? ForgettingCurveUtils::decodeProbability(
+ bigramEntry.getHistoricalInfo(), mHeaderPolicy)
+ : bigramEntry.getProbability();
+ const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(),
+ ptNodeParams.representsBeginningOfSentence(), rawBigramProbability);
ngrams.emplace_back(
+ NgramContext(wordCodePoints.data(), wordCodePoints.size(),
+ ptNodeParams.representsBeginningOfSentence()),
CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
probability, *historicalInfo);
}
@@ -608,8 +612,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- *historicalInfo, std::move(shortcuts));
+ ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(),
+ ptNodeParams.getProbability(), *historicalInfo, std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 1ad5e7e36..0480876ed 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -41,6 +41,7 @@
#include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h"
#include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_pt_node_array_reader.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -75,8 +76,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()),
+ mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
+ mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const {
@@ -112,8 +113,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
bool removeUnigramEntry(const CodePointArrayView wordCodePoints);
- bool addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty);
+ bool addNgramEntry(const NgramProperty *const ngramProperty);
bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints);
@@ -163,8 +163,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieNodeWriter mNodeWriter;
DynamicPtUpdatingHelper mUpdatingHelper;
Ver4PatriciaTrieWritingHelper mWritingHelper;
- int mUnigramCount;
- int mBigramCount;
+ MutableEntryCounters mEntryCounters;
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
index 2887dc6b1..a033d396b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
@@ -43,18 +43,18 @@ namespace backward {
namespace v402 {
bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath,
- const int unigramCount, const int bigramCount) const {
+ const EntryCounts &entryCounts) const {
const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
const int extendedRegionSize = headerPolicy->getExtendedRegionSize()
+ mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize();
if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */,
- unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) {
+ entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, unigramCount, bigramCount,
- extendedRegionSize);
+ "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
+ entryCounts.getBigramCount(), extendedRegionSize);
return false;
}
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -74,7 +74,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) {
+ EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
+ 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
index 9034ee656..1aad33e38 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
@@ -27,6 +27,7 @@
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h"
#include "suggest/policyimpl/dictionary/structure/backward/v402/content/terminal_position_lookup_table.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
namespace backward {
@@ -46,8 +47,7 @@ class Ver4PatriciaTrieWritingHelper {
Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers)
: mBuffers(buffers) {}
- bool writeToDictFile(const char *const dictDirPath, const int unigramCount,
- const int bigramCount) const;
+ bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const;
// This method cannot be const because the original dictionary buffer will be updated to detect
// useless PtNodes during GC.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
index 372c9e36f..9a9a21b6b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
@@ -58,7 +58,7 @@ namespace latinime {
const DictionaryHeaderStructurePolicy::AttributeMap *const attributeMap) {
FormatUtils::FORMAT_VERSION dictFormatVersion = FormatUtils::getFormatVersion(formatVersion);
switch (dictFormatVersion) {
- case FormatUtils::VERSION_4: {
+ case FormatUtils::VERSION_402: {
return newPolicyForOnMemoryV4Dict<backward::v402::Ver4DictConstants,
backward::v402::Ver4DictBuffers,
backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr,
@@ -66,7 +66,7 @@ namespace latinime {
dictFormatVersion, locale, attributeMap);
}
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
- case FormatUtils::VERSION_4_DEV: {
+ case FormatUtils::VERSION_403: {
return newPolicyForOnMemoryV4Dict<Ver4DictConstants, Ver4DictBuffers,
Ver4DictBuffers::Ver4DictBuffersPtr, Ver4PatriciaTriePolicy>(
dictFormatVersion, locale, attributeMap);
@@ -115,9 +115,10 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
switch (formatVersion) {
case FormatUtils::VERSION_2:
case FormatUtils::VERSION_201:
- AKLOGE("Given path is a directory but the format is version 2 or 201. path: %s", path);
+ case FormatUtils::VERSION_202:
+ AKLOGE("Given path is a directory but the format is version 2xx. path: %s", path);
break;
- case FormatUtils::VERSION_4: {
+ case FormatUtils::VERSION_402: {
return newPolicyForV4Dict<backward::v402::Ver4DictConstants,
backward::v402::Ver4DictBuffers,
backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr,
@@ -125,7 +126,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
headerFilePath, formatVersion, std::move(mmappedBuffer));
}
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
- case FormatUtils::VERSION_4_DEV: {
+ case FormatUtils::VERSION_403: {
return newPolicyForV4Dict<Ver4DictConstants, Ver4DictBuffers,
Ver4DictBuffers::Ver4DictBuffersPtr, Ver4PatriciaTriePolicy>(
headerFilePath, formatVersion, std::move(mmappedBuffer));
@@ -177,11 +178,14 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) {
case FormatUtils::VERSION_2:
case FormatUtils::VERSION_201:
+ AKLOGE("Dictionary versions 2 and 201 are incompatible with this version");
+ break;
+ case FormatUtils::VERSION_202:
return DictionaryStructureWithBufferPolicy::StructurePolicyPtr(
new PatriciaTriePolicy(std::move(mmappedBuffer)));
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
- case FormatUtils::VERSION_4:
- case FormatUtils::VERSION_4_DEV:
+ case FormatUtils::VERSION_402:
+ case FormatUtils::VERSION_403:
AKLOGE("Given path is a file but the format is version 4. path: %s", path);
break;
default:
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
index 92fd6f214..e524e86e5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
@@ -146,7 +146,7 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori
const int movedPos = mBuffer->getTailPosition();
int writingPos = movedPos;
const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams,
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, originalPtNodeParams->getParentPos(),
originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
@@ -180,8 +180,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode(
return false;
}
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */,
- parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability()));
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
+ true /* isTerminal */, parentPtNodePos, ptNodeCodePoints,
+ unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
unigramProperty, &writingPos)) {
return false;
@@ -214,7 +215,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount);
if (addsExtraChild) {
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */,
+ false /* isNotAWord */, false /* isPossiblyOffensive */, false /* isTerminal */,
reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints,
NOT_A_PROBABILITY));
if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) {
@@ -222,7 +223,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
}
} else {
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, reallocatingPtNodeParams->getParentPos(),
firstPtNodeCodePoints, unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
@@ -240,7 +241,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
// Write the 2nd part of the reallocating node.
const int secondPartOfReallocatedPtNodePos = writingPos;
const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams,
- reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(),
+ reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isPossiblyOffensive(),
reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos,
reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount),
reallocatingPtNodeParams->getProbability()));
@@ -249,7 +250,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
}
if (addsExtraChild) {
const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, firstPartOfReallocatedPtNodePos,
newPtNodeCodePoints.skip(overlappingCodePointCount),
unigramProperty->getProbability()));
@@ -276,20 +277,20 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams(
const PtNodeParams *const originalPtNodeParams, const bool isNotAWord,
- const bool isBlacklisted, const bool isTerminal, const int parentPos,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
const CodePointArrayView codePoints, const int probability) const {
const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags(
- isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */,
+ isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */,
CHILDREN_POSITION_FIELD_SIZE);
return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability);
}
const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord,
- const bool isBlacklisted, const bool isTerminal, const int parentPos,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
const CodePointArrayView codePoints, const int probability) const {
const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags(
- isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */,
+ isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */,
CHILDREN_POSITION_FIELD_SIZE);
return PtNodeParams(flags, parentPos, codePoints, probability);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
index 2bbe2f4dc..db5f6ab17 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
@@ -85,12 +85,12 @@ class DynamicPtUpdatingHelper {
const CodePointArrayView newPtNodeCodePoints);
const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams,
- const bool isNotAWord, const bool isBlacklisted, const bool isTerminal,
+ const bool isNotAWord, const bool isPossiblyOffensive, const bool isTerminal,
const int parentPos, const CodePointArrayView codePoints, const int probability) const;
- const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted,
- const bool isTerminal, const int parentPos, const CodePointArrayView codePoints,
- const int probability) const;
+ const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
+ const CodePointArrayView codePoints, const int probability) const;
};
} // namespace latinime
#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
index 6a498b2f4..b8d78bf10 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
@@ -41,8 +41,8 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08
const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04;
// Flag for non-words (typically, shortcut only entries)
const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02;
-// Flag for blacklist
-const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01;
+// Flag for possibly offensive words
+const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_POSSIBLY_OFFENSIVE = 0x01;
/* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition(
const uint8_t *const buffer, int *const pos) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
index a69ec4435..6a2bf5d3c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
@@ -54,8 +54,8 @@ class PatriciaTrieReadingUtils {
/**
* Node Flags
*/
- static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) {
- return (flags & FLAG_IS_BLACKLISTED) != 0;
+ static AK_FORCE_INLINE bool isPossiblyOffensive(const NodeFlags flags) {
+ return (flags & FLAG_IS_POSSIBLY_OFFENSIVE) != 0;
}
static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) {
@@ -82,12 +82,12 @@ class PatriciaTrieReadingUtils {
return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags);
}
- static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted,
+ static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isPossiblyOffensive,
const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets,
const bool hasBigrams, const bool hasMultipleChars,
const int childrenPositionFieldSize) {
NodeFlags nodeFlags = 0;
- nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags;
+ nodeFlags = isPossiblyOffensive ? (nodeFlags | FLAG_IS_POSSIBLY_OFFENSIVE) : nodeFlags;
nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags;
nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags;
nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags;
@@ -127,7 +127,7 @@ class PatriciaTrieReadingUtils {
static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS;
static const NodeFlags FLAG_HAS_BIGRAMS;
static const NodeFlags FLAG_IS_NOT_A_WORD;
- static const NodeFlags FLAG_IS_BLACKLISTED;
+ static const NodeFlags FLAG_IS_POSSIBLY_OFFENSIVE;
};
} // namespace latinime
#endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
index 3ff1829bd..e52706e07 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
@@ -144,8 +144,8 @@ class PtNodeParams {
return PatriciaTrieReadingUtils::isTerminal(mFlags);
}
- AK_FORCE_INLINE bool isBlacklisted() const {
- return PatriciaTrieReadingUtils::isBlacklisted(mFlags);
+ AK_FORCE_INLINE bool isPossiblyOffensive() const {
+ return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags);
}
AK_FORCE_INLINE bool isNotAWord() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index b7f1199c5..59873612a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -14,7 +14,6 @@
* limitations under the License.
*/
-
#include "suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h"
#include "defines.h"
@@ -317,8 +316,8 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(
const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability,
const PtNodeParams &ptNodeParams) const {
- return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
- ptNodeParams.getProbability() == 0);
+ return WordAttributes(probability, false /* isBlacklisted */, ptNodeParams.isNotAWord(),
+ ptNodeParams.isPossiblyOffensive());
}
int PatriciaTriePolicy::getProbability(const int unigramProbability,
@@ -345,10 +344,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) {
- // If this is not a word, or if it's a blacklisted entry, it should behave as
- // having no probability outside of the suggestion process (where it should be used
- // for shortcuts).
+ if (ptNodeParams.isNotAWord()) {
+ // If this is not a word, it should behave as having no probability outside of the
+ // suggestion process (where it should be used for shortcuts).
return NOT_A_PROBABILITY;
}
if (!prevWordIds.empty()) {
@@ -451,6 +449,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
bigramWord1CodePoints, &word1Probability);
const int probability = getProbability(word1Probability, bigramsIt.getProbability());
ngrams.emplace_back(
+ NgramContext(wordCodePoints.data(), wordCodePoints.size(),
+ ptNodeParams.representsBeginningOfSentence()),
CodePointArrayView(bigramWord1CodePoints, word1CodePointCount).toVector(),
probability, HistoricalInfo());
}
@@ -476,8 +476,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- HistoricalInfo(), std::move(shortcuts));
+ ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(),
+ ptNodeParams.getProbability(), HistoricalInfo(), std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index b17681388..8933962ab 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -93,8 +93,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return false;
}
- bool addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty) {
+ bool addNgramEntry(const NgramProperty *const ngramProperty) {
// This method should not be called for non-updatable dictionary.
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
return false;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
new file mode 100644
index 000000000..b0fbb3e72
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
@@ -0,0 +1,37 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
+
+namespace latinime {
+
+// These counts are used to provide stable probabilities even if the user's input count is small.
+const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192;
+const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2;
+const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2;
+
+// These are encoded backoff weights.
+// Note that we give positive value for trigrams that means the weight is more than 1.
+// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight.
+const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32;
+const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0;
+const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8;
+
+// This value is used to remove too old entries from the dictionary.
+const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS =
+ 300 * 24 * 60 * 60; // 300 days
+
+} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
new file mode 100644
index 000000000..88bc58fe8
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
@@ -0,0 +1,114 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H
+#define LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H
+
+#include <algorithm>
+
+#include "defines.h"
+#include "suggest/core/dictionary/property/historical_info.h"
+#include "utils/time_keeper.h"
+
+namespace latinime {
+
+class DynamicLanguageModelProbabilityUtils {
+ public:
+ static float computeRawProbabilityFromCounts(const int count, const int contextCount,
+ const int matchedWordCountInContext) {
+ int minCount = 0;
+ switch (matchedWordCountInContext) {
+ case 1:
+ minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
+ break;
+ case 2:
+ minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS;
+ break;
+ case 3:
+ minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
+ break;
+ default:
+ AKLOGE("computeRawProbabilityFromCounts is called with invalid "
+ "matchedWordCountInContext (%d).", matchedWordCountInContext);
+ ASSERT(false);
+ return 0.0f;
+ }
+ return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount));
+ }
+
+ static float backoff(const int ngramProbability, const int matchedWordCountInContext) {
+ int probability = NOT_A_PROBABILITY;
+
+ switch (matchedWordCountInContext) {
+ case 1:
+ probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
+ break;
+ case 2:
+ probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
+ break;
+ case 3:
+ probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
+ break;
+ default:
+ AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).",
+ matchedWordCountInContext);
+ ASSERT(false);
+ return NOT_A_PROBABILITY;
+ }
+ return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY);
+ }
+
+ static int getDecayedProbability(const int probability, const HistoricalInfo historicalInfo) {
+ const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp();
+ if (elapsedTime < 0) {
+ AKLOGE("The elapsed time is negatime value. Timestamp overflow?");
+ return NOT_A_PROBABILITY;
+ }
+ // TODO: Improve this logic.
+ // We don't modify probability depending on the elapsed time.
+ return probability;
+ }
+
+ static int shouldRemoveEntryDuringGC(const HistoricalInfo historicalInfo) {
+ // TODO: Improve this logic.
+ const int elapsedTime = TimeKeeper::peekCurrentTime() - historicalInfo.getTimestamp();
+ return elapsedTime > DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
+ }
+
+ static int getPriorityToPreventFromEviction(const HistoricalInfo historicalInfo) {
+ // TODO: Improve this logic.
+ // More recently input entries get higher priority.
+ return historicalInfo.getTimestamp();
+ }
+
+private:
+ DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicLanguageModelProbabilityUtils);
+
+ static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram.");
+
+ static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
+ static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS;
+ static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
+
+ static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
+ static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
+ static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
+
+ static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
+};
+
+} // namespace latinime
+#endif /* LATINIME_DYNAMIC_LANGUAGE_MODEL_PROBABILITY_UTILS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index c4297f5d6..31b1ea696 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -19,28 +19,28 @@
#include <algorithm>
#include <cstring>
-#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
+#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
namespace latinime {
-const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
-const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
-const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
+const int LanguageModelDictContent::TRIE_MAP_BUFFER_INDEX = 0;
+const int LanguageModelDictContent::GLOBAL_COUNTERS_BUFFER_INDEX = 1;
bool LanguageModelDictContent::save(FILE *const file) const {
- return mTrieMap.save(file);
+ return mTrieMap.save(file) && mGlobalCounters.save(file);
}
bool LanguageModelDictContent::runGC(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount) {
+ const LanguageModelDictContent *const originalContent) {
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
- 0 /* nextLevelBitmapEntryIndex */, outNgramCount);
+ 0 /* nextLevelBitmapEntryIndex */);
}
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
- const int wordId, const HeaderPolicy *const headerPolicy) const {
+ const int wordId, const bool mustMatchAllPrevWords,
+ const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxPrevWordCount = 0;
@@ -54,7 +54,15 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
}
+ const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
+ if (mHasHistoricalInfo && unigramProbabilityEntry.getHistoricalInfo()->getCount() == 0) {
+ // The word should be treated as a invalid word.
+ return WordAttributes();
+ }
for (int i = maxPrevWordCount; i >= 0; --i) {
+ if (mustMatchAllPrevWords && prevWordIds.size() > static_cast<size_t>(i)) {
+ break;
+ }
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) {
continue;
@@ -63,38 +71,41 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
int probability = NOT_A_PROBABILITY;
if (mHasHistoricalInfo) {
- const int rawProbability = ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), headerPolicy);
- if (rawProbability == NOT_A_PROBABILITY) {
- // The entry should not be treated as a valid entry.
- continue;
- }
+ const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
+ int contextCount = 0;
if (i == 0) {
// unigram
- probability = rawProbability;
+ contextCount = mGlobalCounters.getTotalCount();
} else {
const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
if (!prevWordProbabilityEntry.isValid()) {
continue;
}
- if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
- probability = rawProbability;
- } else {
- const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
- prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
- probability = std::min(MAX_PROBABILITY - prevWordRawProbability
- + rawProbability, MAX_PROBABILITY);
+ if (prevWordProbabilityEntry.representsBeginningOfSentence()
+ && historicalInfo->getCount() == 1) {
+ // BoS ngram requires multiple contextCount.
+ continue;
}
+ contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
+ const float rawProbability =
+ DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
+ historicalInfo->getCount(), contextCount, i + 1);
+ const int encodedRawProbability =
+ ProbabilityUtils::encodeRawProbability(rawProbability);
+ const int decayedProbability =
+ DynamicLanguageModelProbabilityUtils::getDecayedProbability(
+ encodedRawProbability, *historicalInfo);
+ probability = DynamicLanguageModelProbabilityUtils::backoff(
+ decayedProbability, i + 1 /* n */);
} else {
probability = probabilityEntry.getProbability();
}
// TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
// probabilityEntry.
- const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
- return WordAttributes(probability, unigramProbabilityEntry.isNotAWord(),
- unigramProbabilityEntry.isBlacklisted(),
+ return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(),
+ unigramProbabilityEntry.isNotAWord(),
unigramProbabilityEntry.isPossiblyOffensive());
}
// Cannot find the word.
@@ -143,28 +154,69 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}
-bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
- const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- if (entryCounts[i] <= maxEntryCounts[i]) {
- outEntryCounts[i] = entryCounts[i];
+std::vector<LanguageModelDictContent::DumppedFullEntryInfo>
+ LanguageModelDictContent::exportAllNgramEntriesRelatedToWord(
+ const HeaderPolicy *const headerPolicy, const int wordId) const {
+ const TrieMap::Result result = mTrieMap.getRoot(wordId);
+ if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ // The word doesn't have any related ngram entries.
+ return std::vector<DumppedFullEntryInfo>();
+ }
+ std::vector<int> prevWordIds = { wordId };
+ std::vector<DumppedFullEntryInfo> entries;
+ exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex,
+ &prevWordIds, &entries);
+ return entries;
+}
+
+void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner(
+ const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex,
+ std::vector<int> *const prevWordIds,
+ std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const {
+ for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+ const int wordId = entry.key();
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+ if (probabilityEntry.isValid()) {
+ const WordAttributes wordAttributes = getWordAttributes(
+ WordIdArrayView(*prevWordIds), wordId, true /* mustMatchAllPrevWords */,
+ headerPolicy);
+ outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId,
+ wordAttributes, probabilityEntry);
+ }
+ if (entry.hasNextLevelMap()) {
+ prevWordIds->push_back(wordId);
+ exportAllNgramEntriesRelatedToWordInner(headerPolicy,
+ entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo);
+ prevWordIds->pop_back();
+ }
+ }
+}
+
+bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
+ const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
+ MutableEntryCounters *const outEntryCounters) {
+ for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
+ const int totalWordCount = prevWordCount + 1;
+ if (currentEntryCounts.getNgramCount(totalWordCount)
+ <= maxEntryCounts.getNgramCount(totalWordCount)) {
+ outEntryCounters->setNgramCount(totalWordCount,
+ currentEntryCounts.getNgramCount(totalWordCount));
continue;
}
- if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
- &outEntryCounts[i])) {
+ int entryCount = 0;
+ if (!turncateEntriesInSpecifiedLevel(headerPolicy,
+ maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
return false;
}
+ outEntryCounters->setNgramCount(totalWordCount, entryCount);
}
return true;
}
bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
- const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
- if (outAddedNewNgramEntryCount) {
- *outAddedNewNgramEntryCount = 0;
- }
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) {
if (!mHasHistoricalInfo) {
AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
return false;
@@ -175,6 +227,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
return false;
}
+ mGlobalCounters.incrementTotalCount();
+ mGlobalCounters.updateMaxValueOfCounters(
+ updatedUnigramProbabilityEntry.getHistoricalInfo()->getCount());
for (size_t i = 0; i < prevWordIds.size(); ++i) {
if (prevWordIds[i] == NOT_A_WORD_ID) {
break;
@@ -188,8 +243,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
return false;
}
- if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
- *outAddedNewNgramEntryCount += 1;
+ mGlobalCounters.updateMaxValueOfCounters(
+ updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
+ if (!originalNgramProbabilityEntry.isValid()) {
+ entryCountersToUpdate->incrementNgramCount(i + 2);
}
}
return true;
@@ -198,10 +255,9 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
- const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
- originalProbabilityEntry.getHistoricalInfo(), isValid ?
- DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
- &historicalInfo, headerPolicy);
+ const HistoricalInfo updatedHistoricalInfo = HistoricalInfo(historicalInfo.getTimestamp(),
+ 0 /* level */, originalProbabilityEntry.getHistoricalInfo()->getCount()
+ + historicalInfo.getCount());
if (originalProbabilityEntry.isValid()) {
return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
} else {
@@ -211,8 +267,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange,
- const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
for (auto &entry : trieMapRange) {
const auto it = terminalIdMap->find(entry.key());
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@@ -222,13 +277,9 @@ bool LanguageModelDictContent::runGCInner(
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
return false;
}
- if (outNgramCount) {
- *outNgramCount += 1;
- }
if (entry.hasNextLevelMap()) {
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
- mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
- outNgramCount)) {
+ mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
return false;
}
}
@@ -237,24 +288,25 @@ bool LanguageModelDictContent::runGCInner(
}
int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
- if (prevWordIds.empty()) {
- return mTrieMap.getRootBitmapEntryIndex();
- }
- const int lastBitmapEntryIndex =
- getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
- if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
- return TrieMap::INVALID_INDEX;
- }
- const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID);
- const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
- if (!result.mIsValid) {
- if (!mTrieMap.put(oldestPrevWordId,
- ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
- return TrieMap::INVALID_INDEX;
+ int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
+ for (const int wordId : prevWordIds) {
+ const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex);
+ if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) {
+ lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+ continue;
}
+ if (!result.mIsValid) {
+ if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo),
+ lastBitmapEntryIndex)) {
+ AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId,
+ lastBitmapEntryIndex);
+ return TrieMap::INVALID_INDEX;
+ }
+ }
+ lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId,
+ lastBitmapEntryIndex);
}
- return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID),
- lastBitmapEntryIndex);
+ return lastBitmapEntryIndex;
}
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
@@ -271,7 +323,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
+ const bool needsToHalveCounters, MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -288,33 +340,41 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
continue;
}
- if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
- && probabilityEntry.isValid()) {
- const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
- probabilityEntry.getHistoricalInfo(), headerPolicy);
- if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
- // Update the entry.
- const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
- if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
- bitmapEntryIndex)) {
- return false;
- }
- } else {
+ if (mHasHistoricalInfo && probabilityEntry.isValid()) {
+ const HistoricalInfo *originalHistoricalInfo = probabilityEntry.getHistoricalInfo();
+ if (DynamicLanguageModelProbabilityUtils::shouldRemoveEntryDuringGC(
+ *originalHistoricalInfo)) {
// Remove the entry.
if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
return false;
}
continue;
}
+ if (needsToHalveCounters) {
+ const int updatedCount = originalHistoricalInfo->getCount() / 2;
+ if (updatedCount == 0) {
+ // Remove the entry.
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+ return false;
+ }
+ continue;
+ }
+ const HistoricalInfo historicalInfoToSave(originalHistoricalInfo->getTimestamp(),
+ originalHistoricalInfo->getLevel(), updatedCount);
+ const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(),
+ &historicalInfoToSave);
+ if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
+ bitmapEntryIndex)) {
+ return false;
+ }
+ }
}
- if (!probabilityEntry.representsBeginningOfSentence()) {
- outEntryCounts[prevWordCount] += 1;
- }
+ outEntryCounters->incrementNgramCount(prevWordCount + 1);
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
- prevWordCount + 1, headerPolicy, outEntryCounts)) {
+ prevWordCount + 1, headerPolicy, needsToHalveCounters, outEntryCounters)) {
return false;
}
}
@@ -368,11 +428,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
- const int probability = (mHasHistoricalInfo) ?
- ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
- headerPolicy) : probabilityEntry.getProbability();
- outEntryInfo->emplace_back(probability,
- probabilityEntry.getHistoricalInfo()->getTimestamp(),
+ const int priority = mHasHistoricalInfo
+ ? DynamicLanguageModelProbabilityUtils::getPriorityToPreventFromEviction(
+ *probabilityEntry.getHistoricalInfo())
+ : probabilityEntry.getProbability();
+ outEntryInfo->emplace_back(priority, probabilityEntry.getHistoricalInfo()->getCount(),
entry.key(), targetLevel, prevWordIds->data());
}
return true;
@@ -380,11 +440,11 @@ bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPoli
bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
- if (left.mProbability != right.mProbability) {
- return left.mProbability < right.mProbability;
+ if (left.mPriority != right.mPriority) {
+ return left.mPriority < right.mPriority;
}
- if (left.mTimestamp != right.mTimestamp) {
- return left.mTimestamp > right.mTimestamp;
+ if (left.mCount != right.mCount) {
+ return left.mCount < right.mCount;
}
if (left.mKey != right.mKey) {
return left.mKey < right.mKey;
@@ -401,10 +461,9 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
return false;
}
-LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
- const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
- : mProbability(probability), mTimestamp(timestamp), mKey(key),
- mPrevWordCount(prevWordCount) {
+LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int priority,
+ const int count, const int key, const int prevWordCount, const int *const prevWordIds)
+ : mPriority(priority), mCount(count), mKey(key), mPrevWordCount(prevWordCount) {
memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 51ef090e1..9678c35f9 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -22,9 +22,11 @@
#include "defines.h"
#include "suggest/core/dictionary/word_attributes.h"
+#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/trie_map.h"
#include "utils/byte_array_view.h"
#include "utils/int_array_view.h"
@@ -40,9 +42,6 @@ class HeaderPolicy;
*/
class LanguageModelDictContent {
public:
- static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
- static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
-
// Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry {
public:
@@ -112,31 +111,54 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo;
};
- LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
+ class DumppedFullEntryInfo {
+ public:
+ DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId,
+ const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry)
+ : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId),
+ mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {}
+
+ const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); }
+ int getTargetWordId() const { return mTargetWordId; }
+ const WordAttributes &getWordAttributes() const { return mWordAttributes; }
+ const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; }
+
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo);
+
+ const std::vector<int> mPrevWordIds;
+ const int mTargetWordId;
+ const WordAttributes mWordAttributes;
+ const ProbabilityEntry mProbabilityEntry;
+ };
+
+ LanguageModelDictContent(const ReadWriteByteArrayView *const buffers,
const bool hasHistoricalInfo)
- : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
+ : mTrieMap(buffers[TRIE_MAP_BUFFER_INDEX]),
+ mGlobalCounters(buffers[GLOBAL_COUNTERS_BUFFER_INDEX]),
+ mHasHistoricalInfo(hasHistoricalInfo) {}
explicit LanguageModelDictContent(const bool hasHistoricalInfo)
- : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {}
+ : mTrieMap(), mGlobalCounters(), mHasHistoricalInfo(hasHistoricalInfo) {}
bool isNearSizeLimit() const {
- return mTrieMap.isNearSizeLimit();
+ return mTrieMap.isNearSizeLimit() || mGlobalCounters.needsToHalveCounters();
}
bool save(FILE *const file) const;
bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount);
+ const LanguageModelDictContent *const originalContent);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
- const HeaderPolicy *const headerPolicy) const;
+ const bool mustMatchAllPrevWords, const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
}
bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) {
+ mGlobalCounters.addToTotalCount(probabilityEntry->getHistoricalInfo()->getCount());
return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
}
@@ -154,22 +176,30 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
+ std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord(
+ const HeaderPolicy *const headerPolicy, const int wordId) const;
+
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- outEntryCounts[i] = 0;
+ MutableEntryCounters *const outEntryCounters) {
+ if (!updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
+ 0 /* prevWordCount */, headerPolicy, mGlobalCounters.needsToHalveCounters(),
+ outEntryCounters)) {
+ return false;
+ }
+ if (mGlobalCounters.needsToHalveCounters()) {
+ mGlobalCounters.halveCounters();
}
- return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
- 0 /* prevWordCount */, headerPolicy, outEntryCounts);
+ return true;
}
// entryCounts should be created by updateAllProbabilityEntries.
- bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
const bool isValid, const HistoricalInfo historicalInfo,
- const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);
+ const HeaderPolicy *const headerPolicy,
+ MutableEntryCounters *const entryCountersToUpdate);
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@@ -184,11 +214,12 @@ class LanguageModelDictContent {
DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
};
- EntryInfoToTurncate(const int probability, const int timestamp, const int key,
+ EntryInfoToTurncate(const int priority, const int count, const int key,
const int prevWordCount, const int *const prevWordIds);
- int mProbability;
- int mTimestamp;
+ int mPriority;
+ // TODO: Remove.
+ int mCount;
int mKey;
int mPrevWordCount;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
@@ -197,19 +228,20 @@ class LanguageModelDictContent {
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
};
- // TODO: Remove
- static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
+ static const int TRIE_MAP_BUFFER_INDEX;
+ static const int GLOBAL_COUNTERS_BUFFER_INDEX;
TrieMap mTrieMap;
+ LanguageModelDictContentGlobalCounters mGlobalCounters;
const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
- int *const outNgramCount);
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ const HeaderPolicy *const headerPolicy, const bool needsToHalveCounters,
+ MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
@@ -218,6 +250,9 @@ class LanguageModelDictContent {
const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
const bool isValid, const HistoricalInfo historicalInfo,
const HeaderPolicy *const headerPolicy) const;
+ void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy,
+ const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+ std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp
new file mode 100644
index 000000000..d6d91887e
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.cpp
@@ -0,0 +1,32 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h"
+
+#include <climits>
+
+#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
+
+namespace latinime {
+
+// Counters must be halved once any per-entry counter comes within 64 of
+// 2^(WORD_COUNT_FIELD_SIZE * CHAR_BIT), i.e. just under the largest value the
+// persisted count field can store.
+const int LanguageModelDictContentGlobalCounters::COUNTER_VALUE_NEAR_LIMIT_THRESHOLD =
+        (1 << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - 64;
+// Halving is also triggered when the dictionary-wide total count reaches 2^30,
+// well before int overflow.
+const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD = 1 << 30;
+// Each counter is persisted as a 4-byte unsigned value in the counters buffer.
+const int LanguageModelDictContentGlobalCounters::COUNTER_SIZE_IN_BYTES = 4;
+// Buffer slot (counter index) holding the total count.
+const int LanguageModelDictContentGlobalCounters::TOTAL_COUNT_INDEX = 0;
+// Buffer slot (counter index) holding the maximum value among entry counters.
+const int LanguageModelDictContentGlobalCounters::MAX_VALUE_OF_COUNTERS_INDEX = 1;
+
+} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h
new file mode 100644
index 000000000..283c2691a
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_global_counters.h
@@ -0,0 +1,101 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H
+#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H
+
+#include <algorithm>
+#include <cstdio>
+
+#include "defines.h"
+#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h"
+#include "utils/byte_array_view.h"
+
+namespace latinime {
+
+// Dictionary-global counters for the language model dict content: the total
+// count over all entries and the maximum value among per-entry counters.
+// Both are persisted in a dedicated content buffer (slots TOTAL_COUNT_INDEX
+// and MAX_VALUE_OF_COUNTERS_INDEX).  When needsToHalveCounters() returns
+// true, the caller halves every per-entry counter and then calls
+// halveCounters() to keep these global values consistent.
+class LanguageModelDictContentGlobalCounters {
+ public:
+    // Loads the two persisted counters from the given buffer; a value missing
+    // from a too-short buffer defaults to 0 (see readValue()).
+    explicit LanguageModelDictContentGlobalCounters(const ReadWriteByteArrayView buffer)
+            : mBuffer(buffer, 0 /* maxAdditionalBufferSize */),
+              mTotalCount(readValue(mBuffer, TOTAL_COUNT_INDEX)),
+              mMaxValueOfCounters(readValue(mBuffer, MAX_VALUE_OF_COUNTERS_INDEX)) {}
+
+    // Creates empty counters (both 0), e.g. for a newly created dictionary.
+    LanguageModelDictContentGlobalCounters()
+            : mBuffer(0 /* maxAdditionalBufferSize */), mTotalCount(0), mMaxValueOfCounters(0) {}
+
+    // Returns whether either counter is close enough to its storage limit that
+    // all counters should be halved before further updates.
+    bool needsToHalveCounters() const {
+        return mMaxValueOfCounters >= COUNTER_VALUE_NEAR_LIMIT_THRESHOLD
+                || mTotalCount >= TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD;
+    }
+
+    int getTotalCount() const {
+        return mTotalCount;
+    }
+
+    // Serializes both counters into a fresh buffer and appends it to the tail
+    // of the given file.  Returns false on any write failure.
+    bool save(FILE *const file) const {
+        BufferWithExtendableBuffer bufferToWrite(
+                BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
+        if (!bufferToWrite.writeUint(mTotalCount, COUNTER_SIZE_IN_BYTES,
+                TOTAL_COUNT_INDEX * COUNTER_SIZE_IN_BYTES)) {
+            return false;
+        }
+        if (!bufferToWrite.writeUint(mMaxValueOfCounters, COUNTER_SIZE_IN_BYTES,
+                MAX_VALUE_OF_COUNTERS_INDEX * COUNTER_SIZE_IN_BYTES)) {
+            return false;
+        }
+        return DictFileWritingUtils::writeBufferToFileTail(file, &bufferToWrite);
+    }
+
+    void incrementTotalCount() {
+        mTotalCount += 1;
+    }
+
+    void addToTotalCount(const int count) {
+        mTotalCount += count;
+    }
+
+    // Records a newly observed per-entry counter value, keeping the running
+    // maximum.
+    void updateMaxValueOfCounters(const int count) {
+        mMaxValueOfCounters = std::max(count, mMaxValueOfCounters);
+    }
+
+    // Halves both global counters.  Called when the owner halves all per-entry
+    // counters so the recorded maximum and total stay in sync.
+    void halveCounters() {
+        mMaxValueOfCounters /= 2;
+        mTotalCount /= 2;
+    }
+
+ private:
+    DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContentGlobalCounters);
+
+    const static int COUNTER_VALUE_NEAR_LIMIT_THRESHOLD;
+    const static int TOTAL_COUNT_VALUE_NEAR_LIMIT_THRESHOLD;
+    // Size in bytes of each persisted counter.
+    const static int COUNTER_SIZE_IN_BYTES;
+    // Buffer slot indices of the two counters.
+    const static int TOTAL_COUNT_INDEX;
+    const static int MAX_VALUE_OF_COUNTERS_INDEX;
+
+    BufferWithExtendableBuffer mBuffer;
+    int mTotalCount;
+    int mMaxValueOfCounters;
+
+    // Reads the index-th counter from the buffer, or 0 when the buffer is too
+    // short to contain it (e.g. a brand-new dictionary without saved counters).
+    static int readValue(const BufferWithExtendableBuffer &buffer, const int index) {
+        const int pos = COUNTER_SIZE_IN_BYTES * index;
+        if (pos + COUNTER_SIZE_IN_BYTES > buffer.getTailPosition()) {
+            return 0;
+        }
+        return buffer.readUint(COUNTER_SIZE_IN_BYTES, pos);
+    }
+};
+} // namespace latinime
+#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_GLOBAL_COUNTERS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index f4d340f86..9c4ab18e4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -105,7 +105,7 @@ class ProbabilityEntry {
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
| static_cast<uint8_t>(mHistoricalInfo.getLevel());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
- | static_cast<uint8_t>(mHistoricalInfo.getCount());
+ | static_cast<uint16_t>(mHistoricalInfo.getCount());
} else {
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
| static_cast<uint8_t>(mProbability);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
index 45f88e9b2..4d088dcab 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
@@ -179,7 +179,7 @@ Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer,
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
mTerminalPositionLookupTable(
contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]),
- mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX],
+ mLanguageModelDictContent(&contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX],
mHeaderPolicy.hasHistoricalInfoOfWords()),
mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]),
mIsUpdatable(mDictBuffer->isUpdatable()) {}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 8e6cb974b..bd89b8da7 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -49,8 +49,8 @@ const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4;
-const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
-const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
+const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 0;
+const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 2;
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2;
@@ -67,6 +67,6 @@ const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80;
const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT = 1;
const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT = 3;
-const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 1;
+const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 2;
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index 600b5ffe4..13d7a5714 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -47,6 +47,7 @@ class Ver4DictConstants {
static const int NOT_A_TERMINAL_ADDRESS;
static const int TERMINAL_ID_FIELD_SIZE;
static const int TIME_STAMP_FIELD_SIZE;
+ // TODO: Remove
static const int WORD_LEVEL_FIELD_SIZE;
static const int WORD_COUNT_FIELD_SIZE;
// Flags in probability entry.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 794c63ffd..3488f7d2a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -342,7 +342,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bo
// Create node flags and write them.
PatriciaTrieReadingUtils::NodeFlags nodeFlags =
PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */,
- false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */,
+ false /* isPossiblyOffensive */, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE);
if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index ea8c0dc22..1992d4a5a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -97,6 +97,9 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return NOT_A_WORD_ID;
}
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ if (ptNodeParams.isDeleted()) {
+ return NOT_A_WORD_ID;
+ }
return ptNodeParams.getTerminalId();
}
@@ -107,7 +110,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
return WordAttributes();
}
return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
- mHeaderPolicy);
+ false /* mustMatchAllPrevWords */, mHeaderPolicy);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -115,18 +118,13 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY;
}
- const ProbabilityEntry probabilityEntry =
- mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId);
- if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted()
- || probabilityEntry.isNotAWord()) {
+ const WordAttributes wordAttributes =
+ mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
+ true /* mustMatchAllPrevWords */, mHeaderPolicy);
+ if (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()) {
return NOT_A_PROBABILITY;
}
- if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
- return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
- mHeaderPolicy);
- } else {
- return probabilityEntry.getProbability();
- }
+ return wordAttributes.getProbability();
}
BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator(
@@ -148,10 +146,16 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
if (!probabilityEntry.isValid()) {
continue;
}
- const int probability = probabilityEntry.hasHistoricalInfo() ?
- ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
- probabilityEntry.getProbability();
+ int probability = NOT_A_PROBABILITY;
+ if (probabilityEntry.hasHistoricalInfo()) {
+ // TODO: Quit checking count here.
+            // If count <= 1, the word can be an invalid word. The actual probability should
+ // be checked using getWordAttributesInContext() in onVisitEntry().
+ probability = probabilityEntry.getHistoricalInfo()->getCount() <= 1 ?
+ NOT_A_PROBABILITY : 0;
+ } else {
+ probability = probabilityEntry.getProbability();
+ }
listener->onVisitEntry(probability, entry.getWordId());
}
}
@@ -211,7 +215,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mUnigramCount++;
+ mEntryCounters.incrementUnigramCount();
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -259,13 +263,12 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
- mUnigramCount--;
+ mEntryCounters.decrementUnigramCount();
}
return true;
}
-bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty) {
+bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
return false;
@@ -275,6 +278,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
mDictBuffer->getTailPosition());
return false;
}
+ const NgramContext *const ngramContext = ngramProperty->getNgramContext();
if (!ngramContext->isValid()) {
AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary.");
return false;
@@ -299,7 +303,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
}
const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */, true /* isNotAWord */,
- false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo());
+ false /* isBlacklisted */, false /* isPossiblyOffensive */,
+ MAX_PROBABILITY /* probability */, HistoricalInfo());
if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) {
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
@@ -316,7 +321,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) {
- mBigramCount++;
+ mEntryCounters.incrementNgramCount(prevWordIds.size() + 1);
}
return true;
} else {
@@ -354,7 +359,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false;
}
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
- mBigramCount--;
+ mEntryCounters.decrementNgramCount(prevWordIds.size());
return true;
} else {
return false;
@@ -375,38 +380,47 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
if (wordId == NOT_A_WORD_ID) {
// The word is not in the dictionary.
const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
- false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
- HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ false /* isNotAWord */, false /* isBlacklisted */, false /* isPossiblyOffensive */,
+ NOT_A_PROBABILITY, HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */,
+ 0 /* count */));
if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
}
+ if (!isValidWord) {
+ return true;
+ }
wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSearch */);
- if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
- && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
- const UnigramProperty beginningOfSentenceUnigramProperty(
- true /* representsBeginningOfSentence */,
- true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
- HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
- if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
- &beginningOfSentenceUnigramProperty)) {
- AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
+ if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
+ if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) {
+ const UnigramProperty beginningOfSentenceUnigramProperty(
+ true /* representsBeginningOfSentence */,
+ true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY,
+ HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
+ &beginningOfSentenceUnigramProperty)) {
+ AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
+ return false;
+ }
+ // Refresh word ids.
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
+ }
+ // Update entries for beginning of sentence.
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(
+                prevWordIds.skip(1 /* n */), prevWordIds[0], true /* isValid */, historicalInfo,
+ mHeaderPolicy, &mEntryCounters)) {
return false;
}
- // Refresh word ids.
- ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
- int addedNewNgramEntryCount = 0;
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
- wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
+ wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) {
return false;
}
- mBigramCount += addedNewNgramEntryCount;
return true;
}
@@ -415,7 +429,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) {
AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath);
return false;
}
- if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) {
+ if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) {
AKLOGE("Cannot flush the dictionary to file.");
mIsCorrupted = true;
return false;
@@ -453,7 +467,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const {
// Needs to reduce dictionary size.
return true;
} else if (mHeaderPolicy->isDecayingDict()) {
- return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount,
+ return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(),
mHeaderPolicy);
}
return false;
@@ -463,19 +477,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mUnigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mBigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getUnigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getBigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
@@ -488,29 +502,37 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
AKLOGE("getWordProperty is called for invalid word.");
return WordProperty();
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- const ProbabilityEntry probabilityEntry =
- mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
- ptNodeParams.getTerminalId());
- const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
- // Fetch bigram information.
- // TODO: Support n-gram.
+ const LanguageModelDictContent *const languageModelDictContent =
+ mBuffers->getLanguageModelDictContent();
+ // Fetch ngram information.
std::vector<NgramProperty> ngrams;
- const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
- int bigramWord1CodePoints[MAX_WORD_LENGTH];
- for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
- prevWordIds)) {
- const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
- MAX_WORD_LENGTH, bigramWord1CodePoints);
- const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry();
- const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
- const int probability = probabilityEntry.hasHistoricalInfo() ?
- ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
- probabilityEntry.getProbability();
- ngrams.emplace_back(CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
- probability, *historicalInfo);
+ int ngramTargetCodePoints[MAX_WORD_LENGTH];
+ int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
+ int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+ bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+ for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord(
+ mHeaderPolicy, wordId)) {
+ const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(),
+ MAX_WORD_LENGTH, ngramTargetCodePoints);
+ const WordIdArrayView prevWordIds = entry.getPrevWordIds();
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i],
+ MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]);
+ ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry(
+ prevWordIds[i]).representsBeginningOfSentence();
+ if (ngramPrevWordIsBeginningOfSentense[i]) {
+ ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker(
+ ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]);
+ }
+ }
+ const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount,
+ ngramPrevWordIsBeginningOfSentense, prevWordIds.size());
+ const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry();
+ const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo();
+ // TODO: Output flags in WordAttributes.
+ ngrams.emplace_back(ngramContext,
+ CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(),
+ entry.getWordAttributes().getProbability(), *historicalInfo);
}
// Fetch shortcut information.
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -530,9 +552,14 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
shortcutProbability);
}
}
+ const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes(
+ WordIdArrayView(), wordId, true /* mustMatchAllPrevWords */, mHeaderPolicy);
+ const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId);
+ const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
- probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
- probabilityEntry.getProbability(), *historicalInfo, std::move(shortcuts));
+ wordAttributes.isNotAWord(), wordAttributes.isBlacklisted(),
+ wordAttributes.isPossiblyOffensive(), wordAttributes.getProbability(),
+ *historicalInfo, std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index c0532815c..13700b390 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -30,6 +30,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -37,7 +38,6 @@ namespace latinime {
class DicNode;
class DicNodeVector;
-// TODO: Support counting ngram entries.
// Word id = Artificial id that is stored in the PtNode looked up by the word.
class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
public:
@@ -51,8 +51,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()),
+ mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
+ mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const {
@@ -92,8 +92,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
bool removeUnigramEntry(const CodePointArrayView wordCodePoints);
- bool addNgramEntry(const NgramContext *const ngramContext,
- const NgramProperty *const ngramProperty);
+ bool addNgramEntry(const NgramProperty *const ngramProperty);
bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints);
@@ -141,9 +140,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieNodeWriter mNodeWriter;
DynamicPtUpdatingHelper mUpdatingHelper;
Ver4PatriciaTrieWritingHelper mWritingHelper;
- int mUnigramCount;
- // TODO: Support counting ngram entries.
- int mBigramCount;
+ MutableEntryCounters mEntryCounters;
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index f0d59c150..7f0604ce8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -33,17 +33,18 @@
namespace latinime {
bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath,
- const int unigramCount, const int bigramCount) const {
+ const EntryCounts &entryCounts) const {
const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
const int extendedRegionSize = headerPolicy->getExtendedRegionSize()
+ mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize();
if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */,
- unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) {
+ entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
- "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, unigramCount, bigramCount,
+ "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
+ "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
+ entryCounts.getBigramCount(), entryCounts.getTrigramCount(),
extendedRegionSize);
return false;
}
@@ -56,15 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
Ver4DictConstants::MAX_DICTIONARY_SIZE));
- int unigramCount = 0;
- int bigramCount = 0;
- if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
+ MutableEntryCounters entryCounters;
+ if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
return false;
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) {
+ entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -72,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
- int *const outUnigramCount, int *const outBigramCount) {
+ MutableEntryCounters *const outEntryCounters) {
Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@@ -80,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
- int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
- headerPolicy, entryCountTable)) {
+ headerPolicy, outEntryCounters)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
if (headerPolicy->isDecayingDict()) {
- int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
- maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxUnigramCount();
- maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxBigramCount();
- for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
- // TODO: Have max n-gram count.
- maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
- }
- if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
- maxEntryCountTable, headerPolicy, entryCountTable)) {
+ const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
+ headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
+ if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
+ outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
+ outEntryCounters)) {
AKLOGE("Failed to truncate entries in language model dict content.");
return false;
}
@@ -141,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&terminalIdMap)) {
return false;
}
- // Run GC for probability dict content.
+ // Run GC for language model dict content.
if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
- mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) {
+ mBuffers->getLanguageModelDictContent())) {
return false;
}
// Run GC for shortcut dict content.
@@ -166,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
return false;
}
- *outUnigramCount =
- entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
- *outBigramCount =
- entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
return true;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
index 3569d0576..c56cea5cf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
@@ -20,6 +20,7 @@
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
@@ -33,9 +34,7 @@ class Ver4PatriciaTrieWritingHelper {
Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers)
: mBuffers(buffers) {}
- // TODO: Support counting ngram entries.
- bool writeToDictFile(const char *const dictDirPath, const int unigramCount,
- const int bigramCount) const;
+ bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const;
// This method cannot be const because the original dictionary buffer will be updated to detect
// useless PtNodes during GC.
@@ -68,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
};
bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
- Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
- int *const outBigramCount);
+ Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);
Ver4DictBuffers *const mBuffers;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
index b7e2a7278..edcb43678 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
@@ -27,6 +27,7 @@
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/format_utils.h"
#include "utils/time_keeper.h"
@@ -43,13 +44,13 @@ const int DictFileWritingUtils::SIZE_OF_BUFFER_SIZE_FIELD = 4;
TimeKeeper::setCurrentTime();
const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::getFormatVersion(dictVersion);
switch (formatVersion) {
- case FormatUtils::VERSION_4:
+ case FormatUtils::VERSION_402:
return createEmptyV4DictFile<backward::v402::Ver4DictConstants,
backward::v402::Ver4DictBuffers,
backward::v402::Ver4DictBuffers::Ver4DictBuffersPtr>(
filePath, localeAsCodePointVector, attributeMap, formatVersion);
case FormatUtils::VERSION_4_ONLY_FOR_TESTING:
- case FormatUtils::VERSION_4_DEV:
+ case FormatUtils::VERSION_403:
return createEmptyV4DictFile<Ver4DictConstants, Ver4DictBuffers,
Ver4DictBuffers::Ver4DictBuffersPtr>(
filePath, localeAsCodePointVector, attributeMap, formatVersion);
@@ -69,8 +70,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr>
DictBuffersPtr dictBuffers = DictBuffers::createVer4DictBuffers(&headerPolicy,
DictConstants::MAX_DICT_EXTENDED_REGION_SIZE);
headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- 0 /* unigramCount */, 0 /* bigramCount */,
- 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer());
+ EntryCounts(), 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer());
if (!DynamicPtWritingUtils::writeEmptyDictionary(
dictBuffers->getWritableTrieBuffer(), 0 /* rootPos */)) {
AKLOGE("Empty ver4 dictionary structure cannot be created on memory.");
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
new file mode 100644
index 000000000..73dc42a18
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_ENTRY_COUNTERS_H
+#define LATINIME_ENTRY_COUNTERS_H
+
+#include <array>
+
+#include "defines.h"
+
+namespace latinime {
+
+// Copyable but immutable
+class EntryCounts final {
+ public:
+ EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
+
+ EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
+ : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
+
+ explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
+ : mEntryCounts(counters) {}
+
+ int getUnigramCount() const {
+ return mEntryCounts[0];
+ }
+
+ int getBigramCount() const {
+ return mEntryCounts[1];
+ }
+
+ int getTrigramCount() const {
+ return mEntryCounts[2];
+ }
+
+ int getNgramCount(const size_t n) const {
+ if (n < 1 || n > mEntryCounts.size()) {
+ return 0;
+ }
+ return mEntryCounts[n - 1];
+ }
+
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
+
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
+};
+
+class MutableEntryCounters final {
+ public:
+ MutableEntryCounters() {
+ mEntryCounters.fill(0);
+ }
+
+ MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount)
+ : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {}
+
+ const EntryCounts getEntryCounts() const {
+ return EntryCounts(mEntryCounters);
+ }
+
+ int getUnigramCount() const {
+ return mEntryCounters[0];
+ }
+
+ int getBigramCount() const {
+ return mEntryCounters[1];
+ }
+
+ int getTrigramCount() const {
+ return mEntryCounters[2];
+ }
+
+ void incrementUnigramCount() {
+ ++mEntryCounters[0];
+ }
+
+ void decrementUnigramCount() {
+ ASSERT(mEntryCounters[0] != 0);
+ --mEntryCounters[0];
+ }
+
+ void incrementBigramCount() {
+ ++mEntryCounters[1];
+ }
+
+ void decrementBigramCount() {
+ ASSERT(mEntryCounters[1] != 0);
+ --mEntryCounters[1];
+ }
+
+ void incrementNgramCount(const size_t n) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ ++mEntryCounters[n - 1];
+ }
+
+ void decrementNgramCount(const size_t n) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ ASSERT(mEntryCounters[n - 1] != 0);
+ --mEntryCounters[n - 1];
+ }
+
+ void setNgramCount(const size_t n, const int count) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ mEntryCounters[n - 1] = count;
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
+
+ std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
+};
+} // namespace latinime
+#endif /* LATINIME_ENTRY_COUNTERS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
index e5ef2abf8..9055f7bfc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
@@ -38,8 +38,7 @@ const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;
-const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
-const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
+const float ForgettingCurveUtils::ENTRY_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable;
@@ -126,14 +125,22 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
}
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
- const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy) {
- if (unigramCount >= getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) {
+ const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
+ if (entryCounts.getUnigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) {
// Unigram count exceeds the limit.
return true;
- } else if (bigramCount >= getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) {
+ }
+ if (entryCounts.getBigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) {
// Bigram count exceeds the limit.
return true;
}
+ if (entryCounts.getTrigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
+ // Trigram count exceeds the limit.
+ return true;
+ }
if (mindsBlockByDecay) {
return false;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index ccbc4a98d..06dcae8a1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
@@ -42,22 +43,17 @@ class ForgettingCurveUtils {
static bool needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy);
- static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
- const int bigramCount, const HeaderPolicy *const headerPolicy);
+ static bool needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounters,
+ const HeaderPolicy *const headerPolicy);
// TODO: Improve probability computation method and remove this.
static int getProbabilityBiasForNgram(const int n) {
return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
}
- AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
- return static_cast<int>(static_cast<float>(maxUnigramCount)
- * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
- }
-
- AK_FORCE_INLINE static int getBigramCountHardLimit(const int maxBigramCount) {
- return static_cast<int>(static_cast<float>(maxBigramCount)
- * BIGRAM_COUNT_HARD_LIMIT_WEIGHT);
+ AK_FORCE_INLINE static int getEntryCountHardLimit(const int maxEntryCount) {
+ return static_cast<int>(static_cast<float>(maxEntryCount)
+ * ENTRY_COUNT_HARD_LIMIT_WEIGHT);
}
private:
@@ -101,8 +97,7 @@ class ForgettingCurveUtils {
static const int OCCURRENCES_TO_RAISE_THE_LEVEL;
static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
- static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT;
- static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT;
+ static const float ENTRY_COUNT_HARD_LIMIT_WEIGHT;
static const ProbabilityTable sProbabilityTable;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
index 0cffe569d..e225c235e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
@@ -28,15 +28,17 @@ const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12;
/* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) {
switch (formatVersion) {
case VERSION_2:
- return VERSION_2;
case VERSION_201:
- return VERSION_201;
+ AKLOGE("Dictionary versions 2 and 201 are incompatible with this version");
+ return UNKNOWN_VERSION;
+ case VERSION_202:
+ return VERSION_202;
case VERSION_4_ONLY_FOR_TESTING:
return VERSION_4_ONLY_FOR_TESTING;
- case VERSION_4:
- return VERSION_4;
- case VERSION_4_DEV:
- return VERSION_4_DEV;
+ case VERSION_402:
+ return VERSION_402;
+ case VERSION_403:
+ return VERSION_403;
default:
return UNKNOWN_VERSION;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
index 96310086b..1616efcce 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
@@ -31,11 +31,15 @@ class FormatUtils {
public:
enum FORMAT_VERSION {
// These MUST have the same values as the relevant constants in FormatSpec.java.
+ // TODO: Remove VERSION_2 and VERSION_201 when we:
+ // * Confirm that old versions of LatinIME download old-format dictionaries
+ // * We no longer need the corresponding constants on the Java side for dicttool
VERSION_2 = 2,
VERSION_201 = 201,
+ VERSION_202 = 202,
VERSION_4_ONLY_FOR_TESTING = 399,
- VERSION_4 = 402,
- VERSION_4_DEV = 403,
+ VERSION_402 = 402,
+ VERSION_403 = 403,
UNKNOWN_VERSION = -1
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp
new file mode 100644
index 000000000..e8fa06942
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.cpp
@@ -0,0 +1,23 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+
+namespace latinime {
+
+const float ProbabilityUtils::PROBABILITY_ENCODING_SCALER = 8.58923700372f;
+
+} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h
index 3b339e61a..2050af1e9 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h
@@ -17,6 +17,9 @@
#ifndef LATINIME_PROBABILITY_UTILS_H
#define LATINIME_PROBABILITY_UTILS_H
+#include <algorithm>
+#include <cmath>
+
#include "defines.h"
namespace latinime {
@@ -47,8 +50,20 @@ class ProbabilityUtils {
+ static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize);
}
+ // Encode probability using the same way as we are doing for main dictionaries.
+ static AK_FORCE_INLINE int encodeRawProbability(const float rawProbability) {
+ const float probability = static_cast<float>(MAX_PROBABILITY)
+ + log2f(rawProbability) * PROBABILITY_ENCODING_SCALER;
+ if (probability < 0.0f) {
+ return 0;
+ }
+ return std::min(static_cast<int>(probability + 0.5f), MAX_PROBABILITY);
+ }
+
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ProbabilityUtils);
+
+ static const float PROBABILITY_ENCODING_SCALER;
};
}
#endif /* LATINIME_PROBABILITY_UTILS_H */
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index 6a2db687d..856808a74 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -24,6 +24,7 @@ const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120;
const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f;
const float ScoringParams::EXACT_MATCH_PROMOTION = 1.1f;
+const float ScoringParams::PERFECT_MATCH_PROMOTION = 1.1f;
const float ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH = 0.01f;
const float ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH = 0.02f;
const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f;
@@ -48,17 +49,17 @@ const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.5508f;
const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f;
const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f;
const float ScoringParams::TRANSPOSITION_COST = 0.5608f;
-const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f;
+const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.33f;
+const float ScoringParams::SPACE_OMISSION_COST = 0.1f;
const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f;
const float ScoringParams::SUBSTITUTION_COST = 0.3806f;
-const float ScoringParams::COST_NEW_WORD = 0.0314f;
const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f;
const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.1214f;
const float ScoringParams::COST_FIRST_COMPLETION = 0.4836f;
const float ScoringParams::COST_COMPLETION = 0.00624f;
const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.0683f;
const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.0362f;
-const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f;
+const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.3482f;
const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f;
const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f;
const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f;
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
index 731424f3d..6f327a370 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
@@ -34,6 +34,7 @@ class ScoringParams {
static const int THRESHOLD_SHORT_WORD_LENGTH;
static const float EXACT_MATCH_PROMOTION;
+ static const float PERFECT_MATCH_PROMOTION;
static const float CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
static const float ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
static const float DIGRAPH_PENALTY_FOR_EXACT_MATCH;
@@ -56,9 +57,9 @@ class ScoringParams {
static const float INSERTION_COST_FIRST_CHAR;
static const float TRANSPOSITION_COST;
static const float SPACE_SUBSTITUTION_COST;
+ static const float SPACE_OMISSION_COST;
static const float ADDITIONAL_PROXIMITY_COST;
static const float SUBSTITUTION_COST;
- static const float COST_NEW_WORD;
static const float COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
static const float DISTANCE_WEIGHT_LANGUAGE;
static const float COST_FIRST_COMPLETION;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
index 0240bcf54..6acd767ea 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
@@ -44,23 +44,50 @@ class TypingScoring : public Scoring {
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize,
const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
- const bool boostExactMatches) const {
+ const bool boostExactMatches, const bool hasProbabilityZero) const {
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance;
if (forceCommit) {
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
}
- if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
- score += ScoringParams::EXACT_MATCH_PROMOTION;
- if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) {
- score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
+ if (hasProbabilityZero) {
+ // Previously, when both legitimate 0-frequency words (such as distracters) and
+ // offensive words were encoded in the same way, distracters would never show up
+ // when the user blocked offensive words (the default setting, as well as the
+ // setting for regression tests).
+ //
+ // When b/11031090 was fixed and a separate encoding was used for offensive words,
+ // 0-frequency words would no longer be blocked when they were an "exact match"
+ // (where case mismatches and accent mismatches would be considered an "exact
+ // match"). The exact match boosting functionality meant that, for example, when
+ // the user typed "mt" they would be suggested the word "Mt", although they most
+ // probably meant to type "my".
+ //
+ // For this reason, we introduced this change, which does the following:
+ // * Defines the "perfect match" as a really exact match, with no room for case or
+ // accent mismatches
+ // * When the target word has probability zero (as "Mt" does, because it is a
+ // distracter), ONLY boost its score if it is a perfect match.
+ //
+ // By doing this, when the user types "mt", the word "Mt" will NOT be boosted, and
+ // they will get "my". However, if the user makes an explicit effort to type "Mt",
+ // we do boost the word "Mt" so that the user's input is not autocorrected to "My".
+ if (boostExactMatches && ErrorTypeUtils::isPerfectMatch(containedErrorTypes)) {
+ score += ScoringParams::PERFECT_MATCH_PROMOTION;
}
- if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) {
- score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
- }
- if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
- score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
+ } else {
+ if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
+ score += ScoringParams::EXACT_MATCH_PROMOTION;
+ if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) {
+ score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
+ }
+ if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) {
+ score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
+ }
+ if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
+ score -= ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH;
+ }
}
}
return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE);
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 84077174d..1338ac81a 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -150,9 +150,10 @@ class TypingWeighting : public Weighting {
return cost + weightedDistance;
}
- float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
+ float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
- return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier();
+ const float cost = ScoringParams::SPACE_OMISSION_COST;
+ return cost * traverseSession->getMultiWordCostMultiplier();
}
float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
@@ -202,7 +203,10 @@ class TypingWeighting : public Weighting {
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
- const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
+ const int inputIndex = dicNode->getInputIndex(0);
+ const float distanceToSpaceKey = traverseSession->getProximityInfoState(0)
+ ->getPointToKeyLength(inputIndex, KEYCODE_SPACE);
+ const float cost = ScoringParams::SPACE_SUBSTITUTION_COST * distanceToSpaceKey;
return cost * traverseSession->getMultiWordCostMultiplier();
}
diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h
index 5e9cdd9b2..7871c26ef 100644
--- a/native/jni/src/utils/char_utils.h
+++ b/native/jni/src/utils/char_utils.h
@@ -101,6 +101,17 @@ class CharUtils {
return codePointCount + 1;
}
+ // Returns updated code point count.
+ static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints,
+ const int codePointCount) {
+ if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) {
+ return codePointCount;
+ }
+ const int newCodePointCount = codePointCount - 1;
+ memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount);
+ return newCodePointCount;
+ }
+
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);
diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h
index 25cc41742..a259e1cd0 100644
--- a/native/jni/src/utils/jni_data_utils.h
+++ b/native/jni/src/utils/jni_data_utils.h
@@ -50,6 +50,7 @@ class JniDataUtils {
const jsize keyUtf8Length = env->GetStringUTFLength(keyString);
char keyChars[keyUtf8Length + 1];
env->GetStringUTFRegion(keyString, 0, env->GetStringLength(keyString), keyChars);
+ env->DeleteLocalRef(keyString);
keyChars[keyUtf8Length] = '\0';
DictionaryHeaderStructurePolicy::AttributeMap::key_type key;
HeaderReadWriteUtils::insertCharactersIntoVector(keyChars, &key);
@@ -59,6 +60,7 @@ class JniDataUtils {
const jsize valueUtf8Length = env->GetStringUTFLength(valueString);
char valueChars[valueUtf8Length + 1];
env->GetStringUTFRegion(valueString, 0, env->GetStringLength(valueString), valueChars);
+ env->DeleteLocalRef(valueString);
valueChars[valueUtf8Length] = '\0';
DictionaryHeaderStructurePolicy::AttributeMap::mapped_type value;
HeaderReadWriteUtils::insertCharactersIntoVector(valueChars, &value);
@@ -113,6 +115,7 @@ class JniDataUtils {
continue;
}
env->GetIntArrayRegion(prevWord, 0, prevWordLength, prevWordCodePoints[i]);
+ env->DeleteLocalRef(prevWord);
prevWordCodePointCount[i] = prevWordLength;
jboolean isBeginningOfSentenceBoolean = JNI_FALSE;
env->GetBooleanArrayRegion(isBeginningOfSentenceArray, i, 1 /* len */,