aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp37
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h121
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp59
3 files changed, 215 insertions, 2 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
index a7d86f9ae..c70047638 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
@@ -98,6 +98,43 @@ bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryInde
return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex,
readEntry(bitmapEntryIndex), 0 /* level */);
}
+/**
+ * Iterate next entry in a certain level.
+ *
+ * @param iterationState the iteration state that will be read and updated in this method.
+ * @param outKey the output key
+ * @return Result instance. mIsValid is false when all entries are iterated.
+ */
+const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState,
+ int *const outKey) const {
+ while (!iterationState->empty()) {
+ TableIterationState &state = iterationState->back();
+ if (state.mTableSize <= state.mCurrentIndex) {
+ // Move to parent.
+ iterationState->pop_back();
+ } else {
+ const int entryIndex = state.mTableIndex + state.mCurrentIndex;
+ state.mCurrentIndex += 1;
+ const Entry entry = readEntry(entryIndex);
+ if (entry.isBitmapEntry()) {
+ // Move to child.
+ iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
+ } else {
+ if (outKey) {
+ *outKey = entry.getKey();
+ }
+ if (!entry.hasTerminalLink()) {
+ return Result(entry.getValue(), true, INVALID_INDEX);
+ }
+ const int valueEntryIndex = entry.getValueEntryIndex();
+ const Entry valueEntry = readEntry(valueEntryIndex);
+ return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1);
+ }
+ }
+ }
+ // Visited all entries.
+ return Result(0, false, INVALID_INDEX);
+}
/**
* Shuffle bits of the key in the fixed order.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 2a9051f98..b5bcc3bc8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -44,6 +44,117 @@ class TrieMap {
mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {}
};
+ /**
+ * Struct to record iteration state in a table.
+ */
+ struct TableIterationState {
+ int mTableSize;
+ int mTableIndex;
+ int mCurrentIndex;
+
+ TableIterationState(const int tableSize, const int tableIndex)
+ : mTableSize(tableSize), mTableIndex(tableIndex), mCurrentIndex(0) {}
+ };
+
+ class TrieMapRange;
+ class TrieMapIterator {
+ public:
+ class IterationResult {
+ public:
+ IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value,
+ const int nextLeveBitmapEntryIndex)
+ : mTrieMap(trieMap), mKey(key), mValue(value),
+ mNextLevelBitmapEntryIndex(nextLeveBitmapEntryIndex) {}
+
+ const TrieMapRange getEntriesInNextLevel() const {
+ return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex);
+ }
+
+ bool hasNextLevelMap() const {
+ return mNextLevelBitmapEntryIndex != INVALID_INDEX;
+ }
+
+ AK_FORCE_INLINE int key() const {
+ return mKey;
+ }
+
+ AK_FORCE_INLINE uint64_t value() const {
+ return mValue;
+ }
+
+ private:
+ const TrieMap *const mTrieMap;
+ const int mKey;
+ const uint64_t mValue;
+ const int mNextLevelBitmapEntryIndex;
+ };
+
+ TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
+ : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
+ mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
+ if (!trieMap) {
+ return;
+ }
+ const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
+ mStateStack.emplace_back(
+ mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex());
+ this->operator++();
+ }
+
+ const IterationResult operator*() const {
+ return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex);
+ }
+
+ bool operator!=(const TrieMapIterator &other) const {
+ // Caveat: This works only for for loops.
+ return mIsValid || other.mIsValid;
+ }
+
+ const TrieMapIterator &operator++() {
+ const Result result = mTrieMap->iterateNext(&mStateStack, &mKey);
+ mValue = result.mValue;
+ mIsValid = result.mIsValid;
+ mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+ return *this;
+ }
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator);
+ DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator);
+
+ const TrieMap *const mTrieMap;
+ std::vector<TrieMap::TableIterationState> mStateStack;
+ const int mBaseBitmapEntryIndex;
+ int mKey;
+ uint64_t mValue;
+ bool mIsValid;
+ int mNextLevelBitmapEntryIndex;
+ };
+
+ /**
+ * Class to support iterating entries in TrieMap by range base for loops.
+ */
+ class TrieMapRange {
+ public:
+ TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex)
+ : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {};
+
+ TrieMapIterator begin() const {
+ return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex);
+ }
+
+ const TrieMapIterator end() const {
+ return TrieMapIterator(nullptr, INVALID_INDEX);
+ }
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange);
+ DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange);
+
+ const TrieMap *const mTrieMap;
+ const int mBaseBitmapEntryIndex;
+ };
+
static const int INVALID_INDEX;
static const uint64_t MAX_VALUE;
@@ -73,6 +184,14 @@ class TrieMap {
bool put(const int key, const uint64_t value, const int bitmapEntryIndex);
+ const TrieMapRange getEntriesInRootLevel() const {
+ return getEntriesInSpecifiedLevel(ROOT_BITMAP_ENTRY_INDEX);
+ }
+
+ const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const {
+ return TrieMapRange(this, bitmapEntryIndex);
+ }
+
private:
DISALLOW_COPY_AND_ASSIGN(TrieMap);
@@ -171,6 +290,8 @@ class TrieMap {
bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value,
const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex,
const int label);
+ const Result iterateNext(std::vector<TableIterationState> *const iterationState,
+ int *const outKey) const;
AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const {
return Entry(readField0(entryIndex), readField1(entryIndex));
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
index 5dd782277..df778b6cf 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
@@ -54,7 +54,7 @@ TEST(TrieMapTest, TestSetAndGetLarge) {
EXPECT_TRUE(trieMap.putRoot(i, i));
}
for (int i = 0; i < ELEMENT_COUNT; ++i) {
- EXPECT_EQ(trieMap.getRoot(i).mValue, static_cast<uint64_t>(i));
+ EXPECT_EQ(static_cast<uint64_t>(i), trieMap.getRoot(i).mValue);
}
}
@@ -78,7 +78,7 @@ TEST(TrieMapTest, TestRandSetAndGetLarge) {
testKeyValuePairs[key] = value;
}
for (const auto &v : testKeyValuePairs) {
- EXPECT_EQ(trieMap.getRoot(v.first).mValue, v.second);
+ EXPECT_EQ(v.second, trieMap.getRoot(v.first).mValue);
}
}
@@ -163,6 +163,61 @@ TEST(TrieMapTest, TestMultiLevel) {
}
}
}
+
+ // Iteration
+ for (const auto &firstLevelEntry : trieMap.getEntriesInRootLevel()) {
+ EXPECT_EQ(trieMap.getRoot(firstLevelEntry.key()).mValue, firstLevelEntry.value());
+ EXPECT_EQ(firstLevelEntries[firstLevelEntry.key()], firstLevelEntry.value());
+ firstLevelEntries.erase(firstLevelEntry.key());
+ for (const auto &secondLevelEntry : firstLevelEntry.getEntriesInNextLevel()) {
+ EXPECT_EQ(twoLevelMap[firstLevelEntry.key()][secondLevelEntry.key()],
+ secondLevelEntry.value());
+ twoLevelMap[firstLevelEntry.key()].erase(secondLevelEntry.key());
+ for (const auto &thirdLevelEntry : secondLevelEntry.getEntriesInNextLevel()) {
+ EXPECT_EQ(threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()]
+ [thirdLevelEntry.key()], thirdLevelEntry.value());
+ threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()].erase(
+ thirdLevelEntry.key());
+ }
+ }
+ }
+
+ // Ensure all entries have been traversed.
+ EXPECT_TRUE(firstLevelEntries.empty());
+ for (const auto &secondLevelEntry : twoLevelMap) {
+ EXPECT_TRUE(secondLevelEntry.second.empty());
+ }
+ for (const auto &secondLevelEntry : threeLevelMap) {
+ for (const auto &thirdLevelEntry : secondLevelEntry.second) {
+ EXPECT_TRUE(thirdLevelEntry.second.empty());
+ }
+ }
+}
+
+TEST(TrieMapTest, TestIteration) {
+ static const int ELEMENT_COUNT = 200000;
+ TrieMap trieMap;
+ std::unordered_map<int, uint64_t> testKeyValuePairs;
+
+ // Use the uniform integer distribution [S_INT_MIN, S_INT_MAX].
+ std::uniform_int_distribution<int> keyDistribution(S_INT_MIN, S_INT_MAX);
+ auto keyRandomNumberGenerator = std::bind(keyDistribution, std::mt19937());
+
+ // Use the uniform distribution [0, TrieMap::MAX_VALUE].
+ std::uniform_int_distribution<uint64_t> valueDistribution(0, TrieMap::MAX_VALUE);
+ auto valueRandomNumberGenerator = std::bind(valueDistribution, std::mt19937());
+ for (int i = 0; i < ELEMENT_COUNT; ++i) {
+ const int key = keyRandomNumberGenerator();
+ const uint64_t value = valueRandomNumberGenerator();
+ EXPECT_TRUE(trieMap.putRoot(key, value));
+ testKeyValuePairs[key] = value;
+ }
+ for (const auto &entry : trieMap.getEntriesInRootLevel()) {
+ EXPECT_EQ(trieMap.getRoot(entry.key()).mValue, entry.value());
+ EXPECT_EQ(testKeyValuePairs[entry.key()], entry.value());
+ testKeyValuePairs.erase(entry.key());
+ }
+ EXPECT_TRUE(testKeyValuePairs.empty());
}
} // namespace