aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp15
-rw-r--r--native/jni/src/suggest/core/dictionary/property/unigram_property.h49
-rw-r--r--native/jni/src/suggest/core/dictionary/property/word_property.cpp2
-rw-r--r--native/jni/src/suggest/core/dictionary/word_attributes.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp20
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h34
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp11
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp22
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp11
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp23
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h10
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h13
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp52
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h26
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp38
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h10
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp45
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h133
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h19
28 files changed, 389 insertions, 211 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index f8dadb488..b01acead7 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -358,7 +358,7 @@ static void latinime_BinaryDictionary_getWordProperty(JNIEnv *env, jclass clazz,
static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz, jlong dict,
jintArray word, jint probability, jintArray shortcutTarget, jint shortcutProbability,
- jboolean isBeginningOfSentence, jboolean isNotAWord, jboolean isBlacklisted,
+ jboolean isBeginningOfSentence, jboolean isNotAWord, jboolean isPossiblyOffensive,
jint timestamp) {
Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
if (!dictionary) {
@@ -377,8 +377,8 @@ static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz,
}
// Use 1 for count to indicate the word has inputted.
const UnigramProperty unigramProperty(isBeginningOfSentence, isNotAWord,
- isBlacklisted, probability, HistoricalInfo(timestamp, 0 /* level */, 1 /* count */),
- std::move(shortcuts));
+ isPossiblyOffensive, probability, HistoricalInfo(timestamp, 0 /* level */,
+ 1 /* count */), std::move(shortcuts));
return dictionary->addUnigramEntry(CodePointArrayView(codePoints, codePointCount),
&unigramProperty);
}
@@ -480,8 +480,8 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
env->GetFieldID(languageModelParamClass, "mShortcutProbability", "I");
jfieldID isNotAWordFieldId =
env->GetFieldID(languageModelParamClass, "mIsNotAWord", "Z");
- jfieldID isBlacklistedFieldId =
- env->GetFieldID(languageModelParamClass, "mIsBlacklisted", "Z");
+ jfieldID isPossiblyOffensiveFieldId =
+ env->GetFieldID(languageModelParamClass, "mIsPossiblyOffensive", "Z");
env->DeleteLocalRef(languageModelParamClass);
for (int i = startIndex; i < languageModelParamCount; ++i) {
@@ -504,7 +504,8 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
jint unigramProbability = env->GetIntField(languageModelParam, unigramProbabilityFieldId);
jint timestamp = env->GetIntField(languageModelParam, timestampFieldId);
jboolean isNotAWord = env->GetBooleanField(languageModelParam, isNotAWordFieldId);
- jboolean isBlacklisted = env->GetBooleanField(languageModelParam, isBlacklistedFieldId);
+ jboolean isPossiblyOffensive = env->GetBooleanField(languageModelParam,
+ isPossiblyOffensiveFieldId);
jintArray shortcutTarget = static_cast<jintArray>(
env->GetObjectField(languageModelParam, shortcutTargetFieldId));
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -519,7 +520,7 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
}
// Use 1 for count to indicate the word has inputted.
const UnigramProperty unigramProperty(false /* isBeginningOfSentence */, isNotAWord,
- isBlacklisted, unigramProbability,
+ isPossiblyOffensive, unigramProbability,
HistoricalInfo(timestamp, 0 /* level */, 1 /* count */), std::move(shortcuts));
dictionary->addUnigramEntry(CodePointArrayView(word1CodePoints, word1Length),
&unigramProperty);
diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
index 5ed2e2602..f194f979a 100644
--- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
@@ -49,21 +49,44 @@ class UnigramProperty {
};
UnigramProperty()
- : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), mIsBlacklisted(false),
- mProbability(NOT_A_PROBABILITY), mHistoricalInfo(), mShortcuts() {}
+ : mRepresentsBeginningOfSentence(false), mIsNotAWord(false),
+ mIsBlacklisted(false), mIsPossiblyOffensive(false), mProbability(NOT_A_PROBABILITY),
+ mHistoricalInfo(), mShortcuts() {}
+ // In contexts which do not support the Blacklisted flag (v2, v4<403)
UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
- const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo,
- const std::vector<ShortcutProperty> &&shortcuts)
+ const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts)
: mRepresentsBeginningOfSentence(representsBeginningOfSentence),
- mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(false),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {}
- // Without shortcuts.
+ // Without shortcuts, in contexts which do not support the Blacklisted flag (v2, v4<403)
UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
- const bool isBlacklisted, const int probability, const HistoricalInfo historicalInfo)
+ const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo)
: mRepresentsBeginningOfSentence(representsBeginningOfSentence),
- mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(false),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
+ mHistoricalInfo(historicalInfo), mShortcuts() {}
+
+ // In contexts which DO support the Blacklisted flag (v403)
+ UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
+ const bool isBlacklisted, const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts)
+ : mRepresentsBeginningOfSentence(representsBeginningOfSentence),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
+ mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {}
+
+ // Without shortcuts, in contexts which DO support the Blacklisted flag (v403)
+ UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord,
+ const bool isBlacklisted, const bool isPossiblyOffensive, const int probability,
+ const HistoricalInfo historicalInfo)
+ : mRepresentsBeginningOfSentence(representsBeginningOfSentence),
+ mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted),
+ mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability),
mHistoricalInfo(historicalInfo), mShortcuts() {}
bool representsBeginningOfSentence() const {
@@ -74,13 +97,12 @@ class UnigramProperty {
return mIsNotAWord;
}
- bool isBlacklisted() const {
- return mIsBlacklisted;
+ bool isPossiblyOffensive() const {
+ return mIsPossiblyOffensive;
}
- bool isPossiblyOffensive() const {
- // TODO: Have dedicated flag.
- return mProbability == 0;
+ bool isBlacklisted() const {
+ return mIsBlacklisted;
}
bool hasShortcuts() const {
@@ -106,6 +128,7 @@ class UnigramProperty {
const bool mRepresentsBeginningOfSentence;
const bool mIsNotAWord;
const bool mIsBlacklisted;
+ const bool mIsPossiblyOffensive;
const int mProbability;
const HistoricalInfo mHistoricalInfo;
const std::vector<ShortcutProperty> mShortcuts;
diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp
index caac8fe79..a707f1ba2 100644
--- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp
+++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp
@@ -28,7 +28,7 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints,
JniDataUtils::outputCodePoints(env, outCodePoints, 0 /* start */,
MAX_WORD_LENGTH /* maxLength */, mCodePoints.data(), mCodePoints.size(),
false /* needsNullTermination */);
- jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isBlacklisted(),
+ jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isPossiblyOffensive(),
!mNgrams.empty(), mUnigramProperty.hasShortcuts(),
mUnigramProperty.representsBeginningOfSentence()};
env->SetBooleanArrayRegion(outFlags, 0 /* start */, NELEMS(flags), flags);
diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h
index 6e9da3570..5351e7d7d 100644
--- a/native/jni/src/suggest/core/dictionary/word_attributes.h
+++ b/native/jni/src/suggest/core/dictionary/word_attributes.h
@@ -43,6 +43,14 @@ class WordAttributes {
return mIsNotAWord;
}
+ // Whether or not a word is possibly offensive.
+ // * Static dictionaries <v202, as well as dynamic dictionaries <v403, will set this based on
+ // whether or not the probability of the word is zero.
+ // * Static dictionaries >=v203 will set this based on the IS_POSSIBLY_OFFENSIVE PtNode flag.
+ // * Dynamic dictionaries >=v403 will set this based on the IS_POSSIBLY_OFFENSIVE language model
+ // flag (the PtNode flag IS_BLACKLISTED is ignored and kept as zero)
+ //
+ // See the ::getWordAttributes function for each of these dictionary policies for more details.
bool isPossiblyOffensive() const {
return mIsPossiblyOffensive;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 8fb256c54..300e96c4e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -30,6 +30,7 @@ const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT";
const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT";
+const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT";
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
@@ -94,12 +95,11 @@ bool HeaderPolicy::readRequiresGermanUmlautProcessing() const {
}
bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount,
- const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const {
+ const EntryCounts &entryCounts, const int extendedRegionSize,
+ BufferWithExtendableBuffer *const outBuffer) const {
int writingPos = 0;
DictionaryHeaderStructurePolicy::AttributeMap attributeMapToWrite(mAttributeMap);
- fillInHeader(updatesLastDecayedTime, unigramCount, bigramCount,
- extendedRegionSize, &attributeMapToWrite);
+ fillInHeader(updatesLastDecayedTime, entryCounts, extendedRegionSize, &attributeMapToWrite);
if (!HeaderReadWriteUtils::writeDictionaryVersion(outBuffer, mDictFormatVersion,
&writingPos)) {
return false;
@@ -126,11 +126,15 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true;
}
-void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int unigramCount,
- const int bigramCount, const int extendedRegionSize,
+void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
+ const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, unigramCount);
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, bigramCount);
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY,
+ entryCounts.getUnigramCount());
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY,
+ entryCounts.getBigramCount());
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY,
+ entryCounts.getTrigramCount());
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize);
// Set the current time as the generation time.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index 836bbe5a1..44c2f443f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -22,6 +22,7 @@
#include "defines.h"
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/format_utils.h"
#include "utils/char_utils.h"
#include "utils/time_keeper.h"
@@ -49,6 +50,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
+ mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@@ -60,6 +63,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
+ mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
+ &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map.
@@ -77,7 +82,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
+ mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
@@ -87,6 +92,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
&mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
+ mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
+ &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information
@@ -99,12 +106,14 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
+ mTrigramCount(headerPolicy->mTrigramCount),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
mMaxBigramCount(headerPolicy->mMaxBigramCount),
+ mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header.
@@ -112,10 +121,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
- mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
+ mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
- mCodePointTable(nullptr) {}
+ mMaxTrigramCount(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@@ -183,6 +192,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mBigramCount;
}
+ AK_FORCE_INLINE int getTrigramCount() const {
+ return mTrigramCount;
+ }
+
AK_FORCE_INLINE int getExtendedRegionSize() const {
return mExtendedRegionSize;
}
@@ -212,15 +225,19 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mMaxBigramCount;
}
+ AK_FORCE_INLINE int getMaxTrigramCount() const {
+ return mMaxTrigramCount;
+ }
+
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
bool fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount,
- const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const;
+ const EntryCounts &entryCounts, const int extendedRegionSize,
+ BufferWithExtendableBuffer *const outBuffer) const;
- void fillInHeader(const bool updatesLastDecayedTime,
- const int unigramCount, const int bigramCount, const int extendedRegionSize,
+ void fillInHeader(const bool updatesLastDecayedTime, const EntryCounts &entryCounts,
+ const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const;
AK_FORCE_INLINE const std::vector<int> *getLocale() const {
@@ -245,6 +262,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const LAST_DECAYED_TIME_KEY;
static const char *const UNIGRAM_COUNT_KEY;
static const char *const BIGRAM_COUNT_KEY;
+ static const char *const TRIGRAM_COUNT_KEY;
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY;
@@ -273,11 +291,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const int mLastDecayedTime;
const int mUnigramCount;
const int mBigramCount;
+ const int mTrigramCount;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId;
const int mMaxUnigramCount;
const int mMaxBigramCount;
+ const int mMaxTrigramCount;
const int *const mCodePointTable;
const std::vector<int> readLocale() const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 6243f14cc..d558b949a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -245,7 +245,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
if (!sourcePtNodeParams.hasBigrams()) {
// Update has bigrams flag.
return updatePtNodeFlags(sourcePtNodeParams.getHeadPos(),
- sourcePtNodeParams.isBlacklisted(), sourcePtNodeParams.isNotAWord(),
+ sourcePtNodeParams.isPossiblyOffensive(), sourcePtNodeParams.isNotAWord(),
sourcePtNodeParams.isTerminal(), sourcePtNodeParams.hasShortcutTargets(),
true /* hasBigrams */,
sourcePtNodeParams.getCodePointCount() > 1 /* hasMultipleChars */);
@@ -316,7 +316,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN
if (!ptNodeParams->hasShortcutTargets()) {
// Update has shortcut targets flag.
return updatePtNodeFlags(ptNodeParams->getHeadPos(),
- ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
+ ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(),
ptNodeParams->isTerminal(), true /* hasShortcutTargets */,
ptNodeParams->hasBigrams(),
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
@@ -330,7 +330,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeHasBigramsAndShortcutTargetsFlags(
ptNodeParams->getTerminalId()) != NOT_A_DICT_POS;
const bool hasShortcutTargets = mBuffers->getShortcutDictContent()->getShortcutListHeadPos(
ptNodeParams->getTerminalId()) != NOT_A_DICT_POS;
- return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isBlacklisted(),
+ return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isPossiblyOffensive(),
ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), hasShortcutTargets,
hasBigrams, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
@@ -386,8 +386,9 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
return false;
}
- return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
- isTerminal, ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(),
+ return updatePtNodeFlags(nodePos, ptNodeParams->isPossiblyOffensive(),
+ ptNodeParams->isNotAWord(), isTerminal, ptNodeParams->hasShortcutTargets(),
+ ptNodeParams->hasBigrams(),
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 0eae934ae..29f9ba37f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mUnigramCount++;
+ mEntryCounters.incrementUnigramCount();
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) {
- mBigramCount++;
+ mEntryCounters.incrementBigramCount();
}
return true;
} else {
@@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
- mBigramCount--;
+ mEntryCounters.decrementBigramCount();
return true;
} else {
return false;
@@ -477,7 +477,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) {
AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath);
return false;
}
- if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) {
+ if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) {
AKLOGE("Cannot flush the dictionary to file.");
mIsCorrupted = true;
return false;
@@ -515,7 +515,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const {
// Needs to reduce dictionary size.
return true;
} else if (mHeaderPolicy->isDecayingDict()) {
- return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount,
+ return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(),
mHeaderPolicy);
}
return false;
@@ -525,19 +525,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mUnigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mBigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getUnigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getBigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
@@ -608,8 +608,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- *historicalInfo, std::move(shortcuts));
+ ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(),
+ ptNodeParams.getProbability(), *historicalInfo, std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 1ad5e7e36..2cda0d3fa 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -41,6 +41,7 @@
#include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h"
#include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_pt_node_array_reader.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -75,8 +76,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()),
+ mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
+ mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const {
@@ -163,8 +164,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieNodeWriter mNodeWriter;
DynamicPtUpdatingHelper mUpdatingHelper;
Ver4PatriciaTrieWritingHelper mWritingHelper;
- int mUnigramCount;
- int mBigramCount;
+ MutableEntryCounters mEntryCounters;
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
index 2887dc6b1..5f440819a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
@@ -43,18 +43,18 @@ namespace backward {
namespace v402 {
bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath,
- const int unigramCount, const int bigramCount) const {
+ const EntryCounts &entryCounts) const {
const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
const int extendedRegionSize = headerPolicy->getExtendedRegionSize()
+ mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize();
if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */,
- unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) {
+ entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, unigramCount, bigramCount,
- extendedRegionSize);
+ "extendedRegionSize: %d", false, entryCounters.getUnigramCount(),
+ entryCounters.getBigramCount(), extendedRegionSize);
return false;
}
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -74,7 +74,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) {
+ EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
+ 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
index 9034ee656..1aad33e38 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h
@@ -27,6 +27,7 @@
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h"
#include "suggest/policyimpl/dictionary/structure/backward/v402/content/terminal_position_lookup_table.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
namespace backward {
@@ -46,8 +47,7 @@ class Ver4PatriciaTrieWritingHelper {
Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers)
: mBuffers(buffers) {}
- bool writeToDictFile(const char *const dictDirPath, const int unigramCount,
- const int bigramCount) const;
+ bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const;
// This method cannot be const because the original dictionary buffer will be updated to detect
// useless PtNodes during GC.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
index 92fd6f214..e524e86e5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp
@@ -146,7 +146,7 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori
const int movedPos = mBuffer->getTailPosition();
int writingPos = movedPos;
const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams,
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, originalPtNodeParams->getParentPos(),
originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
@@ -180,8 +180,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode(
return false;
}
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */,
- parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability()));
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
+ true /* isTerminal */, parentPtNodePos, ptNodeCodePoints,
+ unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
unigramProperty, &writingPos)) {
return false;
@@ -214,7 +215,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount);
if (addsExtraChild) {
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */,
+ false /* isNotAWord */, false /* isPossiblyOffensive */, false /* isTerminal */,
reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints,
NOT_A_PROBABILITY));
if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) {
@@ -222,7 +223,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
}
} else {
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, reallocatingPtNodeParams->getParentPos(),
firstPtNodeCodePoints, unigramProperty->getProbability()));
if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite,
@@ -240,7 +241,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
// Write the 2nd part of the reallocating node.
const int secondPartOfReallocatedPtNodePos = writingPos;
const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams,
- reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(),
+ reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isPossiblyOffensive(),
reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos,
reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount),
reallocatingPtNodeParams->getProbability()));
@@ -249,7 +250,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
}
if (addsExtraChild) {
const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode(
- unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(),
true /* isTerminal */, firstPartOfReallocatedPtNodePos,
newPtNodeCodePoints.skip(overlappingCodePointCount),
unigramProperty->getProbability()));
@@ -276,20 +277,20 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes(
const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams(
const PtNodeParams *const originalPtNodeParams, const bool isNotAWord,
- const bool isBlacklisted, const bool isTerminal, const int parentPos,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
const CodePointArrayView codePoints, const int probability) const {
const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags(
- isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */,
+ isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */,
CHILDREN_POSITION_FIELD_SIZE);
return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability);
}
const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord,
- const bool isBlacklisted, const bool isTerminal, const int parentPos,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
const CodePointArrayView codePoints, const int probability) const {
const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags(
- isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */,
+ isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */,
CHILDREN_POSITION_FIELD_SIZE);
return PtNodeParams(flags, parentPos, codePoints, probability);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
index 2bbe2f4dc..db5f6ab17 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h
@@ -85,12 +85,12 @@ class DynamicPtUpdatingHelper {
const CodePointArrayView newPtNodeCodePoints);
const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams,
- const bool isNotAWord, const bool isBlacklisted, const bool isTerminal,
+ const bool isNotAWord, const bool isPossiblyOffensive, const bool isTerminal,
const int parentPos, const CodePointArrayView codePoints, const int probability) const;
- const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted,
- const bool isTerminal, const int parentPos, const CodePointArrayView codePoints,
- const int probability) const;
+ const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord,
+ const bool isPossiblyOffensive, const bool isTerminal, const int parentPos,
+ const CodePointArrayView codePoints, const int probability) const;
};
} // namespace latinime
#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
index 6a498b2f4..b8d78bf10 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp
@@ -41,8 +41,8 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08
const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04;
// Flag for non-words (typically, shortcut only entries)
const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02;
-// Flag for blacklist
-const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01;
+// Flag for possibly offensive words
+const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_POSSIBLY_OFFENSIVE = 0x01;
/* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition(
const uint8_t *const buffer, int *const pos) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
index a69ec4435..6a2bf5d3c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h
@@ -54,8 +54,8 @@ class PatriciaTrieReadingUtils {
/**
* Node Flags
*/
- static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) {
- return (flags & FLAG_IS_BLACKLISTED) != 0;
+ static AK_FORCE_INLINE bool isPossiblyOffensive(const NodeFlags flags) {
+ return (flags & FLAG_IS_POSSIBLY_OFFENSIVE) != 0;
}
static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) {
@@ -82,12 +82,12 @@ class PatriciaTrieReadingUtils {
return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags);
}
- static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted,
+ static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isPossiblyOffensive,
const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets,
const bool hasBigrams, const bool hasMultipleChars,
const int childrenPositionFieldSize) {
NodeFlags nodeFlags = 0;
- nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags;
+ nodeFlags = isPossiblyOffensive ? (nodeFlags | FLAG_IS_POSSIBLY_OFFENSIVE) : nodeFlags;
nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags;
nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags;
nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags;
@@ -127,7 +127,7 @@ class PatriciaTrieReadingUtils {
static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS;
static const NodeFlags FLAG_HAS_BIGRAMS;
static const NodeFlags FLAG_IS_NOT_A_WORD;
- static const NodeFlags FLAG_IS_BLACKLISTED;
+ static const NodeFlags FLAG_IS_POSSIBLY_OFFENSIVE;
};
} // namespace latinime
#endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
index 3ff1829bd..585e87a24 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
@@ -145,7 +145,18 @@ class PtNodeParams {
}
AK_FORCE_INLINE bool isBlacklisted() const {
- return PatriciaTrieReadingUtils::isBlacklisted(mFlags);
+ // Note: this method will be removed in the next change.
+ // It is used in getProbabilityOfWord and getWordAttributes for both v402 and v403.
+ // * getProbabilityOfWord will be changed to no longer return NOT_A_PROBABILITY
+ // when isBlacklisted (i.e. to only check if isNotAWord or isDeleted)
+ // * getWordAttributes will be changed to always return blacklisted=false and
+ // isPossiblyOffensive according to the function below (instead of the current
+ // behaviour of checking if the probability is zero)
+ return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags);
+ }
+
+ AK_FORCE_INLINE bool isPossiblyOffensive() const {
+ return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags);
}
AK_FORCE_INLINE bool isNotAWord() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index b7f1199c5..ca44da9fb 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -476,8 +476,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- HistoricalInfo(), std::move(shortcuts));
+ ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(),
+ ptNodeParams.getProbability(), HistoricalInfo(), std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index c4297f5d6..a88996524 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -23,8 +23,6 @@
namespace latinime {
-const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
-const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
bool LanguageModelDictContent::save(FILE *const file) const {
@@ -33,10 +31,9 @@ bool LanguageModelDictContent::save(FILE *const file) const {
bool LanguageModelDictContent::runGC(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount) {
+ const LanguageModelDictContent *const originalContent) {
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
- 0 /* nextLevelBitmapEntryIndex */, outNgramCount);
+ 0 /* nextLevelBitmapEntryIndex */);
}
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
@@ -143,28 +140,30 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}
-bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
- const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- if (entryCounts[i] <= maxEntryCounts[i]) {
- outEntryCounts[i] = entryCounts[i];
+bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
+ const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
+ MutableEntryCounters *const outEntryCounters) {
+ for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
+ const int totalWordCount = prevWordCount + 1;
+ if (currentEntryCounts.getNgramCount(totalWordCount)
+ <= maxEntryCounts.getNgramCount(totalWordCount)) {
+ outEntryCounters->setNgramCount(totalWordCount,
+ currentEntryCounts.getNgramCount(totalWordCount));
continue;
}
- if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
- &outEntryCounts[i])) {
+ int entryCount = 0;
+ if (!turncateEntriesInSpecifiedLevel(headerPolicy,
+ maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
return false;
}
+ outEntryCounters->setNgramCount(totalWordCount, entryCount);
}
return true;
}
bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
- const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
- if (outAddedNewNgramEntryCount) {
- *outAddedNewNgramEntryCount = 0;
- }
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) {
if (!mHasHistoricalInfo) {
AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
return false;
@@ -188,8 +187,8 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
return false;
}
- if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
- *outAddedNewNgramEntryCount += 1;
+ if (!originalNgramProbabilityEntry.isValid()) {
+ entryCountersToUpdate->incrementNgramCount(i + 2);
}
}
return true;
@@ -211,8 +210,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange,
- const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
for (auto &entry : trieMapRange) {
const auto it = terminalIdMap->find(entry.key());
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@@ -222,13 +220,9 @@ bool LanguageModelDictContent::runGCInner(
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
return false;
}
- if (outNgramCount) {
- *outNgramCount += 1;
- }
if (entry.hasNextLevelMap()) {
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
- mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
- outNgramCount)) {
+ mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
return false;
}
}
@@ -271,7 +265,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
+ MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -308,13 +302,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
if (!probabilityEntry.representsBeginningOfSentence()) {
- outEntryCounts[prevWordCount] += 1;
+ outEntryCounters->incrementNgramCount(prevWordCount + 1);
}
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
- prevWordCount + 1, headerPolicy, outEntryCounts)) {
+ prevWordCount + 1, headerPolicy, outEntryCounters)) {
return false;
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 51ef090e1..41a429a54 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -25,6 +25,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/trie_map.h"
#include "utils/byte_array_view.h"
#include "utils/int_array_view.h"
@@ -40,9 +41,6 @@ class HeaderPolicy;
*/
class LanguageModelDictContent {
public:
- static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
- static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
-
// Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry {
public:
@@ -126,8 +124,7 @@ class LanguageModelDictContent {
bool save(FILE *const file) const;
bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount);
+ const LanguageModelDictContent *const originalContent);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const;
@@ -155,21 +152,19 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- outEntryCounts[i] = 0;
- }
+ MutableEntryCounters *const outEntryCounters) {
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
- 0 /* prevWordCount */, headerPolicy, outEntryCounts);
+ 0 /* prevWordCount */, headerPolicy, outEntryCounters);
}
// entryCounts should be created by updateAllProbabilityEntries.
- bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
const bool isValid, const HistoricalInfo historicalInfo,
- const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);
+ const HeaderPolicy *const headerPolicy,
+ MutableEntryCounters *const entryCountersToUpdate);
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@@ -204,12 +199,11 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
- int *const outNgramCount);
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 794c63ffd..3488f7d2a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -342,7 +342,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bo
// Create node flags and write them.
PatriciaTrieReadingUtils::NodeFlags nodeFlags =
PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */,
- false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */,
+ false /* isPossiblyOffensive */, isTerminal, false /* hasShortcutTargets */,
false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE);
if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index ea8c0dc22..d90066cec 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -211,7 +211,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mUnigramCount++;
+ mEntryCounters.incrementUnigramCount();
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -259,7 +259,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
- mUnigramCount--;
+ mEntryCounters.decrementUnigramCount();
}
return true;
}
@@ -299,7 +299,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
}
const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */, true /* isNotAWord */,
- false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo());
+ false /* isBlacklisted */, false /* isPossiblyOffensive */,
+ MAX_PROBABILITY /* probability */, HistoricalInfo());
if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) {
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
@@ -316,7 +317,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContex
bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) {
- mBigramCount++;
+ mEntryCounters.incrementNgramCount(prevWordIds.size() + 1);
}
return true;
} else {
@@ -354,7 +355,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false;
}
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
- mBigramCount--;
+ mEntryCounters.decrementNgramCount(prevWordIds.size());
return true;
} else {
return false;
@@ -375,8 +376,9 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
if (wordId == NOT_A_WORD_ID) {
// The word is not in the dictionary.
const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
- false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
- HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ false /* isNotAWord */, false /* isBlacklisted */, false /* isPossiblyOffensive */,
+ NOT_A_PROBABILITY, HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */,
+ 0 /* count */));
if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
@@ -391,7 +393,7 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
&& ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */,
- true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
+ true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY,
HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) {
@@ -401,12 +403,10 @@ bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
// Refresh word ids.
ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
- int addedNewNgramEntryCount = 0;
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
- wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
+ wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) {
return false;
}
- mBigramCount += addedNewNgramEntryCount;
return true;
}
@@ -415,7 +415,7 @@ bool Ver4PatriciaTriePolicy::flush(const char *const filePath) {
AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath);
return false;
}
- if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) {
+ if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) {
AKLOGE("Cannot flush the dictionary to file.");
mIsCorrupted = true;
return false;
@@ -453,8 +453,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const {
// Needs to reduce dictionary size.
return true;
} else if (mHeaderPolicy->isDecayingDict()) {
- return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount,
- mHeaderPolicy);
+ return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(), mHeaderPolicy);
}
return false;
}
@@ -463,19 +462,19 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mUnigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mBigramCount);
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getUnigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxUnigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
- ForgettingCurveUtils::getBigramCountHardLimit(
+ ForgettingCurveUtils::getEntryCountHardLimit(
mHeaderPolicy->getMaxBigramCount()) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
@@ -532,7 +531,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
}
const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
- probabilityEntry.getProbability(), *historicalInfo, std::move(shortcuts));
+ probabilityEntry.isPossiblyOffensive(), probabilityEntry.getProbability(),
+ *historicalInfo, std::move(shortcuts));
return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index c0532815c..e3611cb32 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -30,6 +30,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -37,7 +38,6 @@ namespace latinime {
class DicNode;
class DicNodeVector;
-// TODO: Support counting ngram entries.
// Word id = Artificial id that is stored in the PtNode looked up by the word.
class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
public:
@@ -51,8 +51,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()),
+ mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
+ mHeaderPolicy->getTrigramCount()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const {
@@ -141,9 +141,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieNodeWriter mNodeWriter;
DynamicPtUpdatingHelper mUpdatingHelper;
Ver4PatriciaTrieWritingHelper mWritingHelper;
- int mUnigramCount;
- // TODO: Support counting ngram entries.
- int mBigramCount;
+ MutableEntryCounters mEntryCounters;
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index f0d59c150..c7ffcc860 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -33,17 +33,18 @@
namespace latinime {
bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath,
- const int unigramCount, const int bigramCount) const {
+ const EntryCounts &entryCounts) const {
const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy();
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
const int extendedRegionSize = headerPolicy->getExtendedRegionSize()
+ mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize();
if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */,
- unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) {
+ entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
- "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, unigramCount, bigramCount,
+ "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
+ "extendedRegionSize: %d", false, entryCounters.getUnigramCount(),
+ entryCounters.getBigramCount(), entryCounters.getTrigramCount(),
extendedRegionSize);
return false;
}
@@ -56,15 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
Ver4DictConstants::MAX_DICTIONARY_SIZE));
- int unigramCount = 0;
- int bigramCount = 0;
- if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
+ MutableEntryCounters entryCounters;
+ if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
return false;
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) {
+ entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -72,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
- int *const outUnigramCount, int *const outBigramCount) {
+ MutableEntryCounters *const outEntryCounters) {
Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@@ -80,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
- int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
- headerPolicy, entryCountTable)) {
+ headerPolicy, outEntryCounters)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
if (headerPolicy->isDecayingDict()) {
- int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
- maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxUnigramCount();
- maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxBigramCount();
- for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
- // TODO: Have max n-gram count.
- maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
- }
- if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
- maxEntryCountTable, headerPolicy, entryCountTable)) {
+ const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
+ headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
+ if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
+ outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
+ outEntryCounters)) {
AKLOGE("Failed to truncate entries in language model dict content.");
return false;
}
@@ -141,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&terminalIdMap)) {
return false;
}
- // Run GC for probability dict content.
+ // Run GC for language model dict content.
if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
- mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) {
+ mBuffers->getLanguageModelDictContent())) {
return false;
}
// Run GC for shortcut dict content.
@@ -166,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
return false;
}
- *outUnigramCount =
- entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
- *outBigramCount =
- entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
return true;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
index 3569d0576..c56cea5cf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
@@ -20,6 +20,7 @@
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
@@ -33,9 +34,7 @@ class Ver4PatriciaTrieWritingHelper {
Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers)
: mBuffers(buffers) {}
- // TODO: Support counting ngram entries.
- bool writeToDictFile(const char *const dictDirPath, const int unigramCount,
- const int bigramCount) const;
+ bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const;
// This method cannot be const because the original dictionary buffer will be updated to detect
// useless PtNodes during GC.
@@ -68,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
};
bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
- Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
- int *const outBigramCount);
+ Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);
Ver4DictBuffers *const mBuffers;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
index b7e2a7278..9d8e86675 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
@@ -27,6 +27,7 @@
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h"
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/format_utils.h"
#include "utils/time_keeper.h"
@@ -69,8 +70,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr>
DictBuffersPtr dictBuffers = DictBuffers::createVer4DictBuffers(&headerPolicy,
DictConstants::MAX_DICT_EXTENDED_REGION_SIZE);
headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- 0 /* unigramCount */, 0 /* bigramCount */,
- 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer());
+ EntryCounts(), 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer());
if (!DynamicPtWritingUtils::writeEmptyDictionary(
dictBuffers->getWritableTrieBuffer(), 0 /* rootPos */)) {
AKLOGE("Empty ver4 dictionary structure cannot be created on memory.");
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
new file mode 100644
index 000000000..73dc42a18
--- /dev/null
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
@@ -0,0 +1,133 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_ENTRY_COUNTERS_H
+#define LATINIME_ENTRY_COUNTERS_H
+
+#include <array>
+
+#include "defines.h"
+
+namespace latinime {
+
+// Copyable but immutable
+class EntryCounts final {
+ public:
+ EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
+
+ EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
+ : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
+
+ explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
+ : mEntryCounts(counters) {}
+
+ int getUnigramCount() const {
+ return mEntryCounts[0];
+ }
+
+ int getBigramCount() const {
+ return mEntryCounts[1];
+ }
+
+ int getTrigramCount() const {
+ return mEntryCounts[2];
+ }
+
+ int getNgramCount(const size_t n) const {
+ if (n < 1 || n > mEntryCounts.size()) {
+ return 0;
+ }
+ return mEntryCounts[n - 1];
+ }
+
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
+
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
+};
+
+class MutableEntryCounters final {
+ public:
+ MutableEntryCounters() {
+ mEntryCounters.fill(0);
+ }
+
+ MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount)
+ : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {}
+
+ const EntryCounts getEntryCounts() const {
+ return EntryCounts(mEntryCounters);
+ }
+
+ int getUnigramCount() const {
+ return mEntryCounters[0];
+ }
+
+ int getBigramCount() const {
+ return mEntryCounters[1];
+ }
+
+ int getTrigramCount() const {
+ return mEntryCounters[2];
+ }
+
+ void incrementUnigramCount() {
+ ++mEntryCounters[0];
+ }
+
+ void decrementUnigramCount() {
+ ASSERT(mEntryCounters[0] != 0);
+ --mEntryCounters[0];
+ }
+
+ void incrementBigramCount() {
+ ++mEntryCounters[1];
+ }
+
+ void decrementBigramCount() {
+ ASSERT(mEntryCounters[1] != 0);
+ --mEntryCounters[1];
+ }
+
+ void incrementNgramCount(const size_t n) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ ++mEntryCounters[n - 1];
+ }
+
+ void decrementNgramCount(const size_t n) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ ASSERT(mEntryCounters[n - 1] != 0);
+ --mEntryCounters[n - 1];
+ }
+
+ void setNgramCount(const size_t n, const int count) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ mEntryCounters[n - 1] = count;
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
+
+ std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
+};
+} // namespace latinime
+#endif /* LATINIME_ENTRY_COUNTERS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
index e5ef2abf8..9055f7bfc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
@@ -38,8 +38,7 @@ const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1;
// 15 days
const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60;
-const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
-const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
+const float ForgettingCurveUtils::ENTRY_COUNT_HARD_LIMIT_WEIGHT = 1.2;
const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable;
@@ -126,14 +125,22 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
}
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
- const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy) {
- if (unigramCount >= getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) {
+ const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
+ if (entryCounts.getUnigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) {
// Unigram count exceeds the limit.
return true;
- } else if (bigramCount >= getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) {
+ }
+ if (entryCounts.getBigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) {
// Bigram count exceeds the limit.
return true;
}
+ if (entryCounts.getTrigramCount()
+ >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
+ // Trigram count exceeds the limit.
+ return true;
+ }
if (mindsBlockByDecay) {
return false;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index ccbc4a98d..06dcae8a1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h"
namespace latinime {
@@ -42,22 +43,17 @@ class ForgettingCurveUtils {
static bool needsToKeep(const HistoricalInfo *const historicalInfo,
const HeaderPolicy *const headerPolicy);
- static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
- const int bigramCount, const HeaderPolicy *const headerPolicy);
+ static bool needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounters,
+ const HeaderPolicy *const headerPolicy);
// TODO: Improve probability computation method and remove this.
static int getProbabilityBiasForNgram(const int n) {
return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
}
- AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
- return static_cast<int>(static_cast<float>(maxUnigramCount)
- * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
- }
-
- AK_FORCE_INLINE static int getBigramCountHardLimit(const int maxBigramCount) {
- return static_cast<int>(static_cast<float>(maxBigramCount)
- * BIGRAM_COUNT_HARD_LIMIT_WEIGHT);
+ AK_FORCE_INLINE static int getEntryCountHardLimit(const int maxEntryCount) {
+ return static_cast<int>(static_cast<float>(maxEntryCount)
+ * ENTRY_COUNT_HARD_LIMIT_WEIGHT);
}
private:
@@ -101,8 +97,7 @@ class ForgettingCurveUtils {
static const int OCCURRENCES_TO_RAISE_THE_LEVEL;
static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS;
- static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT;
- static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT;
+ static const float ENTRY_COUNT_HARD_LIMIT_WEIGHT;
static const ProbabilityTable sProbabilityTable;