diff options
Diffstat (limited to 'native/jni/src')
67 files changed, 662 insertions, 557 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 57e18884d..e55c9eb8a 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -301,7 +301,7 @@ static inline void prof_out(void) { #define NOT_A_DICT_POS (S_INT_MIN) #define NOT_A_WORD_ID (S_INT_MIN) #define NOT_A_TIMESTAMP (-1) -#define NOT_A_LANGUAGE_WEIGHT (-1.0f) +#define NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1.0f) // A special value to mean the first word confidence makes no sense in this case, // e.g. this is not a multi-word suggestion. @@ -338,7 +338,7 @@ static inline void prof_out(void) { #define MAX_POINTER_COUNT_G 2 // (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram is supported. -#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 1 +#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 2 #define DISALLOW_DEFAULT_CONSTRUCTOR(TypeName) \ TypeName() = delete diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 3970963e8..5214077dc 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -105,7 +105,7 @@ class DicNode { } // Init for root with prevWordIds which is used for n-gram - void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordIds) { + void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mIsCachedForNextSuggestion = false; mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds); mDicNodeState.init(); @@ -115,12 +115,11 @@ class DicNode { // Init for root with previous word void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; - int newPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds; newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId(); - for (size_t i = 1; i < NELEMS(newPrevWordIds); ++i) { - newPrevWordIds[i] = dicNode->getPrevWordIds()[i - 1]; - } - 
mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordIds); + dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1) + .copyToArray(&newPrevWordIds, 1 /* offset */); + mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds)); mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, dicNode->mDicNodeProperties.getDepth()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -203,8 +202,7 @@ class DicNode { return mDicNodeProperties.getWordId(); } - // TODO: Use view class to return word id array. - const int *getPrevWordIds() const { + const WordIdArrayView getPrevWordIds() const { return mDicNodeProperties.getPrevWordIds(); } @@ -297,8 +295,9 @@ class DicNode { } // Used to prune nodes - float getCompoundDistance(const float languageWeight) const { - return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); + float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const { + return mDicNodeState.mDicNodeStateScoring.getCompoundDistance( + weightOfLangModelVsSpatialModel); } AK_FORCE_INLINE const int *getOutputWordBuf() const { diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index fe5fe8448..7d2898b7a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -28,7 +28,7 @@ namespace latinime { /* static */ void DicNodeUtils::initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordIds, DicNode *const newRootDicNode) { + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) { newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 961a1c29d..b891a842a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h 
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -18,6 +18,7 @@ #define LATINIME_DIC_NODE_UTILS_H #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DicNodeUtils { public: static void initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordIds, DicNode *const newRootDicNode); + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode); static void initAsRootWithPreviousWord( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode); diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index 6a1b84273..1b796b5d4 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -18,8 +18,10 @@ #define LATINIME_DIC_NODE_PROPERTIES_H #include <cstdint> +#include <cstdlib> #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,29 +32,31 @@ class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT), - mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0) {} + mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0), mPrevWordCount(0) {} ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. 
void init(const int childrenPos, const int nodeCodePoint, const int wordId, - const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) { + const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = childrenPos; mDicNodeCodePoint = nodeCodePoint; mWordId = wordId; mDepth = depth; mLeavingDepth = leavingDepth; - memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } // Init for root with prevWordsPtNodePos which is used for n-gram - void init(const int rootPtNodeArrayPos, const int *const prevWordIds) { + void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = rootPtNodeArrayPos; mDicNodeCodePoint = NOT_A_CODE_POINT; mWordId = NOT_A_WORD_ID; mDepth = 0; mLeavingDepth = 0; - memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } void initByCopy(const DicNodeProperties *const dicNodeProp) { @@ -61,7 +65,9 @@ class DicNodeProperties { mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth; mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds)); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } // Init as passing child @@ -71,7 +77,9 @@ class DicNodeProperties { mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds)); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount 
= prevWordIdArrayView.size(); } int getChildrenPtNodeArrayPos() const { @@ -99,8 +107,8 @@ class DicNodeProperties { return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth; } - const int *getPrevWordIds() const { - return mPrevWordIds; + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIds).limit(mPrevWordCount); } int getWordId() const { @@ -116,7 +124,8 @@ class DicNodeProperties { int mWordId; uint16_t mDepth; uint16_t mLeavingDepth; - int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds; + size_t mPrevWordCount; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index c19d48eb9..3a54c2599 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -103,8 +103,10 @@ class DicNodeStateScoring { return getCompoundDistance(1.0f); } - float getCompoundDistance(const float languageWeight) const { - return mSpatialDistance + mLanguageDistance * languageWeight; + float getCompoundDistance( + const float weightOfLangModelVsSpatialModel) const { + return mSpatialDistance + + mLanguageDistance * weightOfLangModelVsSpatialModel; } float getNormalizedCompoundDistance() const { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 1de405104..e4084b0f5 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -47,14 +47,14 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int 
*times, int *pointerIds, int *inputCodePoints, int inputSize, const PrevWordsInfo *const prevWordsInfo, - const SuggestOptions *const suggestOptions, const float languageWeight, + const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); traverseSession->init(this, prevWordsInfo, suggestOptions); const auto &suggest = suggestOptions->isGesture() ? mGestureSuggest : mTypingSuggest; suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - languageWeight, outSuggestionResults); + weightOfLangModelVsSpatialModel, outSuggestionResults); if (DEBUG_DICT) { outSuggestionResults->dumpSuggestions(); } @@ -85,7 +85,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi return; } const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( - mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */); + mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, wordAttributes.getProbability()); } @@ -93,42 +93,42 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - NgramListenerForPrediction listener(prevWordsInfo, - WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults, + 
NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } -int Dictionary::getProbability(const int *word, int length) const { - return getNgramProbability(nullptr /* prevWordsInfo */, word, length); +int Dictionary::getProbability(const CodePointArrayView codePoints) const { + return getNgramProbability(nullptr /* prevWordsInfo */, codePoints); } -int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) const { +int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); return DictionaryUtils::getMaxProbabilityOfExactMatches( - mDictionaryStructureWithBufferPolicy.get(), word, length); + mDictionaryStructureWithBufferPolicy.get(), codePoints); } -int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, - int length) const { +int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); - int wordId = mDictionaryStructureWithBufferPolicy->getWordId( - CodePointArrayView(word, length), false /* forceLowerCaseSearch */); + const int wordId = mDictionaryStructureWithBufferPolicy->getWordId(codePoints, + false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; if (!prevWordsInfo) { - return getDictionaryStructurePolicy()->getProbabilityOfWord( - nullptr /* prevWordsPtNodePos */, wordId); + return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId); } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds + 
(mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } -bool Dictionary::addUnigramEntry(const int *const word, const int length, +bool Dictionary::addUnigramEntry(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty) { if (unigramProperty->representsBeginningOfSentence() && !mDictionaryStructureWithBufferPolicy->getHeaderStructurePolicy() @@ -137,14 +137,12 @@ bool Dictionary::addUnigramEntry(const int *const word, const int length, return false; } TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addUnigramEntry(CodePointArrayView(word, length), - unigramProperty); + return mDictionaryStructureWithBufferPolicy->addUnigramEntry(codePoints, unigramProperty); } -bool Dictionary::removeUnigramEntry(const int *const codePoints, const int codePointCount) { +bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeUnigramEntry( - CodePointArrayView(codePoints, codePointCount)); + return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints); } bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, @@ -154,10 +152,9 @@ bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, } bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { + const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, - CodePointArrayView(word, length)); + return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, codePoints); } bool Dictionary::flush(const char *const filePath) { @@ -182,11 +179,9 @@ void Dictionary::getProperty(const char *const query, const int queryLength, cha maxResultLength); } 
-const WordProperty Dictionary::getWordProperty(const int *const codePoints, - const int codePointCount) { +const WordProperty Dictionary::getWordProperty(const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->getWordProperty( - CodePointArrayView(codePoints, codePointCount)); + return mDictionaryStructureWithBufferPolicy->getWordProperty(codePoints); } int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 0b54f30e9..324e3504a 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -66,29 +66,29 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, const PrevWordsInfo *const prevWordsInfo, - const SuggestOptions *const suggestOptions, const float languageWeight, + const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const; void getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const; - int getProbability(const int *word, int length) const; + int getProbability(const CodePointArrayView codePoints) const; - int getMaxProbabilityOfExactMatches(const int *word, int length) const; + int getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const; int getNgramProbability(const PrevWordsInfo *const prevWordsInfo, - const int *word, int length) const; + const CodePointArrayView codePoints) const; - bool addUnigramEntry(const int *const codePoints, const int codePointCount, + bool addUnigramEntry(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty); - bool 
removeUnigramEntry(const int *const codePoints, const int codePointCount); + bool removeUnigramEntry(const CodePointArrayView codePoints); bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, - const int length); + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView codePoints); bool flush(const char *const filePath); @@ -99,7 +99,7 @@ class Dictionary { void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, const int codePointCount); + const WordProperty getWordProperty(const CodePointArrayView codePoints); // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index f71d4c5f0..b85f3622a 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -23,32 +23,33 @@ #include "suggest/core/dictionary/digraph_utils.h" #include "suggest/core/session/prev_words_info.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { /* static */ int DictionaryUtils::getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { std::vector<DicNode> current; std::vector<DicNode> next; // No prev words information. 
PrevWordsInfo emptyPrevWordsInfo; - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds, - false /* tryLowerCaseSearch */); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds( + dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */); current.emplace_back(); DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front()); - for (int i = 0; i < codePointCount; ++i) { + for (const int codePoint : codePoints) { // The base-lower input is used to ignore case errors and accent errors. - const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); + const int baseLowerCodePoint = CharUtils::toBaseLowerCase(codePoint); for (const DicNode &dicNode : current) { - if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == codePoint) { + if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == baseLowerCodePoint) { next.emplace_back(dicNode); next.back().advanceDigraphIndex(); continue; } - processChildDicNodes(dictionaryStructurePolicy, codePoint, &dicNode, &next); + processChildDicNodes(dictionaryStructurePolicy, baseLowerCodePoint, &dicNode, &next); } current.clear(); current.swap(next); diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.h b/native/jni/src/suggest/core/dictionary/dictionary_utils.h index 358ebf674..4dd21c9be 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.h +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.h @@ -20,6 +20,7 @@ #include <vector> #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DictionaryUtils { public: static int getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount); + const CodePointArrayView codePoints); private: 
DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryUtils); diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp index 979d61edb..761f51ec8 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -35,9 +35,9 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = // Also caches the bigrams if there is space remaining and they have not been cached already. int MultiBigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds, const int nextWordId, + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) { - if (!prevWordIds || prevWordIds[0] == NOT_A_WORD_ID) { + if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) { return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); } const auto mapPosition = mBigramMaps.find(prevWordIds[0]); @@ -56,7 +56,7 @@ int MultiBigramMap::getBigramProbability( void MultiBigramMap::BigramMap::init( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds) { + const WordIdArrayView prevWordIds) { structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */); } @@ -83,16 +83,13 @@ void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const i void MultiBigramMap::addBigramsForWord( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds) { - if (prevWordIds) { - mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds); - } + const WordIdArrayView prevWordIds) { + mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds); } int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds, const int nextWordId, - const 
int unigramProbability) { + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) { const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId); if (bigramProbability != NOT_A_PROBABILITY) { return bigramProbability; diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index a8c4ded57..d2eb5cc32 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -25,6 +25,7 @@ #include "suggest/core/dictionary/bloom_filter.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { @@ -39,7 +40,7 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. 
int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds, const int nextWordId, const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); void clear() { mBigramMaps.clear(); @@ -57,7 +58,7 @@ class MultiBigramMap { virtual ~BigramMap() {} void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds); + const WordIdArrayView prevWordIds); int getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nextWordId, const int unigramProbability) const; @@ -70,11 +71,11 @@ class MultiBigramMap { }; void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds); + const WordIdArrayView prevWordIds); int readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordIds, const int nextWordId, const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; std::unordered_map<int, BigramMap> mBigramMaps; diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 7414f696c..a498b6f65 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -58,15 +58,15 @@ class DictionaryStructureWithBufferPolicy { virtual int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const = 0; - virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds, + virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, const int 
wordId, MultiBigramMap *const multiBigramMap) const = 0; // TODO: Remove virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0; + virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0; - virtual void iterateNgramEntries(const int *const prevWordIds, + virtual void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const = 0; virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0; diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 9e75cace4..ce3684a1c 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -32,9 +32,11 @@ class Scoring { const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, const bool boostExactMatches) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const = 0; - virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const = 0; + virtual float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const = 0; virtual float getDoubleLetterDemotionDistanceCost( const DicNode *const terminalDicNode) const = 0; virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp index 4c10bd08a..3756d1092 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.cpp 
+++ b/native/jni/src/suggest/core/result/suggestion_results.cpp @@ -23,7 +23,7 @@ namespace latinime { void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outputCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, - jfloatArray outLanguageWeight) { + jfloatArray outWeightOfLangModelVsSpatialModel) { int outputIndex = 0; while (!mSuggestedWords.empty()) { const SuggestedWord &suggestedWord = mSuggestedWords.top(); @@ -44,7 +44,8 @@ void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCo mSuggestedWords.pop(); } JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, outputIndex); - JniDataUtils::putFloatToArray(env, outLanguageWeight, 0 /* index */, mLanguageWeight); + JniDataUtils::putFloatToArray(env, outWeightOfLangModelVsSpatialModel, 0 /* index */, + mWeightOfLangModelVsSpatialModel); } void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount, @@ -89,7 +90,7 @@ void SuggestionResults::getSortedScores(int *const outScores) const { } void SuggestionResults::dumpSuggestions() const { - AKLOGE("language weight: %f", mLanguageWeight); + AKLOGE("weight of language model vs spatial model: %f", mWeightOfLangModelVsSpatialModel); std::vector<SuggestedWord> suggestedWords; auto copyOfSuggestedWords = mSuggestedWords; while (!copyOfSuggestedWords.empty()) { diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h index 8e845e2d3..738c78a9f 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.h +++ b/native/jni/src/suggest/core/result/suggestion_results.h @@ -29,13 +29,15 @@ namespace latinime { class SuggestionResults { public: explicit SuggestionResults(const int maxSuggestionCount) - : mMaxSuggestionCount(maxSuggestionCount), mLanguageWeight(NOT_A_LANGUAGE_WEIGHT), + : 
mMaxSuggestionCount(maxSuggestionCount), + mWeightOfLangModelVsSpatialModel(NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL), mSuggestedWords() {} // Returns suggestion count. void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray outLanguageWeight); + jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray outWeightOfLangModelVsSpatialModel); void addPrediction(const int *const codePoints, const int codePointCount, const int score); void addSuggestion(const int *const codePoints, const int codePointCount, const int score, const int type, const int indexToPartialCommit, @@ -43,8 +45,8 @@ class SuggestionResults { void getSortedScores(int *const outScores) const; void dumpSuggestions() const; - void setLanguageWeight(const float languageWeight) { - mLanguageWeight = languageWeight; + void setWeightOfLangModelVsSpatialModel(const float weightOfLangModelVsSpatialModel) { + mWeightOfLangModelVsSpatialModel = weightOfLangModelVsSpatialModel; } int getSuggestionCount() const { @@ -55,7 +57,7 @@ class SuggestionResults { DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionResults); const int mMaxSuggestionCount; - float mLanguageWeight; + float mWeightOfLangModelVsSpatialModel; std::priority_queue< SuggestedWord, std::vector<SuggestedWord>, SuggestedWord::Comparator> mSuggestedWords; }; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 6e0193772..3283f6deb 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -34,7 +34,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; /* static */ void SuggestionsOutputUtils::outputSuggestions( const Scoring *const scoringPolicy, 
DicTraverseSession *traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) { + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -44,12 +45,15 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; for (int index = terminalSize - 1; index >= 0; --index) { traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); } - // Compute a language weight when an invalid language weight is passed. - // NOT_A_LANGUAGE_WEIGHT (-1) is assumed as an invalid language weight. - const float languageWeightToOutputSuggestions = (languageWeight < 0.0f) ? - scoringPolicy->getAdjustedLanguageWeight( - traverseSession, terminals.data(), terminalSize) : languageWeight; - outSuggestionResults->setLanguageWeight(languageWeightToOutputSuggestions); + // Compute a weight of language model when an invalid weight is passed. + // NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1) is taken as an invalid value. + const float weightOfLangModelVsSpatialModelToOutputSuggestions = + (weightOfLangModelVsSpatialModel < 0.0f) + ? scoringPolicy->getAdjustedWeightOfLangModelVsSpatialModel(traverseSession, + terminals.data(), terminalSize) + : weightOfLangModelVsSpatialModel; + outSuggestionResults->setWeightOfLangModelVsSpatialModel( + weightOfLangModelVsSpatialModelToOutputSuggestions); // Force autocorrection for obvious long multi-word suggestions when the top suggestion is // a long multiple words suggestion. // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. 
@@ -65,16 +69,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Output suggestion results here for (auto &terminalDicNode : terminals) { outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode, - languageWeightToOutputSuggestions, boostExactMatches, forceCommitMultiWords, - outputSecondWordFirstLetterInputIndex, outSuggestionResults); + weightOfLangModelVsSpatialModelToOutputSuggestions, boostExactMatches, + forceCommitMultiWords, outputSecondWordFirstLetterInputIndex, outSuggestionResults); } - scoringPolicy->getMostProbableString(traverseSession, languageWeightToOutputSuggestions, - outSuggestionResults); + scoringPolicy->getMostProbableString(traverseSession, + weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults); } /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - const DicNode *const terminalDicNode, const float languageWeight, + const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults) { @@ -83,8 +87,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; } const float doubleLetterCost = scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); - const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) - + doubleLetterCost; + const float compoundDistance = + terminalDicNode->getCompoundDistance(weightOfLangModelVsSpatialModel) + + doubleLetterCost; const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy() ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(), terminalDicNode->getWordId(), nullptr /* multiBigramMap */); diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h 
b/native/jni/src/suggest/core/result/suggestions_output_utils.h index b099b4776..bf8497828 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -33,7 +33,7 @@ class SuggestionsOutputUtils { * Outputs the final list of suggestions (i.e., terminal nodes). */ static void outputSuggestions(const Scoring *const scoringPolicy, - DicTraverseSession *traverseSession, const float languageWeight, + DicTraverseSession *traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults); private: @@ -44,7 +44,7 @@ class SuggestionsOutputUtils { static void outputSuggestionsOfDicNode(const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, - const float languageWeight, const bool boostExactMatches, + const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults); static void outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index d4d4d1eed..4d7505a55 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary, mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds, - true /* tryLowerCaseSearch */); + mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), + &mPrevWordIdArray, true /* tryLowerCaseSearch */).size(); } void 
DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 0e676d897..9f841aa3c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -24,6 +24,7 @@ #include "suggest/core/dicnode/dic_nodes_cache.h" #include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/layout/proximity_info_state.h" +#include "utils/int_array_view.h" namespace latinime { @@ -50,14 +51,11 @@ class DicTraverseSession { } AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) - : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), - mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), - mMultiWordCostMultiplier(1.0f) { + : mPrevWordIdCount(0), mProximityInfo(nullptr), mDictionary(nullptr), + mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), mMultiBigramMap(), + mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. 
- for (size_t i = 0; i < NELEMS(mPrevWordsIds); ++i) { - mPrevWordsIds[i] = NOT_A_DICT_POS; - } } // Non virtual inline destructor -- never inherit this class @@ -79,7 +77,9 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - const int *getPrevWordIds() const { return mPrevWordsIds; } + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIdArray).limit(mPrevWordIdCount); + } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -166,7 +166,8 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - int mPrevWordsIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIdArray; + size_t mPrevWordIdCount; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index fc9a35968..02e82a8e0 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -17,6 +17,8 @@ #ifndef LATINIME_PREV_WORDS_INFO_H #define LATINIME_PREV_WORDS_INFO_H +#include <array> + #include "defines.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -27,12 +29,13 @@ namespace latinime { class PrevWordsInfo { public: // No prev word information. 
- PrevWordsInfo() { + PrevWordsInfo() : mPrevWordCount(0) { clear(); } - PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) + : mPrevWordCount(prevWordsInfo.mPrevWordCount) { + for (size_t i = 0; i < mPrevWordCount; ++i) { mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); @@ -43,9 +46,10 @@ class PrevWordsInfo { // Construct from previous words. PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, - const size_t prevWordCount) { + const size_t prevWordCount) + : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { clear(); - for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { + for (size_t i = 0; i < mPrevWordCount; ++i) { if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { continue; } @@ -58,7 +62,7 @@ class PrevWordsInfo { // Construct from a previous word. 
PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, - const bool isBeginningOfSentence) { + const bool isBeginningOfSentence) : mPrevWordCount(1) { clear(); if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { return; @@ -79,26 +83,29 @@ class PrevWordsInfo { return false; } - void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordIds, const bool tryLowerCaseSearch) const { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - outPrevWordIds[i] = getWordId(dictStructurePolicy, + template<size_t N> + const WordIdArrayView getPrevWordIds( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { + for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) { + prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); } + return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount); } // n is 1-indexed. - const CodePointArrayView getNthPrevWordCodePoints(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { return CodePointArrayView(); } return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); } // n is 1-indexed. 
- bool isNthPrevWordBeginningOfSentence(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + bool isNthPrevWordBeginningOfSentence(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { return false; } return mIsBeginningOfSentence[n - 1]; @@ -142,6 +149,7 @@ class PrevWordsInfo { } } + const size_t mPrevWordCount; int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 947d41f4b..457414f2b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -45,7 +45,7 @@ const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; */ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const float languageWeight, + int inputSize, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { PROF_OPEN; PROF_START(0); @@ -68,7 +68,7 @@ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, PROF_END(1); PROF_START(2); SuggestionsOutputUtils::outputSuggestions( - SCORING, tSession, languageWeight, outSuggestionResults); + SCORING, tSession, weightOfLangModelVsSpatialModel, outSuggestionResults); PROF_END(2); PROF_CLOSE; } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 788e0314b..65d5918cf 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -49,7 +49,8 @@ class Suggest : public SuggestInterface { AK_FORCE_INLINE virtual ~Suggest() {} void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float 
languageWeight, SuggestionResults *const outSuggestionResults) const; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index a6e5aefae..a05aa9c80 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -28,7 +28,8 @@ class SuggestInterface { public: virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float languageWeight, SuggestionResults *const suggestionResults) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const suggestionResults) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 87cf0cd3b..daf40d4f9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -65,7 +65,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)) {} + &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Constructs header information using an attribute map. 
HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion, @@ -97,7 +98,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)) {} + &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Copy header information HeaderPolicy(const HeaderPolicy *const headerPolicy) @@ -118,7 +120,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mForgettingCurveDurationToLevelDown( headerPolicy->mForgettingCurveDurationToLevelDown), mMaxUnigramCount(headerPolicy->mMaxUnigramCount), - mMaxBigramCount(headerPolicy->mMaxBigramCount) {} + mMaxBigramCount(headerPolicy->mMaxBigramCount), + mCodePointTable(headerPolicy->mCodePointTable) {} // Temporary dummy header. 
HeaderPolicy() @@ -128,7 +131,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0), - mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0) {} + mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0), + mCodePointTable(nullptr) {} ~HeaderPolicy() {} @@ -139,6 +143,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { switch (mDictFormatVersion) { case FormatUtils::VERSION_2: return FormatUtils::VERSION_2; + case FormatUtils::VERSION_201: + return FormatUtils::VERSION_201; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: return FormatUtils::VERSION_4_ONLY_FOR_TESTING; case FormatUtils::VERSION_4: @@ -250,6 +256,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mDictFormatVersion >= FormatUtils::VERSION_4; } + const int *getCodePointTable() const { + return mCodePointTable; + } + private: DISALLOW_COPY_AND_ASSIGN(HeaderPolicy); @@ -295,6 +305,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const int mForgettingCurveDurationToLevelDown; const int mMaxUnigramCount; const int mMaxBigramCount; + const int *const mCodePointTable; const std::vector<int> readLocale() const; float readMultipleWordCostMultiplier() const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index d2c3d2fe0..41a8b13b8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -18,6 +18,7 @@ #include <cctype> #include <cstdio> +#include <memory> #include <vector> #include "defines.h" @@ -34,12 +35,13 @@ namespace latinime { const int 
HeaderReadWriteUtils::LARGEST_INT_DIGIT_COUNT = 11; const int HeaderReadWriteUtils::MAX_ATTRIBUTE_KEY_LENGTH = 256; -const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 256; +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 2048; const int HeaderReadWriteUtils::HEADER_MAGIC_NUMBER_SIZE = 4; const int HeaderReadWriteUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; const int HeaderReadWriteUtils::HEADER_FLAG_SIZE = 2; const int HeaderReadWriteUtils::HEADER_SIZE_FIELD_SIZE = 4; +const char *const HeaderReadWriteUtils::CODE_POINT_TABLE_KEY = "codePointTable"; const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; @@ -73,20 +75,32 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; return; } int keyBuffer[MAX_ATTRIBUTE_KEY_LENGTH]; - int valueBuffer[MAX_ATTRIBUTE_VALUE_LENGTH]; + std::unique_ptr<int[]> valueBuffer(new int[MAX_ATTRIBUTE_VALUE_LENGTH]); while (pos < headerSize) { + // The values in the header don't use the code point table for their encoding. 
const int keyLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, - MAX_ATTRIBUTE_KEY_LENGTH, keyBuffer, &pos); + MAX_ATTRIBUTE_KEY_LENGTH, nullptr /* codePointTable */, keyBuffer, &pos); std::vector<int> key; key.insert(key.end(), keyBuffer, keyBuffer + keyLength); const int valueLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, - MAX_ATTRIBUTE_VALUE_LENGTH, valueBuffer, &pos); + MAX_ATTRIBUTE_VALUE_LENGTH, nullptr /* codePointTable */, valueBuffer.get(), &pos); std::vector<int> value; - value.insert(value.end(), valueBuffer, valueBuffer + valueLength); + value.insert(value.end(), valueBuffer.get(), valueBuffer.get() + valueLength); headerAttributes->insert(AttributeMap::value_type(key, value)); } } +/* static */ const int *HeaderReadWriteUtils::readCodePointTable( + AttributeMap *const headerAttributes) { + AttributeMap::key_type keyVector; + insertCharactersIntoVector(CODE_POINT_TABLE_KEY, &keyVector); + AttributeMap::const_iterator it = headerAttributes->find(keyVector); + if (it == headerAttributes->end()) { + return nullptr; + } + return it->second.data(); +} + /* static */ bool HeaderReadWriteUtils::writeDictionaryVersion( BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, int *const writingPos) { @@ -96,7 +110,8 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; } switch (version) { case FormatUtils::VERSION_2: - // Version 2 dictionary writing is not supported. + case FormatUtils::VERSION_201: + // Version 2 or 201 dictionary writing is not supported. 
return false; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4: diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h index 1ab2eec69..5dd91b26c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -46,6 +46,9 @@ class HeaderReadWriteUtils { static void fetchAllHeaderAttributes(const uint8_t *const dictBuf, DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes); + static const int *readCodePointTable( + DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes); + static bool writeDictionaryVersion(BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, int *const writingPos); @@ -101,6 +104,8 @@ class HeaderReadWriteUtils { static const int HEADER_FLAG_SIZE; static const int HEADER_SIZE_FIELD_SIZE; + static const char *const CODE_POINT_TABLE_KEY; + // Value for the "flags" field. It's unused at the moment. 
static const DictionaryFlags NO_FLAGS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp index 82399f190..5c639b19c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp @@ -23,6 +23,7 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.h" @@ -59,8 +60,8 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce const int parentPos = DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos); int codePoints[MAX_WORD_LENGTH]; - const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, codePoints, &pos); + const int codePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictBuf, flags, MAX_WORD_LENGTH, mHeaderPolicy->getCodePointTable(), codePoints, &pos); int terminalIdFieldPos = NOT_A_DICT_POS; int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; int probability = NOT_A_PROBABILITY; @@ -98,7 +99,7 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce // The destination position is stored at the same place as the parent position. 
return fetchPtNodeInfoFromBufferAndProcessMovedPtNode(parentPos, newSiblingNodePos); } else { - return PtNodeParams(headPos, flags, parentPos, codePonitCount, codePoints, + return PtNodeParams(headPos, flags, parentPos, codePointCount, codePoints, terminalIdFieldPos, terminalId, probability, childrenPosFieldPos, childrenPos, newSiblingNodePos); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 9b8a50b07..ee1403739 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -116,7 +116,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, } const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( - const int *const prevWordIds, const int wordId, + const WordIdArrayView prevWordIds, const int wordId, MultiBigramMap *const multiBigramMap) const { if (wordId == NOT_A_WORD_ID) { return WordAttributes(); @@ -128,7 +128,7 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( prevWordIds, wordId, ptNodeParams.getProbability()); return getWordAttributes(probability, ptNodeParams); } - if (prevWordIds) { + if (!prevWordIds.empty()) { const int probability = getProbabilityOfWord(prevWordIds, wordId); if (probability != NOT_A_PROBABILITY) { return getWordAttributes(probability, ptNodeParams); @@ -160,7 +160,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const { if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; @@ -170,7 +170,7 @@ int 
Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordIds) { + if (!prevWordIds.empty()) { const int bigramsPosition = getBigramsPositionOfPtNode( getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); @@ -186,9 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordIds) { + if (prevWordIds.empty()) { return; } const int bigramsPosition = getBigramsPositionOfPtNode( @@ -268,8 +268,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo return false; } const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), - codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } @@ -283,8 +283,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { AKLOGE("Cannot add new shortcut target. 
PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); @@ -332,8 +332,12 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.empty()) { + return false; + } if (prevWordIds[0] == NOT_A_WORD_ID) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { const std::vector<UnigramProperty::ShortcutProperty> shortcuts; @@ -347,7 +351,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } // Refresh word ids. - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); + prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } else { return false; } @@ -390,9 +394,10 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. 
length: %zd", wordCodePoints.size()); } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */); - if (prevWordIds[0] == NOT_A_WORD_ID) { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSerch */); + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { return false; } const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 871b556e1..576d2abb5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -91,14 +91,15 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; - const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId, - MultiBigramMap *const multiBigramMap) const; + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; + void iterateNgramEntries(const WordIdArrayView prevWordIds, + NgramListener *const listener) const; BinaryDictionaryShortcutIterator 
getShortcutIterator(const int wordId) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index 9fa93efc9..372c9e36f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -114,7 +114,8 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str mmappedBuffer->getReadOnlyByteArrayView()); switch (formatVersion) { case FormatUtils::VERSION_2: - AKLOGE("Given path is a directory but the format is version 2. path: %s", path); + case FormatUtils::VERSION_201: + AKLOGE("Given path is a directory but the format is version 2 or 201. path: %s", path); break; case FormatUtils::VERSION_4: { return newPolicyForV4Dict<backward::v402::Ver4DictConstants, @@ -175,6 +176,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str } switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: + case FormatUtils::VERSION_201: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); case FormatUtils::VERSION_4_ONLY_FOR_TESTING: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp index f7fd5c071..1b2f857ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp @@ -39,32 +39,31 @@ const BigramListReadWriteUtils::BigramFlags 
BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; /* static */ bool BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( - const uint8_t *const bigramsBuf, const int bufSize, BigramFlags *const outBigramFlags, + const ReadOnlyByteArrayView buffer, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos) { - if (bufSize <= *bigramEntryPos) { - AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). bufSize: %d, " - "bigramEntryPos: %d.", bufSize, *bigramEntryPos); + if (static_cast<int>(buffer.size()) <= *bigramEntryPos) { + AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). bufSize: %zd, " + "bigramEntryPos: %d.", buffer.size(), *bigramEntryPos); return false; } - const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, + const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), bigramEntryPos); if (outBigramFlags) { *outBigramFlags = bigramFlags; } - const int targetPos = getBigramAddressAndAdvancePosition(bigramsBuf, bigramFlags, - bigramEntryPos); + const int targetPos = getBigramAddressAndAdvancePosition(buffer, bigramFlags, bigramEntryPos); if (outTargetPtNodePos) { *outTargetPtNodePos = targetPos; } return true; } -/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, - const int bufSize, int *const bigramListPos) { +/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const ReadOnlyByteArrayView buffer, + int *const bigramListPos) { BigramFlags flags; do { - if (!getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, bufSize, &flags, - 0 /* outTargetPtNodePos */, bigramListPos)) { + if (!getBigramEntryPropertiesAndAdvancePosition(buffer, &flags, 0 /* outTargetPtNodePos */, + bigramListPos)) { return false; } } while(hasNext(flags)); @@ -72,18 +71,18 @@ const BigramListReadWriteUtils::BigramFlags } /* static */ int 
BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( - const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { + const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos) { int offset = 0; const int origin = *pos; switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); break; case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos); break; case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer.data(), pos); break; } if (isOffsetNegative(flags)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h index 10f93fb7a..a0f7d5e83 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h @@ -21,6 +21,7 @@ #include <cstdlib> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,8 +31,8 @@ class BigramListReadWriteUtils { public: typedef uint8_t BigramFlags; - static bool getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, - const int bufSize, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + static bool getBigramEntryPropertiesAndAdvancePosition(const ReadOnlyByteArrayView buffer, + BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos); static 
AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { @@ -43,8 +44,7 @@ public: } // Bigrams reading methods - static bool skipExistingBigrams(const uint8_t *const bigramsBuf, const int bufSize, - int *const bigramListPos); + static bool skipExistingBigrams(const ReadOnlyByteArrayView buffer, int *const bigramListPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); @@ -61,7 +61,7 @@ private: return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; } - static int getBigramAddressAndAdvancePosition(const uint8_t *const bigramsBuf, + static int getBigramAddressAndAdvancePosition(const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos); }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index 086d98b4a..40782a44f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -218,9 +218,9 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( } int DynamicPtReadingHelper::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) { + const size_t length, const bool forceLowerCaseSearch) { int searchCodePoints[length]; - for (int i = 0; i < length; ++i) { + for (size_t i = 0; i < length; ++i) { searchCodePoints[i] = forceLowerCaseSearch ? 
CharUtils::toLowerCase(inWord[i]) : inWord[i]; } while (!isEnd()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index b7262581a..9a7abc97f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -138,12 +138,12 @@ class DynamicPtReadingHelper { } // Return code point count exclude the last read node's code points. - AK_FORCE_INLINE int getPrevTotalCodePointCount() const { + AK_FORCE_INLINE size_t getPrevTotalCodePointCount() const { return mReadingState.mTotalCodePointCountSinceInitialization; } // Return code point count include the last read node's code points. - AK_FORCE_INLINE int getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { + AK_FORCE_INLINE size_t getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { return mReadingState.mTotalCodePointCountSinceInitialization + ptNodeParams.getCodePointCount(); } @@ -214,7 +214,7 @@ class DynamicPtReadingHelper { int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability); - int getTerminalPtNodePositionOfWord(const int *const inWord, const int length, + int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, const bool forceLowerCaseSearch); private: @@ -234,7 +234,7 @@ class DynamicPtReadingHelper { int mPos; // Remaining node count in the current array. int mRemainingPtNodeCountInThisArray; - int mTotalCodePointCountSinceInitialization; + size_t mTotalCodePointCountSinceInitialization; // Counter of PtNodes used to avoid infinite loops caused by broken or malicious links. 
int mTotalPtNodeIndexInThisArrayChain; // Counter of PtNode arrays used to avoid infinite loops caused by cyclic links of empty diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp index 3c62e2e56..3b58d7d6d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp @@ -28,17 +28,16 @@ namespace latinime { const int DynamicPtUpdatingHelper::CHILDREN_POSITION_FIELD_SIZE = 3; -bool DynamicPtUpdatingHelper::addUnigramWord( - DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram) { +bool DynamicPtUpdatingHelper::addUnigramWord(DynamicPtReadingHelper *const readingHelper, + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram) { int parentPos = NOT_A_DICT_POS; while (!readingHelper->isEnd()) { const PtNodeParams ptNodeParams(readingHelper->getPtNodeParams()); if (!ptNodeParams.isValid()) { break; } - const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); + const size_t matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); if (!readingHelper->isMatchedCodePoint(ptNodeParams, 0 /* index */, wordCodePoints[matchedCodePointCount])) { // The first code point is different from target code point. Skip this node and read @@ -47,26 +46,25 @@ bool DynamicPtUpdatingHelper::addUnigramWord( continue; } // Check following merged node code points. 
- const int nodeCodePointCount = ptNodeParams.getCodePointCount(); - for (int j = 1; j < nodeCodePointCount; ++j) { - const int nextIndex = matchedCodePointCount + j; - if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(ptNodeParams, j, - wordCodePoints[matchedCodePointCount + j])) { + const size_t nodeCodePointCount = ptNodeParams.getCodePointArrayView().size(); + for (size_t j = 1; j < nodeCodePointCount; ++j) { + const size_t nextIndex = matchedCodePointCount + j; + if (nextIndex >= wordCodePoints.size() + || !readingHelper->isMatchedCodePoint(ptNodeParams, j, + wordCodePoints[matchedCodePointCount + j])) { *outAddedNewUnigram = true; return reallocatePtNodeAndAddNewPtNodes(&ptNodeParams, j, unigramProperty, - wordCodePoints + matchedCodePointCount, - codePointCount - matchedCodePointCount); + wordCodePoints.skip(matchedCodePointCount)); } } // All characters are matched. - if (codePointCount == readingHelper->getTotalCodePointCount(ptNodeParams)) { + if (wordCodePoints.size() == readingHelper->getTotalCodePointCount(ptNodeParams)) { return setPtNodeProbability(&ptNodeParams, unigramProperty, outAddedNewUnigram); } if (!ptNodeParams.hasChildren()) { *outAddedNewUnigram = true; return createChildrenPtNodeArrayAndAChildPtNode(&ptNodeParams, unigramProperty, - wordCodePoints + readingHelper->getTotalCodePointCount(ptNodeParams), - codePointCount - readingHelper->getTotalCodePointCount(ptNodeParams)); + wordCodePoints.skip(readingHelper->getTotalCodePointCount(ptNodeParams))); } // Advance to the children nodes. 
parentPos = ptNodeParams.getHeadPos(); @@ -79,9 +77,8 @@ bool DynamicPtUpdatingHelper::addUnigramWord( int pos = readingHelper->getPosOfLastForwardLinkField(); *outAddedNewUnigram = true; return createAndInsertNodeIntoPtNodeArray(parentPos, - wordCodePoints + readingHelper->getPrevTotalCodePointCount(), - codePointCount - readingHelper->getPrevTotalCodePointCount(), - unigramProperty, &pos); + wordCodePoints.skip(readingHelper->getPrevTotalCodePointCount()), unigramProperty, + &pos); } bool DynamicPtUpdatingHelper::addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, @@ -120,23 +117,21 @@ bool DynamicPtUpdatingHelper::removeNgramEntry(const PtNodePosArrayView prevWord } bool DynamicPtUpdatingHelper::addShortcutTarget(const int wordPos, - const int *const targetCodePoints, const int targetCodePointCount, - const int shortcutProbability) { + const CodePointArrayView targetCodePoints, const int shortcutProbability) { const PtNodeParams ptNodeParams(mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos)); - return mPtNodeWriter->addShortcutTarget(&ptNodeParams, targetCodePoints, targetCodePointCount, - shortcutProbability); + return mPtNodeWriter->addShortcutTarget(&ptNodeParams, targetCodePoints.data(), + targetCodePoints.size(), shortcutProbability); } bool DynamicPtUpdatingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, - const int *const nodeCodePoints, const int nodeCodePointCount, - const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos) { + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, + int *const forwardLinkFieldPos) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, newPtNodeArrayPos, forwardLinkFieldPos)) { return false; } - return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, - unigramProperty); + return 
createNewPtNodeArrayWithAChildPtNode(parentPos, ptNodeCodePoints, unigramProperty); } bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, @@ -153,8 +148,7 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams, unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, originalPtNodeParams->getParentPos(), - originalPtNodeParams->getCodePointCount(), originalPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -168,17 +162,17 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori bool DynamicPtUpdatingHelper::createChildrenPtNodeArrayAndAChildPtNode( const PtNodeParams *const parentPtNodeParams, const UnigramProperty *const unigramProperty, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!mPtNodeWriter->updateChildrenPosition(parentPtNodeParams, newPtNodeArrayPos)) { return false; } return createNewPtNodeArrayWithAChildPtNode(parentPtNodeParams->getHeadPos(), codePoints, - codePointCount, unigramProperty); + unigramProperty); } bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( - const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int parentPtNodePos, const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty) { int writingPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, @@ -187,8 +181,7 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( } 
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, - parentPtNodePos, nodeCodePointCount, nodeCodePoints, - unigramProperty->getProbability())); + parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -202,9 +195,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( // Returns whether the dictionary updating was succeeded or not. bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount) { + const PtNodeParams *const reallocatingPtNodeParams, const size_t overlappingCodePointCount, + const UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints) { // When addsExtraChild is true, split the reallocating PtNode and add new child. // Reallocating PtNode: abcde, newNode: abcxy. // abc (1st, not terminal) __ de (2nd) @@ -212,16 +205,18 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( // Otherwise, this method makes 1st part terminal and write information in unigramProperty. // Reallocating PtNode: abcde, newNode: abc. // abc (1st, terminal) __ de (2nd) - const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const bool addsExtraChild = newPtNodeCodePoints.size() > overlappingCodePointCount; const int firstPartOfReallocatedPtNodePos = mBuffer->getTailPosition(); int writingPos = firstPartOfReallocatedPtNodePos; // Write the 1st part of the reallocating node. The children position will be updated later // with actual children position. 
+ const CodePointArrayView firstPtNodeCodePoints = + reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount); if (addsExtraChild) { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */, - reallocatingPtNodeParams->getParentPos(), overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints(), NOT_A_PROBABILITY)); + reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints, + NOT_A_PROBABILITY)); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) { return false; } @@ -229,8 +224,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, reallocatingPtNodeParams->getParentPos(), - overlappingCodePointCount, reallocatingPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + firstPtNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -248,8 +242,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams, reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(), reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos, - reallocatingPtNodeParams->getCodePointCount() - overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints() + overlappingCodePointCount, + reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount), reallocatingPtNodeParams->getProbability())); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&childPartPtNodeParams, &writingPos)) { return false; @@ -258,8 +251,8 @@ bool 
DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, firstPartOfReallocatedPtNodePos, - newNodeCodePointCount - overlappingCodePointCount, - newNodeCodePoints + overlappingCodePointCount, unigramProperty->getProbability())); + newPtNodeCodePoints.skip(overlappingCodePointCount), + unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&extraChildPtNodeParams, unigramProperty, &writingPos)) { return false; @@ -282,26 +275,24 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( } const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams( - const PtNodeParams *const originalPtNodeParams, - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const { + const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, + const bool isBlacklisted, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(originalPtNodeParams, flags, parentPos, codePointCount, codePoints, - probability); + return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability); } -const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode( - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, const int *const codePoints, - const int 
probability) const { +const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord, + const bool isBlacklisted, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(flags, parentPos, codePointCount, codePoints, probability); + return PtNodeParams(flags, parentPos, codePoints, probability); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h index 97c05c1ea..710047e8c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h @@ -40,19 +40,21 @@ class DynamicPtUpdatingHelper { // Add a word to the dictionary. If the word already exists, update the probability. bool addUnigramWord(DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram); + // TODO: Remove after stopping supporting v402. // Add an n-gram entry. bool addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos, const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + // TODO: Remove after stopping supporting v402. // Remove an n-gram entry. 
bool removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos); // Add a shortcut target. - bool addShortcutTarget(const int wordPos, const int *const targetCodePoints, - const int targetCodePointCount, const int shortcutProbability); + bool addShortcutTarget(const int wordPos, const CodePointArrayView targetCodePoints, + const int shortcutProbability); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPtUpdatingHelper); @@ -63,33 +65,32 @@ class DynamicPtUpdatingHelper { const PtNodeReader *const mPtNodeReader; PtNodeWriter *const mPtNodeWriter; - bool createAndInsertNodeIntoPtNodeArray(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty, + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos); bool setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); bool createChildrenPtNodeArrayAndAChildPtNode(const PtNodeParams *const parentPtNodeParams, - const UnigramProperty *const unigramProperty, const int *const codePoints, - const int codePointCount); + const UnigramProperty *const unigramProperty, + const CodePointArrayView remainingCodePoints); - bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty); + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, + const CodePointArrayView ptNodeCodePoints, + const UnigramProperty *const unigramProperty); - bool reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount); + bool 
reallocatePtNodeAndAddNewPtNodes(const PtNodeParams *const reallocatingPtNodeParams, + const size_t overlappingCodePointCount, const UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints); const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, - const int *const codePoints, const int probability) const; + const int parentPos, const CodePointArrayView codePoints, const int probability) const; const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted, - const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const; + const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, + const int probability) const; }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp index e64a13cc4..6a498b2f4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp @@ -61,19 +61,20 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; } /* static */ int PtReadingUtils::getCodePointAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); + const int *const codePointTable, int *const pos) { + return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, codePointTable, pos); } // Returns the number of read characters. 
/* static */ int PtReadingUtils::getCharsAndAdvancePosition(const uint8_t *const buffer, - const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { + const NodeFlags flags, const int maxLength, const int *const codePointTable, + int *const outBuffer, int *const pos) { int length = 0; if (hasMultipleChars(flags)) { - length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, - pos); + length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, codePointTable, + outBuffer, pos); } else { - const int codePoint = getCodePointAndAdvancePosition(buffer, pos); + const int codePoint = getCodePointAndAdvancePosition(buffer, codePointTable, pos); if (codePoint == NOT_A_CODE_POINT) { // CAVEAT: codePoint == NOT_A_CODE_POINT means the code point is // CHARACTER_ARRAY_TERMINATOR. The code point must not be CHARACTER_ARRAY_TERMINATOR @@ -92,12 +93,12 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; // Returns the number of skipped characters. 
/* static */ int PtReadingUtils::skipCharacters(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const pos) { + const int maxLength, const int *const codePointTable, int *const pos) { if (hasMultipleChars(flags)) { return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); } else { if (maxLength > 0) { - getCodePointAndAdvancePosition(buffer, pos); + getCodePointAndAdvancePosition(buffer, codePointTable, pos); return 1; } else { return 0; @@ -134,7 +135,7 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; /* static */ void PtReadingUtils::readPtNodeInfo(const uint8_t *const dictBuf, const int ptNodePos, const DictionaryShortcutsStructurePolicy *const shortcutPolicy, - const DictionaryBigramsStructurePolicy *const bigramPolicy, + const DictionaryBigramsStructurePolicy *const bigramPolicy, const int *const codePointTable, NodeFlags *const outFlags, int *const outCodePointCount, int *const outCodePoint, int *const outProbability, int *const outChildrenPos, int *const outShortcutPos, int *const outBigramPos, int *const outSiblingPos) { @@ -142,7 +143,7 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; const NodeFlags flags = getFlagsAndAdvancePosition(dictBuf, &readingPos); *outFlags = flags; *outCodePointCount = getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, outCodePoint, &readingPos); + dictBuf, flags, MAX_WORD_LENGTH, codePointTable, outCodePoint, &readingPos); *outProbability = isTerminal(flags) ? readProbabilityAndAdvancePosition(dictBuf, &readingPos) : NOT_A_PROBABILITY; *outChildrenPos = hasChildrenInFlags(flags) ? 
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h index c3f09c3b1..a69ec4435 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h @@ -34,15 +34,17 @@ class PatriciaTrieReadingUtils { static NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, int *const pos); - static int getCodePointAndAdvancePosition(const uint8_t *const buffer, int *const pos); + static int getCodePointAndAdvancePosition(const uint8_t *const buffer, + const int *const codePointTable, int *const pos); // Returns the number of read characters. static int getCharsAndAdvancePosition(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const outBuffer, int *const pos); + const int maxLength, const int *const codePointTable, int *const outBuffer, + int *const pos); // Returns the number of skipped characters. 
static int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const pos); + const int maxLength, const int *const codePointTable, int *const pos); static int readProbabilityAndAdvancePosition(const uint8_t *const buffer, int *const pos); @@ -106,9 +108,10 @@ class PatriciaTrieReadingUtils { static void readPtNodeInfo(const uint8_t *const dictBuf, const int ptNodePos, const DictionaryShortcutsStructurePolicy *const shortcutPolicy, const DictionaryBigramsStructurePolicy *const bigramPolicy, - NodeFlags *const outFlags, int *const outCodePointCount, int *const outCodePoint, - int *const outProbability, int *const outChildrenPos, int *const outShortcutPos, - int *const outBigramPos, int *const outSiblingPos); + const int *const codePointTable, NodeFlags *const outFlags, + int *const outCodePointCount, int *const outCodePoint, int *const outProbability, + int *const outChildrenPos, int *const outShortcutPos, int *const outBigramPos, + int *const outSiblingPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index c12fed324..3ff1829bd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -89,9 +89,9 @@ class PtNodeParams { // Construct new params by updating existing PtNode params. 
PtNodeParams(const PtNodeParams *const ptNodeParams, const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(ptNodeParams->getHeadPos()), mFlags(flags), mHasMovedFlag(true), - mParentPos(parentPos), mCodePointCount(codePointCount), mCodePoints(), + mParentPos(parentPos), mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(ptNodeParams->getTerminalIdFieldPos()), mTerminalId(ptNodeParams->getTerminalId()), mProbabilityFieldPos(ptNodeParams->getProbabilityFieldPos()), @@ -102,20 +102,20 @@ class PtNodeParams { mShortcutPos(ptNodeParams->getShortcutPos()), mBigramPos(ptNodeParams->getBigramsPos()), mSiblingPos(ptNodeParams->getSiblingNodePos()) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } PtNodeParams(const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(NOT_A_DICT_POS), mFlags(flags), mHasMovedFlag(true), mParentPos(parentPos), - mCodePointCount(codePointCount), mCodePoints(), + mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(NOT_A_DICT_POS), mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID), mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(probability), mChildrenPosFieldPos(NOT_A_DICT_POS), mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_DICT_POS) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } AK_FORCE_INLINE bool isValid() const { diff --git 
a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp index 91c76941c..40b872055 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp @@ -31,21 +31,23 @@ const int ShortcutListReadingUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; const int ShortcutListReadingUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; /* static */ ShortcutListReadingUtils::ShortcutFlags - ShortcutListReadingUtils::getFlagsAndForwardPointer(const uint8_t *const dictRoot, + ShortcutListReadingUtils::getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); + return ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); } /* static */ int ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer( - const uint8_t *const dictRoot, int *const pos) { + const ReadOnlyByteArrayView buffer, int *const pos) { // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. - return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) + return ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos) - SHORTCUT_LIST_SIZE_FIELD_SIZE; } -/* static */ int ShortcutListReadingUtils::readShortcutTarget( - const uint8_t *const dictRoot, const int maxLength, int *const outWord, int *const pos) { - return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); +/* static */ int ShortcutListReadingUtils::readShortcutTarget(const ReadOnlyByteArrayView buffer, + const int maxLength, int *const outWord, int *const pos) { + // TODO: Use codePointTable for shortcuts. 
+ return ByteArrayUtils::readStringAndAdvancePosition(buffer.data(), maxLength, + nullptr /* codePointTable */, outWord, pos); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h index d065bf7fd..71cb8cc2c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -27,7 +28,8 @@ class ShortcutListReadingUtils { public: typedef uint8_t ShortcutFlags; - static ShortcutFlags getFlagsAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static ShortcutFlags getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getProbabilityFromFlags(const ShortcutFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; @@ -39,14 +41,15 @@ class ShortcutListReadingUtils { // This method returns the size of the shortcut list region excluding the shortcut list size // field at the beginning. 
- static int getShortcutListSizeAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static int getShortcutListSizeAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getShortcutListSizeFieldSize() { return SHORTCUT_LIST_SIZE_FIELD_SIZE; } - static AK_FORCE_INLINE void skipShortcuts(const uint8_t *const dictRoot, int *const pos) { - const int shortcutListSize = getShortcutListSizeAndForwardPointer(dictRoot, pos); + static AK_FORCE_INLINE void skipShortcuts(const ReadOnlyByteArrayView buffer, int *const pos) { + const int shortcutListSize = getShortcutListSizeAndForwardPointer(buffer, pos); *pos += shortcutListSize; } @@ -54,7 +57,7 @@ class ShortcutListReadingUtils { return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; } - static int readShortcutTarget(const uint8_t *const dictRoot, const int maxLength, + static int readShortcutTarget(const ReadOnlyByteArrayView buffer, const int maxLength, int *const outWord, int *const pos); private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h index 73e291ec2..e2608435c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h @@ -22,22 +22,22 @@ #include "defines.h" #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class BigramListPolicy : public DictionaryBigramsStructurePolicy { public: - BigramListPolicy(const uint8_t *const bigramsBuf, const int bufSize) - : mBigramsBuf(bigramsBuf), mBufSize(bufSize) {} + BigramListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~BigramListPolicy() {} 
void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const { BigramListReadWriteUtils::BigramFlags flags; - if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, - mBufSize, &flags, outBigramPos, pos)) { - AKLOGE("Cannot read bigram entry. mBufSize: %d, pos: %d. ", mBufSize, *pos); + if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBuffer, &flags, + outBigramPos, pos)) { + AKLOGE("Cannot read bigram entry. bufSize: %zd, pos: %d. ", mBuffer.size(), *pos); *outProbability = NOT_A_PROBABILITY; *outHasNext = false; return; @@ -47,14 +47,13 @@ class BigramListPolicy : public DictionaryBigramsStructurePolicy { } bool skipAllBigrams(int *const pos) const { - return BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, mBufSize, pos); + return BigramListReadWriteUtils::skipExistingBigrams(mBuffer, pos); } private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListPolicy); - const uint8_t *const mBigramsBuf; - const int mBufSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index e76bae97c..6e7dba9ff 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -37,19 +37,19 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo return; } int nextPos = dicNode->getChildrenPtNodeArrayPos(); - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", - nextPos, mDictBufferSize); + if (!isValidPos(nextPos)) { + AKLOGE("Children PtNode array position is invalid. 
pos: %d, dict size: %zd", + nextPos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); return; } const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &nextPos); + mBuffer.data(), &nextPos); for (int i = 0; i < childCount; i++) { - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", - nextPos, mDictBufferSize, i, childCount); + if (!isValidPos(nextPos)) { + AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %zd, childCount: %d / %d", + nextPos, mBuffer.size(), i, childCount); mIsCorrupted = true; ASSERT(false); return; @@ -81,6 +81,7 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); int pos = getRootPosition(); int wordPos = 0; + const int *const codePointTable = mHeaderPolicy.getCodePointTable(); // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will // only traverse PtNodes that are actually a part of the terminal we are searching, so each // time we enter this loop we are one depth level further than last time. @@ -91,56 +92,57 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int lastCandidatePtNodePos = 0; // Let's loop through PtNodes in this PtNode array searching for either the terminal // or one of its ascendants. - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d", - pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode array position is invalid. 
pos: %d, dict size: %zd", + pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; return 0; } for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { + mBuffer.data(), &pos); ptNodeCount > 0; --ptNodeCount) { const int startPos = pos; - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; return 0; } const PatriciaTrieReadingUtils::NodeFlags flags = - PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mBuffer.data(), &pos); const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); if (ptNodePos == startPos) { // We found the position. Copy the rest of the code points in the buffer and return // the length. 
outCodePoints[wordPos] = character; if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); // We count code points in order to avoid infinite loops if the file is broken // or if there is some other bug int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); } } *outUnigramProbability = - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos); return ++wordPos; } // We need to skip past this PtNode, so skip any remaining code points after the // first and possibly the probability. if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { - PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + PatriciaTrieReadingUtils::skipCharacters(mBuffer.data(), flags, MAX_WORD_LENGTH, + codePointTable, &pos); } if (PatriciaTrieReadingUtils::isTerminal(flags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos); } // The fact that this PtNode has children is very important. Since we already know // that this PtNode does not match, if it has no children we know it is irrelevant @@ -155,7 +157,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int currentPos = pos; // Here comes the tricky part. First, read the children position. 
const int childrenPos = PatriciaTrieReadingUtils - ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos); + ::readChildrenPositionAndAdvancePosition(mBuffer.data(), flags, + &currentPos); if (childrenPos > ptNodePos) { // If the children pos is greater than the position, it means the previous // PtNode, which position is stored in lastCandidatePtNodePos, was the right @@ -185,30 +188,30 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( if (0 != lastCandidatePtNodePos) { const PatriciaTrieReadingUtils::NodeFlags lastFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); // We copy all the characters in this PtNode to the buffer outCodePoints[wordPos] = lastChar; if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); int charCount = maxCodePointCount; while (-1 != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); } } ++wordPos; // Now we only need to branch to the children address. Skip the probability if // it's there, read pos, and break to resume the search at pos. 
if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &lastCandidatePtNodePos); } pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, lastFlags, &lastCandidatePtNodePos); + mBuffer.data(), lastFlags, &lastCandidatePtNodePos); break; } else { // Here is a little tricky part: we come here if we found out that all children @@ -220,14 +223,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // ready to start the next one. if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, + AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); @@ -244,14 +247,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // our pos is after the end of this PtNode, at the start of the next one. if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, pos); + AKLOGE("Cannot skip bigrams. 
BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; @@ -282,8 +285,9 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, return getWordIdFromTerminalPtNodePos(ptNodePos); } -const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds, - const int wordId, MultiBigramMap *const multiBigramMap) const { +const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { if (wordId == NOT_A_WORD_ID) { return WordAttributes(); } @@ -295,7 +299,7 @@ const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *c prevWordIds, wordId, ptNodeParams.getProbability()); return getWordAttributes(probability, ptNodeParams); } - if (prevWordIds) { + if (!prevWordIds.empty()) { const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId); if (bigramProbability != NOT_A_PROBABILITY) { return getWordAttributes(bigramProbability, ptNodeParams); @@ -327,7 +331,8 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const int wordId) const { +int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } @@ -340,7 +345,7 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const // for shortcuts). 
return NOT_A_PROBABILITY; } - if (prevWordIds) { + if (!prevWordIds.empty()) { const int bigramsPosition = getBigramsPositionOfPtNode( getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); @@ -356,9 +361,9 @@ int PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, const return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, +void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordIds) { + if (prevWordIds.empty()) { return; } const int bigramsPosition = getBigramsPositionOfPtNode( @@ -371,8 +376,7 @@ void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, } } -BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator( - const int wordId) const { +BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const { const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos); } @@ -401,9 +405,11 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, &mShortcutListPolicy, - &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, - &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); + const int *const codePointTable = mHeaderPolicy.getCodePointTable(); + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, &mShortcutListPolicy, + &mBigramListPolicy, codePointTable, &flags, &mergedNodeCodePointCount, + mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, + &siblingPos); // 
Skip PtNodes don't start with Unicode code point because they represent non-word information. if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) { const int wordId = PatriciaTrieReadingUtils::isTerminal(flags) ? ptNodePos : NOT_A_WORD_ID; @@ -451,14 +457,14 @@ const WordProperty PatriciaTriePolicy::getWordProperty( int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); if (shortcutPos != NOT_A_DICT_POS) { int shortcutTargetCodePoints[MAX_WORD_LENGTH]; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &shortcutPos); bool hasNext = true; while (hasNext) { const ShortcutListReadingUtils::ShortcutFlags shortcutFlags = - ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, &shortcutPos); hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags); const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget( - mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); + mBuffer, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); const std::vector<int> shortcutTarget(shortcutTargetCodePoints, shortcutTargetCodePoints + shortcutTargetLength); const int shortcutProbability = @@ -511,4 +517,9 @@ int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) cons int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { return wordId == NOT_A_WORD_ID ? 
NOT_A_DICT_POS : wordId; } + +bool PatriciaTriePolicy::isValidPos(const int pos) const { + return pos >= 0 && pos < static_cast<int>(mBuffer.size()); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 8c1665d7d..3cdf6cd16 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -43,15 +43,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer) : mMmappedBuffer(std::move(mmappedBuffer)), mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(), - FormatUtils::VERSION_2), - mDictRoot(mMmappedBuffer->getReadOnlyByteArrayView().data() - + mHeaderPolicy.getSize()), - mDictBufferSize(mMmappedBuffer->getReadOnlyByteArrayView().size() - - mHeaderPolicy.getSize()), - mBigramListPolicy(mDictRoot, mDictBufferSize), mShortcutListPolicy(mDictRoot), - mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy), - mPtNodeArrayReader(mDictRoot, mDictBufferSize), - mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {} + FormatUtils::detectFormatVersion(mMmappedBuffer->getReadOnlyByteArrayView())), + mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())), + mBigramListPolicy(mBuffer), mShortcutListPolicy(mBuffer), + mPtNodeReader(mBuffer, &mBigramListPolicy, &mShortcutListPolicy, + mHeaderPolicy.getCodePointTable()), + mPtNodeArrayReader(mBuffer), mTerminalPtNodePositionsForIteratingWords(), + mIsCorrupted(false) {} AK_FORCE_INLINE int getRootPosition() const { return 0; @@ -66,14 +64,15 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool 
forceLowerCaseSearch) const; - const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId, - MultiBigramMap *const multiBigramMap) const; + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; + void iterateNgramEntries(const WordIdArrayView prevWordIds, + NgramListener *const listener) const; BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; @@ -148,8 +147,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const MmappedBuffer::MmappedBufferPtr mMmappedBuffer; const HeaderPolicy mHeaderPolicy; - const uint8_t *const mDictRoot; - const int mDictBufferSize; + const ReadOnlyByteArrayView mBuffer; const BigramListPolicy mBigramListPolicy; const ShortcutListPolicy mShortcutListPolicy; const Ver2ParticiaTrieNodeReader mPtNodeReader; @@ -165,6 +163,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getTerminalPtNodePosFromWordId(const int wordId) const; const WordAttributes getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const; + bool isValidPos(const int pos) const; }; } // namespace latinime #endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h index 8e16ccc05..5319dd26c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h @@ -22,13 +22,13 @@ #include "defines.h" #include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { public: - explicit ShortcutListPolicy(const uint8_t *const shortcutBuf) - : mShortcutsBuf(shortcutBuf) {} + explicit ShortcutListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~ShortcutListPolicy() {} @@ -37,7 +37,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { return NOT_A_DICT_POS; } int listPos = pos; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mShortcutsBuf, &listPos); + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &listPos); return listPos; } @@ -45,7 +45,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, int *const pos) const { const ShortcutListReadingUtils::ShortcutFlags flags = - ShortcutListReadingUtils::getFlagsAndForwardPointer(mShortcutsBuf, pos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, pos); if (outHasNext) { *outHasNext = ShortcutListReadingUtils::hasNext(flags); } @@ -54,20 +54,20 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { } if (outCodePoint) { *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( - mShortcutsBuf, maxCodePointCount, outCodePoint, pos); + mBuffer, maxCodePointCount, outCodePoint, pos); } } void skipAllShortcuts(int *const pos) const { const int shortcutListSize = ShortcutListReadingUtils - ::getShortcutListSizeAndForwardPointer(mShortcutsBuf, pos); + ::getShortcutListSizeAndForwardPointer(mBuffer, pos); *pos += shortcutListSize; } private: 
DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListPolicy); - const uint8_t *const mShortcutsBuf; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp index c1e938710..dc0ed96d0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp @@ -22,10 +22,10 @@ namespace latinime { const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNodePos( const int ptNodePos) const { - if (ptNodePos < 0 || ptNodePos >= mDictSize) { + if (ptNodePos < 0 || ptNodePos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. - AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %d", - ptNodePos, mDictSize); + AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %zd", + ptNodePos, mBuffer.size()); ASSERT(false); return PtNodeParams(); } @@ -37,9 +37,9 @@ const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNo int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictBuffer, ptNodePos, mShortuctPolicy, - mBigramPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, - &childrenPos, &shortcutPos, &bigramPos, &siblingPos); + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortuctPolicy, + mBigramPolicy, mCodePointTable, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, + &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); if (mergedNodeCodePointCount <= 0) { AKLOGE("Empty PtNode is not allowed. 
Code point count: %d", mergedNodeCodePointCount); ASSERT(false); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h index f0725b66d..24ec5bcca 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_reader.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,21 +31,22 @@ class DictionaryShortcutsStructurePolicy; class Ver2ParticiaTrieNodeReader : public PtNodeReader { public: - Ver2ParticiaTrieNodeReader(const uint8_t *const dictBuffer, const int dictSize, + Ver2ParticiaTrieNodeReader(const ReadOnlyByteArrayView buffer, const DictionaryBigramsStructurePolicy *const bigramPolicy, - const DictionaryShortcutsStructurePolicy *const shortcutPolicy) - : mDictBuffer(dictBuffer), mDictSize(dictSize), mBigramPolicy(bigramPolicy), - mShortuctPolicy(shortcutPolicy) {} + const DictionaryShortcutsStructurePolicy *const shortcutPolicy, + const int *const codePointTable) + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortuctPolicy(shortcutPolicy), + mCodePointTable(codePointTable) {} virtual const PtNodeParams fetchPtNodeParamsInBufferFromPtNodePos(const int ptNodePos) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Ver2ParticiaTrieNodeReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; const DictionaryBigramsStructurePolicy *const mBigramPolicy; const DictionaryShortcutsStructurePolicy *const mShortuctPolicy; + const int *const mCodePointTable; }; } // namespace latinime #endif /* LATINIME_VER2_PATRICIA_TRIE_NODE_READER_H */ diff --git 
a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp index b46617d96..72ad1eb66 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp @@ -22,16 +22,16 @@ namespace latinime { bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const { - if (ptNodeArrayPos < 0 || ptNodeArrayPos >= mDictSize) { + if (ptNodeArrayPos < 0 || ptNodeArrayPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of a bug or a broken dictionary. - AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %d", - ptNodeArrayPos, mDictSize); + AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %zd", + ptNodeArrayPos, mBuffer.size()); ASSERT(false); return false; } int readingPos = ptNodeArrayPos; const int ptNodeCountInArray = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictBuffer, &readingPos); + mBuffer.data(), &readingPos); *outPtNodeCount = ptNodeCountInArray; *outFirstPtNodePos = readingPos; return true; @@ -39,10 +39,10 @@ bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNode bool Ver2PtNodeArrayReader::readForwardLinkAndReturnIfValid(const int forwordLinkPos, int *const outNextPtNodeArrayPos) const { - if (forwordLinkPos < 0 || forwordLinkPos >= mDictSize) { + if (forwordLinkPos < 0 || forwordLinkPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. 
- AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %d", - forwordLinkPos, mDictSize); + AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %zd", + forwordLinkPos, mBuffer.size()); ASSERT(false); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h index 548272148..548f36bf3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h @@ -21,13 +21,13 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_array_reader.h" +#include "utils/byte_array_view.h" namespace latinime { class Ver2PtNodeArrayReader : public PtNodeArrayReader { public: - Ver2PtNodeArrayReader(const uint8_t *const dictBuffer, const int dictSize) - : mDictBuffer(dictBuffer), mDictSize(dictSize) {}; + Ver2PtNodeArrayReader(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {}; virtual bool readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const; @@ -37,8 +37,7 @@ class Ver2PtNodeArrayReader : public PtNodeArrayReader { private: DISALLOW_COPY_AND_ASSIGN(Ver2PtNodeArrayReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif /* LATINIME_VER2_PT_NODE_ARRAY_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index f54bb151a..35f0f768f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC( } int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, - const int wordId) const { + const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxLevel = 0; @@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI if (!result.mIsValid) { continue; } - const int probability = - ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); if (mHasHistoricalInfo) { - return std::min( - probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), - MAX_PROBABILITY); + const int probability = ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), headerPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */); + return std::min(probability, MAX_PROBABILITY); } else { - return probability; + return probabilityEntry.getProbability(); } } // Cannot find the word. 
@@ -166,7 +167,15 @@ int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { return TrieMap::INVALID_INDEX; } - return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1], + const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID); + const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex); + if (!result.mIsValid) { + if (!mTrieMap.put(oldestPrevWordId, + ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) { + return TrieMap::INVALID_INDEX; + } + } + return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID), lastBitmapEntryIndex); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 4e0b47036..a793af4be 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -128,7 +128,8 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); - int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const; + int getWordProbability(const WordIdArrayView prevWordIds, const int wordId, + const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index 3dfaba755..f1bf12cb2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -36,7 +36,8 @@ class ProbabilityEntry { // Dummy entry ProbabilityEntry() - : mFlags(0), mProbability(NOT_A_PROBABILITY), mHistoricalInfo() {} + : mFlags(Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY), mProbability(NOT_A_PROBABILITY), + mHistoricalInfo() {} // Entry without historical information ProbabilityEntry(const int flags, const int probability) @@ -61,7 +62,7 @@ class ProbabilityEntry { bigramProperty->getCount()) {} bool isValid() const { - return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo(); + return (mFlags & Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY) == 0; } bool hasHistoricalInfo() const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 9acf2d44f..39822b94a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -53,6 +53,7 @@ const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; +const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 97035311e..dfcdd4d6f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -51,6 +51,7 @@ class Ver4DictConstants { static const int WORD_COUNT_FIELD_SIZE; // Flags in probability 
entry. static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + static const uint8_t FLAG_NOT_A_VALID_ENTRY; static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE; static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp index 731092efd..d795239fc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp @@ -16,6 +16,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" @@ -51,7 +52,7 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos); int codePoints[MAX_WORD_LENGTH]; const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, codePoints, &pos); + dictBuf, flags, MAX_WORD_LENGTH, mHeaderPolicy->getCodePointTable(), codePoints, &pos); int terminalIdFieldPos = NOT_A_DICT_POS; int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; int probability = NOT_A_PROBABILITY; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 9ca712470..75ec16912 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ 
b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -211,19 +211,17 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - // TODO: Support n-gram. LanguageModelDictContent *const languageModelDictContent = mBuffers->getMutableLanguageModelDictContent(); const ProbabilityEntry probabilityEntry = - languageModelDictContent->getNgramProbabilityEntry( - prevWordIds.limit(1 /* maxSize */), wordId); + languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId); const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty); const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( &probabilityEntry, &probabilityEntryOfBigramProperty); if (!languageModelDictContent->setNgramProbabilityEntry( - prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) { - AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d", - prevWordIds[0], wordId); + prevWordIds, wordId, &updatedProbabilityEntry)) { + AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d", + prevWordIds[0], prevWordIds.size(), wordId); return false; } if (!probabilityEntry.isValid() && outAddedNewBigram) { @@ -234,11 +232,9 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) { - // TODO: Support n-gram. 
LanguageModelDictContent *const languageModelDictContent = mBuffers->getMutableLanguageModelDictContent(); - return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds.limit(1 /* maxSize */), - wordId); + return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId); } // TODO: Remove when we stop supporting v402 format. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 0472a453a..8d4135679 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -16,6 +16,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h" +#include <array> #include <vector> #include "suggest/core/dicnode/dic_node.h" @@ -111,7 +112,7 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, } const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( - const int *const prevWordIds, const int wordId, + const WordIdArrayView prevWordIds, const int wordId, MultiBigramMap *const multiBigramMap) const { if (wordId == NOT_A_WORD_ID) { return WordAttributes(); @@ -119,31 +120,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - // TODO: Support n-gram. 
- return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability( - WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(), - ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); + const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability( + prevWordIds, wordId, mHeaderPolicy); + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + probability == 0); } -int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, - const int bigramProbability) const { - if (mHeaderPolicy->isDecayingDict()) { - // Both probabilities are encoded. Decode them and get probability. - return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); - } else { - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return ProbabilityUtils::backoff(unigramProbability); - } else { - return bigramProbability; - } - } -} - -int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const { - if (wordId == NOT_A_WORD_ID) { + if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } const int ptNodePos = @@ -152,22 +137,17 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds, if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordIds) { - // TODO: Support n-gram. 
- const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( - IntArrayView::singleElementView(prevWordIds), wordId); - if (!probabilityEntry.isValid()) { - return NOT_A_PROBABILITY; - } - if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), - mHeaderPolicy); - } else { - return probabilityEntry.getProbability(); - } + const ProbabilityEntry probabilityEntry = + mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); + if (!probabilityEntry.isValid()) { + return NOT_A_PROBABILITY; + } + if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + mHeaderPolicy); + } else { + return probabilityEntry.getProbability(); } - return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( @@ -176,21 +156,23 @@ BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordIds) { + if (prevWordIds.empty()) { return; } - // TODO: Support n-gram. const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); - for (const auto entry : languageModelDictContent->getProbabilityEntries( - WordIdArrayView::singleElementView(prevWordIds))) { - const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); - const int probability = probabilityEntry.hasHistoricalInfo() ? 
- ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : - probabilityEntry.getProbability(); - listener->onVisitEntry(probability, entry.getWordId()); + for (size_t i = 1; i <= prevWordIds.size(); ++i) { + for (const auto entry : languageModelDictContent->getProbabilityEntries( + prevWordIds.limit(i))) { + const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), mHeaderPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) : + probabilityEntry.getProbability(); + listener->onVisitEntry(probability, entry.getWordId()); + } } } @@ -245,8 +227,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo return false; } const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), - codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } @@ -261,8 +243,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { AKLOGE("Cannot add new shortcut target. 
PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); @@ -321,26 +303,31 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); - // TODO: Support N-gram. - if (prevWordIds[0] == NOT_A_WORD_ID) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const std::vector<UnigramProperty::ShortcutProperty> shortcuts; - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); - return false; - } - // Refresh word ids. 
- prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); - } else { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.empty()) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] != NOT_A_WORD_ID) { + continue; + } + if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { return false; } + const std::vector<UnigramProperty::ShortcutProperty> shortcuts; + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, true /* isNotAWord */, + false /* isBlacklisted */, MAX_PROBABILITY /* probability */, + NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); + if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); + return false; + } + // Refresh word ids. 
+ prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()), false /* forceLowerCaseSearch */); @@ -348,15 +335,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } bool addedNewEntry = false; - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - for (size_t i = 0; i < NELEMS(prevWordIds); ++i) { - prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(prevWordIds[i]); - } - const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(wordId); - if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos), - wordPtNodePos, bigramProperty, &addedNewEntry)) { + if (mNodeWriter.addNgramEntry(prevWordIds, wordId, bigramProperty, &addedNewEntry)) { if (addedNewEntry) { mBigramCount++; } @@ -385,25 +364,17 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", wordCodePoints.size()); } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */); - // TODO: Support N-gram. 
- if (prevWordIds[0] == NOT_A_WORD_ID) { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSerch */); + if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) { return false; } const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) { return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - for (size_t i = 0; i < NELEMS(prevWordIds); ++i) { - prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(prevWordIds[i]); - } - const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(wordId); - if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromFixedSizeArray(prevWordsPtNodePos), - wordPtNodePos)) { + if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { mBigramCount--; return true; } else { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 980c16e4a..a117a3614 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -68,14 +68,19 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; - const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId, - MultiBigramMap *const multiBigramMap) const; + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; - int getProbability(const int unigramProbability, const int bigramProbability) const; + // TODO: Remove + int 
getProbability(const int unigramProbability, const int bigramProbability) const { + // Not used. + return NOT_A_PROBABILITY; + } - int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordIds, NgramListener *const listener) const; + void iterateNgramEntries(const WordIdArrayView prevWordIds, + NgramListener *const listener) const; BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index ecbe7922c..da2c30cd6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -42,8 +42,10 @@ void BufferWithExtendableBuffer::readCodePointsAndAdvancePosition(const int maxC if (readingPosIsInAdditionalBuffer) { *pos -= mOriginalBuffer.size(); } + // Code point table is not used for dynamic format. 
*outCodePointCount = ByteArrayUtils::readStringAndAdvancePosition( - getBuffer(readingPosIsInAdditionalBuffer), maxCodePointCount, outCodePoints, pos); + getBuffer(readingPosIsInAdditionalBuffer), maxCodePointCount, + nullptr /* codePointTable */, outCodePoints, pos); if (readingPosIsInAdditionalBuffer) { *pos += mOriginalBuffer.size(); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index 4b3c98988..abb979050 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -147,11 +147,18 @@ class ByteArrayUtils { */ static AK_FORCE_INLINE int readCodePoint(const uint8_t *const buffer, const int pos) { int p = pos; - return readCodePointAndAdvancePosition(buffer, &p); + return readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, &p); } static AK_FORCE_INLINE int readCodePointAndAdvancePosition( - const uint8_t *const buffer, int *const pos) { + const uint8_t *const buffer, const int *const codePointTable, int *const pos) { + /* + * codePointTable is an array to convert the most frequent characters in this dictionary to + * 1 byte code points. It is only made of the original code points of the most frequent + * characters used in this dictionary. 0x20 - 0xFF is used for the 1 byte characters. + * The original code points are restored by picking the code points at the indices of the + * codePointTable. The indices are calculated by subtracting 0x20 from the firstByte. 
+ */ const uint8_t firstByte = readUint8(buffer, *pos); if (firstByte < MINIMUM_ONE_BYTE_CHARACTER_VALUE) { if (firstByte == CHARACTER_ARRAY_TERMINATOR) { @@ -162,6 +169,9 @@ class ByteArrayUtils { } } else { *pos += 1; + if (codePointTable) { + return codePointTable[firstByte - MINIMUM_ONE_BYTE_CHARACTER_VALUE]; + } return firstByte; } } @@ -173,12 +183,13 @@ class ByteArrayUtils { */ // Returns the length of the string. static int readStringAndAdvancePosition(const uint8_t *const buffer, - const int maxLength, int *const outBuffer, int *const pos) { + const int maxLength, const int *const codePointTable, int *const outBuffer, + int *const pos) { int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); + int codePoint = readCodePointAndAdvancePosition(buffer, codePointTable, pos); while (NOT_A_CODE_POINT != codePoint && length < maxLength) { outBuffer[length++] = codePoint; - codePoint = readCodePointAndAdvancePosition(buffer, pos); + codePoint = readCodePointAndAdvancePosition(buffer, codePointTable, pos); } return length; } @@ -187,9 +198,9 @@ class ByteArrayUtils { static int advancePositionToBehindString( const uint8_t *const buffer, const int maxLength, int *const pos) { int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); + int codePoint = readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, pos); while (NOT_A_CODE_POINT != codePoint && length < maxLength) { - codePoint = readCodePointAndAdvancePosition(buffer, pos); + codePoint = readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, pos); length++; } return length; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index e6e7167c2..0cffe569d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -29,6 +29,8 @@ const size_t 
FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; switch (formatVersion) { case VERSION_2: return VERSION_2; + case VERSION_201: + return VERSION_201; case VERSION_4_ONLY_FOR_TESTING: return VERSION_4_ONLY_FOR_TESTING; case VERSION_4: diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 51ad9877c..96310086b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -32,6 +32,7 @@ class FormatUtils { enum FORMAT_VERSION { // These MUST have the same values as the relevant constants in FormatSpec.java. VERSION_2 = 2, + VERSION_201 = 201, VERSION_4_ONLY_FOR_TESTING = 399, VERSION_4 = 402, VERSION_4_DEV = 403, diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 52c4251f0..0240bcf54 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -33,10 +33,12 @@ class TypingScoring : public Scoring { static const TypingScoring *getInstance() { return &sInstance; } AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const {} + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const {} - AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const { + AK_FORCE_INLINE float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const { return 1.0f; } diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h index 10d7ae278..2b778af6f 100644 --- a/native/jni/src/utils/byte_array_view.h 
+++ b/native/jni/src/utils/byte_array_view.h @@ -42,6 +42,13 @@ class ReadOnlyByteArrayView { return mPtr; } + AK_FORCE_INLINE const ReadOnlyByteArrayView skip(const size_t n) const { + if (mSize <= n) { + return ReadOnlyByteArrayView(); + } + return ReadOnlyByteArrayView(mPtr + n, mSize - n); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(ReadOnlyByteArrayView); diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index c39add9fe..f3a8589ca 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -17,6 +17,7 @@ #ifndef LATINIME_INT_ARRAY_VIEW_H #define LATINIME_INT_ARRAY_VIEW_H +#include <algorithm> #include <array> #include <cstdint> #include <cstring> @@ -57,9 +58,9 @@ class IntArrayView { explicit IntArrayView(const std::vector<int> &vector) : mPtr(vector.data()), mSize(vector.size()) {} - template <int N> - AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) { - return IntArrayView(array, N); + template <size_t N> + AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) { + return IntArrayView(array.data(), array.size()); } // Returns a view that points one int object. @@ -92,12 +93,16 @@ class IntArrayView { return mPtr + mSize; } + AK_FORCE_INLINE bool contains(const int value) const { + return std::find(begin(), end(), value) != end(); + } + // Returns the view whose size is smaller than or equal to the given count. 
- const IntArrayView limit(const size_t maxSize) const { + AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const { return IntArrayView(mPtr, std::min(maxSize, mSize)); } - const IntArrayView skip(const size_t n) const { + AK_FORCE_INLINE const IntArrayView skip(const size_t n) const { if (mSize <= n) { return IntArrayView(); } @@ -110,6 +115,20 @@ class IntArrayView { memmove(buffer->data() + offset, mPtr, sizeof(int) * mSize); } + AK_FORCE_INLINE int firstOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[0]; + } + + AK_FORCE_INLINE int lastOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[mSize - 1]; + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); @@ -120,6 +139,8 @@ class IntArrayView { using WordIdArrayView = IntArrayView; using PtNodePosArrayView = IntArrayView; using CodePointArrayView = IntArrayView; +template <size_t size> +using WordIdArray = std::array<int, size>; } // namespace latinime #endif // LATINIME_MEMORY_VIEW_H diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index cb82d3c3b..235a03bba 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -97,17 +97,13 @@ class JniDataUtils { } static PrevWordsInfo constructPrevWordsInfo(JNIEnv *env, jobjectArray prevWordCodePointArrays, - jbooleanArray isBeginningOfSentenceArray) { + jbooleanArray isBeginningOfSentenceArray, const size_t prevWordCount) { int prevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int prevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool isBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - jsize prevWordsCount = env->GetArrayLength(prevWordCodePointArrays); - for (size_t i = 0; i < NELEMS(prevWordCodePoints); ++i) { + for (size_t i = 0; i < prevWordCount; ++i) { prevWordCodePointCount[i] = 0; isBeginningOfSentence[i] = false; - if (prevWordsCount <= 
static_cast<int>(i)) { - continue; - } jintArray prevWord = (jintArray)env->GetObjectArrayElement(prevWordCodePointArrays, i); if (!prevWord) { continue; @@ -124,7 +120,7 @@ class JniDataUtils { isBeginningOfSentence[i] = isBeginningOfSentenceBoolean == JNI_TRUE; } return PrevWordsInfo(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, - MAX_PREV_WORD_COUNT_FOR_N_GRAM); + prevWordCount); } static void putBooleanToArray(JNIEnv *env, jbooleanArray array, const int index, |