aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp15
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp7
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp6
4 files changed, 18 insertions, 13 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index f54bb151a..0675de6fa 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
}
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
- const int wordId) const {
+ const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0;
@@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
if (!result.mIsValid) {
continue;
}
- const int probability =
- ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
if (mHasHistoricalInfo) {
- return std::min(
- probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
- MAX_PROBABILITY);
+ const int probability = ForgettingCurveUtils::decodeProbability(
+ probabilityEntry.getHistoricalInfo(), headerPolicy)
+ + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
+ return std::min(probability, MAX_PROBABILITY);
} else {
- return probability;
+ return probabilityEntry.getProbability();
}
}
// Cannot find the word.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 4e0b47036..a793af4be 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,7 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
- int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+ int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
+ const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index e624bf338..1336a6229 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
// TODO: Support n-gram.
- return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
- prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
- ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
+ const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
+ prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ probability == 0);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index 7608b45c2..c5849d054 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry);
- EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
+ nullptr /* headerPolicy */));
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry);
- EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
+ nullptr /* headerPolicy */));
}
} // namespace