about | summary | refs | log | tree | commit | diff | stats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r-- native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp | 37
-rw-r--r-- native/jni/src/suggest/core/dictionary/dictionary.cpp | 6
-rw-r--r-- native/jni/src/suggest/core/dictionary/property/bigram_property.h | 6
-rw-r--r-- native/jni/src/suggest/core/dictionary/property/unigram_property.h | 10
-rw-r--r-- native/jni/src/suggest/core/dictionary/property/word_property.h | 6
-rw-r--r-- native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h | 5
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp | 37
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h | 5
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp | 9
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp | 43
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h | 10
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp | 26
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h | 30
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp | 19
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h | 11
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp | 14
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp | 78
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h | 8
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp | 6
-rw-r--r-- native/jni/src/utils/int_array_view.h | 4
-rw-r--r-- native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp | 8
-rw-r--r-- native/jni/tests/utils/int_array_view_test.cpp | 7
27 files changed, 201 insertions, 199 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index 688ce44be..e420f8056 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -364,10 +364,12 @@ static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz,
int codePoints[codePointCount];
env->GetIntArrayRegion(word, 0, codePointCount, codePoints);
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
- std::vector<int> shortcutTargetCodePoints;
- JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints);
- if (!shortcutTargetCodePoints.empty()) {
- shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
+ {
+ std::vector<int> shortcutTargetCodePoints;
+ JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints);
+ if (!shortcutTargetCodePoints.empty()) {
+ shortcuts.emplace_back(std::move(shortcutTargetCodePoints), shortcutProbability);
+ }
}
// Use 1 for count to indicate the word has inputted.
const UnigramProperty unigramProperty(isBeginningOfSentence, isNotAWord,
@@ -401,11 +403,9 @@ static bool latinime_BinaryDictionary_addNgramEntry(JNIEnv *env, jclass clazz, j
jsize wordLength = env->GetArrayLength(word);
int wordCodePoints[wordLength];
env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints);
- const std::vector<int> bigramTargetCodePoints(
- wordCodePoints, wordCodePoints + wordLength);
// Use 1 for count to indicate the bigram has inputted.
- const BigramProperty bigramProperty(&bigramTargetCodePoints, probability,
- timestamp, 0 /* level */, 1 /* count */);
+ const BigramProperty bigramProperty(CodePointArrayView(wordCodePoints, wordLength).toVector(),
+ probability, timestamp, 0 /* level */, 1 /* count */);
return dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty);
}
@@ -483,12 +483,14 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
jintArray shortcutTarget = static_cast<jintArray>(
env->GetObjectField(languageModelParam, shortcutTargetFieldId));
std::vector<UnigramProperty::ShortcutProperty> shortcuts;
- std::vector<int> shortcutTargetCodePoints;
- JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints);
- if (!shortcutTargetCodePoints.empty()) {
- jint shortcutProbability =
- env->GetIntField(languageModelParam, shortcutProbabilityFieldId);
- shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability);
+ {
+ std::vector<int> shortcutTargetCodePoints;
+ JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints);
+ if (!shortcutTargetCodePoints.empty()) {
+ jint shortcutProbability =
+ env->GetIntField(languageModelParam, shortcutProbabilityFieldId);
+ shortcuts.emplace_back(std::move(shortcutTargetCodePoints), shortcutProbability);
+ }
}
// Use 1 for count to indicate the word has inputted.
const UnigramProperty unigramProperty(false /* isBeginningOfSentence */, isNotAWord,
@@ -498,11 +500,10 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j
&unigramProperty);
if (word0) {
jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId);
- const std::vector<int> bigramTargetCodePoints(
- word1CodePoints, word1CodePoints + word1Length);
// Use 1 for count to indicate the bigram has inputted.
- const BigramProperty bigramProperty(&bigramTargetCodePoints, bigramProbability,
- timestamp, 0 /* level */, 1 /* count */);
+ const BigramProperty bigramProperty(
+ CodePointArrayView(word1CodePoints, word1Length).toVector(),
+ bigramProbability, timestamp, 0 /* level */, 1 /* count */);
const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length,
false /* isBeginningOfSentence */);
dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty);
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index e4084b0f5..a3bb408c3 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -77,10 +77,8 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
return;
}
int targetWordCodePoints[MAX_WORD_LENGTH];
- int unigramProbability = 0;
- const int codePointCount = mDictStructurePolicy->
- getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH,
- targetWordCodePoints, &unigramProbability);
+ const int codePointCount = mDictStructurePolicy->getCodePointsAndReturnCodePointCount(
+ targetWordId, MAX_WORD_LENGTH, targetWordCodePoints);
if (codePointCount <= 0) {
return;
}
diff --git a/native/jni/src/suggest/core/dictionary/property/bigram_property.h b/native/jni/src/suggest/core/dictionary/property/bigram_property.h
index 343af143c..9e0baa032 100644
--- a/native/jni/src/suggest/core/dictionary/property/bigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/bigram_property.h
@@ -26,9 +26,9 @@ namespace latinime {
// TODO: Change to NgramProperty.
class BigramProperty {
public:
- BigramProperty(const std::vector<int> *const targetCodePoints,
- const int probability, const int timestamp, const int level, const int count)
- : mTargetCodePoints(*targetCodePoints), mProbability(probability),
+ BigramProperty(const std::vector<int> &&targetCodePoints, const int probability,
+ const int timestamp, const int level, const int count)
+ : mTargetCodePoints(std::move(targetCodePoints)), mProbability(probability),
mTimestamp(timestamp), mLevel(level), mCount(count) {}
const std::vector<int> *getTargetCodePoints() const {
diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
index 902eb000f..b7e7d6686 100644
--- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
@@ -27,8 +27,9 @@ class UnigramProperty {
public:
class ShortcutProperty {
public:
- ShortcutProperty(const std::vector<int> *const targetCodePoints, const int probability)
- : mTargetCodePoints(*targetCodePoints), mProbability(probability) {}
+ ShortcutProperty(const std::vector<int> &&targetCodePoints, const int probability)
+ : mTargetCodePoints(std::move(targetCodePoints)),
+ mProbability(probability) {}
const std::vector<int> *getTargetCodePoints() const {
return &mTargetCodePoints;
@@ -71,6 +72,11 @@ class UnigramProperty {
return mIsBlacklisted;
}
+ bool isPossiblyOffensive() const {
+ // TODO: Have dedicated flag.
+ return mProbability == 0;
+ }
+
bool hasShortcuts() const {
return !mShortcuts.empty();
}
diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h
index aa3e0b68a..4e6febb3f 100644
--- a/native/jni/src/suggest/core/dictionary/property/word_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/word_property.h
@@ -33,10 +33,10 @@ class WordProperty {
WordProperty()
: mCodePoints(), mUnigramProperty(), mBigrams() {}
- WordProperty(const std::vector<int> *const codePoints,
- const UnigramProperty *const unigramProperty,
+ WordProperty(const std::vector<int> &&codePoints, const UnigramProperty *const unigramProperty,
const std::vector<BigramProperty> *const bigrams)
- : mCodePoints(*codePoints), mUnigramProperty(*unigramProperty), mBigrams(*bigrams) {}
+ : mCodePoints(std::move(codePoints)), mUnigramProperty(*unigramProperty),
+ mBigrams(*bigrams) {}
void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags,
jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities,
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index a498b6f65..1546b2610 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -51,9 +51,8 @@ class DictionaryStructureWithBufferPolicy {
virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const = 0;
- virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const = 0;
+ virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const = 0;
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index ee1403739..f752f89f1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -87,14 +87,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
}
-int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const {
+int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
readingHelper.initWithPtNodePos(ptNodePos);
- const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
- maxCodePointCount, outCodePoints, outUnigramProbability);
+ const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
+ maxCodePointCount, outCodePoints);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@@ -495,8 +494,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
return WordProperty();
}
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
const ProbabilityEntry probabilityEntry =
mBuffers->getProbabilityDictContent()->getProbabilityEntry(
ptNodeParams.getTerminalId());
@@ -521,20 +518,17 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
if (word1TerminalPtNodePos == NOT_A_DICT_POS) {
continue;
}
- // Word (unigram) probability
- int word1Probability = NOT_A_PROBABILITY;
- const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
+ const int codePointCount = getCodePointsAndReturnCodePointCount(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
- bigramWord1CodePoints, &word1Probability);
- const std::vector<int> word1(bigramWord1CodePoints,
- bigramWord1CodePoints + codePointCount);
+ bigramWord1CodePoints);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
const int probability = bigramEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
bigramEntry.getHistoricalInfo(), mHeaderPolicy) :
bigramEntry.getProbability();
- bigrams.emplace_back(&word1, probability,
- historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
+ bigrams.emplace_back(
+ CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
+ probability, historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
historicalInfo->getCount());
}
}
@@ -551,15 +545,16 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
int shortcutProbability = NOT_A_PROBABILITY;
shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget,
&shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos);
- const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength);
- shortcuts.emplace_back(&target, shortcutProbability);
+ shortcuts.emplace_back(
+ CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(),
+ shortcutProbability);
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
historicalInfo->getCount(), &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams);
}
int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -580,10 +575,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return 0;
}
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
- &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(
+ getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 576d2abb5..8420c94d0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -85,9 +85,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
index 40782a44f..5e4a4b166 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
@@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi
return !isError();
}
-int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) {
+int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount,
+ int *const outCodePoints) {
// This method traverses parent nodes from the terminal by following parent pointers; thus,
// node code points are stored in the buffer in the reverse order.
int reverseCodePoints[maxCodePointCount];
@@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
// First, read the terminal node and get its probability.
if (!isValidTerminalNode(terminalPtNodeParams)) {
// Node at the ptNodePos is not a valid terminal node.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
- // Store terminal node probability.
- *outUnigramProbability = terminalPtNodeParams.getProbability();
// Then, following parent node link to the dictionary root and fetch node code points.
int totalCodePointCount = 0;
while (!isEnd()) {
@@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
totalCodePointCount = getTotalCodePointCount(ptNodeParams);
if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) {
// The ptNodePos is not a valid terminal node position in the dictionary.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
// Store node code points to buffer in the reverse order.
@@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
}
if (isError()) {
// The node position or the dictionary is invalid.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
// Reverse the stored code points to output them.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
index 9a7abc97f..21c287fdc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
@@ -211,8 +211,7 @@ class DynamicPtReadingHelper {
bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
TraversingEventListener *const listener);
- int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount,
- int *const outCodePoints, int *const outUnigramProbability);
+ int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints);
int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
const bool forceLowerCaseSearch);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index 6e7dba9ff..13cf9a5a8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -58,6 +58,11 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo
}
}
+int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
+ return getCodePointsAndProbabilityAndReturnCodePointCount(wordId, maxCodePointCount,
+ outCodePoints, nullptr /* outUnigramProbability */);
+}
// This retrieves code points and the probability of the word by its id.
// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
// it is possible to check for this with advantageous complexity. For each PtNode array, we search
@@ -82,6 +87,9 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
int pos = getRootPosition();
int wordPos = 0;
const int *const codePointTable = mHeaderPolicy.getCodePointTable();
+ if (outUnigramProbability) {
+ *outUnigramProbability = NOT_A_PROBABILITY;
+ }
// One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
// only traverse PtNodes that are actually a part of the terminal we are searching, so each
// time we enter this loop we are one depth level further than last time.
@@ -97,7 +105,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos, mBuffer.size());
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
@@ -107,7 +114,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size());
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
const PatriciaTrieReadingUtils::NodeFlags flags =
@@ -130,9 +136,11 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
mBuffer.data(), codePointTable, &pos);
}
}
- *outUnigramProbability =
- PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(),
- &pos);
+ if (outUnigramProbability) {
+ *outUnigramProbability =
+ PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(
+ mBuffer.data(), &pos);
+ }
return ++wordPos;
}
// We need to skip past this PtNode, so skip any remaining code points after the
@@ -234,7 +242,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos);
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
}
@@ -257,7 +264,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos);
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
}
@@ -429,8 +435,6 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
// Fetch bigram information.
std::vector<BigramProperty> bigrams;
const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos);
@@ -445,11 +449,10 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()), MAX_WORD_LENGTH,
bigramWord1CodePoints, &word1Probability);
- const std::vector<int> word1(bigramWord1CodePoints,
- bigramWord1CodePoints + word1CodePointCount);
const int probability = getProbability(word1Probability, bigramsIt.getProbability());
- bigrams.emplace_back(&word1, probability,
- NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */);
+ bigrams.emplace_back(
+ CodePointArrayView(bigramWord1CodePoints, word1CodePointCount).toVector(),
+ probability, NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */);
}
}
// Fetch shortcut information.
@@ -465,17 +468,17 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags);
const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget(
mBuffer, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos);
- const std::vector<int> shortcutTarget(shortcutTargetCodePoints,
- shortcutTargetCodePoints + shortcutTargetLength);
const int shortcutProbability =
ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags);
- shortcuts.emplace_back(&shortcutTarget, shortcutProbability);
+ shortcuts.emplace_back(
+ CodePointArrayView(shortcutTargetCodePoints, shortcutTargetLength).toVector(),
+ shortcutProbability);
}
}
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams);
}
int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -497,10 +500,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
return 0;
}
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
- &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(
+ getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 5f179513f..0d679c5dc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -43,7 +43,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer)
: mMmappedBuffer(std::move(mmappedBuffer)),
mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(),
- FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())),
+ FormatUtils::detectFormatVersion(mMmappedBuffer->getReadOnlyByteArrayView())),
mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())),
mBigramListPolicy(mBuffer), mShortcutListPolicy(mBuffer),
mPtNodeReader(mBuffer, &mBigramListPolicy, &mShortcutListPolicy,
@@ -58,9 +58,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
@@ -155,6 +154,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
+ int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints,
+ int *const outUnigramProbability) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
int getBigramsPositionOfPtNode(const int ptNodePos) const;
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 35f0f768f..139230228 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,7 +38,7 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
-int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
@@ -60,17 +60,29 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
+ int probability = NOT_A_PROBABILITY;
if (mHasHistoricalInfo) {
- const int probability = ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), headerPolicy)
- + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
- return std::min(probability, MAX_PROBABILITY);
+ const int rawProbability = ForgettingCurveUtils::decodeProbability(
+ probabilityEntry.getHistoricalInfo(), headerPolicy);
+ if (rawProbability == NOT_A_PROBABILITY) {
+ // The entry should not be treated as a valid entry.
+ continue;
+ }
+ probability = std::min(rawProbability
+ + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
+ MAX_PROBABILITY);
} else {
- return probabilityEntry.getProbability();
+ probability = probabilityEntry.getProbability();
}
+ // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
+ // probabilityEntry.
+ const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
+ return WordAttributes(probability, unigramProbabilityEntry.isNotAWord(),
+ unigramProbabilityEntry.isBlacklisted(),
+ unigramProbabilityEntry.isPossiblyOffensive());
}
// Cannot find the word.
- return NOT_A_PROBABILITY;
+ return WordAttributes();
}
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index a793af4be..b7e4af977 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -21,6 +21,7 @@
#include <vector>
#include "defines.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
@@ -128,7 +129,7 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
- int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
+ const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index f1bf12cb2..e1e10ca17 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -49,7 +49,9 @@ class ProbabilityEntry {
// Create from unigram property.
ProbabilityEntry(const UnigramProperty *const unigramProperty)
- : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
+ : mFlags(createFlags(unigramProperty->representsBeginningOfSentence(),
+ unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isPossiblyOffensive())),
mProbability(unigramProperty->getProbability()),
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
unigramProperty->getCount()) {}
@@ -85,6 +87,18 @@ class ProbabilityEntry {
return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
}
+ bool isNotAWord() const {
+ return (mFlags & Ver4DictConstants::FLAG_NOT_A_WORD) != 0;
+ }
+
+ bool isBlacklisted() const {
+ return (mFlags & Ver4DictConstants::FLAG_BLACKLISTED) != 0;
+ }
+
+ bool isPossiblyOffensive() const {
+ return (mFlags & Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE) != 0;
+ }
+
uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
if (hasHistoricalInfo) {
@@ -142,10 +156,20 @@ class ProbabilityEntry {
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
}
- static uint8_t createFlags(const bool representsBeginningOfSentence) {
+ static uint8_t createFlags(const bool representsBeginningOfSentence,
+ const bool isNotAWord, const bool isBlacklisted, const bool isPossiblyOffensive) {
uint8_t flags = 0;
if (representsBeginningOfSentence) {
- flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ flags |= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ }
+ if (isNotAWord) {
+ flags |= Ver4DictConstants::FLAG_NOT_A_WORD;
+ }
+ if (isBlacklisted) {
+ flags |= Ver4DictConstants::FLAG_BLACKLISTED;
+ }
+ if (isPossiblyOffensive) {
+ flags |= Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE;
}
return flags;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 39822b94a..8e6cb974b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -54,6 +54,9 @@ const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2;
+const uint8_t Ver4DictConstants::FLAG_NOT_A_WORD = 0x4;
+const uint8_t Ver4DictConstants::FLAG_BLACKLISTED = 0x8;
+const uint8_t Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE = 0x10;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index dfcdd4d6f..600b5ffe4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -52,6 +52,9 @@ class Ver4DictConstants {
// Flags in probability entry.
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
static const uint8_t FLAG_NOT_A_VALID_ENTRY;
+ static const uint8_t FLAG_NOT_A_WORD;
+ static const uint8_t FLAG_BLACKLISTED;
+ static const uint8_t FLAG_POSSIBLY_OFFENSIVE;
static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE;
static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp
index d795239fc..4110d6036 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp
@@ -51,26 +51,17 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce
const int parentPos =
DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos);
int codePoints[MAX_WORD_LENGTH];
- const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition(
- dictBuf, flags, MAX_WORD_LENGTH, mHeaderPolicy->getCodePointTable(), codePoints, &pos);
+ // Code point table is not used for ver4 dictionaries.
+ const int codePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition(
+ dictBuf, flags, MAX_WORD_LENGTH, nullptr /* codePointTable */, codePoints, &pos);
int terminalIdFieldPos = NOT_A_DICT_POS;
int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID;
- int probability = NOT_A_PROBABILITY;
if (PatriciaTrieReadingUtils::isTerminal(flags)) {
terminalIdFieldPos = pos;
if (usesAdditionalBuffer) {
terminalIdFieldPos += mBuffer->getOriginalBufferSize();
}
terminalId = Ver4PatriciaTrieReadingUtils::getTerminalIdAndAdvancePosition(dictBuf, &pos);
- // TODO: Quit reading probability here.
- const ProbabilityEntry probabilityEntry =
- mLanguageModelDictContent->getProbabilityEntry(terminalId);
- if (probabilityEntry.hasHistoricalInfo()) {
- probability = ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), mHeaderPolicy);
- } else {
- probability = probabilityEntry.getProbability();
- }
}
int childrenPosFieldPos = pos;
if (usesAdditionalBuffer) {
@@ -91,8 +82,8 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce
// The destination position is stored at the same place as the parent position.
return fetchPtNodeInfoFromBufferAndProcessMovedPtNode(parentPos, newSiblingNodePos);
} else {
- return PtNodeParams(headPos, flags, parentPos, codePonitCount, codePoints,
- terminalIdFieldPos, terminalId, probability, childrenPosFieldPos, childrenPos,
+ return PtNodeParams(headPos, flags, parentPos, codePointCount, codePoints,
+ terminalIdFieldPos, terminalId, NOT_A_PROBABILITY, childrenPosFieldPos, childrenPos,
newSiblingNodePos);
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h
index a91ad5728..f4df544e2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h
@@ -29,15 +29,12 @@ class LanguageModelDictContent;
/*
* This class is used for helping to read nodes of ver4 patricia trie. This class handles moved
- * node and reads node attributes including probability form language model.
+ * node and reads node attributes.
*/
class Ver4PatriciaTrieNodeReader : public PtNodeReader {
public:
- Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer,
- const LanguageModelDictContent *const languageModelDictContent,
- const HeaderPolicy *const headerPolicy)
- : mBuffer(buffer), mLanguageModelDictContent(languageModelDictContent),
- mHeaderPolicy(headerPolicy) {}
+ explicit Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer)
+ : mBuffer(buffer) {}
~Ver4PatriciaTrieNodeReader() {}
@@ -50,8 +47,6 @@ class Ver4PatriciaTrieNodeReader : public PtNodeReader {
DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeReader);
const BufferWithExtendableBuffer *const mBuffer;
- const LanguageModelDictContent *const mLanguageModelDictContent;
- const HeaderPolicy *const mHeaderPolicy;
const PtNodeParams fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos,
const int siblingNodePos) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 75ec16912..a1a33d27a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -191,7 +191,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition(
ptNodeWritingPos);
}
-
bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty,
int *const ptNodeWritingPos) {
@@ -341,8 +340,8 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
return false;
}
- return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
- isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
+ return updatePtNodeFlags(nodePos, isTerminal,
+ ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
// TODO: Move probability handling code to LanguageModelDictContent.
@@ -361,14 +360,13 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
}
}
-bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos,
- const bool isBlacklisted, const bool isNotAWord, const bool isTerminal,
+bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
const bool hasMultipleChars) {
// Create node flags and write them.
PatriciaTrieReadingUtils::NodeFlags nodeFlags =
- PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal,
- false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars,
- CHILDREN_POSITION_FIELD_SIZE);
+ PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */,
+ false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */,
+ false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE);
if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
return false;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index 08b7d3825..17915273b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -103,8 +103,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const;
- bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
- const bool isTerminal, const bool hasMultipleChars);
+ bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);
static const int CHILDREN_POSITION_FIELD_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 8d4135679..0f0696410 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -56,21 +56,11 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
if (!ptNodeParams.isValid()) {
break;
}
- bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted();
- if (isTerminal && mHeaderPolicy->isDecayingDict()) {
- // A DecayingDict may have a terminal PtNode that has a terminal DicNode whose
- // probability is NOT_A_PROBABILITY. In such case, we don't want to treat it as a
- // valid terminal DicNode.
- isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY;
- }
- readingHelper.readNextSiblingNode(ptNodeParams);
- if (ptNodeParams.representsNonWordInfo()) {
- // Skip PtNodes that represent non-word information.
- continue;
- }
+ const bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted();
const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
wordId, ptNodeParams.getCodePointArrayView());
+ readingHelper.readNextSiblingNode(ptNodeParams);
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -78,15 +68,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
}
-int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const {
+int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
readingHelper.initWithPtNodePos(ptNodePos);
- const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
- maxCodePointCount, outCodePoints, outUnigramProbability);
+ const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
+ maxCodePointCount, outCodePoints);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@@ -117,13 +106,8 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
if (wordId == NOT_A_WORD_ID) {
return WordAttributes();
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
- prevWordIds, wordId, mHeaderPolicy);
- return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
- probability == 0);
+ return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
+ mHeaderPolicy);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -131,15 +115,10 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY;
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
- return NOT_A_PROBABILITY;
- }
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId);
- if (!probabilityEntry.isValid()) {
+ if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted()
+ || probabilityEntry.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
@@ -166,6 +145,9 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
for (const auto entry : languageModelDictContent->getProbabilityEntries(
prevWordIds.limit(i))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
+ if (!probabilityEntry.isValid()) {
+ continue;
+ }
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
@@ -463,8 +445,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
ptNodeParams.getTerminalId());
@@ -476,19 +456,15 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
- // Word (unigram) probability
- int word1Probability = NOT_A_PROBABILITY;
- const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability);
- const std::vector<int> word1(bigramWord1CodePoints,
- bigramWord1CodePoints + codePointCount);
+ const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
+ MAX_WORD_LENGTH, bigramWord1CodePoints);
const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry();
const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) :
probabilityEntry.getProbability();
- bigrams.emplace_back(&word1, probability,
- historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
+ bigrams.emplace_back(CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(),
+ probability, historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
historicalInfo->getCount());
}
// Fetch shortcut information.
@@ -504,15 +480,16 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
int shortcutProbability = NOT_A_PROBABILITY;
shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget,
&shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos);
- const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength);
- shortcuts.emplace_back(&target, shortcutProbability);
+ shortcuts.emplace_back(
+ CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(),
+ shortcutProbability);
}
}
- const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
- historicalInfo->getCount(), &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
+ probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
+ probabilityEntry.getProbability(), historicalInfo->getTimeStamp(),
+ historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts);
+ return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams);
}
int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -535,9 +512,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
const PtNodeParams ptNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos);
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(),
+ MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index a117a3614..c9bde2cf5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -45,8 +45,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mDictBuffer(mBuffers->getWritableTrieBuffer()),
mShortcutPolicy(mBuffers->getMutableShortcutDictContent(),
mBuffers->getTerminalPositionLookupTable()),
- mNodeReader(mDictBuffer, mBuffers->getLanguageModelDictContent(), mHeaderPolicy),
- mPtNodeArrayReader(mDictBuffer),
+ mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer),
mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader,
&mPtNodeArrayReader, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
@@ -62,9 +61,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 63e43a544..442abadee 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -73,8 +73,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
int *const outUnigramCount, int *const outBigramCount) {
- Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(),
- mBuffers->getLanguageModelDictContent(), headerPolicy);
+ Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
mBuffers->getTerminalPositionLookupTable());
@@ -137,8 +136,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
}
// Create policy instances for the GCed dictionary.
- Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(),
- buffersToWrite->getLanguageModelDictContent(), headerPolicy);
+ Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer());
Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer());
Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(),
buffersToWrite->getTerminalPositionLookupTable());
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index f3a8589ca..408373176 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -129,6 +129,10 @@ class IntArrayView {
return mPtr[mSize - 1];
}
+ AK_FORCE_INLINE std::vector<int> toVector() const {
+ return std::vector<int>(begin(), end());
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index 06f82df52..daa32c348 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -107,15 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry);
- EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
- nullptr /* headerPolicy */));
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId,
+ nullptr /* headerPolicy */).getProbability());
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry);
- EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
- nullptr /* headerPolicy */));
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId,
+ nullptr /* headerPolicy */).getProbability());
}
} // namespace
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 487bd04b1..4757a416b 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -144,5 +144,12 @@ TEST(IntArrayViewTest, TestLastOrDefault) {
EXPECT_EQ(10, intArrayView.skip(6).lastOrDefault(10));
}
+TEST(IntArrayViewTest, TestToVector) {
+ const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+ IntArrayView intArrayView(intVector);
+ EXPECT_EQ(intVector, intArrayView.toVector());
+ EXPECT_EQ(std::vector<int>(), CodePointArrayView().toVector());
+}
+
} // namespace
} // namespace latinime