aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp6
-rw-r--r--native/jni/src/suggest/core/dictionary/property/unigram_property.h5
-rw-r--r--native/jni/src/suggest/core/dictionary/property/word_property.h7
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp25
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp9
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp30
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h30
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp14
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp59
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp8
20 files changed, 129 insertions, 119 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index e4084b0f5..a3bb408c3 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -77,10 +77,8 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
return;
}
int targetWordCodePoints[MAX_WORD_LENGTH];
- int unigramProbability = 0;
- const int codePointCount = mDictStructurePolicy->
- getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH,
- targetWordCodePoints, &unigramProbability);
+ const int codePointCount = mDictStructurePolicy->getCodePointsAndReturnCodePointCount(
+ targetWordId, MAX_WORD_LENGTH, targetWordCodePoints);
if (codePointCount <= 0) {
return;
}
diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
index 902eb000f..65c8333bb 100644
--- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
@@ -71,6 +71,11 @@ class UnigramProperty {
return mIsBlacklisted;
}
+ bool isPossiblyOffensive() const {
+ // TODO: Have dedicated flag.
+ return mProbability == 0;
+ }
+
bool hasShortcuts() const {
return !mShortcuts.empty();
}
diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h
index aa3e0b68a..f78380e15 100644
--- a/native/jni/src/suggest/core/dictionary/property/word_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/word_property.h
@@ -23,6 +23,7 @@
#include "jni.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -33,10 +34,10 @@ class WordProperty {
WordProperty()
: mCodePoints(), mUnigramProperty(), mBigrams() {}
- WordProperty(const std::vector<int> *const codePoints,
- const UnigramProperty *const unigramProperty,
+ WordProperty(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty,
const std::vector<BigramProperty> *const bigrams)
- : mCodePoints(*codePoints), mUnigramProperty(*unigramProperty), mBigrams(*bigrams) {}
+ : mCodePoints(codePoints.begin(), codePoints.end()), mUnigramProperty(*unigramProperty),
+ mBigrams(*bigrams) {}
void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags,
jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities,
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index a498b6f65..1546b2610 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -51,9 +51,8 @@ class DictionaryStructureWithBufferPolicy {
virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const = 0;
- virtual int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const = 0;
+ virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const = 0;
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index ee1403739..3187aa9ac 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -87,14 +87,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
}
-int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const {
+int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
readingHelper.initWithPtNodePos(ptNodePos);
- const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
- maxCodePointCount, outCodePoints, outUnigramProbability);
+ const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
+ maxCodePointCount, outCodePoints);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@@ -495,8 +494,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
return WordProperty();
}
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
const ProbabilityEntry probabilityEntry =
mBuffers->getProbabilityDictContent()->getProbabilityEntry(
ptNodeParams.getTerminalId());
@@ -521,11 +518,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
if (word1TerminalPtNodePos == NOT_A_DICT_POS) {
continue;
}
- // Word (unigram) probability
- int word1Probability = NOT_A_PROBABILITY;
- const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
+ const int codePointCount = getCodePointsAndReturnCodePointCount(
getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH,
- bigramWord1CodePoints, &word1Probability);
+ bigramWord1CodePoints);
const std::vector<int> word1(bigramWord1CodePoints,
bigramWord1CodePoints + codePointCount);
const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo();
@@ -559,7 +554,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
historicalInfo->getCount(), &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ return WordProperty(wordCodePoints, &unigramProperty, &bigrams);
}
int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -580,10 +575,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return 0;
}
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
- &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(
+ getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 576d2abb5..8420c94d0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -85,9 +85,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
index 40782a44f..5e4a4b166 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
@@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi
return !isError();
}
-int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) {
+int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount,
+ int *const outCodePoints) {
// This method traverses parent nodes from the terminal by following parent pointers; thus,
// node code points are stored in the buffer in the reverse order.
int reverseCodePoints[maxCodePointCount];
@@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
// First, read the terminal node and get its probability.
if (!isValidTerminalNode(terminalPtNodeParams)) {
// Node at the ptNodePos is not a valid terminal node.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
- // Store terminal node probability.
- *outUnigramProbability = terminalPtNodeParams.getProbability();
// Then, following parent node link to the dictionary root and fetch node code points.
int totalCodePointCount = 0;
while (!isEnd()) {
@@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
totalCodePointCount = getTotalCodePointCount(ptNodeParams);
if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) {
// The ptNodePos is not a valid terminal node position in the dictionary.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
// Store node code points to buffer in the reverse order.
@@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount(
}
if (isError()) {
// The node position or the dictionary is invalid.
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
// Reverse the stored code points to output them.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
index 9a7abc97f..21c287fdc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
@@ -211,8 +211,7 @@ class DynamicPtReadingHelper {
bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner(
TraversingEventListener *const listener);
- int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount,
- int *const outCodePoints, int *const outUnigramProbability);
+ int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints);
int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length,
const bool forceLowerCaseSearch);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index 6e7dba9ff..20e0e7476 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -58,6 +58,11 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo
}
}
+int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
+ return getCodePointsAndProbabilityAndReturnCodePointCount(wordId, maxCodePointCount,
+ outCodePoints, nullptr /* outUnigramProbability */);
+}
// This retrieves code points and the probability of the word by its id.
// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
// it is possible to check for this with advantageous complexity. For each PtNode array, we search
@@ -82,6 +87,9 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
int pos = getRootPosition();
int wordPos = 0;
const int *const codePointTable = mHeaderPolicy.getCodePointTable();
+ if (outUnigramProbability) {
+ *outUnigramProbability = NOT_A_PROBABILITY;
+ }
// One iteration of the outer loop iterates through PtNode arrays. As stated above, we will
// only traverse PtNodes that are actually a part of the terminal we are searching, so each
// time we enter this loop we are one depth level further than last time.
@@ -97,7 +105,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos, mBuffer.size());
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(
@@ -107,7 +114,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size());
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
const PatriciaTrieReadingUtils::NodeFlags flags =
@@ -130,9 +136,11 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
mBuffer.data(), codePointTable, &pos);
}
}
- *outUnigramProbability =
- PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(),
- &pos);
+ if (outUnigramProbability) {
+ *outUnigramProbability =
+ PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(
+ mBuffer.data(), &pos);
+ }
return ++wordPos;
}
// We need to skip past this PtNode, so skip any remaining code points after the
@@ -234,7 +242,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
pos);
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
}
@@ -257,7 +264,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos);
mIsCorrupted = true;
ASSERT(false);
- *outUnigramProbability = NOT_A_PROBABILITY;
return 0;
}
}
@@ -429,8 +435,6 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
const PtNodeParams ptNodeParams =
mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
// Fetch bigram information.
std::vector<BigramProperty> bigrams;
const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos);
@@ -475,7 +479,7 @@ const WordProperty PatriciaTriePolicy::getWordProperty(
const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ return WordProperty(wordCodePoints, &unigramProperty, &bigrams);
}
int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -497,10 +501,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC
return 0;
}
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints,
- &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(
+ getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 3cdf6cd16..0d679c5dc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -58,9 +58,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
@@ -155,6 +154,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
+ int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints,
+ int *const outUnigramProbability) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
int getBigramsPositionOfPtNode(const int ptNodePos) const;
int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 35f0f768f..89094c83a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,7 +38,7 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
-int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
@@ -60,17 +60,24 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
+ int probability = NOT_A_PROBABILITY;
if (mHasHistoricalInfo) {
- const int probability = ForgettingCurveUtils::decodeProbability(
+ const int rawProbability = ForgettingCurveUtils::decodeProbability(
probabilityEntry.getHistoricalInfo(), headerPolicy)
+ ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
- return std::min(probability, MAX_PROBABILITY);
+ probability = std::min(rawProbability, MAX_PROBABILITY);
} else {
- return probabilityEntry.getProbability();
+ probability = probabilityEntry.getProbability();
}
+ // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in
+ // probabilityEntry.
+ const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId);
+ return WordAttributes(probability, unigramProbabilityEntry.isNotAWord(),
+ unigramProbabilityEntry.isBlacklisted(),
+ unigramProbabilityEntry.isPossiblyOffensive());
}
// Cannot find the word.
- return NOT_A_PROBABILITY;
+ return WordAttributes();
}
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index a793af4be..b7e4af977 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -21,6 +21,7 @@
#include <vector>
#include "defines.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
@@ -128,7 +129,7 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
- int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
+ const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index f1bf12cb2..e1e10ca17 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -49,7 +49,9 @@ class ProbabilityEntry {
// Create from unigram property.
ProbabilityEntry(const UnigramProperty *const unigramProperty)
- : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
+ : mFlags(createFlags(unigramProperty->representsBeginningOfSentence(),
+ unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(),
+ unigramProperty->isPossiblyOffensive())),
mProbability(unigramProperty->getProbability()),
mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
unigramProperty->getCount()) {}
@@ -85,6 +87,18 @@ class ProbabilityEntry {
return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
}
+ bool isNotAWord() const {
+ return (mFlags & Ver4DictConstants::FLAG_NOT_A_WORD) != 0;
+ }
+
+ bool isBlacklisted() const {
+ return (mFlags & Ver4DictConstants::FLAG_BLACKLISTED) != 0;
+ }
+
+ bool isPossiblyOffensive() const {
+ return (mFlags & Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE) != 0;
+ }
+
uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
if (hasHistoricalInfo) {
@@ -142,10 +156,20 @@ class ProbabilityEntry {
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
}
- static uint8_t createFlags(const bool representsBeginningOfSentence) {
+ static uint8_t createFlags(const bool representsBeginningOfSentence,
+ const bool isNotAWord, const bool isBlacklisted, const bool isPossiblyOffensive) {
uint8_t flags = 0;
if (representsBeginningOfSentence) {
- flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ flags |= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ }
+ if (isNotAWord) {
+ flags |= Ver4DictConstants::FLAG_NOT_A_WORD;
+ }
+ if (isBlacklisted) {
+ flags |= Ver4DictConstants::FLAG_BLACKLISTED;
+ }
+ if (isPossiblyOffensive) {
+ flags |= Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE;
}
return flags;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 39822b94a..8e6cb974b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -54,6 +54,9 @@ const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2;
+const uint8_t Ver4DictConstants::FLAG_NOT_A_WORD = 0x4;
+const uint8_t Ver4DictConstants::FLAG_BLACKLISTED = 0x8;
+const uint8_t Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE = 0x10;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index dfcdd4d6f..600b5ffe4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -52,6 +52,9 @@ class Ver4DictConstants {
// Flags in probability entry.
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
static const uint8_t FLAG_NOT_A_VALID_ENTRY;
+ static const uint8_t FLAG_NOT_A_WORD;
+ static const uint8_t FLAG_BLACKLISTED;
+ static const uint8_t FLAG_POSSIBLY_OFFENSIVE;
static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE;
static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 75ec16912..a1a33d27a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -191,7 +191,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition(
ptNodeWritingPos);
}
-
bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty,
int *const ptNodeWritingPos) {
@@ -341,8 +340,8 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
ptNodeParams->getChildrenPos(), ptNodeWritingPos)) {
return false;
}
- return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(),
- isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
+ return updatePtNodeFlags(nodePos, isTerminal,
+ ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
// TODO: Move probability handling code to LanguageModelDictContent.
@@ -361,14 +360,13 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
}
}
-bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos,
- const bool isBlacklisted, const bool isNotAWord, const bool isTerminal,
+bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
const bool hasMultipleChars) {
// Create node flags and write them.
PatriciaTrieReadingUtils::NodeFlags nodeFlags =
- PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal,
- false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars,
- CHILDREN_POSITION_FIELD_SIZE);
+ PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */,
+ false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */,
+ false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE);
if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) {
AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos);
return false;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index 08b7d3825..17915273b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -103,8 +103,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const;
- bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
- const bool isTerminal, const bool hasMultipleChars);
+ bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);
static const int CHILDREN_POSITION_FIELD_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 8d4135679..0349ba4a0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -63,14 +63,10 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
// valid terminal DicNode.
isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY;
}
- readingHelper.readNextSiblingNode(ptNodeParams);
- if (ptNodeParams.representsNonWordInfo()) {
- // Skip PtNodes that represent non-word information.
- continue;
- }
const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
wordId, ptNodeParams.getCodePointArrayView());
+ readingHelper.readNextSiblingNode(ptNodeParams);
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -78,15 +74,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
}
-int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const {
+int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId,
+ const int maxCodePointCount, int *const outCodePoints) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
readingHelper.initWithPtNodePos(ptNodePos);
- const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount(
- maxCodePointCount, outCodePoints, outUnigramProbability);
+ const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount(
+ maxCodePointCount, outCodePoints);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount().");
@@ -117,13 +112,8 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
if (wordId == NOT_A_WORD_ID) {
return WordAttributes();
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
- prevWordIds, wordId, mHeaderPolicy);
- return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
- probability == 0);
+ return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId,
+ mHeaderPolicy);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -131,15 +121,10 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY;
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
- return NOT_A_PROBABILITY;
- }
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId);
- if (!probabilityEntry.isValid()) {
+ if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted()
+ || probabilityEntry.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
@@ -166,6 +151,9 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
for (const auto entry : languageModelDictContent->getProbabilityEntries(
prevWordIds.limit(i))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
+ if (!probabilityEntry.isValid()) {
+ continue;
+ }
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
@@ -463,8 +451,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- std::vector<int> codePointVector(ptNodeParams.getCodePoints(),
- ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount());
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
ptNodeParams.getTerminalId());
@@ -476,10 +462,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
- // Word (unigram) probability
- int word1Probability = NOT_A_PROBABILITY;
- const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability);
+ const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(),
+ MAX_WORD_LENGTH, bigramWord1CodePoints);
const std::vector<int> word1(bigramWord1CodePoints,
bigramWord1CodePoints + codePointCount);
const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry();
@@ -508,11 +492,11 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
shortcuts.emplace_back(&target, shortcutProbability);
}
}
- const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(),
- ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(),
- historicalInfo->getTimeStamp(), historicalInfo->getLevel(),
- historicalInfo->getCount(), &shortcuts);
- return WordProperty(&codePointVector, &unigramProperty, &bigrams);
+ const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(),
+ probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(),
+ probabilityEntry.getProbability(), historicalInfo->getTimeStamp(),
+ historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts);
+ return WordProperty(wordCodePoints, &unigramProperty, &bigrams);
}
int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints,
@@ -535,9 +519,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
const PtNodeParams ptNodeParams =
mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos);
- int unigramProbability = NOT_A_PROBABILITY;
- *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(
- ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability);
+ *outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(),
+ MAX_WORD_LENGTH, outCodePoints);
const int nextToken = token + 1;
if (nextToken >= terminalPtNodePositionsVectorSize) {
// All words have been iterated.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index a117a3614..758f8da80 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -62,9 +62,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
void createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const;
- int getCodePointsAndProbabilityAndReturnCodePointCount(
- const int wordId, const int maxCodePointCount, int *const outCodePoints,
- int *const outUnigramProbability) const;
+ int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount,
+ int *const outCodePoints) const;
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index 06f82df52..daa32c348 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -107,15 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry);
- EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
- nullptr /* headerPolicy */));
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId,
+ nullptr /* headerPolicy */).getProbability());
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry);
- EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
- nullptr /* headerPolicy */));
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId,
+ nullptr /* headerPolicy */).getProbability());
}
} // namespace