diff options
Diffstat (limited to 'native/jni/src')
73 files changed, 1481 insertions, 1807 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 24d04e51f..57e18884d 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -299,6 +299,7 @@ static inline void prof_out(void) { #define NOT_AN_INDEX (-1) #define NOT_A_PROBABILITY (-1) #define NOT_A_DICT_POS (S_INT_MIN) +#define NOT_A_WORD_ID (S_INT_MIN) #define NOT_A_TIMESTAMP (-1) #define NOT_A_LANGUAGE_WEIGHT (-1.0f) diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index d1b2c87be..ec61783cb 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -26,6 +26,7 @@ #include "suggest/core/dictionary/error_type_utils.h" #include "suggest/core/layout/proximity_info_state.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ @@ -103,10 +104,10 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // Init for root with prevWordsPtNodePos which is used for n-gram - void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { + // Init for root with prevWordIds which is used for n-gram + void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mIsCachedForNextSuggestion = false; - mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); + mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds); mDicNodeState.init(); PROF_NODE_RESET(mProfiler); } @@ -114,12 +115,11 @@ class DicNode { // Init for root with previous word void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; - int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); - for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { - newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; - } - mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds; + newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId(); + dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1) + .copyToArray(&newPrevWordIds, 1 /* offset */); + mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds)); mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, dicNode->mDicNodeProperties.getDepth()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -135,19 +135,16 @@ class DicNode { PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); } - void initAsChild(const DicNode *const dicNode, const int ptNodePos, - const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { + void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos, + const int wordId, const CodePointArrayView mergedCodePoints) { uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast<uint16_t>( - dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); - mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], - probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, - newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); - mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, - mergedNodeCodePoints); + dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size()); + mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0], + wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds()); + mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(), + mergedCodePoints.data()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -179,9 +176,6 @@ class DicNode { // Check if the current word and the previous word can be considered as a valid multiple word // suggestion. bool isValidMultipleWordSuggestion() const { - if (isBlacklistedOrNotAWord()) { - return false; - } // Treat suggestion as invalid if the current and the previous word are single character // words. const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() @@ -204,13 +198,12 @@ class DicNode { } // Used to get n-gram probability in DicNodeUtils. - int getPtNodePos() const { - return mDicNodeProperties.getPtNodePos(); + int getWordId() const { + return mDicNodeProperties.getWordId(); } - // TODO: Use view class to return PtNodePos array. - const int *getPrevWordsTerminalPtNodePos() const { - return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); + const WordIdArrayView getPrevWordIds() const { + return mDicNodeProperties.getPrevWordIds(); } // Used in DicNodeUtils @@ -218,10 +211,6 @@ class DicNode { return mDicNodeProperties.getChildrenPtNodeArrayPos(); } - int getProbability() const { - return mDicNodeProperties.getProbability(); - } - AK_FORCE_INLINE bool isTerminalDicNode() const { const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); const int currentDicNodeDepth = getNodeCodePointCount(); @@ -404,10 +393,6 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); } - bool isBlacklistedOrNotAWord() const { - return mDicNodeProperties.isBlacklistedOrNotAWord(); - } - inline uint16_t getNodeCodePointCount() const { return mDicNodeProperties.getDepth(); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 69ea67418..7d2898b7a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -18,7 +18,6 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" -#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" namespace latinime { @@ -29,8 +28,8 @@ namespace latinime { /* static */ void DicNodeUtils::initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordsPtNodePos, DicNode *const newRootDicNode) { - newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordsPtNodePos); + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) { + newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds); } /*static */ void DicNodeUtils::initAsRootWithPreviousWord( @@ -73,25 +72,16 @@ namespace latinime { if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode, - multiBigramMap); + const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap); + if (dicNode->hasMultipleWords() + && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord())) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. - const float cost = static_cast<float>(MAX_PROBABILITY - probability) + const float cost = static_cast<float>(MAX_PROBABILITY - wordAttributes.getProbability()) / static_cast<float>(MAX_PROBABILITY); return cost; } -/* static */ int DicNodeUtils::getBigramNodeProbability( - const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { - const int unigramProbability = dicNode->getProbability(); - if (multiBigramMap) { - const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos(); - return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, - prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability); - } - return dictionaryStructurePolicy->getProbability(unigramProbability, - NOT_A_PROBABILITY); -} - } // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 00e80c604..b891a842a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -18,6 +18,7 @@ #define LATINIME_DIC_NODE_UTILS_H #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DicNodeUtils { public: static void initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordPtNodePos, DicNode *const newRootDicNode); + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode); static void initAsRootWithPreviousWord( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode); @@ -46,10 +47,6 @@ class DicNodeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; - - static int getBigramNodeProbability( - const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap); }; } // namespace latinime #endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h index 54cde1988..e6b758954 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node.h" +#include "utils/int_array_view.h" namespace latinime { @@ -58,15 +59,11 @@ class DicNodeVector { mDicNodes.back().initAsPassingChild(dicNode); } - void pushLeavingChild(const DicNode *const dicNode, const int ptNodePos, - const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { + void pushLeavingChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos, + const int wordId, const CodePointArrayView mergedCodePoints) { ASSERT(!mLock); mDicNodes.emplace_back(); - mDicNodes.back().initAsChild(dicNode, ptNodePos, childrenPtNodeArrayPos, probability, - isTerminal, hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, - mergedNodeCodePoints); + mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, wordId, mergedCodePoints); } DicNode *operator[](const int id) { diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index 8202176f7..1b796b5d4 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -18,8 +18,10 @@ #define LATINIME_DIC_NODE_PROPERTIES_H #include <cstdint> +#include <cstdlib> #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -29,84 +31,61 @@ namespace latinime { class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() - : mPtNodePos(NOT_A_DICT_POS), mChildrenPtNodeArrayPos(NOT_A_DICT_POS), - mProbability(NOT_A_PROBABILITY), mDicNodeCodePoint(NOT_A_CODE_POINT), - mIsTerminal(false), mHasChildrenPtNodes(false), - mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} + : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT), + mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0), mPrevWordCount(0) {} ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. - void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, - const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordsNodePos) { - mPtNodePos = pos; + void init(const int childrenPos, const int nodeCodePoint, const int wordId, + const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = childrenPos; mDicNodeCodePoint = nodeCodePoint; - mProbability = probability; - mIsTerminal = isTerminal; - mHasChildrenPtNodes = hasChildren; - mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; + mWordId = wordId; mDepth = depth; mLeavingDepth = leavingDepth; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } // Init for root with prevWordsPtNodePos which is used for n-gram - void init(const int rootPtNodeArrayPos, const int *const prevWordsNodePos) { - mPtNodePos = NOT_A_DICT_POS; + void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = rootPtNodeArrayPos; mDicNodeCodePoint = NOT_A_CODE_POINT; - mProbability = NOT_A_PROBABILITY; - mIsTerminal = false; - mHasChildrenPtNodes = true; - mIsBlacklistedOrNotAWord = false; + mWordId = NOT_A_WORD_ID; mDepth = 0; mLeavingDepth = 0; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } void initByCopy(const DicNodeProperties *const dicNodeProp) { - mPtNodePos = dicNodeProp->mPtNodePos; mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos; mDicNodeCodePoint = dicNodeProp->mDicNodeCodePoint; - mProbability = dicNodeProp->mProbability; - mIsTerminal = dicNodeProp->mIsTerminal; - mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes; - mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; + mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth; mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } // Init as passing child void init(const DicNodeProperties *const dicNodeProp, const int codePoint) { - mPtNodePos = dicNodeProp->mPtNodePos; mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos; mDicNodeCodePoint = codePoint; // Overwrite the node char of a passing child - mProbability = dicNodeProp->mProbability; - mIsTerminal = dicNodeProp->mIsTerminal; - mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes; - mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; + mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); - } - - int getPtNodePos() const { - return mPtNodePos; + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } int getChildrenPtNodeArrayPos() const { return mChildrenPtNodeArrayPos; } - int getProbability() const { - return mProbability; - } - int getDicNodeCodePoint() const { return mDicNodeCodePoint; } @@ -121,35 +100,32 @@ class DicNodeProperties { } bool isTerminal() const { - return mIsTerminal; + return mWordId != NOT_A_WORD_ID; } bool hasChildren() const { - return mHasChildrenPtNodes || mDepth != mLeavingDepth; + return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth; } - bool isBlacklistedOrNotAWord() const { - return mIsBlacklistedOrNotAWord; + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIds).limit(mPrevWordCount); } - const int *getPrevWordsTerminalPtNodePos() const { - return mPrevWordsTerminalPtNodePos; + int getWordId() const { + return mWordId; } private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok // for this class - int mPtNodePos; int mChildrenPtNodeArrayPos; - int mProbability; int mDicNodeCodePoint; - bool mIsTerminal; - bool mHasChildrenPtNodes; - bool mIsBlacklistedOrNotAWord; + int mWordId; uint16_t mDepth; uint16_t mLeavingDepth; - int mPrevWordsTerminalPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds; + size_t mPrevWordCount; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h index 558e0a5c3..ee1606b6a 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h @@ -31,6 +31,11 @@ class BinaryDictionaryShortcutIterator { mPos(shortcutStructurePolicy->getStartPos(shortcutPos)), mHasNextShortcutTarget(shortcutPos != NOT_A_DICT_POS) {} + BinaryDictionaryShortcutIterator(const BinaryDictionaryShortcutIterator &&shortcutIterator) + : mShortcutStructurePolicy(shortcutIterator.mShortcutStructurePolicy), + mPos(shortcutIterator.mPos), + mHasNextShortcutTarget(shortcutIterator.mHasNextShortcutTarget) {} + AK_FORCE_INLINE bool hasNextShortcutTarget() const { return mHasNextShortcutTarget; } @@ -45,7 +50,8 @@ class BinaryDictionaryShortcutIterator { } private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryShortcutIterator); + DISALLOW_DEFAULT_CONSTRUCTOR(BinaryDictionaryShortcutIterator); + DISALLOW_ASSIGNMENT_OPERATOR(BinaryDictionaryShortcutIterator); const DictionaryShortcutsStructurePolicy *const mShortcutStructurePolicy; int mPos; diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index d62573970..b843791d6 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -28,6 +28,7 @@ #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" #include "suggest/policyimpl/typing/typing_suggest_policy_factory.h" +#include "utils/int_array_view.h" #include "utils/log_utils.h" #include "utils/time_keeper.h" @@ -60,14 +61,15 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( - const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults, + const PrevWordsInfo *const prevWordsInfo, const WordIdArrayView prevWordIds, + SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) - : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults), - mDictStructurePolicy(dictStructurePolicy) {} + : mPrevWordsInfo(prevWordsInfo), mPrevWordIds(prevWordIds), + mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {} void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability, - const int targetPtNodePos) { - if (targetPtNodePos == NOT_A_DICT_POS) { + const int targetWordId) { + if (targetWordId == NOT_A_WORD_ID) { return; } if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) @@ -77,26 +79,27 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi int targetWordCodePoints[MAX_WORD_LENGTH]; int unigramProbability = 0; const int codePointCount = mDictStructurePolicy-> - getCodePointsAndProbabilityAndReturnCodePointCount(targetPtNodePos, - MAX_WORD_LENGTH, targetWordCodePoints, &unigramProbability); + getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH, + targetWordCodePoints, &unigramProbability); if (codePointCount <= 0) { return; } - const int probability = mDictStructurePolicy->getProbability( - unigramProbability, ngramProbability); - mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability); + const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( + mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); + mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, + wordAttributes.getProbability()); } void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, - mDictionaryStructureWithBufferPolicy.get()); - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); + NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults, + mDictionaryStructureWithBufferPolicy.get()); + mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } int Dictionary::getProbability(const int *word, int length) const { @@ -112,18 +115,17 @@ int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) con int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, int length) const { TimeKeeper::setCurrentTime(); - int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(word, - length, false /* forceLowerCaseSearch */); - if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; + int wordId = mDictionaryStructureWithBufferPolicy->getWordId( + CodePointArrayView(word, length), false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; if (!prevWordsInfo) { - return getDictionaryStructurePolicy()->getProbabilityOfPtNode( - nullptr /* prevWordsPtNodePos */, nextWordPos); + return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds + (mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); + return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } bool Dictionary::addUnigramEntry(const int *const word, const int length, @@ -135,12 +137,14 @@ bool Dictionary::addUnigramEntry(const int *const word, const int length, return false; } TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addUnigramEntry(word, length, unigramProperty); + return mDictionaryStructureWithBufferPolicy->addUnigramEntry(CodePointArrayView(word, length), + unigramProperty); } bool Dictionary::removeUnigramEntry(const int *const codePoints, const int codePointCount) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints, codePointCount); + return mDictionaryStructureWithBufferPolicy->removeUnigramEntry( + CodePointArrayView(codePoints, codePointCount)); } bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, @@ -152,7 +156,8 @@ bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, const int length) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, word, length); + return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, + CodePointArrayView(word, length)); } bool Dictionary::flush(const char *const filePath) { @@ -181,7 +186,7 @@ const WordProperty Dictionary::getWordProperty(const int *const codePoints, const int codePointCount) { TimeKeeper::setCurrentTime(); return mDictionaryStructureWithBufferPolicy->getWordProperty( - codePoints, codePointCount); + CodePointArrayView(codePoints, codePointCount)); } int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 732d3b199..0b54f30e9 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -26,6 +26,7 @@ #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/suggest_interface.h" +#include "utils/int_array_view.h" namespace latinime { @@ -118,14 +119,15 @@ class Dictionary { class NgramListenerForPrediction : public NgramListener { public: NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo, - SuggestionResults *const suggestionResults, + const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy); - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + virtual void onVisitEntry(const int ngramProbability, const int targetWordId); private: DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction); const PrevWordsInfo *const mPrevWordsInfo; + const WordIdArrayView mPrevWordIds; SuggestionResults *const mSuggestionResults; const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; }; diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index b94966cbe..d09266e29 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -23,6 +23,7 @@ #include "suggest/core/dictionary/digraph_utils.h" #include "suggest/core/session/prev_words_info.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { @@ -34,11 +35,11 @@ namespace latinime { // No prev words information. PrevWordsInfo emptyPrevWordsInfo; - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy, - prevWordsPtNodePos, false /* tryLowerCaseSearch */); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds( + dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */); current.emplace_back(); - DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, ¤t.front()); + DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, ¤t.front()); for (int i = 0; i < codePointCount; ++i) { // The base-lower input is used to ignore case errors and accent errors. const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); @@ -59,8 +60,11 @@ namespace latinime { if (!dicNode.isTerminalDicNode()) { continue; } + const WordAttributes wordAttributes = + dictionaryStructurePolicy->getWordAttributesInContext(dicNode.getPrevWordIds(), + dicNode.getWordId(), nullptr /* multiBigramMap */); // dicNode can contain case errors, accent errors, intentional omissions or digraphs. - maxProbability = std::max(maxProbability, dicNode.getProbability()); + maxProbability = std::max(maxProbability, wordAttributes.getProbability()); } return maxProbability; } diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index b6bf7a98c..1e2494e92 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -19,17 +19,18 @@ namespace latinime { const ErrorTypeUtils::ErrorType ErrorTypeUtils::NOT_AN_ERROR = 0x0; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_CASE_ERROR = 0x1; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR = 0x2; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x4; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x8; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x10; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x20; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x40; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_CASE = 0x1; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT = 0x2; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT = 0x4; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x8; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x10; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x20; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x40; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = - NOT_AN_ERROR | MATCH_WITH_CASE_ERROR | MATCH_WITH_ACCENT_ERROR | MATCH_WITH_DIGRAPH; + NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index e3e76b238..fd1d5fcff 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -30,8 +30,9 @@ class ErrorTypeUtils { typedef uint32_t ErrorType; static const ErrorType NOT_AN_ERROR; - static const ErrorType MATCH_WITH_CASE_ERROR; - static const ErrorType MATCH_WITH_ACCENT_ERROR; + static const ErrorType MATCH_WITH_WRONG_CASE; + static const ErrorType MATCH_WITH_MISSING_ACCENT; + static const ErrorType MATCH_WITH_WRONG_ACCENT; static const ErrorType MATCH_WITH_DIGRAPH; // Treat error as an intentional omission when the CorrectionType is omission and the node can // be intentional omission. diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp index 91f33a8dd..761f51ec8 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -35,39 +35,37 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = // Also caches the bigrams if there is space remaining and they have not been cached already. int MultiBigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) { - if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) { return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); } - std::unordered_map<int, BigramMap>::const_iterator mapPosition = - mBigramMaps.find(prevWordsPtNodePos[0]); + const auto mapPosition = mBigramMaps.find(prevWordIds[0]); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + return mapPosition->second.getBigramProbability(structurePolicy, nextWordId, unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos); - return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy, - nextWordPosition, unigramProbability); + addBigramsForWord(structurePolicy, prevWordIds); + return mBigramMaps[prevWordIds[0]].getBigramProbability(structurePolicy, + nextWordId, unigramProbability); } - return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos, - nextWordPosition, unigramProbability); + return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordIds, + nextWordId, unigramProbability); } void MultiBigramMap::BigramMap::init( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */); + const WordIdArrayView prevWordIds) { + structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */); } int MultiBigramMap::BigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int nextWordPosition, const int unigramProbability) const { + const int nextWordId, const int unigramProbability) const { int bigramProbability = NOT_A_PROBABILITY; - if (mBloomFilter.isInFilter(nextWordPosition)) { - const std::unordered_map<int, int>::const_iterator bigramProbabilityIt = - mBigramMap.find(nextWordPosition); + if (mBloomFilter.isInFilter(nextWordId)) { + const auto bigramProbabilityIt = mBigramMap.find(nextWordId); if (bigramProbabilityIt != mBigramMap.end()) { bigramProbability = bigramProbabilityIt->second; } @@ -75,29 +73,24 @@ int MultiBigramMap::BigramMap::getBigramProbability( return structurePolicy->getProbability(unigramProbability, bigramProbability); } -void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, - const int targetPtNodePos) { - if (targetPtNodePos == NOT_A_DICT_POS) { +void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const int targetWordId) { + if (targetWordId == NOT_A_WORD_ID) { return; } - mBigramMap[targetPtNodePos] = ngramProbability; - mBloomFilter.setInFilter(targetPtNodePos); + mBigramMap[targetWordId] = ngramProbability; + mBloomFilter.setInFilter(targetWordId); } -void MultiBigramMap::addBigramsForWordPosition( +void MultiBigramMap::addBigramsForWord( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - if (prevWordsPtNodePos) { - mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos); - } + const WordIdArrayView prevWordIds) { + mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds); } int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability) { - const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos, - nextWordPosition); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) { + const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId); if (bigramProbability != NOT_A_PROBABILITY) { return bigramProbability; } diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index ad36dde83..d2eb5cc32 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -25,6 +25,7 @@ #include "suggest/core/dictionary/bloom_filter.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { @@ -39,8 +40,7 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); void clear() { mBigramMaps.clear(); @@ -58,11 +58,11 @@ class MultiBigramMap { virtual ~BigramMap() {} void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + const WordIdArrayView prevWordIds); int getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int nextWordPosition, const int unigramProbability) const; - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + const int nextWordId, const int unigramProbability) const; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId); private: static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; @@ -70,14 +70,12 @@ class MultiBigramMap { BloomFilter mBloomFilter; }; - void addBigramsForWordPosition( - const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy, + const WordIdArrayView prevWordIds); int readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; std::unordered_map<int, BigramMap> mBigramMaps; diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index 88b88bafb..e9b3c1aaf 100644 --- a/native/jni/src/suggest/core/dictionary/ngram_listener.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -26,7 +26,7 @@ namespace latinime { */ class NgramListener { public: - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; protected: diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp index 5bdd5606b..66daf3e3f 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp +++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp @@ -65,8 +65,6 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, for (const auto &shortcut : mUnigramProperty.getShortcuts()) { const std::vector<int> *const targetCodePoints = shortcut.getTargetCodePoints(); jintArray shortcutTargetCodePointArray = env->NewIntArray(targetCodePoints->size()); - env->SetIntArrayRegion(shortcutTargetCodePointArray, 0 /* start */, - targetCodePoints->size(), targetCodePoints->data()); JniDataUtils::outputCodePoints(env, shortcutTargetCodePointArray, 0 /* start */, targetCodePoints->size(), targetCodePoints->data(), targetCodePoints->size(), false /* needsNullTermination */); diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h new file mode 100644 index 000000000..6e9da3570 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/word_attributes.h @@ -0,0 +1,60 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_WORD_ATTRIBUTES_H +#define LATINIME_WORD_ATTRIBUTES_H + +#include "defines.h" + +class WordAttributes { + public: + // Invalid word attributes. + WordAttributes() + : mProbability(NOT_A_PROBABILITY), mIsBlacklisted(false), mIsNotAWord(false), + mIsPossiblyOffensive(false) {} + + WordAttributes(const int probability, const bool isBlacklisted, const bool isNotAWord, + const bool isPossiblyOffensive) + : mProbability(probability), mIsBlacklisted(isBlacklisted), mIsNotAWord(isNotAWord), + mIsPossiblyOffensive(isPossiblyOffensive) {} + + int getProbability() const { + return mProbability; + } + + bool isBlacklisted() const { + return mIsBlacklisted; + } + + bool isNotAWord() const { + return mIsNotAWord; + } + + bool isPossiblyOffensive() const { + return mIsPossiblyOffensive; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(WordAttributes); + + int mProbability; + bool mIsBlacklisted; + bool mIsNotAWord; + bool mIsPossiblyOffensive; +}; + + // namespace +#endif /* LATINIME_WORD_ATTRIBUTES_H */ diff --git a/native/jni/src/suggest/core/layout/geometry_utils.h b/native/jni/src/suggest/core/layout/geometry_utils.h index b667df68f..000fcd4a1 100644 --- a/native/jni/src/suggest/core/layout/geometry_utils.h +++ b/native/jni/src/suggest/core/layout/geometry_utils.h @@ -38,13 +38,15 @@ class GeometryUtils { } static AK_FORCE_INLINE float getAngleDiff(const float a1, const float a2) { - const float deltaA = fabsf(a1 - a2); - const float diff = ROUND_FLOAT_10000(deltaA); - if (diff > M_PI_F) { - const float normalizedDiff = 2.0f * M_PI_F - diff; - return ROUND_FLOAT_10000(normalizedDiff); + static const float M_2PI_F = M_PI * 2.0f; + float delta = fabsf(a1 - a2); + if (delta > M_2PI_F) { + delta -= (M_2PI_F * static_cast<int>(delta / M_2PI_F)); } - return diff; + if (delta > M_PI_F) { + delta = M_2PI_F - delta; + } + return ROUND_FLOAT_10000(delta); } static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2, diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index e91f07682..a498b6f65 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -20,14 +20,17 @@ #include <memory> #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" #include "suggest/core/dictionary/property/word_property.h" +#include "suggest/core/dictionary/word_attributes.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; class DictionaryHeaderStructurePolicy; -class DictionaryShortcutsStructurePolicy; +class MultiBigramMap; class NgramListener; class PrevWordsInfo; class UnigramProperty; @@ -36,6 +39,7 @@ class UnigramProperty; * This class abstracts the structure of dictionaries. * Implement this policy to support additional dictionaries. */ +// TODO: Use word id instead of terminal PtNode position. class DictionaryStructureWithBufferPolicy { public: typedef std::unique_ptr<DictionaryStructureWithBufferPolicy> StructurePolicyPtr; @@ -48,33 +52,33 @@ class DictionaryStructureWithBufferPolicy { DicNodeVector *const childDicNodes) const = 0; virtual int getCodePointsAndProbabilityAndReturnCodePointCount( - const int nodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const = 0; - virtual int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const = 0; + virtual int getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const = 0; - virtual int getProbability(const int unigramProbability, - const int bigramProbability) const = 0; + virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const = 0; - virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int nodePos) const = 0; + // TODO: Remove + virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, + virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0; + + virtual void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const = 0; - virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; + virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0; virtual const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const = 0; - virtual const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const = 0; - // Returns whether the update was success or not. - virtual bool addUnigramEntry(const int *const word, const int length, + virtual bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) = 0; // Returns whether the update was success or not. - virtual bool removeUnigramEntry(const int *const word, const int length) = 0; + virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0; // Returns whether the update was success or not. virtual bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, @@ -82,7 +86,7 @@ class DictionaryStructureWithBufferPolicy { // Returns whether the update was success or not. virtual bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) = 0; + const CodePointArrayView wordCodePoints) = 0; // Returns whether the flush was success or not. virtual bool flush(const char *const filePath) = 0; @@ -98,8 +102,7 @@ class DictionaryStructureWithBufferPolicy { const int maxResultLength) = 0; // Used for testing. - virtual const WordProperty getWordProperty(const int *const codePonts, - const int codePointCount) const = 0; + virtual const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const = 0; // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 8ddaa0514..6dfa7e314 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -48,7 +48,8 @@ class Traversal { virtual int getTerminalCacheSize() const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const = 0; protected: Traversal() {} diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 0b99b75ec..6e0193772 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -85,9 +85,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + doubleLetterCost; - const bool isPossiblyOffensiveWord = - traverseSession->getDictionaryStructurePolicy()->getProbability( - terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; + const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy() + ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(), + terminalDicNode->getWordId(), nullptr /* multiBigramMap */); const bool isExactMatch = ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); const bool isExactMatchWithIntentionalOmission = @@ -97,19 +97,19 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. // (e.g. "AMD" and "and") const bool isSafeExactMatch = isExactMatch - && !(isPossiblyOffensiveWord && isFirstCharUppercase); + && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase); const int outputTypeFlags = - (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) + (wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) | (isExactMatchWithIntentionalOmission ? Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0); // Entries that are blacklisted or do not represent a word should not be output. - const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); + const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()); // When we have to block offensive words, non-exact matched offensive words should not be // output. const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords(); - const bool isBlockedOffensiveWord = blockOffensiveWords && isPossiblyOffensiveWord + const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive() && !isSafeExactMatch; // Increase output score of top typing suggestion to ensure autocorrection. @@ -139,10 +139,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Shortcut is not supported for multiple words suggestions. // TODO: Check shortcuts during traversal for multiple words suggestions. if (!terminalDicNode->hasMultipleWords()) { - BinaryDictionaryShortcutIterator shortcutIt( - traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), - traverseSession->getDictionaryStructurePolicy() - ->getShortcutPositionOfPtNode(terminalDicNode->getPtNodePos())); + BinaryDictionaryShortcutIterator shortcutIt = + traverseSession->getDictionaryStructurePolicy()->getShortcutIterator( + terminalDicNode->getWordId()); const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); outputShortcuts(&shortcutIt, finalScore, sameAsTyped, outSuggestionResults); } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index f1e411f38..4d7505a55 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary, mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */); + mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), + &mPrevWordIdArray, true /* tryLowerCaseSearch */).size(); } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 5a51a112d..9f841aa3c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -24,6 +24,7 @@ #include "suggest/core/dicnode/dic_nodes_cache.h" #include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/layout/proximity_info_state.h" +#include "utils/int_array_view.h" namespace latinime { @@ -50,14 +51,11 @@ class DicTraverseSession { } AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) - : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), - mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), - mMultiWordCostMultiplier(1.0f) { + : mPrevWordIdCount(0), mProximityInfo(nullptr), mDictionary(nullptr), + mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), mMultiBigramMap(), + mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. - for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) { - mPrevWordsPtNodePos[i] = NOT_A_DICT_POS; - } } // Non virtual inline destructor -- never inherit this class @@ -79,7 +77,9 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - const int *getPrevWordsPtNodePos() const { return mPrevWordsPtNodePos; } + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIdArray).limit(mPrevWordIdCount); + } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -166,7 +166,8 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIdArray; + size_t mPrevWordIdCount; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index e44e876e9..02e82a8e0 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -17,23 +17,25 @@ #ifndef LATINIME_PREV_WORDS_INFO_H #define LATINIME_PREV_WORDS_INFO_H +#include <array> + #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" namespace latinime { -// TODO: Support n-gram. class PrevWordsInfo { public: // No prev word information. - PrevWordsInfo() { + PrevWordsInfo() : mPrevWordCount(0) { clear(); } - PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) + : mPrevWordCount(prevWordsInfo.mPrevWordCount) { + for (size_t i = 0; i < mPrevWordCount; ++i) { mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); @@ -44,9 +46,10 @@ class PrevWordsInfo { // Construct from previous words. PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, - const size_t prevWordCount) { + const size_t prevWordCount) + : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { clear(); - for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { + for (size_t i = 0; i < mPrevWordCount; ++i) { if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { continue; } @@ -59,7 +62,7 @@ class PrevWordsInfo { // Construct from a previous word. PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, - const bool isBeginningOfSentence) { + const bool isBeginningOfSentence) : mPrevWordCount(1) { clear(); if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { return; @@ -80,35 +83,29 @@ class PrevWordsInfo { return false; } - void getPrevWordsTerminalPtNodePos( + template<size_t N> + const WordIdArrayView getPrevWordIds( const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, + std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { + for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) { + prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); } + return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount); } // n is 1-indexed. - const int *getNthPrevWordCodePoints(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return nullptr; - } - return mPrevWordCodePoints[n - 1]; - } - - // n is 1-indexed. - int getNthPrevWordCodePointCount(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return 0; + const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { + return CodePointArrayView(); } - return mPrevWordCodePointCount[n - 1]; + return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); } // n is 1-indexed. - bool isNthPrevWordBeginningOfSentence(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + bool isNthPrevWordBeginningOfSentence(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { return false; } return mIsBeginningOfSentence[n - 1]; @@ -117,12 +114,11 @@ class PrevWordsInfo { private: DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); - static int getTerminalPtNodePosOfWord( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, const int *const wordCodePoints, const int wordCodePointCount, const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return NOT_A_DICT_POS; + return NOT_A_WORD_ID; } int codePoints[MAX_WORD_LENGTH]; int codePointCount = wordCodePointCount; @@ -131,20 +127,19 @@ class PrevWordsInfo { codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount, MAX_WORD_LENGTH); if (codePointCount <= 0) { - return NOT_A_DICT_POS; + return NOT_A_WORD_ID; } } - const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( - codePoints, codePointCount, false /* forceLowerCaseSearch */); - if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { - // Return the position when when the word was found or doesn't try lower case - // search. - return wordPtNodePos; + const CodePointArrayView codePointArrayView(codePoints, codePointCount); + const int wordId = dictStructurePolicy->getWordId( + codePointArrayView, false /* forceLowerCaseSearch */); + if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) { + // Return the id when when the word was found or doesn't try lower case search. + return wordId; } // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". - return dictStructurePolicy->getTerminalPtNodePositionOfWord( - codePoints, codePointCount, true /* forceLowerCaseSearch */); + return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */); } void clear() { @@ -154,6 +149,7 @@ class PrevWordsInfo { } } + const size_t mPrevWordCount; int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 0cd305f5a..947d41f4b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" +#include "suggest/core/dictionary/word_attributes.h" #include "suggest/core/layout/proximity_info.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/policy/traversal.h" @@ -92,7 +93,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const { // Create a new dic node here DicNode rootNode; DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), - traverseSession->getPrevWordsPtNodePos(), &rootNode); + traverseSession->getPrevWordIds(), &rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); } } @@ -412,7 +413,11 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN */ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const { - if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { + const WordAttributes wordAttributes = + traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), + traverseSession->getMultiBigramMap()); + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) { return; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index a8f8f284b..d2c3d2fe0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -142,7 +142,8 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; } /* static */ void HeaderReadWriteUtils::setCodePointVectorAttribute( - AttributeMap *const headerAttributes, const char *const key, const std::vector<int> value) { + AttributeMap *const headerAttributes, const char *const key, + const std::vector<int> &value) { AttributeMap::key_type keyVector; insertCharactersIntoVector(key, &keyVector); (*headerAttributes)[keyVector] = value; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h index 9b90488fc..1ab2eec69 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -64,7 +64,7 @@ class HeaderReadWriteUtils { */ static void setCodePointVectorAttribute( DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes, - const char *const key, const std::vector<int> value); + const char *const key, const std::vector<int> &value); static void setBoolAttribute( DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 278f2b199..97a8bcc98 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) { - AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", - sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); + AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d", + prevWordIds[0], wordId); return false; } const int ptNodePos = @@ -425,6 +425,18 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, return true; } +bool Ver4PatriciaTrieNodeWriter::suppressUnigramEntry(const PtNodeParams *const ptNodeParams) { + if (!mHeaderPolicy->hasHistoricalInfoOfWords()) { + // Require historical info to suppress unigram entry. + return false; + } + const HistoricalInfo suppressedHistorycalInfo(0 /* timestamp */, 0 /* level */, 0 /* count */); + const ProbabilityEntry probabilityEntryToWrite = + ProbabilityEntry().createEntryWithUpdatedHistoricalInfo(&suppressedHistorycalInfo); + return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry( + ptNodeParams->getTerminalId(), &probabilityEntryToWrite); +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h index d49d9a666..9d8a55bff 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h @@ -111,6 +111,11 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { bool updatePtNodeHasBigramsAndShortcutTargetsFlags(const PtNodeParams *const ptNodeParams); + // Suppress unigram not to use the word for generating suggestions. So, this method can be used + // only for dictionaries with historical info. Also, suppressed entries are included in unigram + // count. They will be removed from the dictionary during GC. + bool suppressUnigramEntry(const PtNodeParams *const ptNodeParams); + private: DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeWriter); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 1296b8acd..41b9a11b1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -28,6 +28,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" @@ -76,12 +77,9 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d // Skip PtNodes that represent non-word information. continue; } - childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(), - ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal, - ptNodeParams.hasChildren(), - ptNodeParams.isBlacklisted() - || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */, - ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints()); + const int wordId = isTerminal ? ptNodeParams.getHeadPos() : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(), + wordId, ptNodeParams.getCodePointArrayView()); } if (readingHelper.isError()) { mIsCorrupted = true; @@ -90,9 +88,10 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); readingHelper.initWithPtNodePos(ptNodePos); const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( maxCodePointCount, outCodePoints, outUnigramProbability); @@ -103,17 +102,46 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const { +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); + } + return getWordIdFromTerminalPtNodePos(ptNodePos); +} + +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); + } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + if (multiBigramMap) { + const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */, + prevWordIds, wordId, ptNodeParams.getProbability()); + return getWordAttributes(probability, ptNodeParams); + } + if (!prevWordIds.empty()) { + const int probability = getProbabilityOfWord(prevWordIds, wordId); + if (probability != NOT_A_PROBABILITY) { + return getWordAttributes(probability, ptNodeParams); + } } - return ptNodePos; + return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY), + ptNodeParams); +} + +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const { + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + ptNodeParams.getProbability() == 0); } int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -132,17 +160,19 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + if (!prevWordIds.empty()) { + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); @@ -156,19 +186,27 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.empty()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + listener->onVisitEntry(bigramsIt.getProbability(), + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } +BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( + const int wordId) const { + const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); + return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos); +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -193,7 +231,7 @@ int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) cons ptNodeParams.getTerminalId()); } -bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); @@ -204,13 +242,14 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le mDictBuffer->getTailPosition()); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("The word is too long to insert to the dictionary, length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("The word is too long to insert to the dictionary, length: %zd", + wordCodePoints.size()); return false; } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -219,8 +258,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le readingHelper.initWithPtNodeArrayPos(getRootPosition()); bool addedNewUnigram = false; int codePointsToAdd[MAX_WORD_LENGTH]; - int codePointCountToAdd = length; - memmove(codePointsToAdd, word, sizeof(int) * length); + int codePointCountToAdd = wordCodePoints.size(); + memmove(codePointsToAdd, wordCodePoints.data(), sizeof(int) * codePointCountToAdd); if (unigramProperty->representsBeginningOfSentence()) { codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd, codePointCountToAdd, MAX_WORD_LENGTH); @@ -228,15 +267,16 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (codePointCountToAdd <= 0) { return false; } - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd, - unigramProperty, &addedNewUnigram)) { + const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), + codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId( + getWordId(codePointArrayView, false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { AKLOGE("Cannot find terminal PtNode position to add shortcut target."); return false; @@ -245,7 +285,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcut.getTargetCodePoints()->data(), shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -258,6 +298,20 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } } +bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCodePoints) { + if (!mBuffers->isUpdatable()) { + AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); + return false; + } + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); + if (ptNodePos == NOT_A_DICT_POS) { + return false; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + return mNodeWriter.suppressUnigramEntry(&ptNodeParams); +} + bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { if (!mBuffers->isUpdatable()) { @@ -275,14 +329,16 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty()) { + return false; + } + if (prevWordIds[0] == NOT_A_WORD_ID) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { const std::vector<UnigramProperty::ShortcutProperty> shortcuts; const UnigramProperty beginningOfSentenceUnigramProperty( @@ -290,27 +346,26 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI false /* isBlacklisted */, MAX_PROBABILITY /* probability */, NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - prevWordsInfo->getNthPrevWordCodePointCount(1 /* n */), &beginningOfSentenceUnigramProperty)) { AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); + // Refresh word ids. + prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } else { return false; } } - const int word1Pos = getTerminalPtNodePositionOfWord( - bigramProperty->getTargetCodePoints()->data(), - bigramProperty->getTargetCodePoints()->size(), false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + const int wordPos = getTerminalPtNodePosFromWordId(getWordId( + CodePointArrayView(*bigramProperty->getTargetCodePoints()), + false /* forceLowerCaseSearch */)); + if (wordPos == NOT_A_DICT_POS) { return false; } bool addedNewBigram = false; - if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos), - word1Pos, bigramProperty, &addedNewBigram)) { + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); + if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), + wordPos, bigramProperty, &addedNewBigram)) { if (addedNewBigram) { mBigramCount++; } @@ -321,7 +376,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { + const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; @@ -335,23 +390,24 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", + wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints, + false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { return false; } + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( - PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) { + PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { mBigramCount--; return true; } else { @@ -430,10 +486,10 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer } } -const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); +const WordProperty Ver4PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); if (ptNodePos == NOT_A_DICT_POS) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); @@ -468,8 +524,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code // Word (unigram) probability int word1Probability = NOT_A_PROBABILITY; const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); + getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, + bigramWord1CodePoints, &word1Probability); const std::vector<int> word1(bigramWord1CodePoints, bigramWord1CodePoints + codePointCount); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); @@ -526,7 +582,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints, + &unigramProbability); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. @@ -536,6 +593,14 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return nextToken; } +int Ver4PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId; +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 9e989b268..576d2abb5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -28,6 +28,8 @@ #include <vector> #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h" @@ -39,6 +41,7 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "utils/int_array_view.h" namespace latinime { namespace backward { @@ -55,6 +58,7 @@ class DicNodeVector; namespace backward { namespace v402 { +// Word id = Position of a PtNode that represents the word. class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: Ver4PatriciaTriePolicy(Ver4DictBuffers::Ver4DictBuffersPtr buffers) @@ -74,7 +78,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mBigramCount(mHeaderPolicy->getBigramCount()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; - AK_FORCE_INLINE int getRootPosition() const { + virtual int getRootPosition() const { return 0; } @@ -82,42 +86,37 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DicNodeVector *const childDicNodes) const; int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const word, const int length) { - // Removing unigram entry is not supported. - return false; - } + bool removeUnigramEntry(const CodePointArrayView wordCodePoints); bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, - const int length1); + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView wordCodePoints); bool flush(const char *const filePath); @@ -128,8 +127,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -166,6 +164,11 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mutable bool mIsCorrupted; int getBigramsPositionOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; + const WordAttributes getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const; }; } // namespace v402 } // namespace backward diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index e4ea3da16..9fa93efc9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -111,8 +111,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str return nullptr; } const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::detectFormatVersion( - mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size()); + mmappedBuffer->getReadOnlyByteArrayView()); switch (formatVersion) { case FormatUtils::VERSION_2: AKLOGE("Given path is a directory but the format is version 2. path: %s", path); @@ -174,8 +173,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str if (!mmappedBuffer) { return nullptr; } - switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size())) { + switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index b2e60a837..c12fed324 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -24,6 +24,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" namespace latinime { @@ -174,11 +175,17 @@ class PtNodeParams { return mParentPos; } + AK_FORCE_INLINE const CodePointArrayView getCodePointArrayView() const { + return CodePointArrayView(mCodePoints, mCodePointCount); + } + + // TODO: Remove // Number of code points AK_FORCE_INLINE uint8_t getCodePointCount() const { return mCodePointCount; } + // TODO: Remove AK_FORCE_INLINE const int *getCodePoints() const { return mCodePoints; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index ea32eb2a9..aae61afca 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" @@ -57,7 +58,7 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo } } -// This retrieves code points and the probability of the word by its terminal position. +// This retrieves code points and the probability of the word by its id. // Due to the fact that words are ordered in the dictionary in a strict breadth-first order, // it is possible to check for this with advantageous complexity. For each PtNode array, we search // for PtNodes with children and compare the children position with the position we look for. @@ -68,16 +69,16 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo // with a z, it's the last PtNode of the root array, so all children addresses will be smaller // than the position we look for, and we have to descend the z PtNode). /* Parameters : - * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is - * what is stored as the "bigram position" in each bigram) + * wordId: Id of the word we are searching for. * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size. * outUnigramProbability: a pointer to an int to write the probability into. * Return value : the code point count, of 0 if the word was not found. */ // TODO: Split this function to be more readable int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); int pos = getRootPosition(); int wordPos = 0; // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will @@ -267,18 +268,48 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( } // This function gets the position of the terminal PtNode of the exact matching word in the -// dictionary. If no match is found, it returns NOT_A_DICT_POS. -int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const { +// dictionary. If no match is found, it returns NOT_A_WORD_ID. +int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); } - return ptNodePos; + return getWordIdFromTerminalPtNodePos(ptNodePos); +} + +const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); + } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams = + mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (multiBigramMap) { + const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */, + prevWordIds, wordId, ptNodeParams.getProbability()); + return getWordAttributes(probability, ptNodeParams); + } + if (!prevWordIds.empty()) { + const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId); + if (bigramProbability != NOT_A_PROBABILITY) { + return getWordAttributes(bigramProbability, ptNodeParams); + } + } + return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY), + ptNodeParams); +} + +const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const { + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + ptNodeParams.getProbability() == 0); } int PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -297,11 +328,12 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { @@ -310,8 +342,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP // for shortcuts). return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + if (!prevWordIds.empty()) { + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); @@ -325,19 +358,26 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.empty()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + listener->onVisitEntry(bigramsIt.getProbability(), + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } +BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const { + const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); + return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos); +} + int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -362,29 +402,26 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, getShortcutsStructurePolicy(), + PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, &mShortcutListPolicy, &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); // Skip PtNodes don't start with Unicode code point because they represent non-word information. if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) { - childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability, - PatriciaTrieReadingUtils::isTerminal(flags), - PatriciaTrieReadingUtils::hasChildrenInFlags(flags), - PatriciaTrieReadingUtils::isBlacklisted(flags) - || PatriciaTrieReadingUtils::isNotAWord(flags), - mergedNodeCodePointCount, mergedNodeCodePoints); + const int wordId = PatriciaTrieReadingUtils::isTerminal(flags) ? ptNodePos : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, childrenPos, wordId, + CodePointArrayView(mergedNodeCodePoints, mergedNodeCodePointCount)); } return siblingPos; } -const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { +const WordProperty PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty was called for invalid word."); return WordProperty(); } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); std::vector<int> codePointVector(ptNodeParams.getCodePoints(), @@ -401,8 +438,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin if (bigramsIt.getBigramPos() != NOT_A_DICT_POS) { int word1Probability = NOT_A_PROBABILITY; const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()), MAX_WORD_LENGTH, + bigramWord1CodePoints, &word1Probability); const std::vector<int> word1(bigramWord1CodePoints, bigramWord1CodePoints + word1CodePointCount); const int probability = getProbability(word1Probability, bigramsIt.getProbability()); @@ -456,8 +493,9 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, - MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints, + &unigramProbability); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. @@ -467,4 +505,11 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC return nextToken; } +int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId; +} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 70351d147..fc65de58c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -30,12 +30,14 @@ #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" #include "utils/byte_array_view.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; +// Word id = Position of a PtNode that represents the word. class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer) @@ -59,37 +61,35 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DicNodeVector *const childDicNodes) const; int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return &mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutListPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); return false; } - bool removeUnigramEntry(const int *const word, const int length) { + bool removeUnigramEntry(const CodePointArrayView wordCodePoints) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; @@ -102,8 +102,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return false; } - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, - const int length) { + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView wordCodePoints) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; @@ -135,8 +135,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { } } - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -159,9 +158,14 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; + int getShortcutPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const; int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, DicNodeVector *const childDicNodes) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; + const WordAttributes getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const; }; } // namespace latinime #endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp deleted file mode 100644 index 08dc107ab..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" - -#include "suggest/core/dictionary/property/bigram_property.h" -#include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" - -namespace latinime { - -void Ver4BigramListPolicy::getNextBigram(int *const outBigramPos, int *const outProbability, - bool *const outHasNext, int *const bigramEntryPos) const { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(bigramEntryPos); - if (outBigramPos) { - // Lookup target PtNode position. - *outBigramPos = mTerminalPositionLookupTable->getTerminalPtNodePosition( - bigramEntry.getTargetTerminalId()); - } - if (outProbability) { - if (bigramEntry.hasHistoricalInfo()) { - *outProbability = - ForgettingCurveUtils::decodeProbability(bigramEntry.getHistoricalInfo(), - mHeaderPolicy); - } else { - *outProbability = bigramEntry.getProbability(); - } - } - if (outHasNext) { - *outHasNext = bigramEntry.hasNext(); - } -} - -bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { - // 1. The word has no bigrams yet. - // 2. The word has bigrams, and there is the target in the list. - // 3. The word has bigrams, and there is an invalid entry that can be reclaimed. - // 4. The word has bigrams. We have to append new bigram entry to the list. - // 5. Same as 4, but the list is the last entry of the content file. - if (outAddedNewEntry) { - *outAddedNewEntry = false; - } - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Case 1. PtNode that doesn't have a bigram list. - // Create new bigram list. - if (!mBigramDictContent->createNewBigramList(terminalId)) { - return false; - } - const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom(&newBigramEntry, - bigramProperty); - // Write an entry. - int writingPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (!mBigramDictContent->writeBigramEntryAndAdvancePosition(&bigramEntryToWrite, - &writingPos)) { - AKLOGE("Cannot write bigram entry. pos: %d.", writingPos); - return false; - } - if (!mBigramDictContent->writeTerminator(writingPos)) { - AKLOGE("Cannot write bigram list terminator. pos: %d.", writingPos); - return false; - } - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - return true; - } - - int tailEntryPos = NOT_A_DICT_POS; - const int entryPosToUpdate = getEntryPosToUpdate(newTargetTerminalId, bigramListPos, - &tailEntryPos); - if (entryPosToUpdate == NOT_A_DICT_POS) { - // Case 4, 5. Add new entry to the bigram list. - const int contentTailPos = mBigramDictContent->getContentTailPos(); - // If the tail entry is at the tail of content buffer, the new entry can be written without - // link (Case 5). - const bool canAppendEntry = - contentTailPos == tailEntryPos + mBigramDictContent->getBigramEntrySize(); - const int newEntryPos = canAppendEntry ? tailEntryPos : contentTailPos; - int writingPos = newEntryPos; - // Write new entry at the tail position of the bigram content. - const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &newBigramEntry, bigramProperty); - if (!mBigramDictContent->writeBigramEntryAndAdvancePosition(&bigramEntryToWrite, - &writingPos)) { - AKLOGE("Cannot write bigram entry. pos: %d.", writingPos); - return false; - } - if (!mBigramDictContent->writeTerminator(writingPos)) { - AKLOGE("Cannot write bigram list terminator. pos: %d.", writingPos); - return false; - } - if (!canAppendEntry) { - // Update link of the current tail entry. - if (!mBigramDictContent->writeLink(newEntryPos, tailEntryPos)) { - AKLOGE("Cannot update bigram entry link. pos: %d, linked entry pos: %d.", - tailEntryPos, newEntryPos); - return false; - } - } - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - return true; - } - - // Case 2. Overwrite the existing entry. Case 3. Reclaim and reuse the existing invalid entry. - const BigramEntry originalBigramEntry = mBigramDictContent->getBigramEntry(entryPosToUpdate); - if (!originalBigramEntry.isValid()) { - // Case 3. Reuse the existing invalid entry. outAddedNewEntry is false when an existing - // entry is updated. - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - } - const BigramEntry updatedBigramEntry = - originalBigramEntry.updateTargetTerminalIdAndGetEntry(newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &updatedBigramEntry, bigramProperty); - return mBigramDictContent->writeBigramEntry(&bigramEntryToWrite, entryPosToUpdate); -} - -bool Ver4BigramListPolicy::removeEntry(const int terminalId, const int targetTerminalId) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. - return false; - } - const int entryPosToUpdate = getEntryPosToUpdate(targetTerminalId, bigramListPos, - nullptr /* outTailEntryPos */); - if (entryPosToUpdate == NOT_A_DICT_POS) { - // Bigram entry doesn't exist. - return false; - } - const BigramEntry bigramEntry = mBigramDictContent->getBigramEntry(entryPosToUpdate); - if (targetTerminalId != bigramEntry.getTargetTerminalId()) { - // Bigram entry doesn't exist. - return false; - } - // Remove bigram entry by marking it as invalid entry and overwriting the original entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - return mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPosToUpdate); -} - -bool Ver4BigramListPolicy::updateAllBigramEntriesAndDeleteUselessEntries(const int terminalId, - int *const outBigramCount) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. - return true; - } - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - mBigramDictContent->getBigramEntrySize(); - hasNext = bigramEntry.hasNext(); - if (!bigramEntry.isValid()) { - continue; - } - const int targetPtNodePos = mTerminalPositionLookupTable->getTerminalPtNodePosition( - bigramEntry.getTargetTerminalId()); - if (targetPtNodePos == NOT_A_DICT_POS) { - // Invalidate bigram entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - } else if (bigramEntry.hasHistoricalInfo()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - bigramEntry.getHistoricalInfo(), mHeaderPolicy); - if (ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy)) { - const BigramEntry updatedBigramEntry = - bigramEntry.updateHistoricalInfoAndGetEntry(&historicalInfo); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - *outBigramCount += 1; - } else { - // Remove entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - } - } else { - *outBigramCount += 1; - } - } - return true; -} - -int Ver4BigramListPolicy::getBigramEntryConut(const int terminalId) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. - return 0; - } - int bigramCount = 0; - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = bigramEntry.hasNext(); - if (bigramEntry.isValid()) { - bigramCount++; - } - } - return bigramCount; -} - -int Ver4BigramListPolicy::getEntryPosToUpdate(const int targetTerminalIdToFind, - const int bigramListPos, int *const outTailEntryPos) const { - if (outTailEntryPos) { - *outTailEntryPos = NOT_A_DICT_POS; - } - int invalidEntryPos = NOT_A_DICT_POS; - int readingPos = bigramListPos; - while (true) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - mBigramDictContent->getBigramEntrySize(); - if (!bigramEntry.hasNext()) { - if (outTailEntryPos) { - *outTailEntryPos = entryPos; - } - break; - } - if (bigramEntry.getTargetTerminalId() == targetTerminalIdToFind) { - // Entry with same target is found. - return entryPos; - } else if (!bigramEntry.isValid()) { - // Invalid entry that can be reused is found. - invalidEntryPos = entryPos; - } - } - return invalidEntryPos; -} - -const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom( - const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const { - // TODO: Consolidate historical info and probability. - if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(bigramProperty->getTimestamp(), - bigramProperty->getLevel(), bigramProperty->getCount()); - const HistoricalInfo updatedHistoricalInfo = - ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalBigramEntry->getHistoricalInfo(), bigramProperty->getProbability(), - &historicalInfoForUpdate, mHeaderPolicy); - return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo); - } else { - return originalBigramEntry->updateProbabilityAndGetEntry(bigramProperty->getProbability()); - } -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h deleted file mode 100644 index 4b3bb3725..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_VER4_BIGRAM_LIST_POLICY_H -#define LATINIME_VER4_BIGRAM_LIST_POLICY_H - -#include "defines.h" -#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h" - -namespace latinime { - -class BigramDictContent; -class BigramProperty; -class HeaderPolicy; -class TerminalPositionLookupTable; - -class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { - public: - Ver4BigramListPolicy(BigramDictContent *const bigramDictContent, - const TerminalPositionLookupTable *const terminalPositionLookupTable, - const HeaderPolicy *const headerPolicy) - : mBigramDictContent(bigramDictContent), - mTerminalPositionLookupTable(terminalPositionLookupTable), - mHeaderPolicy(headerPolicy) {} - - void getNextBigram(int *const outBigramPos, int *const outProbability, - bool *const outHasNext, int *const bigramEntryPos) const; - - bool skipAllBigrams(int *const pos) const { - // Do nothing because we don't need to skip bigram lists in ver4 dictionaries. - return true; - } - - bool addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); - - bool removeEntry(const int terminalId, const int targetTerminalId); - - bool updateAllBigramEntriesAndDeleteUselessEntries(const int terminalId, - int *const outBigramCount); - - int getBigramEntryConut(const int terminalId); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4BigramListPolicy); - - int getEntryPosToUpdate(const int targetTerminalIdToFind, const int bigramListPos, - int *const outTailEntryPos) const; - - const BigramEntry createUpdatedBigramEntryFrom(const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const; - - BigramDictContent *const mBigramDictContent; - const TerminalPositionLookupTable *const mTerminalPositionLookupTable; - const HeaderPolicy *const mHeaderPolicy; -}; -} // namespace latinime -#endif /* LATINIME_VER4_BIGRAM_LIST_POLICY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp deleted file mode 100644 index d7e1952b5..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" - -#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" - -namespace latinime { - -const int BigramDictContent::INVALID_LINKED_ENTRY_POS = Ver4DictConstants::NOT_A_TERMINAL_ID; - -const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( - int *const bigramEntryPos) const { - const BufferWithExtendableBuffer *const bigramListBuffer = getContentBuffer(); - const int bigramEntryTailPos = (*bigramEntryPos) + getBigramEntrySize(); - if (*bigramEntryPos < 0 || bigramEntryTailPos > bigramListBuffer->getTailPosition()) { - AKLOGE("Invalid bigram entry position. bigramEntryPos: %d, bigramEntryTailPos: %d, " - "bufSize: %d", *bigramEntryPos, bigramEntryTailPos, - bigramListBuffer->getTailPosition()); - ASSERT(false); - return BigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - Ver4DictConstants::NOT_A_TERMINAL_ID); - } - const int bigramFlags = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE, bigramEntryPos); - const bool isLink = (bigramFlags & Ver4DictConstants::BIGRAM_IS_LINK_MASK) != 0; - int probability = NOT_A_PROBABILITY; - int timestamp = NOT_A_TIMESTAMP; - int level = 0; - int count = 0; - if (mHasHistoricalInfo) { - timestamp = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, bigramEntryPos); - level = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, bigramEntryPos); - count = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, bigramEntryPos); - } else { - probability = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::PROBABILITY_SIZE, bigramEntryPos); - } - const int encodedTargetTerminalId = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE, bigramEntryPos); - const int targetTerminalId = - (encodedTargetTerminalId == Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID) ? - Ver4DictConstants::NOT_A_TERMINAL_ID : encodedTargetTerminalId; - if (isLink) { - const int linkedEntryPos = targetTerminalId; - if (linkedEntryPos == INVALID_LINKED_ENTRY_POS) { - // Bigram list terminator is found. - return BigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - Ver4DictConstants::NOT_A_TERMINAL_ID); - } - *bigramEntryPos = linkedEntryPos; - return getBigramEntryAndAdvancePosition(bigramEntryPos); - } - // hasNext is always true because we should continue to read the next entry until the terminator - // is found. - if (mHasHistoricalInfo) { - const HistoricalInfo historicalInfo(timestamp, level, count); - return BigramEntry(true /* hasNext */, probability, &historicalInfo, targetTerminalId); - } else { - return BigramEntry(true /* hasNext */, probability, targetTerminalId); - } -} - -bool BigramDictContent::writeBigramEntryAndAdvancePosition( - const BigramEntry *const bigramEntryToWrite, int *const entryWritingPos) { - return writeBigramEntryAttributesAndAdvancePosition(false /* isLink */, - bigramEntryToWrite->getProbability(), bigramEntryToWrite->getTargetTerminalId(), - bigramEntryToWrite->getHistoricalInfo()->getTimeStamp(), - bigramEntryToWrite->getHistoricalInfo()->getLevel(), - bigramEntryToWrite->getHistoricalInfo()->getCount(), - entryWritingPos); -} - -bool BigramDictContent::writeBigramEntryAttributesAndAdvancePosition( - const bool isLink, const int probability, const int targetTerminalId, - const int timestamp, const int level, const int count, int *const entryWritingPos) { - BufferWithExtendableBuffer *const bigramListBuffer = getWritableContentBuffer(); - const int bigramFlags = isLink ? Ver4DictConstants::BIGRAM_IS_LINK_MASK : 0; - if (!bigramListBuffer->writeUintAndAdvancePosition(bigramFlags, - Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram flags. pos: %d, flags: %x", *entryWritingPos, bigramFlags); - return false; - } - if (mHasHistoricalInfo) { - if (!bigramListBuffer->writeUintAndAdvancePosition(timestamp, - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram timestamps. pos: %d, timestamp: %d", *entryWritingPos, - timestamp); - return false; - } - if (!bigramListBuffer->writeUintAndAdvancePosition(level, - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram level. pos: %d, level: %d", *entryWritingPos, - level); - return false; - } - if (!bigramListBuffer->writeUintAndAdvancePosition(count, - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram count. pos: %d, count: %d", *entryWritingPos, - count); - return false; - } - } else { - if (!bigramListBuffer->writeUintAndAdvancePosition(probability, - Ver4DictConstants::PROBABILITY_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram probability. pos: %d, probability: %d", *entryWritingPos, - probability); - return false; - } - } - const int targetTerminalIdToWrite = (targetTerminalId == Ver4DictConstants::NOT_A_TERMINAL_ID) ? - Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID : targetTerminalId; - if (!bigramListBuffer->writeUintAndAdvancePosition(targetTerminalIdToWrite, - Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram target terminal id. pos: %d, target terminal id: %d", - *entryWritingPos, targetTerminalId); - return false; - } - return true; -} - -bool BigramDictContent::writeLink(const int linkedEntryPos, const int writingPos) { - const int targetTerminalId = linkedEntryPos; - int pos = writingPos; - return writeBigramEntryAttributesAndAdvancePosition(true /* isLink */, - NOT_A_PROBABILITY /* probability */, targetTerminalId, NOT_A_TIMESTAMP, 0 /* level */, - 0 /* count */, &pos); -} - -bool BigramDictContent::runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const BigramDictContent *const originalBigramDictContent, - int *const outBigramEntryCount) { - for (TerminalPositionLookupTable::TerminalIdMap::const_iterator it = terminalIdMap->begin(); - it != terminalIdMap->end(); ++it) { - const int originalBigramListPos = - originalBigramDictContent->getBigramListHeadPos(it->first); - if (originalBigramListPos == NOT_A_DICT_POS) { - // This terminal does not have a bigram list. - continue; - } - const int bigramListPos = getContentBuffer()->getTailPosition(); - int bigramEntryCount = 0; - // Copy bigram list with GC from original content. - if (!runGCBigramList(originalBigramListPos, originalBigramDictContent, bigramListPos, - terminalIdMap, &bigramEntryCount)) { - AKLOGE("Cannot complete GC for the bigram list. original pos: %d, pos: %d", - originalBigramListPos, bigramListPos); - return false; - } - if (bigramEntryCount == 0) { - // All bigram entries are useless. This terminal does not have a bigram list. - continue; - } - *outBigramEntryCount += bigramEntryCount; - // Set bigram list position to the lookup table. - if (!getUpdatableAddressLookupTable()->set(it->second, bigramListPos)) { - AKLOGE("Cannot set bigram list position. terminal id: %d, pos: %d", - it->second, bigramListPos); - return false; - } - } - return true; -} - -// Returns whether GC for the bigram list was succeeded or not. -bool BigramDictContent::runGCBigramList(const int bigramListPos, - const BigramDictContent *const sourceBigramDictContent, const int toPos, - const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - int *const outEntryCount) { - bool hasNext = true; - int readingPos = bigramListPos; - int writingPos = toPos; - while (hasNext) { - const BigramEntry originalBigramEntry = - sourceBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = originalBigramEntry.hasNext(); - if (!originalBigramEntry.isValid()) { - continue; - } - TerminalPositionLookupTable::TerminalIdMap::const_iterator it = - terminalIdMap->find(originalBigramEntry.getTargetTerminalId()); - if (it == terminalIdMap->end()) { - // Target word has been removed. - continue; - } - const BigramEntry updatedBigramEntry = - originalBigramEntry.updateTargetTerminalIdAndGetEntry(it->second); - if (!writeBigramEntryAndAdvancePosition(&updatedBigramEntry, &writingPos)) { - AKLOGE("Cannot write bigram entry to run GC. pos: %d", writingPos); - return false; - } - *outEntryCount += 1; - } - if (*outEntryCount > 0) { - if (!writeTerminator(writingPos)) { - AKLOGE("Cannot write terminator to run GC. pos: %d", writingPos); - return false; - } - } - return true; -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h deleted file mode 100644 index 361dd2c74..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BIGRAM_DICT_CONTENT_H -#define LATINIME_BIGRAM_DICT_CONTENT_H - -#include <cstdint> -#include <cstdio> - -#include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" - -namespace latinime { - -class BigramDictContent : public SparseTableDictContent { - public: - BigramDictContent(uint8_t *const *buffers, const int *bufferSizes, const bool hasHistoricalInfo) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE), - mHasHistoricalInfo(hasHistoricalInfo) {} - - BigramDictContent(const bool hasHistoricalInfo) - : SparseTableDictContent(Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE), - mHasHistoricalInfo(hasHistoricalInfo) {} - - int getContentTailPos() const { - return getContentBuffer()->getTailPosition(); - } - - const BigramEntry getBigramEntry(const int bigramEntryPos) const { - int readingPos = bigramEntryPos; - return getBigramEntryAndAdvancePosition(&readingPos); - } - - const BigramEntry getBigramEntryAndAdvancePosition(int *const bigramEntryPos) const; - - // Returns head position of bigram list for a PtNode specified by terminalId. - int getBigramListHeadPos(const int terminalId) const { - const SparseTable *const addressLookupTable = getAddressLookupTable(); - if (!addressLookupTable->contains(terminalId)) { - return NOT_A_DICT_POS; - } - return addressLookupTable->get(terminalId); - } - - bool writeBigramEntryAtTail(const BigramEntry *const bigramEntryToWrite) { - int writingPos = getContentBuffer()->getTailPosition(); - return writeBigramEntryAndAdvancePosition(bigramEntryToWrite, &writingPos); - } - - bool writeBigramEntry(const BigramEntry *const bigramEntryToWrite, const int entryWritingPos) { - int writingPos = entryWritingPos; - return writeBigramEntryAndAdvancePosition(bigramEntryToWrite, &writingPos); - } - - bool writeBigramEntryAndAdvancePosition(const BigramEntry *const bigramEntryToWrite, - int *const entryWritingPos); - - bool writeTerminator(const int writingPos) { - // Terminator is a link to the invalid position. - return writeLink(INVALID_LINKED_ENTRY_POS, writingPos); - } - - bool writeLink(const int linkedPos, const int writingPos); - - bool createNewBigramList(const int terminalId) { - const int bigramListPos = getContentBuffer()->getTailPosition(); - return getUpdatableAddressLookupTable()->set(terminalId, bigramListPos); - } - - bool flushToFile(FILE *const file) const { - return flush(file); - } - - bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const BigramDictContent *const originalBigramDictContent, - int *const outBigramEntryCount); - - int getBigramEntrySize() const { - if (mHasHistoricalInfo) { - return Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE - + Ver4DictConstants::TIME_STAMP_FIELD_SIZE - + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE - + Ver4DictConstants::WORD_COUNT_FIELD_SIZE - + Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - } else { - return Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE - + Ver4DictConstants::PROBABILITY_SIZE - + Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - } - } - - private: - DISALLOW_COPY_AND_ASSIGN(BigramDictContent); - - static const int INVALID_LINKED_ENTRY_POS; - - bool writeBigramEntryAttributesAndAdvancePosition( - const bool isLink, const int probability, const int targetTerminalId, - const int timestamp, const int level, const int count, int *const entryWritingPos); - - bool runGCBigramList(const int bigramListPos, - const BigramDictContent *const sourceBigramDictContent, const int toPos, - const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - int *const outEntryCount); - - bool mHasHistoricalInfo; -}; -} // namespace latinime -#endif /* LATINIME_BIGRAM_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h deleted file mode 100644 index 2b0cbd93b..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BIGRAM_ENTRY_H -#define LATINIME_BIGRAM_ENTRY_H - -#include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" - -namespace latinime { - -class BigramEntry { - public: - BigramEntry(const BigramEntry& bigramEntry) - : mHasNext(bigramEntry.mHasNext), mProbability(bigramEntry.mProbability), - mHistoricalInfo(), mTargetTerminalId(bigramEntry.mTargetTerminalId) {} - - // Entry with historical information. - BigramEntry(const bool hasNext, const int probability, const int targetTerminalId) - : mHasNext(hasNext), mProbability(probability), mHistoricalInfo(), - mTargetTerminalId(targetTerminalId) {} - - // Entry with historical information. - BigramEntry(const bool hasNext, const int probability, - const HistoricalInfo *const historicalInfo, const int targetTerminalId) - : mHasNext(hasNext), mProbability(probability), mHistoricalInfo(*historicalInfo), - mTargetTerminalId(targetTerminalId) {} - - const BigramEntry getInvalidatedEntry() const { - return updateTargetTerminalIdAndGetEntry(Ver4DictConstants::NOT_A_TERMINAL_ID); - } - - const BigramEntry updateHasNextAndGetEntry(const bool hasNext) const { - return BigramEntry(hasNext, mProbability, &mHistoricalInfo, mTargetTerminalId); - } - - const BigramEntry updateTargetTerminalIdAndGetEntry(const int newTargetTerminalId) const { - return BigramEntry(mHasNext, mProbability, &mHistoricalInfo, newTargetTerminalId); - } - - const BigramEntry updateProbabilityAndGetEntry(const int probability) const { - return BigramEntry(mHasNext, probability, &mHistoricalInfo, mTargetTerminalId); - } - - const BigramEntry updateHistoricalInfoAndGetEntry( - const HistoricalInfo *const historicalInfo) const { - return BigramEntry(mHasNext, mProbability, historicalInfo, mTargetTerminalId); - } - - bool isValid() const { - return mTargetTerminalId != Ver4DictConstants::NOT_A_TERMINAL_ID; - } - - bool hasNext() const { - return mHasNext; - } - - int getProbability() const { - return mProbability; - } - - bool hasHistoricalInfo() const { - return mHistoricalInfo.isValid(); - } - - const HistoricalInfo *getHistoricalInfo() const { - return &mHistoricalInfo; - } - - int getTargetTerminalId() const { - return mTargetTerminalId; - } - - private: - // Copy constructor is public to use this class as a type of return value. - DISALLOW_DEFAULT_CONSTRUCTOR(BigramEntry); - DISALLOW_ASSIGNMENT_OPERATOR(BigramEntry); - - const bool mHasNext; - const int mProbability; - const HistoricalInfo mHistoricalInfo; - const int mTargetTerminalId; -}; -} // namespace latinime -#endif /* LATINIME_BIGRAM_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 5dc91ba10..0675de6fa 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -16,8 +16,16 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" +#include <algorithm> +#include <cstring> + +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" + namespace latinime { +const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0; +const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1; + bool LanguageModelDictContent::save(FILE *const file) const { return mTrieMap.save(file); } @@ -30,6 +38,41 @@ bool LanguageModelDictContent::runGC( 0 /* nextLevelBitmapEntryIndex */, outNgramCount); } +int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, + const int wordId, const HeaderPolicy *const headerPolicy) const { + int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); + int maxLevel = 0; + for (size_t i = 0; i < prevWordIds.size(); ++i) { + const int nextBitmapEntryIndex = + mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; + if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { + break; + } + maxLevel = i + 1; + bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; + } + + for (int i = maxLevel; i >= 0; --i) { + const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); + if (!result.mIsValid) { + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); + if (mHasHistoricalInfo) { + const int probability = ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), headerPolicy) + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */); + return std::min(probability, MAX_PROBABILITY); + } else { + return probabilityEntry.getProbability(); + } + } + // Cannot find the word. + return NOT_A_PROBABILITY; +} + ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( const WordIdArrayView prevWordIds, const int wordId) const { const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); @@ -45,12 +88,47 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( } bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, - const int terminalId, const ProbabilityEntry *const probabilityEntry) { + const int wordId, const ProbabilityEntry *const probabilityEntry) { + if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) { + return false; + } + const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds); + if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + return false; + } + return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); +} + +bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, + const int wordId) { const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + // Cannot find bitmap entry for the probability entry. The entry doesn't exist. return false; } - return mTrieMap.put(terminalId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); + return mTrieMap.remove(wordId, bitmapEntryIndex); +} + +LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries( + const WordIdArrayView prevWordIds) const { + const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); + return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); +} + +bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, + const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy, + int *const outEntryCounts) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + if (entryCounts[i] <= maxEntryCounts[i]) { + outEntryCounts[i] = entryCounts[i]; + continue; + } + if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i, + &outEntryCounts[i])) { + return false; + } + } + return true; } bool LanguageModelDictContent::runGCInner( @@ -80,6 +158,19 @@ bool LanguageModelDictContent::runGCInner( return true; } +int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { + if (prevWordIds.empty()) { + return mTrieMap.getRootBitmapEntryIndex(); + } + const int lastBitmapEntryIndex = + getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1)); + if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { + return TrieMap::INVALID_INDEX; + } + return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1], + lastBitmapEntryIndex); +} + int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); for (const int wordId : prevWordIds) { @@ -92,4 +183,132 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } +bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex, + const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", + level, MAX_PREV_WORD_COUNT_FOR_N_GRAM); + return false; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) { + const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( + probabilityEntry.getHistoricalInfo(), headerPolicy); + if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { + // Update the entry. + const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); + if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), + bitmapEntryIndex)) { + return false; + } + } else { + // Remove the entry. + if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + } + if (!probabilityEntry.representsBeginningOfSentence()) { + outEntryCounts[level] += 1; + } + if (!entry.hasNextLevelMap()) { + continue; + } + if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1, + headerPolicy, outEntryCounts)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( + const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, + int *const outEntryCount) { + std::vector<int> prevWordIds; + std::vector<EntryInfoToTurncate> entryInfoVector; + if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), + &prevWordIds, &entryInfoVector)) { + return false; + } + if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) { + *outEntryCount = static_cast<int>(entryInfoVector.size()); + return true; + } + *outEntryCount = maxEntryCount; + const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount; + std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, + entryInfoVector.end(), + EntryInfoToTurncate::Comparator()); + for (int i = 0; i < entryCountToRemove; ++i) { + const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; + if (!removeNgramProbabilityEntry( + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, + const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const { + const int currentLevel = prevWordIds->size(); + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (currentLevel < targetLevel) { + if (!entry.hasNextLevelMap()) { + continue; + } + prevWordIds->push_back(entry.key()); + if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(), + prevWordIds, outEntryInfo)) { + return false; + } + prevWordIds->pop_back(); + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + const int probability = (mHasHistoricalInfo) ? + ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + headerPolicy) : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(probability, + probabilityEntry.getHistoricalInfo()->getTimeStamp(), + entry.key(), targetLevel, prevWordIds->data()); + } + return true; +} + +bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( + const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { + if (left.mProbability != right.mProbability) { + return left.mProbability < right.mProbability; + } + if (left.mTimestamp != right.mTimestamp) { + return left.mTimestamp > right.mTimestamp; + } + if (left.mKey != right.mKey) { + return left.mKey < right.mKey; + } + if (left.mEntryLevel != right.mEntryLevel) { + return left.mEntryLevel > right.mEntryLevel; + } + for (int i = 0; i < left.mEntryLevel; ++i) { + if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) { + return left.mPrevWordIds[i] < right.mPrevWordIds[i]; + } + } + // left and rigth represent the same entry. + return false; +} + +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, + const int timestamp, const int key, const int entryLevel, const int *const prevWordIds) + : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) { + memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0])); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 18f2e0170..a793af4be 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -18,6 +18,7 @@ #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H #include <cstdio> +#include <vector> #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" @@ -29,6 +30,8 @@ namespace latinime { +class HeaderPolicy; + /** * Class representing language model. * @@ -36,6 +39,78 @@ namespace latinime { */ class LanguageModelDictContent { public: + static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE; + static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE; + + // Pair of word id and probability entry used for iteration. + class WordIdAndProbabilityEntry { + public: + WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry) + : mWordId(wordId), mProbabilityEntry(probabilityEntry) {} + + int getWordId() const { return mWordId; } + const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry); + DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry); + + const int mWordId; + const ProbabilityEntry mProbabilityEntry; + }; + + // Iterator. + class EntryIterator { + public: + EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator, + const bool hasHistoricalInfo) + : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {} + + const WordIdAndProbabilityEntry operator*() const { + const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator; + return WordIdAndProbabilityEntry( + result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo)); + } + + bool operator!=(const EntryIterator &other) const { + return mTrieMapIterator != other.mTrieMapIterator; + } + + const EntryIterator &operator++() { + ++mTrieMapIterator; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator); + DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator); + + TrieMap::TrieMapIterator mTrieMapIterator; + const bool mHasHistoricalInfo; + }; + + // Class represents range to use range base for loops. + class EntryRange { + public: + EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo) + : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {} + + EntryIterator begin() const { + return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo); + } + + EntryIterator end() const { + return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange); + DISALLOW_ASSIGNMENT_OPERATOR(EntryRange); + + const TrieMap::TrieMapRange mTrieMapRange; + const bool mHasHistoricalInfo; + }; + LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, const bool hasHistoricalInfo) : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} @@ -53,6 +128,9 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); + int getWordProbability(const WordIdArrayView prevWordIds, const int wordId, + const HeaderPolicy *const headerPolicy) const; + ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); } @@ -61,23 +139,74 @@ class LanguageModelDictContent { return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } + bool removeProbabilityEntry(const int wordId) { + return removeNgramProbabilityEntry(WordIdArrayView(), wordId); + } + ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const; bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId, const ProbabilityEntry *const probabilityEntry); + bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); + + EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; + + bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, + int *const outEntryCounts) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + outEntryCounts[i] = 0; + } + return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, + headerPolicy, outEntryCounts); + } + + // entryCounts should be created by updateAllProbabilityEntries. + bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, + const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); + class EntryInfoToTurncate { + public: + class Comparator { + public: + bool operator()(const EntryInfoToTurncate &left, + const EntryInfoToTurncate &right) const; + private: + DISALLOW_ASSIGNMENT_OPERATOR(Comparator); + }; + + EntryInfoToTurncate(const int probability, const int timestamp, const int key, + const int entryLevel, const int *const prevWordIds); + + int mProbability; + int mTimestamp; + int mKey; + int mEntryLevel; + int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); + }; + TrieMap mTrieMap; const bool mHasHistoricalInfo; bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, int *const outNgramCount); - + int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; + bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, + const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, + const int maxEntryCount, const int targetLevel, int *const outEntryCount); + bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, + const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index feff6b57f..3dfaba755 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -21,6 +21,8 @@ #include <cstdint> #include "defines.h" +#include "suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/utils/historical_info.h" @@ -41,24 +43,32 @@ class ProbabilityEntry { : mFlags(flags), mProbability(probability), mHistoricalInfo() {} // Entry with historical information. - ProbabilityEntry(const int flags, const int probability, - const HistoricalInfo *const historicalInfo) - : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {} - - const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const { - return ProbabilityEntry(mFlags, probability, &mHistoricalInfo); - } - - const ProbabilityEntry createEntryWithUpdatedHistoricalInfo( - const HistoricalInfo *const historicalInfo) const { - return ProbabilityEntry(mFlags, mProbability, historicalInfo); + ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo) + : mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {} + + // Create from unigram property. + ProbabilityEntry(const UnigramProperty *const unigramProperty) + : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())), + mProbability(unigramProperty->getProbability()), + mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(), + unigramProperty->getCount()) {} + + // Create from bigram property. + // TODO: Set flags. + ProbabilityEntry(const BigramProperty *const bigramProperty) + : mFlags(0), mProbability(bigramProperty->getProbability()), + mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(), + bigramProperty->getCount()) {} + + bool isValid() const { + return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo(); } bool hasHistoricalInfo() const { return mHistoricalInfo.isValid(); } - int getFlags() const { + uint8_t getFlags() const { return mFlags; } @@ -70,6 +80,10 @@ class ProbabilityEntry { return &mHistoricalInfo; } + bool representsBeginningOfSentence() const { + return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0; + } + uint64_t encode(const bool hasHistoricalInfo) const { uint64_t encodedEntry = static_cast<uint64_t>(mFlags); if (hasHistoricalInfo) { @@ -89,7 +103,7 @@ class ProbabilityEntry { static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) { if (hasHistoricalInfo) { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::TIME_STAMP_FIELD_SIZE + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); @@ -103,10 +117,10 @@ class ProbabilityEntry { const int count = readFromEncodedEntry(encodedEntry, Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */); const HistoricalInfo historicalInfo(timestamp, level, count); - return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); + return ProbabilityEntry(flags, &historicalInfo); } else { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::PROBABILITY_SIZE); const int probability = readFromEncodedEntry(encodedEntry, Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */); @@ -118,7 +132,7 @@ class ProbabilityEntry { // Copy constructor is public to use this class as a type of return value. DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry); - const int mFlags; + const uint8_t mFlags; const int mProbability; const HistoricalInfo mHistoricalInfo; @@ -126,6 +140,14 @@ class ProbabilityEntry { return static_cast<int>( (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); } + + static uint8_t createFlags(const bool representsBeginningOfSentence) { + uint8_t flags = 0; + if (representsBeginningOfSentence) { + flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + } + return flags; + } }; } // namespace latinime #endif /* LATINIME_PROBABILITY_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h index 7b12aff16..85c9ce8d8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SHORTCUT_DICT_CONTENT_H #define LATINIME_SHORTCUT_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -27,11 +26,12 @@ namespace latinime { +class ReadWriteByteArrayView; + class ShortcutDictContent : public SparseTableDictContent { public: - ShortcutDictContent(uint8_t *const *buffers, const int *bufferSizes) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, + ShortcutDictContent(const ReadWriteByteArrayView *const buffers) + : SparseTableDictContent(buffers, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE) {} ShortcutDictContent() diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h index 921774181..309c434cf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SINGLE_DICT_CONTENT_H #define LATINIME_SINGLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -30,9 +29,9 @@ namespace latinime { class SingleDictContent { public: - SingleDictContent(uint8_t *const buffer, const int bufferSize) - : mExpandableContentBuffer(ReadWriteByteArrayView(buffer, bufferSize), - BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} + SingleDictContent(const ReadWriteByteArrayView buffer) + : mExpandableContentBuffer(buffer, + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} SingleDictContent() : mExpandableContentBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h index c98dd11fd..0ce2da7bf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SPARSE_TABLE_DICT_CONTENT_H #define LATINIME_SPARSE_TABLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -31,19 +30,13 @@ namespace latinime { // TODO: Support multiple contents. class SparseTableDictContent { public: - AK_FORCE_INLINE SparseTableDictContent(uint8_t *const *buffers, const int *bufferSizes, + AK_FORCE_INLINE SparseTableDictContent(const ReadWriteByteArrayView *const buffers, const int sparseTableBlockSize, const int sparseTableDataSize) - : mExpandableLookupTableBuffer( - ReadWriteByteArrayView(buffers[LOOKUP_TABLE_BUFFER_INDEX], - bufferSizes[LOOKUP_TABLE_BUFFER_INDEX]), + : mExpandableLookupTableBuffer(buffers[LOOKUP_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableAddressTableBuffer( - ReadWriteByteArrayView(buffers[ADDRESS_TABLE_BUFFER_INDEX], - bufferSizes[ADDRESS_TABLE_BUFFER_INDEX]), + mExpandableAddressTableBuffer(buffers[ADDRESS_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableContentBuffer( - ReadWriteByteArrayView(buffers[CONTENT_BUFFER_INDEX], - bufferSizes[CONTENT_BUFFER_INDEX]), + mExpandableContentBuffer(buffers[CONTENT_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer, sparseTableBlockSize, sparseTableDataSize) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp index cf238ee5f..2bdf07752 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp @@ -34,7 +34,7 @@ int TerminalPositionLookupTable::getTerminalPtNodePosition(const int terminalId) bool TerminalPositionLookupTable::setTerminalPtNodePosition( const int terminalId, const int terminalPtNodePos) { if (terminalId < 0) { - return NOT_A_DICT_POS; + return false; } while (terminalId >= mSize) { // Write new entry. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h index b2262bf1e..febcbe5b4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h @@ -17,13 +17,13 @@ #ifndef LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H #define LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H -#include <cstdint> #include <cstdio> #include <unordered_map> #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -31,8 +31,8 @@ class TerminalPositionLookupTable : public SingleDictContent { public: typedef std::unordered_map<int, int> TerminalIdMap; - TerminalPositionLookupTable(uint8_t *const buffer, const int bufferSize) - : SingleDictContent(buffer, bufferSize), + TerminalPositionLookupTable(const ReadWriteByteArrayView buffer) + : SingleDictContent(buffer), mSize(getBuffer()->getTailPosition() / Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 3c8008dc4..45f88e9b2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -45,16 +45,13 @@ namespace latinime { if (!bodyBuffer) { return Ver4DictBuffersPtr(nullptr); } - std::vector<uint8_t *> buffers; - std::vector<int> bufferSizes; + std::vector<ReadWriteByteArrayView> buffers; const ReadWriteByteArrayView buffer = bodyBuffer->getReadWriteByteArrayView(); int position = 0; while (position < static_cast<int>(buffer.size())) { const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition( buffer.data(), &position); - const ReadWriteByteArrayView subBuffer = buffer.subView(position, bufferSize); - buffers.push_back(subBuffer.data()); - bufferSizes.push_back(subBuffer.size()); + buffers.push_back(buffer.subView(position, bufferSize)); position += bufferSize; if (bufferSize < 0 || position < 0 || position > static_cast<int>(buffer.size())) { AKLOGE("The dict body file is corrupted."); @@ -66,7 +63,7 @@ namespace latinime { return Ver4DictBuffersPtr(nullptr); } return Ver4DictBuffersPtr(new Ver4DictBuffers(std::move(headerBuffer), std::move(bodyBuffer), - formatVersion, buffers, bufferSizes)); + formatVersion, buffers)); } bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath, @@ -162,11 +159,6 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { AKLOGE("Language model dict content cannot be written."); return false; } - // Write bigram dict content. - if (!mBigramDictContent.flushToFile(file)) { - AKLOGE("Bigram dict content cannot be written."); - return false; - } // Write shortcut dict content. if (!mShortcutDictContent.flushToFile(file)) { AKLOGE("Shortcut dict content cannot be written."); @@ -178,29 +170,18 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, const std::vector<int> &contentBufferSizes) + const std::vector<ReadWriteByteArrayView> &contentBuffers) : mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(std::move(bodyBuffer)), mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion), mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableTrieBuffer( - ReadWriteByteArrayView(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX]), + mExpandableTrieBuffer(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( - contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX], - contentBufferSizes[ - Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mLanguageModelDictContent( - ReadWriteByteArrayView( - contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX]), - mHeaderPolicy.hasHistoricalInfoOfWords()), - mBigramDictContent(&contentBuffers[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], - &contentBufferSizes[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], + contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), + mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], mHeaderPolicy.hasHistoricalInfoOfWords()), - mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX], - &contentBufferSizes[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), + mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), mIsUpdatable(mDictBuffer->isUpdatable()) {} Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize) @@ -208,7 +189,6 @@ Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const i mExpandableHeaderBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE), mExpandableTrieBuffer(maxTrieSize), mTerminalPositionLookupTable(), mLanguageModelDictContent(headerPolicy->hasHistoricalInfoOfWords()), - mBigramDictContent(headerPolicy->hasHistoricalInfoOfWords()), mShortcutDictContent(), - mIsUpdatable(true) {} + mShortcutDictContent(), mIsUpdatable(true) {} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h index 68027dcb8..5407525af 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h @@ -22,7 +22,6 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" @@ -53,7 +52,6 @@ class Ver4DictBuffers { return mExpandableTrieBuffer.isNearSizeLimit() || mTerminalPositionLookupTable.isNearSizeLimit() || mLanguageModelDictContent.isNearSizeLimit() - || mBigramDictContent.isNearSizeLimit() || mShortcutDictContent.isNearSizeLimit(); } @@ -89,14 +87,6 @@ class Ver4DictBuffers { return &mLanguageModelDictContent; } - AK_FORCE_INLINE BigramDictContent *getMutableBigramDictContent() { - return &mBigramDictContent; - } - - AK_FORCE_INLINE const BigramDictContent *getBigramDictContent() const { - return &mBigramDictContent; - } - AK_FORCE_INLINE ShortcutDictContent *getMutableShortcutDictContent() { return &mShortcutDictContent; } @@ -122,8 +112,7 @@ class Ver4DictBuffers { Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, - const std::vector<int> &contentBufferSizes); + const std::vector<ReadWriteByteArrayView> &contentBuffers); Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize); @@ -136,7 +125,6 @@ class Ver4DictBuffers { BufferWithExtendableBuffer mExpandableTrieBuffer; TerminalPositionLookupTable mTerminalPositionLookupTable; LanguageModelDictContent mLanguageModelDictContent; - BigramDictContent mBigramDictContent; ShortcutDictContent mShortcutDictContent; const int mIsUpdatable; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 93d4e562d..9acf2d44f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -29,24 +29,22 @@ const int Ver4DictConstants::MAX_DICT_EXTENDED_REGION_SIZE = 1 * 1024 * 1024; // NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT for Trie and TerminalAddressLookupTable. // NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT for language model. -// NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT for bigram and shortcut. +// NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT for shortcut. const size_t Ver4DictConstants::NUM_OF_CONTENT_BUFFERS_IN_BODY_FILE = NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT * 2 + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT - + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT * 2; + + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; const int Ver4DictConstants::TRIE_BUFFER_INDEX = 0; const int Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX = TRIE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; const int Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX = TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; -const int Ver4DictConstants::BIGRAM_BUFFERS_INDEX = - LANGUAGE_MODEL_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX = - BIGRAM_BUFFERS_INDEX + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; + LANGUAGE_MODEL_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1; const int Ver4DictConstants::PROBABILITY_SIZE = 1; -const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1; +const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1; const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3; const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0; const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4; @@ -54,21 +52,11 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4; const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; -const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16; -const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4; +const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; + const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4; -const int Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE = 3; -// Unsigned int max value of BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE-byte is used for representing -// invalid terminal ID in bigram lists. -const int Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID = - (1 << (BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE * 8)) - 1; -const int Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE = 1; -const int Ver4DictConstants::BIGRAM_PROBABILITY_MASK = 0x0F; -const int Ver4DictConstants::BIGRAM_IS_LINK_MASK = 0x80; -const int Ver4DictConstants::BIGRAM_LARGE_PROBABILITY_FIELD_SIZE = 1; - const int Ver4DictConstants::SHORTCUT_FLAGS_FIELD_SIZE = 1; const int Ver4DictConstants::SHORTCUT_PROBABILITY_MASK = 0x0F; const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 6950ca70f..97035311e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -20,6 +20,7 @@ #include "defines.h" #include <cstddef> +#include <cstdint> namespace latinime { @@ -41,27 +42,19 @@ class Ver4DictConstants { static const int NOT_A_TERMINAL_ID; static const int PROBABILITY_SIZE; - static const int FLAGS_IN_PROBABILITY_FILE_SIZE; + static const int FLAGS_IN_LANGUAGE_MODEL_SIZE; static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE; static const int NOT_A_TERMINAL_ADDRESS; static const int TERMINAL_ID_FIELD_SIZE; static const int TIME_STAMP_FIELD_SIZE; static const int WORD_LEVEL_FIELD_SIZE; static const int WORD_COUNT_FIELD_SIZE; + // Flags in probability entry. + static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; - static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE; - static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE; static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE; static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE; - static const int BIGRAM_FLAGS_FIELD_SIZE; - static const int BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - static const int INVALID_BIGRAM_TARGET_TERMINAL_ID; - static const int BIGRAM_IS_LINK_MASK; - static const int BIGRAM_PROBABILITY_MASK; - // Used when bigram list has time stamp. - static const int BIGRAM_LARGE_PROBABILITY_FIELD_SIZE; - static const int SHORTCUT_FLAGS_FIELD_SIZE; static const int SHORTCUT_PROBABILITY_MASK; static const int SHORTCUT_HAS_NEXT_MASK; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 857222f5d..9ca712470 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -21,7 +21,6 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" @@ -145,10 +144,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); - const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry, - unigramProperty); + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); + const ProbabilityEntry updatedProbabilityEntry = + createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry); + toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry); } bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( @@ -160,29 +160,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); - if (originalProbabilityEntry.hasHistoricalInfo()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); - const ProbabilityEntry probabilityEntry = - originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo); - if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { - AKLOGE("Cannot write updated probability entry. terminalId: %d", - toBeUpdatedPtNodeParams->getTerminalId()); - return false; - } - const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy); - if (!isValid) { - if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { - AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); - return false; - } - } - *outNeedsToKeepPtNode = isValid; - } else { - // No need to update probability. + if (originalProbabilityEntry.isValid()) { *outNeedsToKeepPtNode = true; + return true; } + if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { + AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); + return false; + } + *outNeedsToKeepPtNode = false; return true; } @@ -216,31 +202,50 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( } // Write probability. ProbabilityEntry newProbabilityEntry; + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( - &newProbabilityEntry, unigramProperty); + &newProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( terminalId, &probabilityEntryToWrite); } bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) { - AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", + // TODO: Support n-gram. + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + const ProbabilityEntry probabilityEntry = + languageModelDictContent->getNgramProbabilityEntry( + prevWordIds.limit(1 /* maxSize */), wordId); + const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty); + const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( + &probabilityEntry, &probabilityEntryOfBigramProperty); + if (!languageModelDictContent->setNgramProbabilityEntry( + prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) { + AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d", prevWordIds[0], wordId); return false; } + if (!probabilityEntry.isValid() && outAddedNewBigram) { + *outAddedNewBigram = true; + } return true; } bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) { - return mBigramPolicy->removeEntry(prevWordIds[0], wordId); + // TODO: Support n-gram. + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds.limit(1 /* maxSize */), + wordId); } +// TODO: Remove when we stop supporting v402 format. bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries( const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) { - return mBigramPolicy->updateAllBigramEntriesAndDeleteUselessEntries( - sourcePtNodeParams->getTerminalId(), outBigramEntryCount); + // Do nothing. + return true; } bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( @@ -275,12 +280,6 @@ bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( if (!updateChildrenPosition(toBeUpdatedPtNodeParams, childrenPos)) { return false; } - - // Counts bigram entries. - if (outBigramEntryCount) { - *outBigramEntryCount = mBigramPolicy->getBigramEntryConut( - toBeUpdatedPtNodeParams->getTerminalId()); - } return true; } @@ -350,22 +349,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } +// TODO: Move probability handling code to LanguageModelDictContent. const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const { - // TODO: Consolidate historical info and probability. + const ProbabilityEntry *const probabilityEntry) const { if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(), - unigramProperty->getLevel(), unigramProperty->getCount()); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry->getHistoricalInfo(), - unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); - return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( - &updatedHistoricalInfo); + probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(), + mHeaderPolicy); + return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo); } else { - return originalProbabilityEntry->createEntryWithUpdatedProbability( - unigramProperty->getProbability()); + return *probabilityEntry; } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 6703dba04..08b7d3825 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -27,7 +27,6 @@ namespace latinime { class BufferWithExtendableBuffer; class HeaderPolicy; -class Ver4BigramListPolicy; class Ver4DictBuffers; class Ver4PatriciaTrieNodeReader; class Ver4PtNodeArrayReader; @@ -42,10 +41,9 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy, const PtNodeReader *const ptNodeReader, const PtNodeArrayReader *const ptNodeArrayReader, - Ver4BigramListPolicy *const bigramPolicy, Ver4ShortcutListPolicy *const shortcutPolicy) + Ver4ShortcutListPolicy *const shortcutPolicy) : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy), - mReadingHelper(ptNodeReader, ptNodeArrayReader), mBigramPolicy(bigramPolicy), - mShortcutPolicy(shortcutPolicy) {} + mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {} virtual ~Ver4PatriciaTrieNodeWriter() {} @@ -98,12 +96,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const PtNodeParams *const ptNodeParams, int *const outTerminalId, int *const ptNodeWritingPos); - // Create updated probability entry using given unigram property. In addition to the + // Create updated probability entry using given probability property. In addition to the // probability, this method updates historical information if needed. - // TODO: Update flags belonging to the unigram property. + // TODO: Update flags. const ProbabilityEntry createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const; + const ProbabilityEntry *const probabilityEntry) const; bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, const bool hasMultipleChars); @@ -114,7 +112,6 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { Ver4DictBuffers *const mBuffers; const HeaderPolicy *const mHeaderPolicy; DynamicPtReadingHelper mReadingHelper; - Ver4BigramListPolicy *const mBigramPolicy; Ver4ShortcutListPolicy *const mShortcutPolicy; }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 723808399..d537711b0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -16,10 +16,12 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h" +#include <array> #include <vector> #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" @@ -66,12 +68,9 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d // Skip PtNodes that represent non-word information. continue; } - childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(), - ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal, - ptNodeParams.hasChildren(), - ptNodeParams.isBlacklisted() - || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */, - ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints()); + const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(), + wordId, ptNodeParams.getCodePointArrayView()); } if (readingHelper.isError()) { mIsCorrupted = true; @@ -80,9 +79,11 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); readingHelper.initWithPtNodePos(ptNodePos); const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( maxCodePointCount, outCodePoints, outUnigramProbability); @@ -93,76 +94,95 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const { +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); } - return ptNodePos; + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_WORD_ID; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + return ptNodeParams.getTerminalId(); } -int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, - const int bigramProbability) const { - if (mHeaderPolicy->isDecayingDict()) { - // Both probabilities are encoded. Decode them and get probability. - return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); - } else { - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return ProbabilityUtils::backoff(unigramProbability); - } else { - return bigramProbability; - } +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + // TODO: Support n-gram. + const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability( + prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy); + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + probability == 0); } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } - const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == ptNodePos - && bigramsIt.getProbability() != NOT_A_PROBABILITY) { - return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); - } - } + // TODO: Support n-gram. + const ProbabilityEntry probabilityEntry = + mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( + prevWordIds.limit(1 /* maxSize */), wordId); + if (!probabilityEntry.isValid()) { return NOT_A_PROBABILITY; } - return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); + if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + mHeaderPolicy); + } else { + return probabilityEntry.getProbability(); + } } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( + const int wordId) const { + const int shortcutPos = getShortcutPositionOfWord(wordId); + return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos); +} + +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.empty()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + // TODO: Support n-gram. + const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); + for (const auto entry : languageModelDictContent->getProbabilityEntries( + prevWordIds.limit(1 /* maxSize */))) { + const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : + probabilityEntry.getProbability(); + listener->onVisitEntry(probability, entry.getWordId()); } } -int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getShortcutPositionOfWord(const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_DICT_POS; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); if (ptNodeParams.isDeleted()) { return NOT_A_DICT_POS; @@ -171,19 +191,7 @@ int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) con ptNodeParams.getTerminalId()); } -int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { - return NOT_A_DICT_POS; - } - const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (ptNodeParams.isDeleted()) { - return NOT_A_DICT_POS; - } - return mBuffers->getBigramDictContent()->getBigramListHeadPos( - ptNodeParams.getTerminalId()); -} - -bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); @@ -194,13 +202,14 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le mDictBuffer->getTailPosition()); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("The word is too long to insert to the dictionary, length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("The word is too long to insert to the dictionary, length: %zd", + wordCodePoints.size()); return false; } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -209,8 +218,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le readingHelper.initWithPtNodeArrayPos(getRootPosition()); bool addedNewUnigram = false; int codePointsToAdd[MAX_WORD_LENGTH]; - int codePointCountToAdd = length; - memmove(codePointsToAdd, word, sizeof(int) * length); + int codePointCountToAdd = wordCodePoints.size(); + memmove(codePointsToAdd, wordCodePoints.data(), sizeof(int) * codePointCountToAdd); if (unigramProperty->representsBeginningOfSentence()) { codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd, codePointCountToAdd, MAX_WORD_LENGTH); @@ -218,24 +227,26 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (codePointCountToAdd <= 0) { return false; } - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd, - unigramProperty, &addedNewUnigram)) { + const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView.data(), + codePointArrayView.size(), unigramProperty, &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { - AKLOGE("Cannot find terminal PtNode position to add shortcut target."); + const int wordId = getWordId(codePointArrayView, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { + AKLOGE("Cannot find word id to add shortcut target."); return false; } + const int wordPos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcut.getTargetCodePoints()->data(), shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -248,21 +259,25 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } } -bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int length) { +bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; } - const int ptNodePos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) { AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos); return false; } + if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(wordId)) { + return false; + } if (!ptNodeParams.representsNonWordInfo()) { mUnigramCount--; } @@ -286,43 +301,51 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const std::vector<UnigramProperty::ShortcutProperty> shortcuts; - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - prevWordsInfo->getNthPrevWordCodePointCount(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); - return false; - } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); - } else { + if (prevWordIds.empty()) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] != NOT_A_WORD_ID) { + continue; + } + if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { return false; } + const std::vector<UnigramProperty::ShortcutProperty> shortcuts; + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, true /* isNotAWord */, + false /* isBlacklisted */, MAX_PROBABILITY /* probability */, + NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); + if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); + return false; + } + // Refresh word ids. + prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } - const int word1Pos = getTerminalPtNodePositionOfWord( - bigramProperty->getTargetCodePoints()->data(), - bigramProperty->getTargetCodePoints()->size(), false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()), + false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } + // TODO: Support N-gram. bool addedNewEntry = false; - if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty, - &addedNewEntry)) { + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos; + for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { + prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(prevWordIds[i]); + } + const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(wordId); + if (mUpdatingHelper.addNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos), + wordPtNodePos, bigramProperty, &addedNewEntry)) { if (addedNewEntry) { mBigramCount++; } @@ -333,7 +356,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { + const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; @@ -347,23 +370,29 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", + wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } - if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) { + std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos; + for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { + prevWordsPtNodePos[i] = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(prevWordIds[i]); + } + const int wordPtNodePos = mBuffers->getTerminalPositionLookupTable() + ->getTerminalPtNodePosition(wordId); + if (mUpdatingHelper.removeNgramEntry(WordIdArrayView::fromArray(prevWordsPtNodePos), + wordPtNodePos)) { mBigramCount--; return true; } else { @@ -442,14 +471,15 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer } } -const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { +const WordProperty Ver4PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); std::vector<int> codePointVector(ptNodeParams.getCodePoints(), ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); @@ -458,45 +488,30 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code ptNodeParams.getTerminalId()); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); // Fetch bigram information. + // TODO: Support n-gram. std::vector<BigramProperty> bigrams; - const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); - if (bigramListPos != NOT_A_DICT_POS) { - int bigramWord1CodePoints[MAX_WORD_LENGTH]; - const BigramDictContent *const bigramDictContent = mBuffers->getBigramDictContent(); - const TerminalPositionLookupTable *const terminalPositionLookupTable = - mBuffers->getTerminalPositionLookupTable(); - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = bigramEntry.hasNext(); - const int word1TerminalId = bigramEntry.getTargetTerminalId(); - const int word1TerminalPtNodePos = - terminalPositionLookupTable->getTerminalPtNodePosition(word1TerminalId); - if (word1TerminalPtNodePos == NOT_A_DICT_POS) { - continue; - } - // Word (unigram) probability - int word1Probability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + codePointCount); - const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - bigramEntry.getProbability(); - bigrams.emplace_back(&word1, probability, - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount()); - } + const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId); + int bigramWord1CodePoints[MAX_WORD_LENGTH]; + for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( + prevWordIds)) { + // Word (unigram) probability + int word1Probability = NOT_A_PROBABILITY; + const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability); + const std::vector<int> word1(bigramWord1CodePoints, + bigramWord1CodePoints + codePointCount); + const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry(); + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) : + probabilityEntry.getProbability(); + bigrams.emplace_back(&word1, probability, + historicalInfo->getTimeStamp(), historicalInfo->getLevel(), + historicalInfo->getCount()); } // Fetch shortcut information. std::vector<UnigramProperty::ShortcutProperty> shortcuts; - int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); + int shortcutPos = getShortcutPositionOfWord(wordId); if (shortcutPos != NOT_A_DICT_POS) { int shortcutTarget[MAX_WORD_LENGTH]; const ShortcutDictContent *const shortcutDictContent = @@ -536,9 +551,11 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; + const PtNodeParams ptNodeParams = + mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos); int unigramProbability = NOT_A_PROBABILITY; *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index faad4290d..a117a3614 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -23,7 +23,6 @@ #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" @@ -31,25 +30,25 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; +// Word id = Artificial id that is stored in the PtNode looked up by the word. class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: Ver4PatriciaTriePolicy(Ver4DictBuffers::Ver4DictBuffersPtr buffers) : mBuffers(std::move(buffers)), mHeaderPolicy(mBuffers->getHeaderPolicy()), mDictBuffer(mBuffers->getWritableTrieBuffer()), - mBigramPolicy(mBuffers->getMutableBigramDictContent(), - mBuffers->getTerminalPositionLookupTable(), mHeaderPolicy), mShortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()), mNodeReader(mDictBuffer, mBuffers->getLanguageModelDictContent(), mHeaderPolicy), mPtNodeArrayReader(mDictBuffer), mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader, - &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), + &mPtNodeArrayReader, &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), mUnigramCount(mHeaderPolicy->getUnigramCount()), @@ -64,39 +63,41 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DicNodeVector *const childDicNodes) const; int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; - int getProbability(const int unigramProbability, const int bigramProbability) const; + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + // TODO: Remove + int getProbability(const int unigramProbability, const int bigramProbability) const { + // Not used. + return NOT_A_PROBABILITY; + } + + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const word, const int length); + bool removeUnigramEntry(const CodePointArrayView wordCodePoints); bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, - const int length1); + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const CodePointArrayView wordCodePoints); bool flush(const char *const filePath); @@ -107,8 +108,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -132,7 +132,6 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers; const HeaderPolicy *const mHeaderPolicy; BufferWithExtendableBuffer *const mDictBuffer; - Ver4BigramListPolicy mBigramPolicy; Ver4ShortcutListPolicy mShortcutPolicy; Ver4PatriciaTrieNodeReader mNodeReader; Ver4PtNodeArrayReader mPtNodeArrayReader; @@ -144,7 +143,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; - int getBigramsPositionOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfWord(const int wordId) const; }; } // namespace latinime #endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 4220312e0..63e43a544 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -20,7 +20,6 @@ #include <queue> #include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" @@ -77,13 +76,33 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(), mBuffers->getLanguageModelDictContent(), headerPolicy); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); - Ver4BigramListPolicy bigramPolicy(mBuffers->getMutableBigramDictContent(), - mBuffers->getTerminalPositionLookupTable(), headerPolicy); Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(), - mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, - &shortcutPolicy); + mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); + + int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy, + entryCountTable)) { + AKLOGE("Failed to update probabilities in language model dict content."); + return false; + } + if (headerPolicy->isDecayingDict()) { + int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = + headerPolicy->getMaxUnigramCount(); + maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] = + headerPolicy->getMaxBigramCount(); + for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) { + // TODO: Have max n-gram count. + maxEntryCountTable[i] = headerPolicy->getMaxBigramCount(); + } + if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable, + maxEntryCountTable, headerPolicy, entryCountTable)) { + AKLOGE("Failed to truncate entries in language model dict content."); + return false; + } + } DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); @@ -95,16 +114,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { return false; } - const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted - .getValidUnigramCount(); - const int maxUnigramCount = headerPolicy->getMaxUnigramCount(); - if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) { - if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) { - AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount, - maxUnigramCount); - return false; - } - } readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPtGcEventListeners::TraversePolicyToUpdateBigramProbability @@ -113,21 +122,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateBigramProbability)) { return false; } - const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount(); - const int maxBigramCount = headerPolicy->getMaxBigramCount(); - if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) { - if (!truncateBigrams(maxBigramCount)) { - AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount); - return false; - } - } // Mapping from positions in mBuffer to positions in bufferToWrite. PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap; readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, - &shortcutPolicy); + buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers, buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap); @@ -140,12 +140,10 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(), buffersToWrite->getLanguageModelDictContent(), headerPolicy); Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer()); - Ver4BigramListPolicy newBigramPolicy(buffersToWrite->getMutableBigramDictContent(), - buffersToWrite->getTerminalPositionLookupTable(), headerPolicy); Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(), buffersToWrite->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader, &newBigramPolicy, + buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader, &newShortcutPolicy); // Re-assign terminal IDs for valid terminal PtNodes. TerminalPositionLookupTable::TerminalIdMap terminalIdMap; @@ -158,11 +156,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) { return false; } - // Run GC for bigram dict content. - if(!buffersToWrite->getMutableBigramDictContent()->runGC(&terminalIdMap, - mBuffers->getBigramDictContent(), outBigramCount)) { - return false; - } // Run GC for shortcut dict content. if(!buffersToWrite->getMutableShortcutDictContent()->runGC(&terminalIdMap, mBuffers->getShortcutDictContent())) { @@ -183,92 +176,10 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { return false; } - *outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount(); - return true; -} - -bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( - const Ver4PatriciaTrieNodeReader *const ptNodeReader, - Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { - const TerminalPositionLookupTable *const terminalPosLookupTable = - mBuffers->getTerminalPositionLookupTable(); - const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); - std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator> - priorityQueue; - for (int i = 0; i < nextTerminalId; ++i) { - const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i); - if (terminalPos == NOT_A_DICT_POS) { - continue; - } - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry(i); - const int probability = probabilityEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : - probabilityEntry.getProbability(); - priorityQueue.push(DictProbability(terminalPos, probability, - probabilityEntry.getHistoricalInfo()->getTimeStamp())); - } - - // Delete unigrams. - while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) { - const int ptNodePos = priorityQueue.top().getDictPos(); - priorityQueue.pop(); - const PtNodeParams ptNodeParams = - ptNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - if (ptNodeParams.representsNonWordInfo()) { - continue; - } - if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) { - AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos); - return false; - } - } - return true; -} - -bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { - const TerminalPositionLookupTable *const terminalPosLookupTable = - mBuffers->getTerminalPositionLookupTable(); - const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); - std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator> - priorityQueue; - BigramDictContent *const bigramDictContent = mBuffers->getMutableBigramDictContent(); - for (int i = 0; i < nextTerminalId; ++i) { - const int bigramListPos = bigramDictContent->getBigramListHeadPos(i); - if (bigramListPos == NOT_A_DICT_POS) { - continue; - } - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - bigramDictContent->getBigramEntrySize(); - hasNext = bigramEntry.hasNext(); - if (!bigramEntry.isValid()) { - continue; - } - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : - bigramEntry.getProbability(); - priorityQueue.push(DictProbability(entryPos, probability, - bigramEntry.getHistoricalInfo()->getTimeStamp())); - } - } - - // Delete bigrams. - while (static_cast<int>(priorityQueue.size()) > maxBigramCount) { - const int entryPos = priorityQueue.top().getDictPos(); - const BigramEntry bigramEntry = bigramDictContent->getBigramEntry(entryPos); - const BigramEntry invalidatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!bigramDictContent->writeBigramEntry(&invalidatedBigramEntry, entryPos)) { - AKLOGE("Cannot write bigram entry to remove. pos: %d", entryPos); - return false; - } - priorityQueue.pop(); - } + *outUnigramCount = + entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE]; + *outBigramCount = + entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE]; return true; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index bb464ad28..b6278c4cb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -66,49 +66,6 @@ class Ver4PatriciaTrieWritingHelper { const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap; }; - // For truncateUnigrams() and truncateBigrams(). - class DictProbability { - public: - DictProbability(const int dictPos, const int probability, const int timestamp) - : mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {} - - int getDictPos() const { - return mDictPos; - } - - int getProbability() const { - return mProbability; - } - - int getTimestamp() const { - return mTimestamp; - } - - private: - DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability); - - int mDictPos; - int mProbability; - int mTimestamp; - }; - - // For truncateUnigrams() and truncateBigrams(). - class DictProbabilityComparator { - public: - bool operator()(const DictProbability &left, const DictProbability &right) { - if (left.getProbability() != right.getProbability()) { - return left.getProbability() > right.getProbability(); - } - if (left.getTimestamp() != right.getTimestamp()) { - return left.getTimestamp() < right.getTimestamp(); - } - return left.getDictPos() > right.getDictPos(); - } - - private: - DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator); - }; - bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, int *const outBigramCount); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index 833063c17..ecbe7922c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -31,7 +31,7 @@ uint32_t BufferWithExtendableBuffer::readUint(const int size, const int pos) con uint32_t BufferWithExtendableBuffer::readUintAndAdvancePosition(const int size, int *const pos) const { - const int value = readUint(size, *pos); + const uint32_t value = readUint(size, *pos); *pos += size; return value; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index c0a9fcb1d..4b3c98988 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -114,7 +114,7 @@ class ByteArrayUtils { return buffer[(*pos)++]; } - static AK_FORCE_INLINE int readUint(const uint8_t *const buffer, + static AK_FORCE_INLINE uint32_t readUint(const uint8_t *const buffer, const int size, const int pos) { // size must be in 1 to 4. ASSERT(size >= 1 && size <= 4); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index 9910777b8..313eb6b64 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -48,6 +48,11 @@ class ForgettingCurveUtils { static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy); + // TODO: Improve probability computation method and remove this. + static int getProbabilityBiasForNgram(const int n) { + return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE; + } + AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) { return static_cast<int>(static_cast<float>(maxUnigramCount) * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 1916ea560..e6e7167c2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -23,7 +23,7 @@ namespace latinime { const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE; // Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12 -const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; +const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) { switch (formatVersion) { @@ -40,14 +40,14 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; } } /* static */ FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion( - const uint8_t *const dict, const int dictSize) { + const ReadOnlyByteArrayView dictBuffer) { // The magic number is stored big-endian. // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) { + if (dictBuffer.size() < DICTIONARY_MINIMUM_SIZE) { return UNKNOWN_VERSION; } - const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); + const uint32_t magicNumber = ByteArrayUtils::readUint32(dictBuffer.data(), 0); switch (magicNumber) { case MAGIC_NUMBER: // The layout of the header is as follows: @@ -58,7 +58,7 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; // Conceptually this converts the hardcoded value of the bytes in the file into // the symbolic value we use in the code. But we want the constants to be the // same so we use them for both here. - return getFormatVersion(ByteArrayUtils::readUint16(dict, 4)); + return getFormatVersion(ByteArrayUtils::readUint16(dictBuffer.data(), 4)); default: return UNKNOWN_VERSION; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 55ad5799f..51ad9877c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -42,12 +43,12 @@ class FormatUtils { static const uint32_t MAGIC_NUMBER; static FORMAT_VERSION getFormatVersion(const int formatVersion); - static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); + static FORMAT_VERSION detectFormatVersion(const ReadOnlyByteArrayView dictBuffer); private: DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils); - static const int DICTIONARY_MINIMUM_SIZE; + static const size_t DICTIONARY_MINIMUM_SIZE; }; } // namespace latinime #endif /* LATINIME_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h b/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h index fca8120f1..e1a96c6f7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h @@ -24,7 +24,6 @@ namespace latinime { -// Note that there is a corresponding implementation in SparseTable.java. // TODO: Support multiple content buffers. class SparseTable { public: diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp index 407b8efd0..39f417ebb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -26,6 +26,7 @@ const int TrieMap::FIELD1_SIZE = 3; const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE; const uint32_t TrieMap::VALUE_FLAG = 0x400000; const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF; +const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK; const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000; const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF; const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5; @@ -34,6 +35,7 @@ const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_O const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0; const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE; const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0); +const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry. const uint64_t TrieMap::MAX_VALUE = (static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1; const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE; @@ -76,14 +78,14 @@ int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIn return terminalEntry.getValueEntryIndex() + 1; } // Create a value entry and a bitmap entry. - const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) { return INVALID_INDEX; } if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) { return INVALID_INDEX; } - if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, valueEntryIndex)) { + if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex)) { return INVALID_INDEX; } return valueEntryIndex + 1; @@ -108,6 +110,31 @@ bool TrieMap::save(FILE *const file) const { return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer); } +bool TrieMap::remove(const int key, const int bitmapEntryIndex) { + const Entry bitmapEntry = readEntry(bitmapEntryIndex); + const uint32_t unsignedKey = static_cast<uint32_t>(key); + const int terminalEntryIndex = getTerminalEntryIndex( + unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */); + if (terminalEntryIndex == INVALID_INDEX) { + // Not found. + return false; + } + const Entry terminalEntry = readEntry(terminalEntryIndex); + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) { + return false; + } + if (terminalEntry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1); + if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)){ + return false; + } + } + return true; +} + /** * Iterate next entry in a certain level. * @@ -129,7 +156,7 @@ const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *con if (entry.isBitmapEntry()) { // Move to child. iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); - } else { + } else if (entry.isValidTerminalEntry()) { if (outKey) { *outKey = entry.getKey(); } @@ -162,12 +189,12 @@ uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const { } bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) { - if (value <= VALUE_MASK) { + if (value < VALUE_MASK) { // Write value into the terminal entry. return writeField1(value | VALUE_FLAG, terminalEntryIndex); } // Create value entry and write value. - const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) { return false; } @@ -227,6 +254,9 @@ int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey, // Move to the next level. return getTerminalEntryIndex(key, hashedKey, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + return INVALID_INDEX; + } if (entry.getKey() == key) { // Terminal entry is found. return entryIndex; @@ -287,6 +317,10 @@ bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32 // Bitmap entry is found. Go to the next level. return putInternal(key, value, hashedKey, entryIndex, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + // Overwrite invalid terminal entry. + return writeTerminalEntry(key, value, entryIndex); + } if (entry.getKey() == key) { // Terminal entry for the key is found. Update the value. return updateValue(entry, value, entryIndex); @@ -384,4 +418,37 @@ bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t val return true; } +bool TrieMap::removeInner(const Entry &bitmapEntry) { + const int tableSize = popCount(bitmapEntry.getBitmap()); + if (tableSize <= 0) { + // The table is empty. No need to remove any entries. + return true; + } + for (int i = 0; i < tableSize; ++i) { + const int entryIndex = bitmapEntry.getTableIndex() + i; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Delete next bitmap entry recursively. + if (!removeInner(entry)) { + return false; + } + } else { + // Invalidate terminal entry just in case. + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) { + return false; + } + if (entry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1); + if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)) { + return false; + } + } + } + } + return true; +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index 3e5c4010c..00765888b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -84,6 +84,10 @@ class TrieMap { return mValue; } + AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { + return mNextLevelBitmapEntryIndex; + } + private: const TrieMap *const mTrieMap; const int mKey; @@ -94,7 +98,7 @@ class TrieMap { TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { - if (!trieMap) { + if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) { return; } const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); @@ -202,6 +206,8 @@ class TrieMap { bool save(FILE *const file) const; + bool remove(const int key, const int bitmapEntryIndex); + private: DISALLOW_COPY_AND_ASSIGN(TrieMap); @@ -245,6 +251,11 @@ class TrieMap { } // For terminal entry. + AK_FORCE_INLINE bool isValidTerminalEntry() const { + return hasTerminalLink() || ((mData1 & VALUE_MASK) != INVALID_VALUE_IN_KEY_VALUE_ENTRY); + } + + // For terminal entry. AK_FORCE_INLINE uint32_t getValueEntryIndex() const { return mData1 & TERMINAL_LINK_MASK; } @@ -272,6 +283,7 @@ class TrieMap { static const int ENTRY_SIZE; static const uint32_t VALUE_FLAG; static const uint32_t VALUE_MASK; + static const uint32_t INVALID_VALUE_IN_KEY_VALUE_ENTRY; static const uint32_t TERMINAL_LINK_FLAG; static const uint32_t TERMINAL_LINK_MASK; static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL; @@ -280,6 +292,7 @@ class TrieMap { static const int ROOT_BITMAP_ENTRY_INDEX; static const int ROOT_BITMAP_ENTRY_POS; static const Entry EMPTY_BITMAP_ENTRY; + static const int TERMINAL_LINKED_ENTRY_COUNT; static const int MAX_BUFFER_SIZE; uint32_t getBitShuffledKey(const uint32_t key) const; @@ -378,6 +391,8 @@ class TrieMap { AK_FORCE_INLINE int getTailEntryIndex() const { return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE; } + + bool removeInner(const Entry &bitmapEntry); }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 04cb6603a..52c4251f0 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -51,10 +51,10 @@ class TypingScoring : public Scoring { } if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; } - if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; } if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index cb3dfac70..b64ee8be4 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -161,8 +161,8 @@ class TypingTraversal : public Traversal { return true; } - AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const { - const int probability = dicNode->getProbability(); + AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const { if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) { return false; } diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 54f65c786..1d590c353 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -36,25 +36,34 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor // Compare the node code point with original primary code point on the keyboard. const ProximityInfoState *const pInfoState = traverseSession->getProximityInfoState(0); - const int primaryOriginalCodePoint = pInfoState->getPrimaryOriginalCodePointAt( + const int primaryCodePoint = pInfoState->getPrimaryCodePointAt( dicNode->getInputIndex(0)); const int nodeCodePoint = dicNode->getNodeCodePoint(); - if (primaryOriginalCodePoint == nodeCodePoint) { + // TODO: Check whether the input code point is on the keyboard. + if (primaryCodePoint == nodeCodePoint) { // Node code point is same as original code point on the keyboard. return ErrorTypeUtils::NOT_AN_ERROR; - } else if (CharUtils::toLowerCase(primaryOriginalCodePoint) == + } else if (CharUtils::toLowerCase(primaryCodePoint) == CharUtils::toLowerCase(nodeCodePoint)) { // Only cases of the code points are different. - return ErrorTypeUtils::MATCH_WITH_CASE_ERROR; - } else if (CharUtils::toBaseCodePoint(primaryOriginalCodePoint) == - CharUtils::toBaseCodePoint(nodeCodePoint)) { + return ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else if (primaryCodePoint == CharUtils::toBaseCodePoint(nodeCodePoint)) { // Node code point is a variant of original code point. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR; - } else { + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT; + } else if (CharUtils::toBaseCodePoint(primaryCodePoint) + == CharUtils::toBaseCodePoint(nodeCodePoint)) { + // Base code points are the same but the code point is intentionally input. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT; + } else if (CharUtils::toLowerCase(primaryCodePoint) + == CharUtils::toBaseLowerCase(nodeCodePoint)) { // Node code point is a variant of original code point and the cases are also // different. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR - | ErrorTypeUtils::MATCH_WITH_CASE_ERROR; + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else { + // Base code points are the same and the cases are different. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; } } break; diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h index 2c97c6d58..10d7ae278 100644 --- a/native/jni/src/utils/byte_array_view.h +++ b/native/jni/src/utils/byte_array_view.h @@ -77,10 +77,12 @@ class ReadWriteByteArrayView { } private: - DISALLOW_ASSIGNMENT_OPERATOR(ReadWriteByteArrayView); + // Default copy constructor and assignment operator are used for using this class with STL + // containers. - uint8_t *const mPtr; - const size_t mSize; + // These members cannot be const to have the assignment operator. + uint8_t *mPtr; + size_t mSize; }; } // namespace latinime diff --git a/native/jni/src/utils/char_utils.cpp b/native/jni/src/utils/char_utils.cpp index b17e0847d..3bb9055b2 100644 --- a/native/jni/src/utils/char_utils.cpp +++ b/native/jni/src/utils/char_utils.cpp @@ -1057,11 +1057,11 @@ static int compare_pair_capital(const void *a, const void *b) { - static_cast<int>((static_cast<const struct LatinCapitalSmallPair *>(b))->capital); } -/* static */ unsigned short CharUtils::latin_tolower(const unsigned short c) { +/* static */ int CharUtils::latin_tolower(const int c) { struct LatinCapitalSmallPair *p = static_cast<struct LatinCapitalSmallPair *>(bsearch(&c, SORTED_CHAR_MAP, NELEMS(SORTED_CHAR_MAP), sizeof(SORTED_CHAR_MAP[0]), compare_pair_capital)); - return p ? p->small : c; + return p ? static_cast<int>(p->small) : c; } /* diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 63786502b..5e9cdd9b2 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -27,20 +27,14 @@ namespace latinime { class CharUtils { public: + static const std::vector<int> EMPTY_STRING; + static AK_FORCE_INLINE bool isAsciiUpper(int c) { // Note: isupper(...) reports false positives for some Cyrillic characters, causing them to // be incorrectly lower-cased using toAsciiLower(...) rather than latin_tolower(...). return (c >= 'A' && c <= 'Z'); } - static AK_FORCE_INLINE int toAsciiLower(int c) { - return c - 'A' + 'a'; - } - - static AK_FORCE_INLINE bool isAscii(int c) { - return isascii(c) != 0; - } - static AK_FORCE_INLINE int toLowerCase(const int c) { if (isAsciiUpper(c)) { return toAsciiLower(c); @@ -48,7 +42,7 @@ class CharUtils { if (isAscii(c)) { return c; } - return static_cast<int>(latin_tolower(static_cast<unsigned short>(c))); + return latin_tolower(c); } static AK_FORCE_INLINE int toBaseLowerCase(const int c) { @@ -59,7 +53,6 @@ class CharUtils { // TODO: Do not hardcode here return codePoint == KEYCODE_SINGLE_QUOTE || codePoint == KEYCODE_HYPHEN_MINUS; } - static AK_FORCE_INLINE int getCodePointCount(const int arraySize, const int *const codePoints) { int size = 0; for (; size < arraySize; ++size) { @@ -91,9 +84,6 @@ class CharUtils { return codePoint >= MIN_UNICODE_CODE_POINT && codePoint <= MAX_UNICODE_CODE_POINT; } - static unsigned short latin_tolower(const unsigned short c); - static const std::vector<int> EMPTY_STRING; - // Returns updated code point count. Returns 0 when the code points cannot be marked as a // Beginning-of-Sentence. static AK_FORCE_INLINE int attachBeginningOfSentenceMarker(int *const codePoints, @@ -125,6 +115,16 @@ class CharUtils { */ static const int BASE_CHARS_SIZE = 0x0500; static const unsigned short BASE_CHARS[BASE_CHARS_SIZE]; + + static AK_FORCE_INLINE bool isAscii(int c) { + return isascii(c) != 0; + } + + static AK_FORCE_INLINE int toAsciiLower(int c) { + return c - 'A' + 'a'; + } + + static int latin_tolower(const int c); }; } // namespace latinime #endif // LATINIME_CHAR_UTILS_H diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index c1ddc9812..cc5f328ba 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -17,8 +17,10 @@ #ifndef LATINIME_INT_ARRAY_VIEW_H #define LATINIME_INT_ARRAY_VIEW_H +#include <algorithm> +#include <array> #include <cstdint> -#include <cstdlib> +#include <cstring> #include <vector> #include "defines.h" @@ -56,14 +58,14 @@ class IntArrayView { explicit IntArrayView(const std::vector<int> &vector) : mPtr(vector.data()), mSize(vector.size()) {} - template <int N> - AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) { - return IntArrayView(array, N); + template <size_t N> + AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) { + return IntArrayView(array.data(), array.size()); } - // Returns a view that points one int object. Does not take ownership of the given object. - AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) { - return IntArrayView(object, 1); + // Returns a view that points one int object. + AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) { + return IntArrayView(ptr, 1); } AK_FORCE_INLINE int operator[](const size_t index) const { @@ -91,6 +93,28 @@ class IntArrayView { return mPtr + mSize; } + AK_FORCE_INLINE bool contains(const int value) const { + return std::find(begin(), end(), value) != end(); + } + + // Returns the view whose size is smaller than or equal to the given count. + AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const { + return IntArrayView(mPtr, std::min(maxSize, mSize)); + } + + AK_FORCE_INLINE const IntArrayView skip(const size_t n) const { + if (mSize <= n) { + return IntArrayView(); + } + return IntArrayView(mPtr + n, mSize - n); + } + + template <size_t N> + void copyToArray(std::array<int, N> *const buffer, const size_t offset) const { + ASSERT(mSize + offset <= N); + memmove(buffer->data() + offset, mPtr, sizeof(int) * mSize); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); @@ -100,6 +124,9 @@ class IntArrayView { using WordIdArrayView = IntArrayView; using PtNodePosArrayView = IntArrayView; +using CodePointArrayView = IntArrayView; +template <size_t size> +using WordIdArray = std::array<int, size>; } // namespace latinime #endif // LATINIME_MEMORY_VIEW_H |