aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h71
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h2
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp20
4 files changed, 98 insertions, 1 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index ea2d24e67..eb2b1ec3e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -71,6 +71,12 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
return mTrieMap.remove(wordId, bitmapEntryIndex);
}
+LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries(
+ const WordIdArrayView prevWordIds) const {
+ const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+ return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
+}
+
bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 43b2aab66..961637679 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -39,6 +39,75 @@ class HeaderPolicy;
*/
class LanguageModelDictContent {
public:
+ // Pair of word id and probability entry used for iteration.
+ class WordIdAndProbabilityEntry {
+ public:
+ WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry)
+ : mWordId(wordId), mProbabilityEntry(probabilityEntry) {}
+
+ int getWordId() const { return mWordId; }
+ const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; }
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry);
+ DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry);
+
+ const int mWordId;
+ const ProbabilityEntry mProbabilityEntry;
+ };
+
+ // Iterator.
+ class EntryIterator {
+ public:
+ EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator,
+ const bool hasHistoricalInfo)
+ : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {}
+
+ const WordIdAndProbabilityEntry operator*() const {
+ const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator;
+ return WordIdAndProbabilityEntry(
+ result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo));
+ }
+
+ bool operator!=(const EntryIterator &other) const {
+ return mTrieMapIterator != other.mTrieMapIterator;
+ }
+
+ const EntryIterator &operator++() {
+ ++mTrieMapIterator;
+ return *this;
+ }
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator);
+ DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator);
+
+ TrieMap::TrieMapIterator mTrieMapIterator;
+ const bool mHasHistoricalInfo;
+ };
+
+ // Class represents range to use range base for loops.
+ class EntryRange {
+ public:
+ EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo)
+ : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {}
+
+ EntryIterator begin() const {
+ return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo);
+ }
+
+ EntryIterator end() const {
+ return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo);
+ }
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange);
+ DISALLOW_ASSIGNMENT_OPERATOR(EntryRange);
+
+ const TrieMap::TrieMapRange mTrieMapRange;
+ const bool mHasHistoricalInfo;
+ };
+
LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer,
const bool hasHistoricalInfo)
: mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {}
@@ -76,6 +145,8 @@ class LanguageModelDictContent {
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
+ EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
+
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index c2aeac211..00765888b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -98,7 +98,7 @@ class TrieMap {
TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
: mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
- if (!trieMap) {
+ if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
return;
}
const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index 3cacba1c3..ca8d56f27 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -18,6 +18,8 @@
#include <gtest/gtest.h>
+#include <unordered_set>
+
#include "utils/int_array_view.h"
namespace latinime {
@@ -69,5 +71,23 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
}
+TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
+ LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
+
+ const ProbabilityEntry originalEntry(0xFC, 100);
+
+ const int wordIds[] = { 1, 2, 3, 4, 5 };
+ for (const int wordId : wordIds) {
+ languageModelDictContent.setProbabilityEntry(wordId, &originalEntry);
+ }
+ std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds));
+ for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) {
+ EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags());
+ EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability());
+ wordIdSet.erase(entry.getWordId());
+ }
+ EXPECT_TRUE(wordIdSet.empty());
+}
+
} // namespace
} // namespace latinime