aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp34
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp24
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h5
-rw-r--r--native/jni/src/utils/int_array_view.h6
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp58
-rw-r--r--native/jni/tests/utils/int_array_view_test.cpp2
8 files changed, 95 insertions, 40 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 88982e540..df3daa816 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -354,7 +354,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
bool addedNewBigram = false;
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
- if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
+ if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) {
mBigramCount++;
@@ -396,7 +396,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
}
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
- PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
+ PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mBigramCount--;
return true;
} else {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index d5749e9eb..f54bb151a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
+int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+ const int wordId) const {
+ int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
+ int maxLevel = 0;
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ const int nextBitmapEntryIndex =
+ mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
+ if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ break;
+ }
+ maxLevel = i + 1;
+ bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
+ }
+
+ for (int i = maxLevel; i >= 0; --i) {
+ const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
+ if (!result.mIsValid) {
+ continue;
+ }
+ const int probability =
+ ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+ if (mHasHistoricalInfo) {
+ return std::min(
+ probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
+ MAX_PROBABILITY);
+ } else {
+ return probability;
+ }
+ }
+ // Cannot find the word.
+ return NOT_A_PROBABILITY;
+}
+
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
const WordIdArrayView prevWordIds, const int wordId) const {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index aa612e35a..4e0b47036 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,6 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
+ int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 6de3e5a81..308c35585 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -115,24 +115,12 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
const int wordId, MultiBigramMap *const multiBigramMap) const {
- // TODO: Quit using MultiBigramMap.
if (wordId == NOT_A_WORD_ID) {
return NOT_A_PROBABILITY;
}
- const int ptNodePos =
- mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
- const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
- if (multiBigramMap) {
- return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
- wordId, ptNodeParams.getProbability());
- }
- if (prevWordIds) {
- const int probability = getProbabilityOfWord(prevWordIds, wordId);
- if (probability != NOT_A_PROBABILITY) {
- return probability;
- }
- }
- return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
+ // TODO: Support n-gram.
+ return mBuffers->getLanguageModelDictContent()->getWordProbability(
+ WordIdArrayView::singleElementView(prevWordIds), wordId);
}
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@@ -166,7 +154,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
// TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
- IntArrayView::fromObject(prevWordIds), wordId);
+ IntArrayView::singleElementView(prevWordIds), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
@@ -194,7 +182,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
// TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(
- WordIdArrayView::fromObject(prevWordIds))) {
+ WordIdArrayView::singleElementView(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
@@ -511,7 +499,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information.
// TODO: Support n-gram.
std::vector<BigramProperty> bigrams;
- const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
+ const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index 9910777b8..313eb6b64 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -48,6 +48,11 @@ class ForgettingCurveUtils {
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
const int bigramCount, const HeaderPolicy *const headerPolicy);
+ // TODO: Improve probability computation method and remove this.
+ static int getProbabilityBiasForNgram(const int n) {
+ return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
+ }
+
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
return static_cast<int>(static_cast<float>(maxUnigramCount)
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c9c3b21d4..08256bdef 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -61,9 +61,9 @@ class IntArrayView {
return IntArrayView(array, N);
}
- // Returns a view that points one int object. Does not take ownership of the given object.
- AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) {
- return IntArrayView(object, 1);
+ // Returns a view that points one int object.
+ AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
+ return IntArrayView(ptr, 1);
}
AK_FORCE_INLINE int operator[](const size_t index) const {
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index ca8d56f27..e6f0353e3 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -26,28 +26,28 @@ namespace latinime {
namespace {
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
- LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */);
+ LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
const int flag = 0xFF;
const int probability = 10;
const int wordId = 100;
const ProbabilityEntry probabilityEntry(flag, probability);
- LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry entry =
- LanguageModelDictContent.getProbabilityEntry(wordId);
+ languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(probability, entry.getProbability());
// Remove
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
- EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
- EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+ EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
}
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
- LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */);
+ LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */);
const int flag = 0xF0;
const int timestamp = 0x3FFFFFFF;
@@ -56,19 +56,19 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
const int wordId = 100;
const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
- LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
- const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
// Remove
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
- EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
}
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
@@ -89,5 +89,31 @@ TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
EXPECT_TRUE(wordIdSet.empty());
}
+TEST(LanguageModelDictContentTest, TestGetWordProbability) {
+ LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
+
+ const int flag = 0xFF;
+ const int probability = 10;
+ const int bigramProbability = 20;
+ const int trigramProbability = 30;
+ const int wordId = 100;
+ const int prevWordIdArray[] = { 1, 2 };
+ const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
+
+ const ProbabilityEntry probabilityEntry(flag, probability);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability);
+ languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
+ &bigramProbabilityEntry);
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
+ prevWordIds[1], &probabilityEntry);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
+ &trigramProbabilityEntry);
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+}
+
} // namespace
} // namespace latinime
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 161df2f43..93bad5822 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -52,7 +52,7 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
TEST(IntArrayViewTest, TestConstructFromObject) {
const int object = 10;
- const auto intArrayView = IntArrayView::fromObject(&object);
+ const auto intArrayView = IntArrayView::singleElementView(&object);
EXPECT_EQ(1u, intArrayView.size());
EXPECT_EQ(object, intArrayView[0]);
}