diff options
Diffstat (limited to 'native/jni/src')
35 files changed, 403 insertions, 369 deletions
diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index cecfc7aa9..1b796b5d4 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -32,7 +32,7 @@ class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT), - mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0) {} + mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0), mPrevWordCount(0) {} ~DicNodeProperties() {} @@ -45,6 +45,7 @@ class DicNodeProperties { mDepth = depth; mLeavingDepth = leavingDepth; prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } // Init for root with prevWordsPtNodePos which is used for n-gram @@ -55,6 +56,7 @@ class DicNodeProperties { mDepth = 0; mLeavingDepth = 0; prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } void initByCopy(const DicNodeProperties *const dicNodeProp) { @@ -63,8 +65,9 @@ class DicNodeProperties { mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth; mLeavingDepth = dicNodeProp->mLeavingDepth; - WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds) - .copyToArray(&mPrevWordIds, 0 /* offset */); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } // Init as passing child @@ -74,8 +77,9 @@ class DicNodeProperties { mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = dicNodeProp->mLeavingDepth; - WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds) - .copyToArray(&mPrevWordIds, 0 /* offset */); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + 
prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } int getChildrenPtNodeArrayPos() const { @@ -104,7 +108,7 @@ class DicNodeProperties { } const WordIdArrayView getPrevWordIds() const { - return WordIdArrayView::fromArray(mPrevWordIds); + return WordIdArrayView::fromArray(mPrevWordIds).limit(mPrevWordCount); } int getWordId() const { @@ -121,6 +125,7 @@ class DicNodeProperties { uint16_t mDepth; uint16_t mLeavingDepth; WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds; + size_t mPrevWordCount; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index ec261cfbf..f9f36ce44 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -93,42 +93,42 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds; - prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(), + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - const WordIdArrayView prevWordIdArrayView = WordIdArrayView::fromArray(prevWordIds); - NgramListenerForPrediction listener(prevWordsInfo, prevWordIdArrayView, outSuggestionResults, + NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); - mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIdArrayView, &listener); + 
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } -int Dictionary::getProbability(const int *word, int length) const { - return getNgramProbability(nullptr /* prevWordsInfo */, word, length); +int Dictionary::getProbability(const CodePointArrayView codePoints) const { + return getNgramProbability(nullptr /* prevWordsInfo */, codePoints); } -int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) const { +int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); return DictionaryUtils::getMaxProbabilityOfExactMatches( - mDictionaryStructureWithBufferPolicy.get(), word, length); + mDictionaryStructureWithBufferPolicy.get(), codePoints); } -int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, - int length) const { +int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); - int wordId = mDictionaryStructureWithBufferPolicy->getWordId( - CodePointArrayView(word, length), false /* forceLowerCaseSearch */); + const int wordId = mDictionaryStructureWithBufferPolicy->getWordId(codePoints, + false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; if (!prevWordsInfo) { return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId); } - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds; - prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(), + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds + (mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - return getDictionaryStructurePolicy()->getProbabilityOfWord( - IntArrayView::fromArray(prevWordIds), wordId); + return 
getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } -bool Dictionary::addUnigramEntry(const int *const word, const int length, +bool Dictionary::addUnigramEntry(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty) { if (unigramProperty->representsBeginningOfSentence() && !mDictionaryStructureWithBufferPolicy->getHeaderStructurePolicy() @@ -137,14 +137,12 @@ bool Dictionary::addUnigramEntry(const int *const word, const int length, return false; } TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addUnigramEntry(CodePointArrayView(word, length), - unigramProperty); + return mDictionaryStructureWithBufferPolicy->addUnigramEntry(codePoints, unigramProperty); } -bool Dictionary::removeUnigramEntry(const int *const codePoints, const int codePointCount) { +bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeUnigramEntry( - CodePointArrayView(codePoints, codePointCount)); + return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints); } bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, @@ -154,10 +152,9 @@ bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, } bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { + const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, - CodePointArrayView(word, length)); + return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, codePoints); } bool Dictionary::flush(const char *const filePath) { @@ -182,11 +179,9 @@ void Dictionary::getProperty(const char *const query, const int queryLength, cha maxResultLength); } -const WordProperty Dictionary::getWordProperty(const int *const codePoints, - const int codePointCount) { 
+const WordProperty Dictionary::getWordProperty(const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->getWordProperty( - CodePointArrayView(codePoints, codePointCount)); + return mDictionaryStructureWithBufferPolicy->getWordProperty(codePoints); } int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 0b54f30e9..f6482ab78 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -72,23 +72,23 @@ class Dictionary { void getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const; - int getProbability(const int *word, int length) const; + int getProbability(const CodePointArrayView codePoints) const; - int getMaxProbabilityOfExactMatches(const int *word, int length) const; + int getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const; int getNgramProbability(const PrevWordsInfo *const prevWordsInfo, - const int *word, int length) const; + const CodePointArrayView codePoints) const; - bool addUnigramEntry(const int *const codePoints, const int codePointCount, + bool addUnigramEntry(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const codePoints, const int codePointCount); + bool removeUnigramEntry(const CodePointArrayView codePoints); bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, - const int length); + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView codePoints); bool flush(const char *const filePath); @@ -99,7 +99,7 @@ class Dictionary { void getProperty(const 
char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, const int codePointCount); + const WordProperty getWordProperty(const CodePointArrayView codePoints); // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index 617bf5a90..b85f3622a 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -29,28 +29,27 @@ namespace latinime { /* static */ int DictionaryUtils::getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { std::vector<DicNode> current; std::vector<DicNode> next; // No prev words information. PrevWordsInfo emptyPrevWordsInfo; - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds; - emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds.data(), - false /* tryLowerCaseSearch */); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds( + dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */); current.emplace_back(); - DicNodeUtils::initAsRoot(dictionaryStructurePolicy, - IntArrayView::fromArray(prevWordIds), ¤t.front()); - for (int i = 0; i < codePointCount; ++i) { + DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, ¤t.front()); + for (const int codePoint : codePoints) { // The base-lower input is used to ignore case errors and accent errors. 
- const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); + const int baseLowerCodePoint = CharUtils::toBaseLowerCase(codePoint); for (const DicNode &dicNode : current) { - if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == codePoint) { + if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == baseLowerCodePoint) { next.emplace_back(dicNode); next.back().advanceDigraphIndex(); continue; } - processChildDicNodes(dictionaryStructurePolicy, codePoint, &dicNode, &next); + processChildDicNodes(dictionaryStructurePolicy, baseLowerCodePoint, &dicNode, &next); } current.clear(); current.swap(next); diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.h b/native/jni/src/suggest/core/dictionary/dictionary_utils.h index 358ebf674..4dd21c9be 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.h +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.h @@ -20,6 +20,7 @@ #include <vector> #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DictionaryUtils { public: static int getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount); + const CodePointArrayView codePoints); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryUtils); diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 4f58c0c54..4d7505a55 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary, mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds.data(), - true /* 
tryLowerCaseSearch */); + mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), + &mPrevWordIdArray, true /* tryLowerCaseSearch */).size(); } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index af199071e..9f841aa3c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -51,12 +51,11 @@ class DicTraverseSession { } AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) - : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), - mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), - mMultiWordCostMultiplier(1.0f) { + : mPrevWordIdCount(0), mProximityInfo(nullptr), mDictionary(nullptr), + mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), mMultiBigramMap(), + mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. 
- mPrevWordsIds.fill(NOT_A_DICT_POS); } // Non virtual inline destructor -- never inherit this class @@ -78,7 +77,9 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - const WordIdArrayView getPrevWordIds() const { return IntArrayView::fromArray(mPrevWordsIds); } + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIdArray).limit(mPrevWordIdCount); + } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -165,7 +166,8 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordsIds; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIdArray; + size_t mPrevWordIdCount; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index fc9a35968..02e82a8e0 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -17,6 +17,8 @@ #ifndef LATINIME_PREV_WORDS_INFO_H #define LATINIME_PREV_WORDS_INFO_H +#include <array> + #include "defines.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -27,12 +29,13 @@ namespace latinime { class PrevWordsInfo { public: // No prev word information. 
- PrevWordsInfo() { + PrevWordsInfo() : mPrevWordCount(0) { clear(); } - PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) + : mPrevWordCount(prevWordsInfo.mPrevWordCount) { + for (size_t i = 0; i < mPrevWordCount; ++i) { mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); @@ -43,9 +46,10 @@ class PrevWordsInfo { // Construct from previous words. PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, - const size_t prevWordCount) { + const size_t prevWordCount) + : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { clear(); - for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { + for (size_t i = 0; i < mPrevWordCount; ++i) { if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { continue; } @@ -58,7 +62,7 @@ class PrevWordsInfo { // Construct from a previous word. 
PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, - const bool isBeginningOfSentence) { + const bool isBeginningOfSentence) : mPrevWordCount(1) { clear(); if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { return; @@ -79,26 +83,29 @@ class PrevWordsInfo { return false; } - void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordIds, const bool tryLowerCaseSearch) const { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - outPrevWordIds[i] = getWordId(dictStructurePolicy, + template<size_t N> + const WordIdArrayView getPrevWordIds( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { + for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) { + prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); } + return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount); } // n is 1-indexed. - const CodePointArrayView getNthPrevWordCodePoints(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { return CodePointArrayView(); } return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); } // n is 1-indexed. 
- bool isNthPrevWordBeginningOfSentence(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + bool isNthPrevWordBeginningOfSentence(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { return false; } return mIsBeginningOfSentence[n - 1]; @@ -142,6 +149,7 @@ class PrevWordsInfo { } } + const size_t mPrevWordCount; int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index dfc3d2d9b..ee1403739 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -268,8 +268,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo return false; } const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), - codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } @@ -283,8 +283,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { AKLOGE("Cannot add new 
shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); @@ -332,8 +332,12 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.empty()) { + return false; + } if (prevWordIds[0] == NOT_A_WORD_ID) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { const std::vector<UnigramProperty::ShortcutProperty> shortcuts; @@ -347,7 +351,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } // Refresh word ids. - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */); + prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } else { return false; } @@ -390,9 +394,10 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. 
length: %zd", wordCodePoints.size()); } - int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */); - if (prevWordIds[0] == NOT_A_WORD_ID) { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSerch */); + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { return false; } const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp index f7fd5c071..1b2f857ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp @@ -39,32 +39,31 @@ const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; /* static */ bool BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( - const uint8_t *const bigramsBuf, const int bufSize, BigramFlags *const outBigramFlags, + const ReadOnlyByteArrayView buffer, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos) { - if (bufSize <= *bigramEntryPos) { - AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). bufSize: %d, " - "bigramEntryPos: %d.", bufSize, *bigramEntryPos); + if (static_cast<int>(buffer.size()) <= *bigramEntryPos) { + AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). 
bufSize: %zd, " + "bigramEntryPos: %d.", buffer.size(), *bigramEntryPos); return false; } - const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, + const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), bigramEntryPos); if (outBigramFlags) { *outBigramFlags = bigramFlags; } - const int targetPos = getBigramAddressAndAdvancePosition(bigramsBuf, bigramFlags, - bigramEntryPos); + const int targetPos = getBigramAddressAndAdvancePosition(buffer, bigramFlags, bigramEntryPos); if (outTargetPtNodePos) { *outTargetPtNodePos = targetPos; } return true; } -/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, - const int bufSize, int *const bigramListPos) { +/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const ReadOnlyByteArrayView buffer, + int *const bigramListPos) { BigramFlags flags; do { - if (!getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, bufSize, &flags, - 0 /* outTargetPtNodePos */, bigramListPos)) { + if (!getBigramEntryPropertiesAndAdvancePosition(buffer, &flags, 0 /* outTargetPtNodePos */, + bigramListPos)) { return false; } } while(hasNext(flags)); @@ -72,18 +71,18 @@ const BigramListReadWriteUtils::BigramFlags } /* static */ int BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( - const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { + const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos) { int offset = 0; const int origin = *pos; switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); break; case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos); break; 
case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer.data(), pos); break; } if (isOffsetNegative(flags)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h index 10f93fb7a..a0f7d5e83 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h @@ -21,6 +21,7 @@ #include <cstdlib> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,8 +31,8 @@ class BigramListReadWriteUtils { public: typedef uint8_t BigramFlags; - static bool getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, - const int bufSize, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + static bool getBigramEntryPropertiesAndAdvancePosition(const ReadOnlyByteArrayView buffer, + BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos); static AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { @@ -43,8 +44,7 @@ public: } // Bigrams reading methods - static bool skipExistingBigrams(const uint8_t *const bigramsBuf, const int bufSize, - int *const bigramListPos); + static bool skipExistingBigrams(const ReadOnlyByteArrayView buffer, int *const bigramListPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); @@ -61,7 +61,7 @@ private: return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; } - static int getBigramAddressAndAdvancePosition(const uint8_t *const bigramsBuf, + static int getBigramAddressAndAdvancePosition(const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos); }; } // namespace 
latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index 086d98b4a..40782a44f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -218,9 +218,9 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( } int DynamicPtReadingHelper::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) { + const size_t length, const bool forceLowerCaseSearch) { int searchCodePoints[length]; - for (int i = 0; i < length; ++i) { + for (size_t i = 0; i < length; ++i) { searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(inWord[i]) : inWord[i]; } while (!isEnd()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index b7262581a..9a7abc97f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -138,12 +138,12 @@ class DynamicPtReadingHelper { } // Return code point count exclude the last read node's code points. - AK_FORCE_INLINE int getPrevTotalCodePointCount() const { + AK_FORCE_INLINE size_t getPrevTotalCodePointCount() const { return mReadingState.mTotalCodePointCountSinceInitialization; } // Return code point count include the last read node's code points. 
- AK_FORCE_INLINE int getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { + AK_FORCE_INLINE size_t getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { return mReadingState.mTotalCodePointCountSinceInitialization + ptNodeParams.getCodePointCount(); } @@ -214,7 +214,7 @@ class DynamicPtReadingHelper { int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability); - int getTerminalPtNodePositionOfWord(const int *const inWord, const int length, + int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, const bool forceLowerCaseSearch); private: @@ -234,7 +234,7 @@ class DynamicPtReadingHelper { int mPos; // Remaining node count in the current array. int mRemainingPtNodeCountInThisArray; - int mTotalCodePointCountSinceInitialization; + size_t mTotalCodePointCountSinceInitialization; // Counter of PtNodes used to avoid infinite loops caused by broken or malicious links. 
int mTotalPtNodeIndexInThisArrayChain; // Counter of PtNode arrays used to avoid infinite loops caused by cyclic links of empty diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp index 3c62e2e56..3b58d7d6d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp @@ -28,17 +28,16 @@ namespace latinime { const int DynamicPtUpdatingHelper::CHILDREN_POSITION_FIELD_SIZE = 3; -bool DynamicPtUpdatingHelper::addUnigramWord( - DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram) { +bool DynamicPtUpdatingHelper::addUnigramWord(DynamicPtReadingHelper *const readingHelper, + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram) { int parentPos = NOT_A_DICT_POS; while (!readingHelper->isEnd()) { const PtNodeParams ptNodeParams(readingHelper->getPtNodeParams()); if (!ptNodeParams.isValid()) { break; } - const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); + const size_t matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); if (!readingHelper->isMatchedCodePoint(ptNodeParams, 0 /* index */, wordCodePoints[matchedCodePointCount])) { // The first code point is different from target code point. Skip this node and read @@ -47,26 +46,25 @@ bool DynamicPtUpdatingHelper::addUnigramWord( continue; } // Check following merged node code points. 
- const int nodeCodePointCount = ptNodeParams.getCodePointCount(); - for (int j = 1; j < nodeCodePointCount; ++j) { - const int nextIndex = matchedCodePointCount + j; - if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(ptNodeParams, j, - wordCodePoints[matchedCodePointCount + j])) { + const size_t nodeCodePointCount = ptNodeParams.getCodePointArrayView().size(); + for (size_t j = 1; j < nodeCodePointCount; ++j) { + const size_t nextIndex = matchedCodePointCount + j; + if (nextIndex >= wordCodePoints.size() + || !readingHelper->isMatchedCodePoint(ptNodeParams, j, + wordCodePoints[matchedCodePointCount + j])) { *outAddedNewUnigram = true; return reallocatePtNodeAndAddNewPtNodes(&ptNodeParams, j, unigramProperty, - wordCodePoints + matchedCodePointCount, - codePointCount - matchedCodePointCount); + wordCodePoints.skip(matchedCodePointCount)); } } // All characters are matched. - if (codePointCount == readingHelper->getTotalCodePointCount(ptNodeParams)) { + if (wordCodePoints.size() == readingHelper->getTotalCodePointCount(ptNodeParams)) { return setPtNodeProbability(&ptNodeParams, unigramProperty, outAddedNewUnigram); } if (!ptNodeParams.hasChildren()) { *outAddedNewUnigram = true; return createChildrenPtNodeArrayAndAChildPtNode(&ptNodeParams, unigramProperty, - wordCodePoints + readingHelper->getTotalCodePointCount(ptNodeParams), - codePointCount - readingHelper->getTotalCodePointCount(ptNodeParams)); + wordCodePoints.skip(readingHelper->getTotalCodePointCount(ptNodeParams))); } // Advance to the children nodes. 
parentPos = ptNodeParams.getHeadPos(); @@ -79,9 +77,8 @@ bool DynamicPtUpdatingHelper::addUnigramWord( int pos = readingHelper->getPosOfLastForwardLinkField(); *outAddedNewUnigram = true; return createAndInsertNodeIntoPtNodeArray(parentPos, - wordCodePoints + readingHelper->getPrevTotalCodePointCount(), - codePointCount - readingHelper->getPrevTotalCodePointCount(), - unigramProperty, &pos); + wordCodePoints.skip(readingHelper->getPrevTotalCodePointCount()), unigramProperty, + &pos); } bool DynamicPtUpdatingHelper::addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, @@ -120,23 +117,21 @@ bool DynamicPtUpdatingHelper::removeNgramEntry(const PtNodePosArrayView prevWord } bool DynamicPtUpdatingHelper::addShortcutTarget(const int wordPos, - const int *const targetCodePoints, const int targetCodePointCount, - const int shortcutProbability) { + const CodePointArrayView targetCodePoints, const int shortcutProbability) { const PtNodeParams ptNodeParams(mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos)); - return mPtNodeWriter->addShortcutTarget(&ptNodeParams, targetCodePoints, targetCodePointCount, - shortcutProbability); + return mPtNodeWriter->addShortcutTarget(&ptNodeParams, targetCodePoints.data(), + targetCodePoints.size(), shortcutProbability); } bool DynamicPtUpdatingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, - const int *const nodeCodePoints, const int nodeCodePointCount, - const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos) { + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, + int *const forwardLinkFieldPos) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, newPtNodeArrayPos, forwardLinkFieldPos)) { return false; } - return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, - unigramProperty); + return 
createNewPtNodeArrayWithAChildPtNode(parentPos, ptNodeCodePoints, unigramProperty); } bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, @@ -153,8 +148,7 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams, unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, originalPtNodeParams->getParentPos(), - originalPtNodeParams->getCodePointCount(), originalPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -168,17 +162,17 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori bool DynamicPtUpdatingHelper::createChildrenPtNodeArrayAndAChildPtNode( const PtNodeParams *const parentPtNodeParams, const UnigramProperty *const unigramProperty, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!mPtNodeWriter->updateChildrenPosition(parentPtNodeParams, newPtNodeArrayPos)) { return false; } return createNewPtNodeArrayWithAChildPtNode(parentPtNodeParams->getHeadPos(), codePoints, - codePointCount, unigramProperty); + unigramProperty); } bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( - const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int parentPtNodePos, const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty) { int writingPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, @@ -187,8 +181,7 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( } 
const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, - parentPtNodePos, nodeCodePointCount, nodeCodePoints, - unigramProperty->getProbability())); + parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -202,9 +195,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( // Returns whether the dictionary updating was succeeded or not. bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount) { + const PtNodeParams *const reallocatingPtNodeParams, const size_t overlappingCodePointCount, + const UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints) { // When addsExtraChild is true, split the reallocating PtNode and add new child. // Reallocating PtNode: abcde, newNode: abcxy. // abc (1st, not terminal) __ de (2nd) @@ -212,16 +205,18 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( // Otherwise, this method makes 1st part terminal and write information in unigramProperty. // Reallocating PtNode: abcde, newNode: abc. // abc (1st, terminal) __ de (2nd) - const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const bool addsExtraChild = newPtNodeCodePoints.size() > overlappingCodePointCount; const int firstPartOfReallocatedPtNodePos = mBuffer->getTailPosition(); int writingPos = firstPartOfReallocatedPtNodePos; // Write the 1st part of the reallocating node. The children position will be updated later // with actual children position. 
+ const CodePointArrayView firstPtNodeCodePoints = + reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount); if (addsExtraChild) { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */, - reallocatingPtNodeParams->getParentPos(), overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints(), NOT_A_PROBABILITY)); + reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints, + NOT_A_PROBABILITY)); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) { return false; } @@ -229,8 +224,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, reallocatingPtNodeParams->getParentPos(), - overlappingCodePointCount, reallocatingPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + firstPtNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -248,8 +242,7 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams, reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(), reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos, - reallocatingPtNodeParams->getCodePointCount() - overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints() + overlappingCodePointCount, + reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount), reallocatingPtNodeParams->getProbability())); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&childPartPtNodeParams, &writingPos)) { return false; @@ -258,8 +251,8 @@ bool 
DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode( unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, firstPartOfReallocatedPtNodePos, - newNodeCodePointCount - overlappingCodePointCount, - newNodeCodePoints + overlappingCodePointCount, unigramProperty->getProbability())); + newPtNodeCodePoints.skip(overlappingCodePointCount), + unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&extraChildPtNodeParams, unigramProperty, &writingPos)) { return false; @@ -282,26 +275,24 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( } const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams( - const PtNodeParams *const originalPtNodeParams, - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const { + const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, + const bool isBlacklisted, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(originalPtNodeParams, flags, parentPos, codePointCount, codePoints, - probability); + return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability); } -const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode( - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, const int *const codePoints, - const int 
probability) const { +const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord, + const bool isBlacklisted, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(flags, parentPos, codePointCount, codePoints, probability); + return PtNodeParams(flags, parentPos, codePoints, probability); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h index 97c05c1ea..710047e8c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h @@ -40,19 +40,21 @@ class DynamicPtUpdatingHelper { // Add a word to the dictionary. If the word already exists, update the probability. bool addUnigramWord(DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram); + // TODO: Remove after stopping supporting v402. // Add an n-gram entry. bool addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos, const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + // TODO: Remove after stopping supporting v402. // Remove an n-gram entry. 
bool removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos); // Add a shortcut target. - bool addShortcutTarget(const int wordPos, const int *const targetCodePoints, - const int targetCodePointCount, const int shortcutProbability); + bool addShortcutTarget(const int wordPos, const CodePointArrayView targetCodePoints, + const int shortcutProbability); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPtUpdatingHelper); @@ -63,33 +65,32 @@ class DynamicPtUpdatingHelper { const PtNodeReader *const mPtNodeReader; PtNodeWriter *const mPtNodeWriter; - bool createAndInsertNodeIntoPtNodeArray(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty, + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos); bool setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); bool createChildrenPtNodeArrayAndAChildPtNode(const PtNodeParams *const parentPtNodeParams, - const UnigramProperty *const unigramProperty, const int *const codePoints, - const int codePointCount); + const UnigramProperty *const unigramProperty, + const CodePointArrayView remainingCodePoints); - bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty); + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, + const CodePointArrayView ptNodeCodePoints, + const UnigramProperty *const unigramProperty); - bool reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount); + bool 
reallocatePtNodeAndAddNewPtNodes(const PtNodeParams *const reallocatingPtNodeParams, + const size_t overlappingCodePointCount, const UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints); const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, - const int *const codePoints, const int probability) const; + const int parentPos, const CodePointArrayView codePoints, const int probability) const; const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted, - const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const; + const bool isTerminal, const int parentPos, const CodePointArrayView codePoints, + const int probability) const; }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index c12fed324..3ff1829bd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -89,9 +89,9 @@ class PtNodeParams { // Construct new params by updating existing PtNode params. 
PtNodeParams(const PtNodeParams *const ptNodeParams, const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(ptNodeParams->getHeadPos()), mFlags(flags), mHasMovedFlag(true), - mParentPos(parentPos), mCodePointCount(codePointCount), mCodePoints(), + mParentPos(parentPos), mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(ptNodeParams->getTerminalIdFieldPos()), mTerminalId(ptNodeParams->getTerminalId()), mProbabilityFieldPos(ptNodeParams->getProbabilityFieldPos()), @@ -102,20 +102,20 @@ class PtNodeParams { mShortcutPos(ptNodeParams->getShortcutPos()), mBigramPos(ptNodeParams->getBigramsPos()), mSiblingPos(ptNodeParams->getSiblingNodePos()) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } PtNodeParams(const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(NOT_A_DICT_POS), mFlags(flags), mHasMovedFlag(true), mParentPos(parentPos), - mCodePointCount(codePointCount), mCodePoints(), + mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(NOT_A_DICT_POS), mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID), mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(probability), mChildrenPosFieldPos(NOT_A_DICT_POS), mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_DICT_POS) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } AK_FORCE_INLINE bool isValid() const { diff --git 
a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp index 91c76941c..7cb7dff9a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp @@ -31,21 +31,21 @@ const int ShortcutListReadingUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; const int ShortcutListReadingUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; /* static */ ShortcutListReadingUtils::ShortcutFlags - ShortcutListReadingUtils::getFlagsAndForwardPointer(const uint8_t *const dictRoot, + ShortcutListReadingUtils::getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); + return ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); } /* static */ int ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer( - const uint8_t *const dictRoot, int *const pos) { + const ReadOnlyByteArrayView buffer, int *const pos) { // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. 
- return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) + return ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos) - SHORTCUT_LIST_SIZE_FIELD_SIZE; } -/* static */ int ShortcutListReadingUtils::readShortcutTarget( - const uint8_t *const dictRoot, const int maxLength, int *const outWord, int *const pos) { - return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); +/* static */ int ShortcutListReadingUtils::readShortcutTarget(const ReadOnlyByteArrayView buffer, + const int maxLength, int *const outWord, int *const pos) { + return ByteArrayUtils::readStringAndAdvancePosition(buffer.data(), maxLength, outWord, pos); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h index d065bf7fd..71cb8cc2c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -27,7 +28,8 @@ class ShortcutListReadingUtils { public: typedef uint8_t ShortcutFlags; - static ShortcutFlags getFlagsAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static ShortcutFlags getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getProbabilityFromFlags(const ShortcutFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; @@ -39,14 +41,15 @@ class ShortcutListReadingUtils { // This method returns the size of the shortcut list region excluding the shortcut list size // field at the beginning. 
- static int getShortcutListSizeAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static int getShortcutListSizeAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getShortcutListSizeFieldSize() { return SHORTCUT_LIST_SIZE_FIELD_SIZE; } - static AK_FORCE_INLINE void skipShortcuts(const uint8_t *const dictRoot, int *const pos) { - const int shortcutListSize = getShortcutListSizeAndForwardPointer(dictRoot, pos); + static AK_FORCE_INLINE void skipShortcuts(const ReadOnlyByteArrayView buffer, int *const pos) { + const int shortcutListSize = getShortcutListSizeAndForwardPointer(buffer, pos); *pos += shortcutListSize; } @@ -54,7 +57,7 @@ class ShortcutListReadingUtils { return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; } - static int readShortcutTarget(const uint8_t *const dictRoot, const int maxLength, + static int readShortcutTarget(const ReadOnlyByteArrayView buffer, const int maxLength, int *const outWord, int *const pos); private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h index 73e291ec2..e2608435c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h @@ -22,22 +22,22 @@ #include "defines.h" #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class BigramListPolicy : public DictionaryBigramsStructurePolicy { public: - BigramListPolicy(const uint8_t *const bigramsBuf, const int bufSize) - : mBigramsBuf(bigramsBuf), mBufSize(bufSize) {} + BigramListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~BigramListPolicy() {} 
void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const { BigramListReadWriteUtils::BigramFlags flags; - if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, - mBufSize, &flags, outBigramPos, pos)) { - AKLOGE("Cannot read bigram entry. mBufSize: %d, pos: %d. ", mBufSize, *pos); + if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBuffer, &flags, + outBigramPos, pos)) { + AKLOGE("Cannot read bigram entry. bufSize: %zd, pos: %d. ", mBuffer.size(), *pos); *outProbability = NOT_A_PROBABILITY; *outHasNext = false; return; @@ -47,14 +47,13 @@ class BigramListPolicy : public DictionaryBigramsStructurePolicy { } bool skipAllBigrams(int *const pos) const { - return BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, mBufSize, pos); + return BigramListReadWriteUtils::skipExistingBigrams(mBuffer, pos); } private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListPolicy); - const uint8_t *const mBigramsBuf; - const int mBufSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index aae61afca..64b767dac 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -37,19 +37,19 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo return; } int nextPos = dicNode->getChildrenPtNodeArrayPos(); - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", - nextPos, mDictBufferSize); + if (!isValidPos(nextPos)) { + AKLOGE("Children PtNode array position is invalid. 
pos: %d, dict size: %zd", + nextPos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); return; } const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &nextPos); + mBuffer.data(), &nextPos); for (int i = 0; i < childCount; i++) { - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", - nextPos, mDictBufferSize, i, childCount); + if (!isValidPos(nextPos)) { + AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %zd, childCount: %d / %d", + nextPos, mBuffer.size(), i, childCount); mIsCorrupted = true; ASSERT(false); return; @@ -91,56 +91,57 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int lastCandidatePtNodePos = 0; // Let's loop through PtNodes in this PtNode array searching for either the terminal // or one of its ascendants. - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d", - pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode array position is invalid. pos: %d, dict size: %zd", + pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; return 0; } for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { + mBuffer.data(), &pos); ptNodeCount > 0; --ptNodeCount) { const int startPos = pos; - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode position is invalid. 
pos: %d, dict size: %zd", pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; return 0; } const PatriciaTrieReadingUtils::NodeFlags flags = - PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mBuffer.data(), &pos); const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), &pos); if (ptNodePos == startPos) { // We found the position. Copy the rest of the code points in the buffer and return // the length. outCodePoints[wordPos] = character; if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), &pos); // We count code points in order to avoid infinite loops if the file is broken // or if there is some other bug int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), &pos); } } *outUnigramProbability = - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos); return ++wordPos; } // We need to skip past this PtNode, so skip any remaining code points after the // first and possibly the probability. 
if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { - PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + PatriciaTrieReadingUtils::skipCharacters(mBuffer.data(), flags, MAX_WORD_LENGTH, + &pos); } if (PatriciaTrieReadingUtils::isTerminal(flags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos); } // The fact that this PtNode has children is very important. Since we already know // that this PtNode does not match, if it has no children we know it is irrelevant @@ -155,7 +156,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int currentPos = pos; // Here comes the tricky part. First, read the children position. const int childrenPos = PatriciaTrieReadingUtils - ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos); + ::readChildrenPositionAndAdvancePosition(mBuffer.data(), flags, + &currentPos); if (childrenPos > ptNodePos) { // If the children pos is greater than the position, it means the previous // PtNode, which position is stored in lastCandidatePtNodePos, was the right @@ -185,30 +187,30 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( if (0 != lastCandidatePtNodePos) { const PatriciaTrieReadingUtils::NodeFlags lastFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); // We copy all the characters in this PtNode to the buffer outCodePoints[wordPos] = lastChar; if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); int charCount =
maxCodePointCount; while (-1 != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); } } ++wordPos; // Now we only need to branch to the children address. Skip the probability if // it's there, read pos, and break to resume the search at pos. if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &lastCandidatePtNodePos); } pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, lastFlags, &lastCandidatePtNodePos); + mBuffer.data(), lastFlags, &lastCandidatePtNodePos); break; } else { // Here is a little tricky part: we come here if we found out that all children @@ -220,14 +222,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // ready to start the next one. if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, + AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); @@ -244,14 +246,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // our pos is after the end of this PtNode, at the start of the next one. 
if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, pos); + AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); *outUnigramProbability = NOT_A_PROBABILITY; @@ -402,7 +404,7 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, &mShortcutListPolicy, + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, &mShortcutListPolicy, &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); // Skip PtNodes don't start with Unicode code point because they represent non-word information. 
@@ -452,14 +454,14 @@ const WordProperty PatriciaTriePolicy::getWordProperty( int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); if (shortcutPos != NOT_A_DICT_POS) { int shortcutTargetCodePoints[MAX_WORD_LENGTH]; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &shortcutPos); bool hasNext = true; while (hasNext) { const ShortcutListReadingUtils::ShortcutFlags shortcutFlags = - ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, &shortcutPos); hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags); const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget( - mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); + mBuffer, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); const std::vector<int> shortcutTarget(shortcutTargetCodePoints, shortcutTargetCodePoints + shortcutTargetLength); const int shortcutProbability = @@ -512,4 +514,9 @@ int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) cons int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { return wordId == NOT_A_WORD_ID ? 
NOT_A_DICT_POS : wordId; } + +bool PatriciaTriePolicy::isValidPos(const int pos) const { + return pos >= 0 && pos < static_cast<int>(mBuffer.size()); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index fc65de58c..70e8d847e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -44,14 +44,11 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { : mMmappedBuffer(std::move(mmappedBuffer)), mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(), FormatUtils::VERSION_2), - mDictRoot(mMmappedBuffer->getReadOnlyByteArrayView().data() - + mHeaderPolicy.getSize()), - mDictBufferSize(mMmappedBuffer->getReadOnlyByteArrayView().size() - - mHeaderPolicy.getSize()), - mBigramListPolicy(mDictRoot, mDictBufferSize), mShortcutListPolicy(mDictRoot), - mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy), - mPtNodeArrayReader(mDictRoot, mDictBufferSize), - mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {} + mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())), + mBigramListPolicy(mBuffer), mShortcutListPolicy(mBuffer), + mPtNodeReader(mBuffer, &mBigramListPolicy, &mShortcutListPolicy), + mPtNodeArrayReader(mBuffer), mTerminalPtNodePositionsForIteratingWords(), + mIsCorrupted(false) {} AK_FORCE_INLINE int getRootPosition() const { return 0; @@ -149,8 +146,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const MmappedBuffer::MmappedBufferPtr mMmappedBuffer; const HeaderPolicy mHeaderPolicy; - const uint8_t *const mDictRoot; - const int mDictBufferSize; + const ReadOnlyByteArrayView mBuffer; const BigramListPolicy mBigramListPolicy; const ShortcutListPolicy 
mShortcutListPolicy; const Ver2ParticiaTrieNodeReader mPtNodeReader; @@ -166,6 +162,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getTerminalPtNodePosFromWordId(const int wordId) const; const WordAttributes getWordAttributes(const int probability, const PtNodeParams &ptNodeParams) const; + bool isValidPos(const int pos) const; }; } // namespace latinime #endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h index 8e16ccc05..5319dd26c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h @@ -22,13 +22,13 @@ #include "defines.h" #include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { public: - explicit ShortcutListPolicy(const uint8_t *const shortcutBuf) - : mShortcutsBuf(shortcutBuf) {} + explicit ShortcutListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~ShortcutListPolicy() {} @@ -37,7 +37,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { return NOT_A_DICT_POS; } int listPos = pos; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mShortcutsBuf, &listPos); + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &listPos); return listPos; } @@ -45,7 +45,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, int *const pos) const { const ShortcutListReadingUtils::ShortcutFlags flags = - 
ShortcutListReadingUtils::getFlagsAndForwardPointer(mShortcutsBuf, pos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, pos); if (outHasNext) { *outHasNext = ShortcutListReadingUtils::hasNext(flags); } @@ -54,20 +54,20 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { } if (outCodePoint) { *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( - mShortcutsBuf, maxCodePointCount, outCodePoint, pos); + mBuffer, maxCodePointCount, outCodePoint, pos); } } void skipAllShortcuts(int *const pos) const { const int shortcutListSize = ShortcutListReadingUtils - ::getShortcutListSizeAndForwardPointer(mShortcutsBuf, pos); + ::getShortcutListSizeAndForwardPointer(mBuffer, pos); *pos += shortcutListSize; } private: DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListPolicy); - const uint8_t *const mShortcutsBuf; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp index c1e938710..74cdf7929 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp @@ -22,10 +22,10 @@ namespace latinime { const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNodePos( const int ptNodePos) const { - if (ptNodePos < 0 || ptNodePos >= mDictSize) { + if (ptNodePos < 0 || ptNodePos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. 
- AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %d", - ptNodePos, mDictSize); + AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %zd", + ptNodePos, mBuffer.size()); ASSERT(false); return PtNodeParams(); } @@ -37,7 +37,7 @@ const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNo int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictBuffer, ptNodePos, mShortuctPolicy, + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortuctPolicy, mBigramPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); if (mergedNodeCodePointCount <= 0) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h index f0725b66d..0f6769dc8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_reader.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,19 +31,17 @@ class DictionaryShortcutsStructurePolicy; class Ver2ParticiaTrieNodeReader : public PtNodeReader { public: - Ver2ParticiaTrieNodeReader(const uint8_t *const dictBuffer, const int dictSize, + Ver2ParticiaTrieNodeReader(const ReadOnlyByteArrayView buffer, const DictionaryBigramsStructurePolicy *const bigramPolicy, const DictionaryShortcutsStructurePolicy *const shortcutPolicy) - : mDictBuffer(dictBuffer), mDictSize(dictSize), mBigramPolicy(bigramPolicy), - 
mShortuctPolicy(shortcutPolicy) {} + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortuctPolicy(shortcutPolicy) {} virtual const PtNodeParams fetchPtNodeParamsInBufferFromPtNodePos(const int ptNodePos) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Ver2ParticiaTrieNodeReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; const DictionaryBigramsStructurePolicy *const mBigramPolicy; const DictionaryShortcutsStructurePolicy *const mShortuctPolicy; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp index b46617d96..72ad1eb66 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp @@ -22,16 +22,16 @@ namespace latinime { bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const { - if (ptNodeArrayPos < 0 || ptNodeArrayPos >= mDictSize) { + if (ptNodeArrayPos < 0 || ptNodeArrayPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of a bug or a broken dictionary. 
- AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %d", - ptNodeArrayPos, mDictSize); + AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %zd", + ptNodeArrayPos, mBuffer.size()); ASSERT(false); return false; } int readingPos = ptNodeArrayPos; const int ptNodeCountInArray = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictBuffer, &readingPos); + mBuffer.data(), &readingPos); *outPtNodeCount = ptNodeCountInArray; *outFirstPtNodePos = readingPos; return true; @@ -39,10 +39,10 @@ bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNode bool Ver2PtNodeArrayReader::readForwardLinkAndReturnIfValid(const int forwordLinkPos, int *const outNextPtNodeArrayPos) const { - if (forwordLinkPos < 0 || forwordLinkPos >= mDictSize) { + if (forwordLinkPos < 0 || forwordLinkPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. - AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %d", - forwordLinkPos, mDictSize); + AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %zd", + forwordLinkPos, mBuffer.size()); ASSERT(false); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h index 548272148..548f36bf3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h @@ -21,13 +21,13 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_array_reader.h" +#include "utils/byte_array_view.h" namespace latinime { class Ver2PtNodeArrayReader : public PtNodeArrayReader { public: - Ver2PtNodeArrayReader(const uint8_t *const dictBuffer, const int dictSize) - : 
mDictBuffer(dictBuffer), mDictSize(dictSize) {}; + Ver2PtNodeArrayReader(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {}; virtual bool readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const; @@ -37,8 +37,7 @@ class Ver2PtNodeArrayReader : public PtNodeArrayReader { private: DISALLOW_COPY_AND_ASSIGN(Ver2PtNodeArrayReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif /* LATINIME_VER2_PT_NODE_ARRAY_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index f54bb151a..35f0f768f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC( } int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, - const int wordId) const { + const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); int maxLevel = 0; @@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI if (!result.mIsValid) { continue; } - const int probability = - ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); if (mHasHistoricalInfo) { - return std::min( - probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), - MAX_PROBABILITY); + const int probability = ForgettingCurveUtils::decodeProbability( + 
probabilityEntry.getHistoricalInfo(), headerPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */); + return std::min(probability, MAX_PROBABILITY); } else { - return probability; + return probabilityEntry.getProbability(); } } // Cannot find the word. @@ -166,7 +167,15 @@ int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { return TrieMap::INVALID_INDEX; } - return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1], + const int oldestPrevWordId = prevWordIds.lastOrDefault(NOT_A_WORD_ID); + const TrieMap::Result result = mTrieMap.get(oldestPrevWordId, lastBitmapEntryIndex); + if (!result.mIsValid) { + if (!mTrieMap.put(oldestPrevWordId, + ProbabilityEntry().encode(mHasHistoricalInfo), lastBitmapEntryIndex)) { + return TrieMap::INVALID_INDEX; + } + } + return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds.lastOrDefault(NOT_A_WORD_ID), lastBitmapEntryIndex); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 4e0b47036..a793af4be 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -128,7 +128,8 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); - int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const; + int getWordProbability(const WordIdArrayView prevWordIds, const int wordId, + const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); diff --git 
a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index 3dfaba755..f1bf12cb2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -36,7 +36,8 @@ class ProbabilityEntry { // Dummy entry ProbabilityEntry() - : mFlags(0), mProbability(NOT_A_PROBABILITY), mHistoricalInfo() {} + : mFlags(Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY), mProbability(NOT_A_PROBABILITY), + mHistoricalInfo() {} // Entry without historical information ProbabilityEntry(const int flags, const int probability) @@ -61,7 +62,7 @@ class ProbabilityEntry { bigramProperty->getCount()) {} bool isValid() const { - return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo(); + return (mFlags & Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY) == 0; } bool hasHistoricalInfo() const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 9acf2d44f..39822b94a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -53,6 +53,7 @@ const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; +const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 
97035311e..dfcdd4d6f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -51,6 +51,7 @@ class Ver4DictConstants { static const int WORD_COUNT_FIELD_SIZE; // Flags in probability entry. static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + static const uint8_t FLAG_NOT_A_VALID_ENTRY; static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE; static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 9ca712470..75ec16912 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -211,19 +211,17 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - // TODO: Support n-gram. LanguageModelDictContent *const languageModelDictContent = mBuffers->getMutableLanguageModelDictContent(); const ProbabilityEntry probabilityEntry = - languageModelDictContent->getNgramProbabilityEntry( - prevWordIds.limit(1 /* maxSize */), wordId); + languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId); const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty); const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( &probabilityEntry, &probabilityEntryOfBigramProperty); if (!languageModelDictContent->setNgramProbabilityEntry( - prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) { - AKLOGE("Cannot add new ngram entry. 
prevWordId: %d, wordId: %d", - prevWordIds[0], wordId); + prevWordIds, wordId, &updatedProbabilityEntry)) { + AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d", + prevWordIds[0], prevWordIds.size(), wordId); return false; } if (!probabilityEntry.isValid() && outAddedNewBigram) { @@ -234,11 +232,9 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) { - // TODO: Support n-gram. LanguageModelDictContent *const languageModelDictContent = mBuffers->getMutableLanguageModelDictContent(); - return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds.limit(1 /* maxSize */), - wordId); + return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId); } // TODO: Remove when we stop supporting v402 format. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 5b1907f49..8d4135679 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -120,15 +120,15 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - // TODO: Support n-gram. 
- return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability( - prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(), - ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0); + const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability( + prevWordIds, wordId, mHeaderPolicy); + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + probability == 0); } int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const { - if (wordId == NOT_A_WORD_ID) { + if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } const int ptNodePos = @@ -137,10 +137,8 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - // TODO: Support n-gram. const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( - prevWordIds.limit(1 /* maxSize */), wordId); + mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); if (!probabilityEntry.isValid()) { return NOT_A_PROBABILITY; } @@ -163,16 +161,18 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI if (prevWordIds.empty()) { return; } - // TODO: Support n-gram. const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); - for (const auto entry : languageModelDictContent->getProbabilityEntries( - prevWordIds.limit(1 /* maxSize */))) { - const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); - const int probability = probabilityEntry.hasHistoricalInfo() ? 
- ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : - probabilityEntry.getProbability(); - listener->onVisitEntry(probability, entry.getWordId()); + for (size_t i = 1; i <= prevWordIds.size(); ++i) { + for (const auto entry : languageModelDictContent->getProbabilityEntries( + prevWordIds.limit(i))) { + const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), mHeaderPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) : + probabilityEntry.getProbability(); + listener->onVisitEntry(probability, entry.getWordId()); + } } } @@ -227,8 +227,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo return false; } const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), - codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } @@ -243,8 +243,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { AKLOGE("Cannot add new shortcut target. 
PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); @@ -303,26 +303,31 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds; - prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */); - // TODO: Support N-gram. - if (prevWordIds[0] == NOT_A_WORD_ID) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const std::vector<UnigramProperty::ShortcutProperty> shortcuts; - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); - return false; - } - // Refresh word ids. 
- prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */); - } else { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.empty()) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] != NOT_A_WORD_ID) { + continue; + } + if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { + return false; + } + const std::vector<UnigramProperty::ShortcutProperty> shortcuts; + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, true /* isNotAWord */, + false /* isBlacklisted */, MAX_PROBABILITY /* probability */, + NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); + if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } + // Refresh word ids. 
+ prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()), false /* forceLowerCaseSearch */); @@ -330,15 +335,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } bool addedNewEntry = false; - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos; - for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { - prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(prevWordIds[i]); - } - const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(wordId); - if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos), - wordPtNodePos, bigramProperty, &addedNewEntry)) { + if (mNodeWriter.addNgramEntry(prevWordIds, wordId, bigramProperty, &addedNewEntry)) { if (addedNewEntry) { mBigramCount++; } @@ -367,25 +364,17 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", wordCodePoints.size()); } - WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds; - prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSerch */); - // TODO: Support N-gram. 
- if (prevWordIds[0] == NOT_A_WORD_ID) { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSerch */); + if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) { return false; } const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); if (wordId == NOT_A_WORD_ID) { return false; } - std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos; - for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { - prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(prevWordIds[i]); - } - const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() - ->getTerminalPtNodePosition(wordId); - if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos), - wordPtNodePos)) { + if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { mBigramCount--; return true; } else { diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h index 10d7ae278..2b778af6f 100644 --- a/native/jni/src/utils/byte_array_view.h +++ b/native/jni/src/utils/byte_array_view.h @@ -42,6 +42,13 @@ class ReadOnlyByteArrayView { return mPtr; } + AK_FORCE_INLINE const ReadOnlyByteArrayView skip(const size_t n) const { + if (mSize <= n) { + return ReadOnlyByteArrayView(); + } + return ReadOnlyByteArrayView(mPtr + n, mSize - n); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(ReadOnlyByteArrayView); diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index caa13d976..f3a8589ca 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -17,6 +17,7 @@ #ifndef LATINIME_INT_ARRAY_VIEW_H #define LATINIME_INT_ARRAY_VIEW_H +#include <algorithm> #include <array> #include <cstdint> #include <cstring> @@ -92,12 +93,16 @@ class IntArrayView { return mPtr + mSize; } + AK_FORCE_INLINE 
bool contains(const int value) const { + return std::find(begin(), end(), value) != end(); + } + // Returns the view whose size is smaller than or equal to the given count. - const IntArrayView limit(const size_t maxSize) const { + AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const { return IntArrayView(mPtr, std::min(maxSize, mSize)); } - const IntArrayView skip(const size_t n) const { + AK_FORCE_INLINE const IntArrayView skip(const size_t n) const { if (mSize <= n) { return IntArrayView(); } @@ -110,6 +115,20 @@ class IntArrayView { memmove(buffer->data() + offset, mPtr, sizeof(int) * mSize); } + AK_FORCE_INLINE int firstOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[0]; + } + + AK_FORCE_INLINE int lastOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[mSize - 1]; + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); |