aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp1
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h1
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp29
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp2
6 files changed, 28 insertions, 18 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 0675de6fa..85d6d434d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -167,6 +167,14 @@ int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView
if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
return TrieMap::INVALID_INDEX;
}
+ const int oldestPrevWordId = prevWordIds[prevWordIds.size() - 1];
+ const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex);
+ if (!result.mIsValid) {
+ if (!mTrieMap.put(oldestPrevWordId,
+ ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) {
+ return TrieMap::INVALID_INDEX;
+ }
+ }
return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
lastBitmapEntryIndex);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index 3dfaba755..f1bf12cb2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -36,7 +36,8 @@ class ProbabilityEntry {
// Dummy entry
ProbabilityEntry()
- : mFlags(0), mProbability(NOT_A_PROBABILITY), mHistoricalInfo() {}
+ : mFlags(Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY), mProbability(NOT_A_PROBABILITY),
+ mHistoricalInfo() {}
// Entry without historical information
ProbabilityEntry(const int flags, const int probability)
@@ -61,7 +62,7 @@ class ProbabilityEntry {
bigramProperty->getCount()) {}
bool isValid() const {
- return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
+ return (mFlags & Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY) == 0;
}
bool hasHistoricalInfo() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 9acf2d44f..39822b94a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -53,6 +53,7 @@ const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
+const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index 97035311e..dfcdd4d6f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -51,6 +51,7 @@ class Ver4DictConstants {
static const int WORD_COUNT_FIELD_SIZE;
// Flags in probability entry.
static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ static const uint8_t FLAG_NOT_A_VALID_ENTRY;
static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE;
static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index b808c904d..2af218ab6 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -120,16 +120,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
const int ptNodePos =
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
- // TODO: Support n-gram.
const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
- prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
+ prevWordIds, wordId, mHeaderPolicy);
return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
probability == 0);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
const int wordId) const {
- if (wordId == NOT_A_WORD_ID) {
+ if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) {
return NOT_A_PROBABILITY;
}
const int ptNodePos =
@@ -138,10 +137,8 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
- // TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
- mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
- prevWordIds.limit(1 /* maxSize */), wordId);
+ mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
@@ -164,16 +161,18 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
if (prevWordIds.empty()) {
return;
}
- // TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
- for (const auto entry : languageModelDictContent->getProbabilityEntries(
- prevWordIds.limit(1 /* maxSize */))) {
- const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
- const int probability = probabilityEntry.hasHistoricalInfo() ?
- ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
- probabilityEntry.getProbability();
- listener->onVisitEntry(probability, entry.getWordId());
+ for (size_t i = 1; i <= prevWordIds.size(); ++i) {
+ for (const auto entry : languageModelDictContent->getProbabilityEntries(
+ prevWordIds.limit(i))) {
+ const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
+ const int probability = probabilityEntry.hasHistoricalInfo() ?
+ ForgettingCurveUtils::decodeProbability(
+ probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
+ + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) :
+ probabilityEntry.getProbability();
+ listener->onVisitEntry(probability, entry.getWordId());
+ }
}
}
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index c5849d054..06f82df52 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -29,7 +29,7 @@ namespace {
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
- const int flag = 0xFF;
+ const int flag = 0xF0;
const int probability = 10;
const int wordId = 100;
const ProbabilityEntry probabilityEntry(flag, probability);