diff options
Diffstat (limited to 'native')
19 files changed, 393 insertions, 100 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index bf917d69c..d62573970 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -59,32 +59,44 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } } +Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( + const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults, + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) + : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults), + mDictStructurePolicy(dictStructurePolicy) {} + +void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability, + const int targetPtNodePos) { + if (targetPtNodePos == NOT_A_DICT_POS) { + return; + } + if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + && ngramProbability == NOT_A_PROBABILITY) { + return; + } + int targetWordCodePoints[MAX_WORD_LENGTH]; + int unigramProbability = 0; + const int codePointCount = mDictStructurePolicy-> + getCodePointsAndProbabilityAndReturnCodePointCount(targetPtNodePos, + MAX_WORD_LENGTH, targetWordCodePoints, &unigramProbability); + if (codePointCount <= 0) { + return; + } + const int probability = mDictStructurePolicy->getProbability( + unigramProbability, ngramProbability); + mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability); +} + void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - int unigramProbability = 0; - int bigramCodePoints[MAX_WORD_LENGTH]; - BinaryDictionaryBigramsIterator bigramsIt = prevWordsInfo->getBigramsIteratorForPrediction( + NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { - continue; - } - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) - && bigramsIt.getProbability() == NOT_A_PROBABILITY) { - continue; - } - const int codePointCount = mDictionaryStructureWithBufferPolicy-> - getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(), - MAX_WORD_LENGTH, bigramCodePoints, &unigramProbability); - if (codePointCount <= 0) { - continue; - } - const int probability = mDictionaryStructureWithBufferPolicy->getProbability( - unigramProbability, bigramsIt.getProbability()); - outSuggestionResults->addPrediction(bigramCodePoints, codePointCount, probability); - } + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos( + mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + true /* tryLowerCaseSearch */); + mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); } int Dictionary::getProbability(const int *word, int length) const { @@ -103,7 +115,15 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(word, length, false /* forceLowerCaseSearch */); if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; - return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsInfo, nextWordPos); + if (!prevWordsInfo) { + return getDictionaryStructurePolicy()->getProbabilityOfPtNode( + nullptr /* prevWordsPtNodePos */, nextWordPos); + } + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos( + mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + true /* tryLowerCaseSearch */); + return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); } bool Dictionary::addUnigramEntry(const int *const word, const int length, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 3b41088fe..732d3b199 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -21,6 +21,7 @@ #include "defines.h" #include "jni.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/word_property.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" @@ -114,6 +115,21 @@ class Dictionary { typedef std::unique_ptr<SuggestInterface> SuggestInterfacePtr; + class NgramListenerForPrediction : public NgramListener { + public: + NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo, + SuggestionResults *const suggestionResults, + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy); + virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction); + + const PrevWordsInfo *const mPrevWordsInfo; + SuggestionResults *const mSuggestionResults; + const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; + }; + static const int HEADER_ATTRIBUTE_BUFFER_SIZE; const DictionaryStructureWithBufferPolicy::StructurePolicyPtr diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h new file mode 100644 index 000000000..88b88bafb --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_NGRAM_LISTENER_H +#define LATINIME_NGRAM_LISTENER_H + +#include "defines.h" + +namespace latinime { + +/** + * Interface to iterate ngram entries. + */ +class NgramListener { + public: + virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; + virtual ~NgramListener() {}; + + protected: + NgramListener() {} + + private: + DISALLOW_COPY_AND_ASSIGN(NgramListener); + +}; +} // namespace latinime +#endif /* LATINIME_NGRAM_LISTENER_H */ diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index 6b1a319aa..e6180fe17 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -215,13 +215,13 @@ class ProximityInfoState { std::vector<float> mSpeedRates; std::vector<float> mDirections; // probabilities of skipping or mapping to a key for each point. - std::vector<std::unordered_map<int, float> > mCharProbabilities; + std::vector<std::unordered_map<int, float>> mCharProbabilities; // The vector for the key code set which holds nearby keys of some trailing sampled input points // for each sampled input point. These nearby keys contain the next characters which can be in // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled // inputs including the current input point. std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeySets; - std::vector<std::vector<int> > mSampledSearchKeyVectors; + std::vector<std::vector<int>> mSampledSearchKeyVectors; bool mTouchPositionCorrectionEnabled; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mSampledInputSize; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index ea3b02216..0aeb36aad 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -621,7 +621,7 @@ namespace latinime { const std::vector<int> *const sampledLengthCache, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const ProximityInfo *const proximityInfo, - std::vector<std::unordered_map<int, float> > *charProbabilities) { + std::vector<std::unordered_map<int, float>> *charProbabilities) { charProbabilities->resize(sampledInputSize); // Calculates probabilities of using a point as a correlated point with the character // for each point. @@ -822,9 +822,9 @@ namespace latinime { /* static */ void ProximityInfoStateUtils::updateSampledSearchKeySets( const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, std::vector<NearKeycodesSet> *sampledSearchKeySets, - std::vector<std::vector<int> > *sampledSearchKeyVectors) { + std::vector<std::vector<int>> *sampledSearchKeyVectors) { sampledSearchKeySets->resize(sampledInputSize); sampledSearchKeyVectors->resize(sampledInputSize); const int readForwordLength = static_cast<int>( @@ -868,7 +868,7 @@ namespace latinime { /* static */ bool ProximityInfoStateUtils::suppressCharProbabilities(const int mostCommonKeyWidth, const int sampledInputSize, const std::vector<int> *const lengthCache, const int index0, const int index1, - std::vector<std::unordered_map<int, float> > *charProbabilities) { + std::vector<std::unordered_map<int, float>> *charProbabilities) { ASSERT(0 <= index0 && index0 < sampledInputSize); ASSERT(0 <= index1 && index1 < sampledInputSize); const float keyWidthFloat = static_cast<float>(mostCommonKeyWidth); @@ -933,7 +933,7 @@ namespace latinime { // returns probability of generating the word. /* static */ float ProximityInfoStateUtils::getMostProbableString( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, int *const codePointBuf) { ASSERT(sampledInputSize >= 0); memset(codePointBuf, 0, sizeof(codePointBuf[0]) * MAX_WORD_LENGTH); diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h index 211a79737..4043334e6 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h @@ -72,13 +72,13 @@ class ProximityInfoStateUtils { const std::vector<int> *const sampledLengthCache, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const ProximityInfo *const proximityInfo, - std::vector<std::unordered_map<int, float> > *charProbabilities); + std::vector<std::unordered_map<int, float>> *charProbabilities); static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, std::vector<NearKeycodesSet> *sampledSearchKeySets, - std::vector<std::vector<int> > *sampledSearchKeyVectors); + std::vector<std::vector<int>> *sampledSearchKeyVectors); static float getPointToKeyByIdLength(const float maxPointToKeyLength, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); @@ -105,7 +105,7 @@ class ProximityInfoStateUtils { // TODO: Move to most_probable_string_utils.h static float getMostProbableString(const ProximityInfo *const proximityInfo, const int sampledInputSize, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, int *const codePointBuf); private: @@ -147,7 +147,7 @@ class ProximityInfoStateUtils { const int index2); static bool suppressCharProbabilities(const int mostCommonKeyWidth, const int sampledInputSize, const std::vector<int> *const lengthCache, const int index0, - const int index1, std::vector<std::unordered_map<int, float> > *charProbabilities); + const int index1, std::vector<std::unordered_map<int, float>> *charProbabilities); static float calculateSquaredDistanceFromSweetSpotCenter( const ProximityInfo *const proximityInfo, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, const int keyIndex, diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index a61227626..6da390e55 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -30,7 +30,7 @@ namespace latinime { */ class DictionaryHeaderStructurePolicy { public: - typedef std::map<std::vector<int>, std::vector<int> > AttributeMap; + typedef std::map<std::vector<int>, std::vector<int>> AttributeMap; virtual ~DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 7ad20e782..7e3bf3ff6 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -30,6 +30,7 @@ class DicNodeVector; class DictionaryBigramsStructurePolicy; class DictionaryHeaderStructurePolicy; class DictionaryShortcutsStructurePolicy; +class NgramListener; class PrevWordsInfo; class UnigramProperty; @@ -58,9 +59,12 @@ class DictionaryStructureWithBufferPolicy { virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, + virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int nodePos) const = 0; + virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const = 0; + virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; virtual BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int nodePos) const = 0; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index 76276f528..e44e876e9 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -90,13 +90,6 @@ class PrevWordsInfo { } } - BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const { - return getBigramsIteratorForWordWithTryingLowerCaseSearch( - dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0], - mIsBeginningOfSentence[0]); - } - // n is 1-indexed. const int *getNthPrevWordCodePoints(const int n) const { if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { @@ -154,46 +147,6 @@ class PrevWordsInfo { codePoints, codePointCount, true /* forceLowerCaseSearch */); } - static BinaryDictionaryBigramsIterator getBigramsIteratorForWordWithTryingLowerCaseSearch( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - const int *const wordCodePoints, const int wordCodePointCount, - const bool isBeginningOfSentence) { - if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return BinaryDictionaryBigramsIterator(); - } - int codePoints[MAX_WORD_LENGTH]; - int codePointCount = wordCodePointCount; - memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); - if (isBeginningOfSentence) { - codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, - codePointCount, MAX_WORD_LENGTH); - if (codePointCount <= 0) { - return BinaryDictionaryBigramsIterator(); - } - } - BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorForWord(dictStructurePolicy, - codePoints, codePointCount, false /* forceLowerCaseSearch */); - // getBigramsIteratorForWord returns an empty iterator if this word isn't in the dictionary - // or has no bigrams. - if (bigramsIt.hasNext()) { - return bigramsIt; - } - // If no bigrams for this exact word, search again in lower case. - return getBigramsIteratorForWord(dictStructurePolicy, codePoints, - codePointCount, true /* forceLowerCaseSearch */); - } - - static BinaryDictionaryBigramsIterator getBigramsIteratorForWord( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - const int *wordCodePoints, const int wordCodePointCount, - const bool forceLowerCaseSearch) { - if (!wordCodePoints || wordCodePointCount <= 0) return BinaryDictionaryBigramsIterator(); - const int terminalPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( - wordCodePoints, wordCodePointCount, forceLowerCaseSearch); - if (NOT_A_DICT_POS == terminalPtNodePos) return BinaryDictionaryBigramsIterator(); - return dictStructurePolicy->getBigramsIteratorOfPtNode(terminalPtNodePos); - } - void clear() { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { mPrevWordCodePointCount[i] = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 327741065..994c42505 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -28,6 +28,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" @@ -131,7 +132,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, +int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; @@ -140,9 +141,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsInfo) { + if (prevWordsPtNodePos) { BinaryDictionaryBigramsIterator bigramsIt = - prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */); + getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == ptNodePos @@ -155,6 +156,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index c80a73af7..ff69de7c0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -90,8 +90,10 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, - const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index b909e8268..53415aeb6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" @@ -296,7 +297,7 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, +int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; @@ -309,9 +310,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWo // for shortcuts). return NOT_A_PROBABILITY; } - if (prevWordsInfo) { + if (prevWordsPtNodePos) { BinaryDictionaryBigramsIterator bigramsIt = - prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */); + getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == ptNodePos @@ -324,6 +325,18 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWo return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 1dd5705be..07cb72b23 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -63,7 +63,10 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index cada3d1f7..22f7e1182 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -20,6 +20,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" @@ -121,7 +122,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, +int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; @@ -130,9 +131,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsInfo) { + if (prevWordsPtNodePos) { BinaryDictionaryBigramsIterator bigramsIt = - prevWordsInfo->getBigramsIteratorForPrediction(this /* dictStructurePolicy */); + getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == ptNodePos @@ -145,6 +146,18 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const PrevWordsInfo *const pr return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorOfPtNode(prevWordsPtNodePos[0]); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index b0f16cd01..c5b6a80c0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -72,7 +72,10 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const PrevWordsInfo *const prevWordsInfo, const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index 3ff80aeec..9910777b8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -84,7 +84,7 @@ class ForgettingCurveUtils { static const int STRONG_BASE_PROBABILITY; static const int AGGRESSIVE_BASE_PROBABILITY; - std::vector<std::vector<std::vector<int> > > mTables; + std::vector<std::vector<std::vector<int>>> mTables; static int getBaseProbabilityForLevel(const int tableId, const int level); }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp index a7d86f9ae..c70047638 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -98,6 +98,43 @@ bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryInde return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex, readEntry(bitmapEntryIndex), 0 /* level */); } +/** + * Iterate next entry in a certain level. + * + * @param iterationState the iteration state that will be read and updated in this method. + * @param outKey the output key + * @return Result instance. mIsValid is false when all entries are iterated. + */ +const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState, + int *const outKey) const { + while (!iterationState->empty()) { + TableIterationState &state = iterationState->back(); + if (state.mTableSize <= state.mCurrentIndex) { + // Move to parent. + iterationState->pop_back(); + } else { + const int entryIndex = state.mTableIndex + state.mCurrentIndex; + state.mCurrentIndex += 1; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Move to child. + iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); + } else { + if (outKey) { + *outKey = entry.getKey(); + } + if (!entry.hasTerminalLink()) { + return Result(entry.getValue(), true, INVALID_INDEX); + } + const int valueEntryIndex = entry.getValueEntryIndex(); + const Entry valueEntry = readEntry(valueEntryIndex); + return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1); + } + } + } + // Visited all entries. + return Result(0, false, INVALID_INDEX); +} /** * Shuffle bits of the key in the fixed order. diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index 2a9051f98..b5bcc3bc8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -44,6 +44,117 @@ class TrieMap { mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {} }; + /** + * Struct to record iteration state in a table. + */ + struct TableIterationState { + int mTableSize; + int mTableIndex; + int mCurrentIndex; + + TableIterationState(const int tableSize, const int tableIndex) + : mTableSize(tableSize), mTableIndex(tableIndex), mCurrentIndex(0) {} + }; + + class TrieMapRange; + class TrieMapIterator { + public: + class IterationResult { + public: + IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value, + const int nextLeveBitmapEntryIndex) + : mTrieMap(trieMap), mKey(key), mValue(value), + mNextLevelBitmapEntryIndex(nextLeveBitmapEntryIndex) {} + + const TrieMapRange getEntriesInNextLevel() const { + return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex); + } + + bool hasNextLevelMap() const { + return mNextLevelBitmapEntryIndex != INVALID_INDEX; + } + + AK_FORCE_INLINE int key() const { + return mKey; + } + + AK_FORCE_INLINE uint64_t value() const { + return mValue; + } + + private: + const TrieMap *const mTrieMap; + const int mKey; + const uint64_t mValue; + const int mNextLevelBitmapEntryIndex; + }; + + TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), + mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { + if (!trieMap) { + return; + } + const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); + mStateStack.emplace_back( + mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex()); + this->operator++(); + } + + const IterationResult operator*() const { + return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex); + } + + bool operator!=(const TrieMapIterator &other) const { + // Caveat: This works only for for loops. + return mIsValid || other.mIsValid; + } + + const TrieMapIterator &operator++() { + const Result result = mTrieMap->iterateNext(&mStateStack, &mKey); + mValue = result.mValue; + mIsValid = result.mIsValid; + mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator); + + const TrieMap *const mTrieMap; + std::vector<TrieMap::TableIterationState> mStateStack; + const int mBaseBitmapEntryIndex; + int mKey; + uint64_t mValue; + bool mIsValid; + int mNextLevelBitmapEntryIndex; + }; + + /** + * Class to support iterating entries in TrieMap by range base for loops. + */ + class TrieMapRange { + public: + TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {}; + + TrieMapIterator begin() const { + return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex); + } + + const TrieMapIterator end() const { + return TrieMapIterator(nullptr, INVALID_INDEX); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange); + + const TrieMap *const mTrieMap; + const int mBaseBitmapEntryIndex; + }; + static const int INVALID_INDEX; static const uint64_t MAX_VALUE; @@ -73,6 +184,14 @@ class TrieMap { bool put(const int key, const uint64_t value, const int bitmapEntryIndex); + const TrieMapRange getEntriesInRootLevel() const { + return getEntriesInSpecifiedLevel(ROOT_BITMAP_ENTRY_INDEX); + } + + const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const { + return TrieMapRange(this, bitmapEntryIndex); + } + private: DISALLOW_COPY_AND_ASSIGN(TrieMap); @@ -171,6 +290,8 @@ class TrieMap { bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value, const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex, const int label); + const Result iterateNext(std::vector<TableIterationState> *const iterationState, + int *const outKey) const; AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const { return Entry(readField0(entryIndex), readField1(entryIndex)); diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp index 5dd782277..df778b6cf 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp @@ -54,7 +54,7 @@ TEST(TrieMapTest, TestSetAndGetLarge) { EXPECT_TRUE(trieMap.putRoot(i, i)); } for (int i = 0; i < ELEMENT_COUNT; ++i) { - EXPECT_EQ(trieMap.getRoot(i).mValue, static_cast<uint64_t>(i)); + EXPECT_EQ(static_cast<uint64_t>(i), trieMap.getRoot(i).mValue); } } @@ -78,7 +78,7 @@ TEST(TrieMapTest, TestRandSetAndGetLarge) { testKeyValuePairs[key] = value; } for (const auto &v : testKeyValuePairs) { - EXPECT_EQ(trieMap.getRoot(v.first).mValue, v.second); + EXPECT_EQ(v.second, trieMap.getRoot(v.first).mValue); } } @@ -163,6 +163,61 @@ TEST(TrieMapTest, TestMultiLevel) { } } } + + // Iteration + for (const auto &firstLevelEntry : trieMap.getEntriesInRootLevel()) { + EXPECT_EQ(trieMap.getRoot(firstLevelEntry.key()).mValue, firstLevelEntry.value()); + EXPECT_EQ(firstLevelEntries[firstLevelEntry.key()], firstLevelEntry.value()); + firstLevelEntries.erase(firstLevelEntry.key()); + for (const auto &secondLevelEntry : firstLevelEntry.getEntriesInNextLevel()) { + EXPECT_EQ(twoLevelMap[firstLevelEntry.key()][secondLevelEntry.key()], + secondLevelEntry.value()); + twoLevelMap[firstLevelEntry.key()].erase(secondLevelEntry.key()); + for (const auto &thirdLevelEntry : secondLevelEntry.getEntriesInNextLevel()) { + EXPECT_EQ(threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()] + [thirdLevelEntry.key()], thirdLevelEntry.value()); + threeLevelMap[firstLevelEntry.key()][secondLevelEntry.key()].erase( + thirdLevelEntry.key()); + } + } + } + + // Ensure all entries have been traversed. + EXPECT_TRUE(firstLevelEntries.empty()); + for (const auto &secondLevelEntry : twoLevelMap) { + EXPECT_TRUE(secondLevelEntry.second.empty()); + } + for (const auto &secondLevelEntry : threeLevelMap) { + for (const auto &thirdLevelEntry : secondLevelEntry.second) { + EXPECT_TRUE(thirdLevelEntry.second.empty()); + } + } +} + +TEST(TrieMapTest, TestIteration) { + static const int ELEMENT_COUNT = 200000; + TrieMap trieMap; + std::unordered_map<int, uint64_t> testKeyValuePairs; + + // Use the uniform integer distribution [S_INT_MIN, S_INT_MAX]. + std::uniform_int_distribution<int> keyDistribution(S_INT_MIN, S_INT_MAX); + auto keyRandomNumberGenerator = std::bind(keyDistribution, std::mt19937()); + + // Use the uniform distribution [0, TrieMap::MAX_VALUE]. + std::uniform_int_distribution<uint64_t> valueDistribution(0, TrieMap::MAX_VALUE); + auto valueRandomNumberGenerator = std::bind(valueDistribution, std::mt19937()); + for (int i = 0; i < ELEMENT_COUNT; ++i) { + const int key = keyRandomNumberGenerator(); + const uint64_t value = valueRandomNumberGenerator(); + EXPECT_TRUE(trieMap.putRoot(key, value)); + testKeyValuePairs[key] = value; + } + for (const auto &entry : trieMap.getEntriesInRootLevel()) { + EXPECT_EQ(trieMap.getRoot(entry.key()).mValue, entry.value()); + EXPECT_EQ(testKeyValuePairs[entry.key()], entry.value()); + testKeyValuePairs.erase(entry.key()); + } + EXPECT_TRUE(testKeyValuePairs.empty()); } } // namespace |