aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp66
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.h16
-rw-r--r--native/jni/src/suggest/core/dictionary/ngram_listener.h40
-rw-r--r--native/jni/src/suggest/core/layout/proximity_info_state.h4
-rw-r--r--native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp10
-rw-r--r--native/jni/src/suggest/core/layout/proximity_info_state_utils.h10
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h2
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h6
-rw-r--r--native/jni/src/suggest/core/session/prev_words_info.h47
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp19
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp19
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp19
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp37
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h121
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp59
19 files changed, 393 insertions, 100 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index bf917d69c..d62573970 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -59,32 +59,44 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession
}
}
+Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( // Builds the listener that getPredictions() hands to iterateNgramEntries().
+ const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults,
+ const DictionaryStructureWithBufferPolicy *const dictStructurePolicy)
+ : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults),
+ mDictStructurePolicy(dictStructurePolicy) {} // Only stores the three pointers; onVisitEntry() dereferences them without null checks.
+
+void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability,
+ const int targetPtNodePos) { // Turns one visited ngram entry into a prediction in mSuggestionResults.
+ if (targetPtNodePos == NOT_A_DICT_POS) { // The entry has no target PtNode position: nothing to predict.
+ return;
+ }
+ if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)
+ && ngramProbability == NOT_A_PROBABILITY) { // Previous word is beginning-of-sentence and the entry has no ngram probability: skip it.
+ return;
+ }
+ int targetWordCodePoints[MAX_WORD_LENGTH];
+ int unigramProbability = 0;
+ const int codePointCount = mDictStructurePolicy-> // Fills targetWordCodePoints and unigramProbability for the target PtNode.
+ getCodePointsAndProbabilityAndReturnCodePointCount(targetPtNodePos,
+ MAX_WORD_LENGTH, targetWordCodePoints, &unigramProbability);
+ if (codePointCount <= 0) { // No code points were returned for the target: skip it.
+ return;
+ }
+ const int probability = mDictStructurePolicy->getProbability( // The policy combines the unigram and ngram probabilities.
+ unigramProbability, ngramProbability);
+ mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability);
+}
+
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
- int unigramProbability = 0;
- int bigramCodePoints[MAX_WORD_LENGTH];
- BinaryDictionaryBigramsIterator bigramsIt = prevWordsInfo->getBigramsIteratorForPrediction(
+ NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, // The listener adds a prediction to outSuggestionResults for each entry it accepts.
mDictionaryStructureWithBufferPolicy.get());
- while (bigramsIt.hasNext()) {
- bigramsIt.next();
- if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) {
- continue;
- }
- if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)
- && bigramsIt.getProbability() == NOT_A_PROBABILITY) {
- continue;
- }
- const int codePointCount = mDictionaryStructureWithBufferPolicy->
- getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(),
- MAX_WORD_LENGTH, bigramCodePoints, &unigramProbability);
- if (codePointCount <= 0) {
- continue;
- }
- const int probability = mDictionaryStructureWithBufferPolicy->getProbability(
- unigramProbability, bigramsIt.getProbability());
- outSuggestionResults->addPrediction(bigramCodePoints, codePointCount, probability);
- }
+ int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; // NOTE(review): not initialized here; presumably every slot is filled by the call below - confirm.
+ prevWordsInfo->getPrevWordsTerminalPtNodePos( // prevWordsInfo is dereferenced without a null check.
+ mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
+ true /* tryLowerCaseSearch */);
+ mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); // The policy reports each ngram entry to listener.onVisitEntry().
}
int Dictionary::getProbability(const int *word, int length) const {
@@ -103,7 +115,15 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co
int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(word,
length, false /* forceLowerCaseSearch */);
if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY;
- return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsInfo, nextWordPos);
+ if (!prevWordsInfo) {
+ return getDictionaryStructurePolicy()->getProbabilityOfPtNode(
+ nullptr /* prevWordsPtNodePos */, nextWordPos);
+ }
+ int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+ prevWordsInfo->getPrevWordsTerminalPtNodePos(
+ mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos,
+ true /* tryLowerCaseSearch */);
+ return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos);
}
bool Dictionary::addUnigramEntry(const int *const word, const int length,
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index 3b41088fe..732d3b199 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "jni.h"
+#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/word_property.h"
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
@@ -114,6 +115,21 @@ class Dictionary {
typedef std::unique_ptr<SuggestInterface> SuggestInterfacePtr;
+ class NgramListenerForPrediction : public NgramListener { // NgramListener that adds visited ngram entries to a SuggestionResults as predictions.
+ public:
+ NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo,
+ SuggestionResults *const suggestionResults,
+ const DictionaryStructureWithBufferPolicy *const dictStructurePolicy);
+ virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); // Implements NgramListener::onVisitEntry.
+
+ private:
+ DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction);
+
+ const PrevWordsInfo *const mPrevWordsInfo; // Queried for the beginning-of-sentence state of the previous word.
+ SuggestionResults *const mSuggestionResults; // Receives the predictions via addPrediction().
+ const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; // Used to read target words and to combine probabilities.
+ };
+
static const int HEADER_ATTRIBUTE_BUFFER_SIZE;
const DictionaryStructureWithBufferPolicy::StructurePolicyPtr
diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h
new file mode 100644
index 000000000..88b88bafb
--- /dev/null
+++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_NGRAM_LISTENER_H
+#define LATINIME_NGRAM_LISTENER_H
+
+#include "defines.h"
+
+namespace latinime {
+
+/**
+ * Callback interface to iterate ngram entries: onVisitEntry() is called once per visited entry.
+ */
+class NgramListener {
+ public:
+ virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; // Receives the probability and the target PtNode position of the visited entry.
+ virtual ~NgramListener() {};
+
+ protected:
+ NgramListener() {} // Protected: instances are created only through subclasses.
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(NgramListener);
+
+};
+} // namespace latinime
+#endif /* LATINIME_NGRAM_LISTENER_H */
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h
index 6b1a319aa..e6180fe17 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state.h
+++ b/native/jni/src/suggest/core/layout/proximity_info_state.h
@@ -215,13 +215,13 @@ class ProximityInfoState {
std::vector<float> mSpeedRates;
std::vector<float> mDirections;
// probabilities of skipping or mapping to a key for each point.
- std::vector<std::unordered_map<int, float> > mCharProbabilities;
+ std::vector<std::unordered_map<int, float>> mCharProbabilities;
// The vector for the key code set which holds nearby keys of some trailing sampled input points
// for each sampled input point. These nearby keys contain the next characters which can be in
// the dictionary. Specifically, currently we are looking for keys nearby trailing sampled
// inputs including the current input point.
std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeySets;
- std::vector<std::vector<int> > mSampledSearchKeyVectors;
+ std::vector<std::vector<int>> mSampledSearchKeyVectors;
bool mTouchPositionCorrectionEnabled;
int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH];
int mSampledInputSize;
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
index ea3b02216..0aeb36aad 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
+++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
@@ -621,7 +621,7 @@ namespace latinime {
const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const sampledNormalizedSquaredLengthCache,
const ProximityInfo *const proximityInfo,
- std::vector<std::unordered_map<int, float> > *charProbabilities) {
+ std::vector<std::unordered_map<int, float>> *charProbabilities) {
charProbabilities->resize(sampledInputSize);
// Calculates probabilities of using a point as a correlated point with the character
// for each point.
@@ -822,9 +822,9 @@ namespace latinime {
/* static */ void ProximityInfoStateUtils::updateSampledSearchKeySets(
const ProximityInfo *const proximityInfo, const int sampledInputSize,
const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache,
- const std::vector<std::unordered_map<int, float> > *const charProbabilities,
+ const std::vector<std::unordered_map<int, float>> *const charProbabilities,
std::vector<NearKeycodesSet> *sampledSearchKeySets,
- std::vector<std::vector<int> > *sampledSearchKeyVectors) {
+ std::vector<std::vector<int>> *sampledSearchKeyVectors) {
sampledSearchKeySets->resize(sampledInputSize);
sampledSearchKeyVectors->resize(sampledInputSize);
const int readForwordLength = static_cast<int>(
@@ -868,7 +868,7 @@ namespace latinime {
/* static */ bool ProximityInfoStateUtils::suppressCharProbabilities(const int mostCommonKeyWidth,
const int sampledInputSize, const std::vector<int> *const lengthCache,
const int index0, const int index1,
- std::vector<std::unordered_map<int, float> > *charProbabilities) {
+ std::vector<std::unordered_map<int, float>> *charProbabilities) {
ASSERT(0 <= index0 && index0 < sampledInputSize);
ASSERT(0 <= index1 && index1 < sampledInputSize);
const float keyWidthFloat = static_cast<float>(mostCommonKeyWidth);
@@ -933,7 +933,7 @@ namespace latinime {
// returns probability of generating the word.
/* static */ float ProximityInfoStateUtils::getMostProbableString(
const ProximityInfo *const proximityInfo, const int sampledInputSize,
- const std::vector<std::unordered_map<int, float> > *const charProbabilities,
+ const std::vector<std::unordered_map<int, float>> *const charProbabilities,
int *const codePointBuf) {
ASSERT(sampledInputSize >= 0);
memset(codePointBuf, 0, sizeof(codePointBuf[0]) * MAX_WORD_LENGTH);
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
index 211a79737..4043334e6 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
+++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
@@ -72,13 +72,13 @@ class ProximityInfoStateUtils {
const std::vector<int> *const sampledLengthCache,
const std::vector<float> *const sampledNormalizedSquaredLengthCache,
const ProximityInfo *const proximityInfo,
- std::vector<std::unordered_map<int, float> > *charProbabilities);
+ std::vector<std::unordered_map<int, float>> *charProbabilities);
static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo,
const int sampledInputSize, const int lastSavedInputSize,
const std::vector<int> *const sampledLengthCache,
- const std::vector<std::unordered_map<int, float> > *const charProbabilities,
+ const std::vector<std::unordered_map<int, float>> *const charProbabilities,
std::vector<NearKeycodesSet> *sampledSearchKeySets,
- std::vector<std::vector<int> > *sampledSearchKeyVectors);
+ std::vector<std::vector<int>> *sampledSearchKeyVectors);
static float getPointToKeyByIdLength(const float maxPointToKeyLength,
const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount,
const int inputIndex, const int keyId);
@@ -105,7 +105,7 @@ class ProximityInfoStateUtils {
// TODO: Move to most_probable_string_utils.h
static float getMostProbableString(const ProximityInfo *const proximityInfo,
const int sampledInputSize,
- const std::vector<std::unordered_map<int, float> > *const charProbabilities,
+ const std::vector<std::unordered_map<int, float>> *const charProbabilities,
int *const codePointBuf);
private:
@@ -147,7 +147,7 @@ class ProximityInfoStateUtils {
const int index2);
static bool suppressCharProbabilities(const int mostCommonKeyWidth,
const int sampledInputSize, const std::vector<int> *const lengthCache, const int index0,
- const int index1, std::vector<std::unordered_map<int, float> > *charProbabilities);
+ const int index1, std::vector<std::unordered_map<int, float>> *charProbabilities);
static float calculateSquaredDistanceFromSweetSpotCenter(
const ProximityInfo *const proximityInfo, const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs, const int keyIndex,
diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
index a61227626..6da390e55 100644
--- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
@@ -30,7 +30,7 @@ namespace latinime {
*/
class DictionaryHeaderStructurePolicy {
public:
- typedef std::map<std::vector<int>, std::vector<int> > AttributeMap;
+ typedef std::map<std::vector<int>, std::vector<int>> AttributeMap;
virtual ~DictionaryHeaderStructurePolicy() {}
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index 7ad20e782..7e3bf3ff6 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -30,6 +30,7 @@ class DicNodeVector;
class DictionaryBigramsStructurePolicy;
class DictionaryHeaderStructurePolicy;
class DictionaryShortcutsStructurePolicy;
+class NgramListener;
class PrevWordsInfo;
class UnigramProperty;
@@ -58,9 +59,12 @@ class DictionaryStructureWithBufferPolicy {
virtual int getProbability(const int unigramProbability,
const int bigramProbability) const = 0;
- virtual int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
+ virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int nodePos) const = 0;
+ virtual void iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const = 0;
+
virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0;
virtual BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int nodePos) const = 0;
diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h
index 76276f528..e44e876e9 100644
--- a/native/jni/src/suggest/core/session/prev_words_info.h
+++ b/native/jni/src/suggest/core/session/prev_words_info.h
@@ -90,13 +90,6 @@ class PrevWordsInfo {
}
}
- BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction(
- const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const {
- return getBigramsIteratorForWordWithTryingLowerCaseSearch(
- dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0],
- mIsBeginningOfSentence[0]);
- }
-
// n is 1-indexed.
const int *getNthPrevWordCodePoints(const int n) const {
if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
@@ -154,46 +147,6 @@ class PrevWordsInfo {
codePoints, codePointCount, true /* forceLowerCaseSearch */);
}
- static BinaryDictionaryBigramsIterator getBigramsIteratorForWordWithTryingLowerCaseSearch(
- const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
- const int *const wordCodePoints, const int wordCodePointCount,
- const bool isBeginningOfSentence) {
- if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) {
- return BinaryDictionaryBigramsIterator();
- }
- int codePoints[MAX_WORD_LENGTH];
- int codePointCount = wordCodePointCount;
- memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
- if (isBeginningOfSentence) {
- codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
- codePointCount, MAX_WORD_LENGTH);
- if (codePointCount <= 0) {
- return BinaryDictionaryBigramsIterator();
- }
- }
- BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorForWord(dictStructurePolicy,
- codePoints, codePointCount, false /* forceLowerCaseSearch */);
- // getBigramsIteratorForWord returns an empty iterator if this word isn't in the dictionary
- // or has no bigrams.
- if (bigramsIt.hasNext()) {
- return bigramsIt;
- }
- // If no bigrams for this exact word, search again in lower case.
- return getBigramsIteratorForWord(dictStructurePolicy, codePoints,
- codePointCount, true /* forceLowerCaseSearch */);
- }
-
- static BinaryDictionaryBigramsIterator getBigramsIteratorForWord(
- const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
- const int *wordCodePoints, const int wordCodePointCount,
- const bool forceLowerCaseSearch) {
- if (!wordCodePoints || wordCodePointCount <= 0) return BinaryDictionaryBigramsIterator();
- const int terminalPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord(
- wordCodePoints, wordCodePointCount, forceLowerCaseSearch);
- if (NOT_A_DICT_POS == terminalPtNodePos) return BinaryDictionaryBigramsIterator();
- return dictStructurePolicy->getBigramsIteratorOfPtNode(terminalPtNodePos);
- }
-
void clear() {
for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
mPrevWordCodePointCount[i] = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 327741065..994c42505 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -28,6 +28,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/core/dictionary/property/word_property.h"
@@ -131,7 +132,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
-int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
+int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_PROBABILITY;
@@ -140,9 +141,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
- if (prevWordsInfo) {
+ if (prevWordsPtNodePos) {
BinaryDictionaryBigramsIterator bigramsIt =
- prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */);
+ getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]);
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
@@ -155,6 +156,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
+void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const { // Reports every bigram entry of the first previous word to listener.
+ if (!prevWordsPtNodePos) { // No previous-word positions were given: nothing to iterate.
+ return;
+ }
+ BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); // Only element 0 is read; other previous words are not consulted.
+ while (bigramsIt.hasNext()) {
+ bigramsIt.next();
+ listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); // listener is dereferenced without a null check.
+ }
+}
+
int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_DICT_POS;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index c80a73af7..ff69de7c0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -90,8 +90,10 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getProbability(const int unigramProbability, const int bigramProbability) const;
- int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
- const int ptNodePos) const;
+ int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
+
+ void iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index b909e8268..53415aeb6 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -21,6 +21,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
+#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/session/prev_words_info.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
@@ -296,7 +297,7 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
-int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
+int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_PROBABILITY;
@@ -309,9 +310,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWo
// for shortcuts).
return NOT_A_PROBABILITY;
}
- if (prevWordsInfo) {
+ if (prevWordsPtNodePos) {
BinaryDictionaryBigramsIterator bigramsIt =
- prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */);
+ getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]);
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
@@ -324,6 +325,18 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWo
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
+void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const { // Reports every bigram entry of the first previous word to listener.
+ if (!prevWordsPtNodePos) { // No previous-word positions were given: nothing to iterate.
+ return;
+ }
+ BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); // Only element 0 is read; other previous words are not consulted.
+ while (bigramsIt.hasNext()) {
+ bigramsIt.next();
+ listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); // listener is dereferenced without a null check.
+ }
+}
+
int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_DICT_POS;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 1dd5705be..07cb72b23 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -63,7 +63,10 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getProbability(const int unigramProbability, const int bigramProbability) const;
- int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, const int ptNodePos) const;
+ int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
+
+ void iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index cada3d1f7..22f7e1182 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -20,6 +20,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/core/dictionary/property/word_property.h"
@@ -121,7 +122,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
}
}
-int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo,
+int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos,
const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_PROBABILITY;
@@ -130,9 +131,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
- if (prevWordsInfo) {
+ if (prevWordsPtNodePos) {
BinaryDictionaryBigramsIterator bigramsIt =
- prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */);
+ getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]);
while (bigramsIt.hasNext()) {
bigramsIt.next();
if (bigramsIt.getBigramPos() == ptNodePos
@@ -145,6 +146,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
+void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const { // Reports every bigram entry of the first previous word to listener.
+ if (!prevWordsPtNodePos) { // No previous-word positions were given: nothing to iterate.
+ return;
+ }
+ BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); // Only element 0 is read; other previous words are not consulted.
+ while (bigramsIt.hasNext()) {
+ bigramsIt.next();
+ listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); // listener is dereferenced without a null check.
+ }
+}
+
int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const {
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_DICT_POS;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index b0f16cd01..c5b6a80c0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -72,7 +72,10 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getProbability(const int unigramProbability, const int bigramProbability) const;
- int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, const int ptNodePos) const;
+ int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const;
+
+ void iterateNgramEntries(const int *const prevWordsPtNodePos,
+ NgramListener *const listener) const;
int getShortcutPositionOfPtNode(const int ptNodePos) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index 3ff80aeec..9910777b8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -84,7 +84,7 @@ class ForgettingCurveUtils {
static const int STRONG_BASE_PROBABILITY;
static const int AGGRESSIVE_BASE_PROBABILITY;
- std::vector<std::vector<std::vector<int> > > mTables;
+ std::vector<std::vector<std::vector<int>>> mTables;
static int getBaseProbabilityForLevel(const int tableId, const int level);
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
index a7d86f9ae..c70047638 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
@@ -98,6 +98,43 @@ bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryInde
return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex,
readEntry(bitmapEntryIndex), 0 /* level */);
}
+/**
+ * Iterate next entry in a certain level: walks the table stack, pushing child tables and popping exhausted ones, until a non-bitmap entry is found.
+ *
+ * @param iterationState stack of per-table iteration states that will be read and updated in this method.
+ * @param outKey the output key; may be null, in which case the key is not written.
+ * @return Result instance. mIsValid is false when all entries are iterated.
+ */
+const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState,
+ int *const outKey) const {
+ while (!iterationState->empty()) {
+ TableIterationState &state = iterationState->back(); // The table currently being scanned is on top of the stack.
+ if (state.mTableSize <= state.mCurrentIndex) {
+ // Every entry of this table has been read. Move to parent.
+ iterationState->pop_back();
+ } else {
+ const int entryIndex = state.mTableIndex + state.mCurrentIndex;
+ state.mCurrentIndex += 1; // Advance before descending or returning so the next call resumes after this entry.
+ const Entry entry = readEntry(entryIndex);
+ if (entry.isBitmapEntry()) {
+ // Move to child: push the table this bitmap entry refers to.
+ iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
+ } else {
+ if (outKey) {
+ *outKey = entry.getKey();
+ }
+ if (!entry.hasTerminalLink()) {
+ return Result(entry.getValue(), true, INVALID_INDEX); // Value is read from the entry itself; INVALID_INDEX presumably marks that there is no next-level map.
+ }
+ const int valueEntryIndex = entry.getValueEntryIndex();
+ const Entry valueEntry = readEntry(valueEntryIndex);
+ return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1); // Value is read from a separate value entry; the entry after it is presumably the next-level bitmap entry - confirm.
+ }
+ }
+ }
+ // Visited all entries.
+ return Result(0, false, INVALID_INDEX);
+}
/**
* Shuffle bits of the key in the fixed order.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 2a9051f98..b5bcc3bc8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -44,6 +44,117 @@ class TrieMap {
mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {}
};
+    /**
+     * Cursor over a single table, kept on a stack while the entries of a level are iterated.
+     */
+    struct TableIterationState {
+        // Number of entries in the table.
+        int mTableSize;
+        // Index of the table's first entry.
+        int mTableIndex;
+        // Offset, relative to mTableIndex, of the next entry to read.
+        int mCurrentIndex;
+
+        // Starts the cursor at the first entry of the table.
+        TableIterationState(const int tableSize, const int tableIndex)
+                : mTableSize(tableSize),
+                  mTableIndex(tableIndex),
+                  mCurrentIndex(0) {}
+    };
+
+
+ class TrieMapRange;
+    /**
+     * Iterator over the entries of one level of a TrieMap.
+     *
+     * It provides only what a range-based for loop uses: operator*, operator!= and the prefix
+     * operator++.
+     */
+    class TrieMapIterator {
+     public:
+        /**
+         * Copy of the entry an iterator currently points to: its key, its value and the
+         * bitmap entry index of the level linked from it, if any.
+         */
+        class IterationResult {
+         public:
+            IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value,
+                    const int nextLevelBitmapEntryIndex)
+                    : mTrieMap(trieMap), mKey(key), mValue(value),
+                      mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {}
+
+            // Range over the entries of the level linked from this entry. The range is empty
+            // when hasNextLevelMap() is false.
+            const TrieMapRange getEntriesInNextLevel() const {
+                return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex);
+            }
+
+            // Whether this entry links to a next level.
+            bool hasNextLevelMap() const {
+                return mNextLevelBitmapEntryIndex != INVALID_INDEX;
+            }
+
+            AK_FORCE_INLINE int key() const {
+                return mKey;
+            }
+
+            AK_FORCE_INLINE uint64_t value() const {
+                return mValue;
+            }
+
+         private:
+            const TrieMap *const mTrieMap;
+            const int mKey;
+            const uint64_t mValue;
+            const int mNextLevelBitmapEntryIndex;
+        };
+
+        /**
+         * Creates an iterator positioned on the first entry of the level whose bitmap entry is
+         * at bitmapEntryIndex.
+         *
+         * @param trieMap the map to iterate. Null creates the end sentinel.
+         * @param bitmapEntryIndex the bitmap entry index of the level to iterate.
+         */
+        TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex)
+                : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex),
+                  mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) {
+            // Nothing to iterate for the end sentinel (null map) or for an entry without a next
+            // level (INVALID_INDEX). Stay invalid instead of reading an entry at INVALID_INDEX.
+            if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) {
+                return;
+            }
+            const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex);
+            mStateStack.emplace_back(
+                    mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex());
+            // Move onto the first entry, or become invalid if the level has none.
+            this->operator++();
+        }
+
+        const IterationResult operator*() const {
+            return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex);
+        }
+
+        bool operator!=(const TrieMapIterator &other) const {
+            // Caveat: positions are not compared. The result is true while either iterator is
+            // still valid, which is only meaningful as the termination test of a for loop.
+            return mIsValid || other.mIsValid;
+        }
+
+        // Advances to the next entry; mIsValid becomes false once all entries were visited.
+        const TrieMapIterator &operator++() {
+            const Result result = mTrieMap->iterateNext(&mStateStack, &mKey);
+            mValue = result.mValue;
+            mIsValid = result.mIsValid;
+            mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex;
+            return *this;
+        }
+
+     private:
+        DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator);
+        DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator);
+
+        const TrieMap *const mTrieMap;
+        // Stack of table cursors that iterateNext() reads and updates.
+        std::vector<TrieMap::TableIterationState> mStateStack;
+        const int mBaseBitmapEntryIndex;
+        int mKey;
+        uint64_t mValue;
+        bool mIsValid;
+        int mNextLevelBitmapEntryIndex;
+    };
+
+    /**
+     * Range over the entries of one level of a TrieMap, so that the entries can be visited
+     * with a range-based for loop.
+     */
+    class TrieMapRange {
+     public:
+        TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex)
+                : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {}
+
+        // Iterator over the level whose bitmap entry is at mBaseBitmapEntryIndex.
+        TrieMapIterator begin() const {
+            return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex);
+        }
+
+        // Sentinel iterator: it is created without a map and therefore starts out invalid.
+        const TrieMapIterator end() const {
+            return TrieMapIterator(nullptr, INVALID_INDEX);
+        }
+
+     private:
+        DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange);
+        DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange);
+
+        const TrieMap *const mTrieMap;
+        const int mBaseBitmapEntryIndex;
+    };
+
static const int INVALID_INDEX;
static const uint64_t MAX_VALUE;
@@ -73,6 +184,14 @@ class TrieMap {
bool put(const int key, const uint64_t value, const int bitmapEntryIndex);
+    // Range over the entries of the level whose bitmap entry is at ROOT_BITMAP_ENTRY_INDEX.
+    const TrieMapRange getEntriesInRootLevel() const {
+        return TrieMapRange(this, ROOT_BITMAP_ENTRY_INDEX);
+    }
+
+    // Range over the entries of the level whose bitmap entry is at bitmapEntryIndex.
+    const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const {
+        return TrieMapRange(this, bitmapEntryIndex);
+    }
+
private:
DISALLOW_COPY_AND_ASSIGN(TrieMap);
@@ -171,6 +290,8 @@ class TrieMap {
bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value,
const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex,
const int label);
+ const Result iterateNext(std::vector<TableIterationState> *const iterationState,
+ int *const outKey) const;
AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const {
return Entry(readField0(entryIndex), readField1(entryIndex));
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
index 5dd782277..df778b6cf 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
@@ -54,7 +54,7 @@ TEST(TrieMapTest, TestSetAndGetLarge) {
EXPECT_TRUE(trieMap.putRoot(i, i));
}
for (int i = 0; i < ELEMENT_COUNT; ++i) {
- EXPECT_EQ(trieMap.getRoot(i).mValue, static_cast<uint64_t>(i));
+ EXPECT_EQ(static_cast<uint64_t>(i), trieMap.getRoot(i).mValue);
}
}
@@ -78,7 +78,7 @@ TEST(TrieMapTest, TestRandSetAndGetLarge) {
testKeyValuePairs[key] = value;
}
for (const auto &v : testKeyValuePairs) {
- EXPECT_EQ(trieMap.getRoot(v.first).mValue, v.second);
+ EXPECT_EQ(v.second, trieMap.getRoot(v.first).mValue);
}
}
@@ -163,6 +163,61 @@ TEST(TrieMapTest, TestMultiLevel) {
}
}
}
+
+ // Iteration
+ for (const auto &firstLevelEntry : trieMap.getEntriesInRootLevel()) {
+ EXPECT_EQ(trieMap.getRoot(firstLevelEntry.key()).mValue, firstLevelEntry.value());
+ EXPECT_EQ(firstLevelEntries[firstLevelEntry.key()], firstLevelEntry.value());
+ firstLevelEntries.erase(firstLevelEntry.key());
+ for (const auto &secondLevelEntry : firstLevelEntry.getEntriesInNextLevel()) {
+ EXPECT_EQ(twoLevelMap[firstLevelEntry.key()][secondLevelEntry.key()],
+ secondLevelEntry.value());
+ twoLevelMap[firstLevelEntry.key()].erase(secondLevelEntry.key());
+ for (const auto &thirdLevelEntry : secondLevelEntry.getEntriesInNextLevel()) {
+ EXPECT_EQ(threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()]
+ [thirdLevelEntry.key()], thirdLevelEntry.value());
+ threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()].erase(
+ thirdLevelEntry.key());
+ }
+ }
+ }
+
+ // Ensure all entries have been traversed.
+ EXPECT_TRUE(firstLevelEntries.empty());
+ for (const auto &secondLevelEntry : twoLevelMap) {
+ EXPECT_TRUE(secondLevelEntry.second.empty());
+ }
+ for (const auto &secondLevelEntry : threeLevelMap) {
+ for (const auto &thirdLevelEntry : secondLevelEntry.second) {
+ EXPECT_TRUE(thirdLevelEntry.second.empty());
+ }
+ }
+}
+
+// Fills the root level with random key/value pairs and checks that iterating the root level
+// yields every stored pair with the value that was stored last for its key.
+TEST(TrieMapTest, TestIteration) {
+    static const int ELEMENT_COUNT = 200000;
+    TrieMap trieMap;
+    std::unordered_map<int, uint64_t> expectedEntries;
+
+    // Keys follow the uniform integer distribution [S_INT_MIN, S_INT_MAX].
+    std::uniform_int_distribution<int> keyDistribution(S_INT_MIN, S_INT_MAX);
+    std::mt19937 keyEngine;
+    // Values follow the uniform distribution [0, TrieMap::MAX_VALUE].
+    std::uniform_int_distribution<uint64_t> valueDistribution(0, TrieMap::MAX_VALUE);
+    std::mt19937 valueEngine;
+
+    // Populate the map and mirror every insertion in expectedEntries.
+    for (int insertedCount = 0; insertedCount < ELEMENT_COUNT; ++insertedCount) {
+        const int key = keyDistribution(keyEngine);
+        const uint64_t value = valueDistribution(valueEngine);
+        EXPECT_TRUE(trieMap.putRoot(key, value));
+        expectedEntries[key] = value;
+    }
+
+    // Each iterated entry has to agree with a direct lookup and with the mirror. Entries are
+    // removed from the mirror as they are seen.
+    for (const auto &entry : trieMap.getEntriesInRootLevel()) {
+        const int key = entry.key();
+        const uint64_t value = entry.value();
+        EXPECT_EQ(trieMap.getRoot(key).mValue, value);
+        EXPECT_EQ(expectedEntries[key], value);
+        expectedEntries.erase(key);
+    }
+
+    // Anything left in the mirror was never reached by the iteration.
+    EXPECT_TRUE(expectedEntries.empty());
}
} // namespace