diff options
Diffstat (limited to 'native/jni/src')
31 files changed, 931 insertions, 218 deletions
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index bb54e608e..e81591992 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -21,7 +21,6 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/multi_bigram_map.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -93,13 +92,15 @@ namespace latinime { if (NOT_A_VALID_WORD_POS == wordPos || NOT_A_VALID_WORD_POS == prevWordPos) { // Note: Normally wordPos comes from the dictionary and should never equal // NOT_A_VALID_WORD_POS. - return ProbabilityUtils::backoff(unigramProbability); + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } if (multiBigramMap) { return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, prevWordPos, wordPos, unigramProbability); } - return ProbabilityUtils::backoff(unigramProbability); + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } //////////////// diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h index f437c95f6..9bc96877e 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h @@ -116,10 +116,6 @@ class DicNodeStatePrevWord { return mPrevWordStart; } - int16_t getPrevWordProbability() const { - return mPrevWordProbability; - } - int getPrevWordNodePos() const { return mPrevWordNodePos; } diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index e74a1dbc8..cf1cd8815 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -23,7 +23,6 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/dictionary.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -131,7 +130,7 @@ int BigramDictionary::getPredictions(const int *prevWord, const int prevWordLeng // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 // in very bad cases. This means that sometimes, we'll see some bigrams interverted // here, but it can't get too bad. - const int probability = ProbabilityUtils::computeProbabilityForBigram( + const int probability = mDictionaryStructurePolicy->getProbability( unigramProbability, bigramsIt.getProbability()); addWordBigram(bigramBuffer, codePointCount, probability, outBigramProbability, outBigramCodePoints, outputTypes); diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 8418a608a..02ece639c 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -90,7 +90,7 @@ int Dictionary::getProbability(const int *word, int length) const { if (NOT_A_VALID_WORD_POS == pos) { return NOT_A_PROBABILITY; } - return getDictionaryStructurePolicy()->getUnigramProbability(pos); + return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); } bool Dictionary::isValidBigram(const int *word0, int length0, const int *word1, int length1) const { diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index fb4a80083..9efe5f6f9 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -22,7 +22,6 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/bloom_filter.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/hash_map_compat.h" @@ -43,11 +42,12 @@ class MultiBigramMap { hash_map_compat<int, BigramMap>::const_iterator mapPosition = mBigramMaps.find(wordPosition); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { addBigramsForWordPosition(structurePolicy, wordPosition); - return mBigramMaps[wordPosition].getBigramProbability( + return mBigramMaps[wordPosition].getBigramProbability(structurePolicy, nextWordPosition, unigramProbability); } return readBigramProbabilityFromBinaryDictionary(structurePolicy, wordPosition, @@ -82,17 +82,17 @@ class MultiBigramMap { } AK_FORCE_INLINE int getBigramProbability( + const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nextWordPosition, const int unigramProbability) const { + int bigramProbability = NOT_A_PROBABILITY; if (mBloomFilter.isInFilter(nextWordPosition)) { const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = mBigramMap.find(nextWordPosition); if (bigramProbabilityIt != mBigramMap.end()) { - const int bigramProbability = bigramProbabilityIt->second; - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramProbability); + bigramProbability = bigramProbabilityIt->second; } } - return ProbabilityUtils::backoff(unigramProbability); + return structurePolicy->getProbability(unigramProbability, bigramProbability); } private: @@ -111,17 +111,18 @@ class MultiBigramMap { AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos, const int nextWordPosition, const int unigramProbability) { + int bigramProbability = NOT_A_PROBABILITY; const int bigramsListPos = structurePolicy->getBigramsPositionOfNode(nodePos); BinaryDictionaryBigramsIterator bigramsIt(structurePolicy->getBigramsStructurePolicy(), bigramsListPos); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == nextWordPosition) { - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramsIt.getProbability()); + bigramProbability = bigramsIt.getProbability(); + break; } } - return ProbabilityUtils::backoff(unigramProbability); + return structurePolicy->getProbability(unigramProbability, bigramProbability); } static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 532411509..c8cbbcfdf 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -47,7 +47,10 @@ class DictionaryStructureWithBufferPolicy { virtual int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const = 0; - virtual int getUnigramProbability(const int nodePos) const = 0; + virtual int getProbability(const int unigramProbability, + const int bigramProbability) const = 0; + + virtual int getUnigramProbabilityOfPtNode(const int nodePos) const = 0; virtual int getShortcutPositionOfNode(const int nodePos) const = 0; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 7d8dd21c5..e788e914a 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -171,7 +171,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + doubleLetterCost; - const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; + const bool isPossiblyOffensiveWord = + traverseSession->getDictionaryStructurePolicy()->getProbability( + terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; const bool isExactMatch = terminalDicNode->isExactMatch(); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp index e31a91069..936dc9c5d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp @@ -20,12 +20,13 @@ namespace latinime { bool DynamicBigramListPolicy::copyAllBigrams(int *const fromPos, int *const toPos) { const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); - const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); if (usesAdditionalBuffer) { *fromPos -= mBuffer->getOriginalBufferSize(); } BigramListReadWriteUtils::BigramFlags flags; do { + // The buffer address can be changed after calling buffer writing methods. + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); flags = BigramListReadWriteUtils::getFlagsAndForwardPointer(buffer, fromPos); int bigramPos = BigramListReadWriteUtils::getBigramAddressAndForwardPointer( buffer, flags, fromPos); @@ -63,7 +64,6 @@ bool DynamicBigramListPolicy::copyAllBigrams(int *const fromPos, int *const toPo bool DynamicBigramListPolicy::addBigramEntry(const int bigramPos, const int probability, int *const pos) { const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*pos); - const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); if (usesAdditionalBuffer) { *pos -= mBuffer->getOriginalBufferSize(); } @@ -73,6 +73,8 @@ bool DynamicBigramListPolicy::addBigramEntry(const int bigramPos, const int prob if (usesAdditionalBuffer) { entryPos += mBuffer->getOriginalBufferSize(); } + // The buffer address can be changed after calling buffer writing methods. + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); flags = BigramListReadWriteUtils::getFlagsAndForwardPointer(buffer, pos); BigramListReadWriteUtils::getBigramAddressAndForwardPointer(buffer, flags, pos); if (BigramListReadWriteUtils::hasNext(flags)) { @@ -118,13 +120,14 @@ bool DynamicBigramListPolicy::addBigramEntry(const int bigramPos, const int prob bool DynamicBigramListPolicy::removeBigram(const int bigramListPos, const int targetBigramPos) { const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(bigramListPos); - const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); int pos = bigramListPos; if (usesAdditionalBuffer) { pos -= mBuffer->getOriginalBufferSize(); } BigramListReadWriteUtils::BigramFlags flags; do { + // The buffer address can be changed after calling buffer writing methods. + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); flags = BigramListReadWriteUtils::getFlagsAndForwardPointer(buffer, &pos); int bigramOffsetFieldPos = pos; if (usesAdditionalBuffer) { @@ -139,8 +142,7 @@ bool DynamicBigramListPolicy::removeBigram(const int bigramListPos, const int ta continue; } // Target entry is found. Write 0 into the bigram pos field to mark the bigram invalid. - const int bigramOffsetFieldSize = - BigramListReadWriteUtils::attributeAddressSize(flags); + const int bigramOffsetFieldSize = BigramListReadWriteUtils::attributeAddressSize(flags); if (!mBuffer->writeUintAndAdvancePosition(0 /* data */, bigramOffsetFieldSize, &bigramOffsetFieldPos)) { return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp index 6bb90fc2d..5674cb48e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp @@ -34,7 +34,7 @@ void DynamicPatriciaTrieNodeReader::fetchNodeInfoFromBufferAndProcessMovedNode(c mFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); const int parentPos = DynamicPatriciaTrieReadingUtils::getParentPosAndAdvancePosition(dictBuf, &pos); - mParentPos = (parentPos != 0) ? mNodePos + parentPos : NOT_A_DICT_POS; + mParentPos = (parentPos != 0) ? nodePos + parentPos : NOT_A_DICT_POS; if (outCodePoints != 0) { mCodePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( dictBuf, mFlags, maxCodePointCount, outCodePoints, &pos); @@ -43,10 +43,19 @@ void DynamicPatriciaTrieNodeReader::fetchNodeInfoFromBufferAndProcessMovedNode(c dictBuf, mFlags, MAX_WORD_LENGTH, &pos); } if (isTerminal()) { + mProbabilityFieldPos = pos; + if (usesAdditionalBuffer) { + mProbabilityFieldPos += mBuffer->getOriginalBufferSize(); + } mProbability = PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictBuf, &pos); } else { + mProbabilityFieldPos = NOT_A_DICT_POS; mProbability = NOT_A_PROBABILITY; } + mChildrenPosFieldPos = pos; + if (usesAdditionalBuffer) { + mChildrenPosFieldPos += mBuffer->getOriginalBufferSize(); + } mChildrenPos = DynamicPatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( dictBuf, mFlags, &pos); if (usesAdditionalBuffer && mChildrenPos != NOT_A_DICT_POS) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h index acc68b321..2ee7c2495 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h @@ -40,9 +40,11 @@ class DynamicPatriciaTrieNodeReader { const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) : mBuffer(buffer), mBigramsPolicy(bigramsPolicy), mShortcutsPolicy(shortcutsPolicy), mNodePos(NOT_A_VALID_WORD_POS), mFlags(0), - mParentPos(NOT_A_DICT_POS), mCodePointCount(0), mProbability(NOT_A_PROBABILITY), - mChildrenPos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS), - mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_VALID_WORD_POS) {} + mParentPos(NOT_A_DICT_POS), mCodePointCount(0), + mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(NOT_A_PROBABILITY), + mChildrenPosFieldPos(NOT_A_DICT_POS), mChildrenPos(NOT_A_DICT_POS), + mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), + mSiblingPos(NOT_A_VALID_WORD_POS) {} ~DynamicPatriciaTrieNodeReader() {} @@ -95,11 +97,19 @@ class DynamicPatriciaTrieNodeReader { } // Probability + AK_FORCE_INLINE int getProbabilityFieldPos() const { + return mProbabilityFieldPos; + } + AK_FORCE_INLINE int getProbability() const { return mProbability; } - // Children node group position + // Children PtNode array position + AK_FORCE_INLINE int getChildrenPosFieldPos() const { + return mChildrenPosFieldPos; + } + AK_FORCE_INLINE int getChildrenPos() const { return mChildrenPos; } @@ -129,7 +139,9 @@ class DynamicPatriciaTrieNodeReader { DynamicPatriciaTrieReadingUtils::NodeFlags mFlags; int mParentPos; uint8_t mCodePointCount; + int mProbabilityFieldPos; int mProbability; + int mChildrenPosFieldPos; int mChildrenPos; int mShortcutPos; int mBigramPos; diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp index 3b9878b82..945677b50 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -24,6 +24,7 @@ #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { @@ -134,7 +135,20 @@ int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const in return NOT_A_VALID_WORD_POS; } -int DynamicPatriciaTriePolicy::getUnigramProbability(const int nodePos) const { +int DynamicPatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + // TODO: check mHeaderPolicy.usesForgettingCurve(); + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } +} + +int DynamicPatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int nodePos) const { if (nodePos == NOT_A_VALID_WORD_POS) { return NOT_A_PROBABILITY; } @@ -144,7 +158,7 @@ int DynamicPatriciaTriePolicy::getUnigramProbability(const int nodePos) const { if (nodeReader.isDeleted() || nodeReader.isBlacklisted() || nodeReader.isNotAWord()) { return NOT_A_PROBABILITY; } - return nodeReader.getProbability(); + return getProbability(nodeReader.getProbability(), NOT_A_PROBABILITY); } int DynamicPatriciaTriePolicy::getShortcutPositionOfNode(const int nodePos) const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h index 5873d3d65..cdab0e16a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -57,7 +57,9 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const; - int getUnigramProbability(const int nodePos) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; + + int getUnigramProbabilityOfPtNode(const int nodePos) const; int getShortcutPositionOfNode(const int nodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp index 2042fcbd2..a0b5be6a4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp @@ -70,9 +70,10 @@ void DynamicPatriciaTrieReadingHelper::followForwardLink() { if (usesAdditionalBuffer) { mPos += mBuffer->getOriginalBufferSize(); } + mPosOfLastForwardLinkField = mPos; if (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(forwardLinkPosition)) { // Follow the forward link. - mPos = forwardLinkPosition; + mPos += forwardLinkPosition; nextNodeArray(); } else { // All node arrays have been read. diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h index b108ed5fb..db1c392bb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h @@ -38,8 +38,8 @@ class DynamicPatriciaTrieReadingHelper { const DictionaryBigramsStructurePolicy *const bigramsPolicy, const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) : mIsError(false), mPos(NOT_A_DICT_POS), mNodeCount(0), mPrevTotalCodePointCount(0), - mTotalNodeCount(0), mNodeArrayCount(0), mBuffer(buffer), - mNodeReader(mBuffer, bigramsPolicy, shortcutsPolicy) {} + mTotalNodeCount(0), mNodeArrayCount(0), mPosOfLastForwardLinkField(NOT_A_DICT_POS), + mBuffer(buffer), mNodeReader(mBuffer, bigramsPolicy, shortcutsPolicy) {} ~DynamicPatriciaTrieReadingHelper() {} @@ -62,6 +62,7 @@ class DynamicPatriciaTrieReadingHelper { mPrevTotalCodePointCount = 0; mTotalNodeCount = 0; mNodeArrayCount = 0; + mPosOfLastForwardLinkField = NOT_A_DICT_POS; nextNodeArray(); if (!isEnd()) { fetchNodeInfo(); @@ -81,6 +82,7 @@ class DynamicPatriciaTrieReadingHelper { mPrevTotalCodePointCount = 0; mTotalNodeCount = 1; mNodeArrayCount = 1; + mPosOfLastForwardLinkField = NOT_A_DICT_POS; fetchNodeInfo(); } } @@ -140,6 +142,7 @@ class DynamicPatriciaTrieReadingHelper { mTotalNodeCount = 0; mNodeArrayCount = 0; mPos = mNodeReader.getChildrenPos(); + mPosOfLastForwardLinkField = NOT_A_DICT_POS; // Read children node array. nextNodeArray(); if (!isEnd()) { @@ -158,12 +161,17 @@ class DynamicPatriciaTrieReadingHelper { mNodeArrayCount = 1; mNodeCount = 1; mPos = mNodeReader.getParentPos(); + mPosOfLastForwardLinkField = NOT_A_DICT_POS; fetchNodeInfo(); } else { mPos = NOT_A_DICT_POS; } } + AK_FORCE_INLINE int getPosOfLastForwardLinkField() const { + return mPosOfLastForwardLinkField; + } + private: DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTrieReadingHelper); @@ -177,6 +185,7 @@ class DynamicPatriciaTrieReadingHelper { int mPrevTotalCodePointCount; int mTotalNodeCount; int mNodeArrayCount; + int mPosOfLastForwardLinkField; const BufferWithExtendableBuffer *const mBuffer; DynamicPatriciaTrieNodeReader mNodeReader; int mMergedNodeCodePoints[MAX_WORD_LENGTH]; diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h index a6cb46d39..62d73bb02 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h @@ -56,6 +56,15 @@ class DynamicPatriciaTrieReadingUtils { return FLAG_IS_DELETED == (MASK_MOVED & flags); } + static AK_FORCE_INLINE NodeFlags updateAndGetFlags(const NodeFlags originalFlags, + const bool isMoved, const bool isDeleted) { + NodeFlags flags = originalFlags; + flags = isMoved ? ((flags & (!MASK_MOVED)) | FLAG_IS_MOVED) : flags; + flags = isDeleted ? ((flags & (!MASK_MOVED)) | FLAG_IS_DELETED) : flags; + flags = (!isMoved && !isDeleted) ? ((flags & (!MASK_MOVED)) | FLAG_IS_NOT_MOVED) : flags; + return flags; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieReadingUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp index 128d69d88..99a983f21 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp @@ -19,6 +19,9 @@ #include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" namespace latinime { @@ -26,6 +29,7 @@ namespace latinime { bool DynamicPatriciaTrieWritingHelper::addUnigramWord( DynamicPatriciaTrieReadingHelper *const readingHelper, const int *const wordCodePoints, const int codePointCount, const int probability) { + int parentPos = NOT_A_VALID_WORD_POS; while (!readingHelper->isEnd()) { const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); if (!readingHelper->isMatchedCodePoint(0 /* index */, @@ -40,38 +44,37 @@ bool DynamicPatriciaTrieWritingHelper::addUnigramWord( const int nodeCodePointCount = nodeReader->getCodePointCount(); for (int j = 1; j < nodeCodePointCount; ++j) { const int nextIndex = matchedCodePointCount + j; - if (nextIndex >= codePointCount) { - // TODO: split current node after j - 1, create child and make this terminal. - return false; - } - if (!readingHelper->isMatchedCodePoint(j, + if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(j, wordCodePoints[matchedCodePointCount + j])) { - // TODO: split current node after j - 1 and create two children. - return false; + return reallocatePtNodeAndAddNewPtNodes(nodeReader, + readingHelper->getMergedNodeCodePoints(), j, probability, + wordCodePoints + matchedCodePointCount, + codePointCount - matchedCodePointCount); } } // All characters are matched. if (codePointCount == readingHelper->getTotalCodePointCount()) { - if (nodeReader->isTerminal()) { - // TODO: Update probability. - } else { - // TODO: Make it terminal and update probability. - } - return false; + return setPtNodeProbability(nodeReader, probability, + readingHelper->getMergedNodeCodePoints()); } if (!nodeReader->hasChildren()) { - // TODO: Create children node array and add new node as a child. - return false; + return createChildrenPtNodeArrayAndAChildPtNode(nodeReader, probability, + wordCodePoints + readingHelper->getTotalCodePointCount(), + codePointCount - readingHelper->getTotalCodePointCount()); } // Advance to the children nodes. + parentPos = nodeReader->getNodePos(); readingHelper->readChildNode(); } if (readingHelper->isError()) { // The dictionary is invalid. return false; } - // TODO: add at the last position of the node array. - return false; + int pos = readingHelper->getPosOfLastForwardLinkField(); + return createAndInsertNodeIntoPtNodeArray(parentPos, + wordCodePoints + readingHelper->getPrevTotalCodePointCount(), + codePointCount - readingHelper->getPrevTotalCodePointCount(), + probability, &pos); } bool DynamicPatriciaTrieWritingHelper::addBigramWords(const int word0Pos, const int word1Pos, @@ -96,4 +99,243 @@ bool DynamicPatriciaTrieWritingHelper::removeBigramWords(const int word0Pos, con return false; } +bool DynamicPatriciaTrieWritingHelper::markNodeAsMovedAndSetPosition( + const DynamicPatriciaTrieNodeReader *const originalNode, const int movedPos) { + int pos = originalNode->getNodePos(); + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(pos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + // Read original flags + const PatriciaTrieReadingUtils::NodeFlags originalFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */, + false /* isDeleted */); + int writingPos = originalNode->getNodePos(); + // Update flags. + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos)) { + return false; + } + // Update moved position, which is stored in the parent offset field. + const int movedPosOffset = movedPos - originalNode->getNodePos(); + if (!DynamicPatriciaTrieWritingUtils::writeParentOffsetAndAdvancePosition( + mBuffer, movedPosOffset, &writingPos)) { + return false; + } + return true; +} + +// Write new PtNode at writingPos. +bool DynamicPatriciaTrieWritingHelper::writePtNodeWithFullInfoToBuffer(const bool isBlacklisted, + const bool isNotAWord, const int parentPos, const int *const codePoints, + const int codePointCount, const int probability, const int childrenPos, + const int originalBigramListPos, const int originalShortcutListPos, + int *const writingPos) { + const int nodePos = *writingPos; + // Create node flags and write them. + const PatriciaTrieReadingUtils::NodeFlags nodeFlags = + PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, + probability != NOT_A_PROBABILITY, originalShortcutListPos != NOT_A_DICT_POS, + originalBigramListPos != NOT_A_DICT_POS, codePointCount > 1, + 3 /* childrenPositionFieldSize */); + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, nodeFlags, + writingPos)) { + return false; + } + // Calculate a parent offset and write the offset. + const int parentOffset = (parentPos != NOT_A_DICT_POS) ? parentPos - nodePos : NOT_A_DICT_POS; + if (!DynamicPatriciaTrieWritingUtils::writeParentOffsetAndAdvancePosition(mBuffer, + parentOffset, writingPos)) { + return false; + } + // Write code points + if (!DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition(mBuffer, codePoints, + codePointCount, writingPos)) { + return false;; + } + // Write probability when the probability is a valid probability, which means this node is + // terminal. + if (probability != NOT_A_PROBABILITY) { + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(mBuffer, + probability, writingPos)) { + return false; + } + } + // Write children position + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + childrenPos, writingPos)) { + return false; + } + // Copy shortcut list when the originalShortcutListPos is valid dictionary position. + if (originalShortcutListPos != NOT_A_DICT_POS) { + int fromPos = originalShortcutListPos; + if (!mShortcutPolicy->copyAllShortcutsAndReturnIfSucceededOrNot(&fromPos, writingPos)) { + return false; + } + } + // Copy bigram list when the originalBigramListPos is valid dictionary position. + if (originalBigramListPos != NOT_A_DICT_POS) { + int fromPos = originalBigramListPos; + if (!mBigramPolicy->copyAllBigrams(&fromPos, writingPos)) { + return false; + } + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBuffer(const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(false /* isBlacklisted */, false /* isNotAWord */, + parentPos, codePoints, codePointCount, probability, + NOT_A_DICT_POS /* childrenPos */, NOT_A_DICT_POS /* originalBigramsPos */, + NOT_A_DICT_POS /* originalShortcutPos */, writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBufferByCopyingPtNodeInfo( + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(originalNode->isBlacklisted(), + originalNode->isNotAWord(), parentPos, codePoints, codePointCount, probability, + originalNode->getChildrenPos(), originalNode->getBigramsPos(), + originalNode->getShortcutPos(), writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, + const int *const nodeCodePoints, const int nodeCodePointCount, const int probability, + int *const forwardLinkFieldPos) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, forwardLinkFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, + probability); +} + +bool DynamicPatriciaTrieWritingHelper::setPtNodeProbability( + const DynamicPatriciaTrieNodeReader *const originalPtNode, const int probability, + const int *const codePoints) { + if (originalPtNode->isTerminal()) { + // Overwrites the probability. + int probabilityFieldPos = originalPtNode->getProbabilityFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(mBuffer, + probability, &probabilityFieldPos)) { + return false; + } + } else { + // Make the node terminal and write the probability. + int movedPos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(originalPtNode, movedPos)) { + return false; + } + if (!writePtNodeToBufferByCopyingPtNodeInfo(originalPtNode, originalPtNode->getParentPos(), + codePoints, originalPtNode->getCodePointCount(), probability, &movedPos)) { + return false; + } + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + int childrenPosFieldPos = parentNode->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, &childrenPosFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentNode->getNodePos(), codePoints, + codePointCount, probability); +} + +bool DynamicPatriciaTrieWritingHelper::createNewPtNodeArrayWithAChildPtNode( + const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int probability) { + int writingPos = mBuffer->getTailPosition(); + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + 1 /* arraySize */, &writingPos)) { + return false; + } + if (!writePtNodeToBuffer(parentPtNodePos, nodeCodePoints, nodeCodePointCount, probability, + &writingPos)) { + return false; + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + return true; +} + +// Returns whether the dictionary updating was succeeded or not. +bool DynamicPatriciaTrieWritingHelper::reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount) { + // When addsExtraChild is true, split the reallocating PtNode and add new child. + // Reallocating PtNode: abcde, newNode: abcxy. + // abc (1st, not terminal) __ de (2nd) + // \_ xy (extra child, terminal) + // Otherwise, this method makes 1st part terminal and write probabilityOfNewPtNode. + // Reallocating PtNode: abcde, newNode: abc. + // abc (1st, terminal) __ de (2nd) + const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const int firstPtNodePos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(reallocatingPtNode, firstPtNodePos)) { + return false; + } + int writingPos = firstPtNodePos; + // Write the 1st part of the reallocating node. The children position will be updated later + // with actual children position. + const int newProbability = addsExtraChild ? NOT_A_PROBABILITY : probabilityOfNewPtNode; + if (!writePtNodeToBuffer(reallocatingPtNode->getParentPos(), reallocatingPtNodeCodePoints, + overlappingCodePointCount, newProbability, &writingPos)) { + return false; + } + const int actualChildrenPos = writingPos; + // Create new children PtNode array. + const size_t newPtNodeCount = addsExtraChild ? 2 : 1; + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + newPtNodeCount, &writingPos)) { + return false; + } + // Write the 2nd part of the reallocating node. + if (!writePtNodeToBufferByCopyingPtNodeInfo(reallocatingPtNode, + reallocatingPtNode->getNodePos(), + reallocatingPtNodeCodePoints + overlappingCodePointCount, + reallocatingPtNode->getCodePointCount() - overlappingCodePointCount, + reallocatingPtNode->getProbability(), &writingPos)) { + return false; + } + if (addsExtraChild) { + if (!writePtNodeToBuffer(reallocatingPtNode->getNodePos(), + newNodeCodePoints + overlappingCodePointCount, + newNodeCodePointCount - overlappingCodePointCount, probabilityOfNewPtNode, + &writingPos)) { + return false; + } + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + // Load node info. Information of the 1st part will be fetched. + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoFromBuffer(firstPtNodePos); + // Update children position. + int childrenPosFieldPos = nodeReader.getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + actualChildrenPos, &childrenPosFieldPos)) { + return false; + } + return true; +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h index f8165fc3d..ada634a54 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h @@ -23,6 +23,7 @@ namespace latinime { class BufferWithExtendableBuffer; class DynamicBigramListPolicy; +class DynamicPatriciaTrieNodeReader; class DynamicPatriciaTrieReadingHelper; class DynamicShortcutListPolicy; @@ -51,6 +52,41 @@ class DynamicPatriciaTrieWritingHelper { BufferWithExtendableBuffer *const mBuffer; DynamicBigramListPolicy *const mBigramPolicy; DynamicShortcutListPolicy *const mShortcutPolicy; + + bool markNodeAsMovedAndSetPosition(const DynamicPatriciaTrieNodeReader *const nodeToUpdate, + const int movedPos); + + bool writePtNodeWithFullInfoToBuffer(const bool isBlacklisted, const bool isNotAWord, + const int parentPos, const int *const codePoints, const int codePointCount, + const int probability, const int childrenPos, const int originalBigramListPos, + const int originalShortcutListPos, int *const writingPos); + + bool writePtNodeToBuffer(const int parentPos, const int *const codePoints, + const int codePointCount, const int probability, int *const writingPos); + + bool writePtNodeToBufferByCopyingPtNodeInfo( + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos); + + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability, int *const forwardLinkFieldPos); + + bool setPtNodeProbability(const DynamicPatriciaTrieNodeReader *const originalNode, + const int probability, const int *const codePoints); + + bool createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount); + + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability); + + bool reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount); }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp new file mode 100644 index 000000000..b261e594d --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" + +#include <cstddef> +#include <cstdlib> +#include <stdint.h> + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD = 0x7F; +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE = 0x7FFF; +const int DynamicPatriciaTrieWritingUtils::SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE = 2; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG = 0x8000; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_FIELD_SIZE = 3; +const int DynamicPatriciaTrieWritingUtils::MAX_DICT_OFFSET_VALUE = 0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::MIN_DICT_OFFSET_VALUE = -0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_NEGATIVE_FLAG = 0x800000; +const int DynamicPatriciaTrieWritingUtils::PROBABILITY_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE = 1; + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos) { + const int offset = (forwardLinkPos != NOT_A_DICT_POS) ? + forwardLinkPos - (*forwardLinkFieldPos) : 0; + return writeDictOffset(buffer, offset, forwardLinkFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const size_t arraySize, + int *const arraySizeFieldPos) { + if (arraySize <= MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD) { + return buffer->writeUintAndAdvancePosition(arraySize, SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else if (arraySize <= MAX_PTNODE_ARRAY_SIZE) { + uint32_t data = arraySize | LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + return buffer->writeUintAndAdvancePosition(data, LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else { + AKLOGI("PtNode array size cannot be written because arraySize is too large: %zd", + arraySize); + ASSERT(false); + return false; + } +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, int *const nodeFlagsFieldPos) { + return buffer->writeUintAndAdvancePosition(nodeFlags, NODE_FLAG_FIELD_SIZE, nodeFlagsFieldPos); +} + +// Note that parentOffset is offset from node's head position. +/* static */ bool DynamicPatriciaTrieWritingUtils::writeParentOffsetAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int parentOffset, + int *const parentPosFieldPos) { + int offset = (parentOffset != NOT_A_DICT_POS) ? parentOffset : 0; + return writeDictOffset(buffer, offset, parentPosFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int *const codePoints, + const int codePointCount, int *const codePointFieldPos) { + if (codePointCount <= 0) { + AKLOGI("code points cannot be written because codePointCount is invalid: %d", + codePointCount); + ASSERT(false); + return false; + } + const bool hasMultipleCodePoints = codePointCount > 1; + return buffer->writeCodePointsAndAdvancePosition(codePoints, codePointCount, + hasMultipleCodePoints, codePointFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int probability, + int *const probabilityFieldPos) { + if (probability < 0 || probability > MAX_PROBABILITY) { + AKLOGI("probability cannot be written because the probability is invalid: %d", + probability); + ASSERT(false); + return false; + } + return buffer->writeUintAndAdvancePosition(probability, PROBABILITY_FIELD_SIZE, + probabilityFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int childrenPosition, + int *const childrenPositionFieldPos) { + int offset = (childrenPosition != NOT_A_DICT_POS) ? + childrenPosition - (*childrenPositionFieldPos) : 0; + return writeDictOffset(buffer, offset, childrenPositionFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeDictOffset( + BufferWithExtendableBuffer *const buffer, const int offset, int *const offsetFieldPos) { + if (offset > MAX_DICT_OFFSET_VALUE || offset < MIN_DICT_OFFSET_VALUE) { + AKLOGI("offset cannot be written because the offset is too large or too small: %d", + offset); + ASSERT(false); + return false; + } + uint32_t data = 0; + if (offset >= 0) { + data = offset; + } else { + data = abs(offset) | DICT_OFFSET_NEGATIVE_FLAG; + } + return buffer->writeUintAndAdvancePosition(data, DICT_OFFSET_FIELD_SIZE, offsetFieldPos); +} +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h new file mode 100644 index 000000000..183ede444 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H + +#include <cstddef> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class DynamicPatriciaTrieWritingUtils { + public: + static bool writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos); + + static bool writePtNodeArraySizeAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const size_t arraySize, int *const arraySizeFieldPos); + + static bool writeFlagsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, + int *const nodeFlagsFieldPos); + + static bool writeParentOffsetAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int parentPosition, int *const parentPosFieldPos); + + static bool writeCodePointsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int *const codePoints, const int codePointCount, int *const codePointFieldPos); + + static bool writeProbabilityAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int probability, int *const probabilityFieldPos); + + static bool writeChildrenPositionAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int childrenPosition, int *const childrenPositionFieldPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingUtils); + + static const size_t MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD; + static const size_t MAX_PTNODE_ARRAY_SIZE; + static const int SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + static const int DICT_OFFSET_FIELD_SIZE; + static const int MAX_DICT_OFFSET_VALUE; + static const int MIN_DICT_OFFSET_VALUE; + static const int DICT_OFFSET_NEGATIVE_FLAG; + static const int NODE_FLAG_FIELD_SIZE; + static const int PROBABILITY_FIELD_SIZE; + + static bool writeDictOffset(BufferWithExtendableBuffer *const buffer, const int offset, + int *const offsetFieldPos); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index eb828b58c..196da5c97 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -16,24 +16,115 @@ #include "suggest/policyimpl/dictionary/header/header_policy.h" +#include <cstddef> + namespace latinime { -const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = - "MULTIPLE_WORDS_DEMOTION_RATE"; -const float HeaderPolicy::DEFAULT_MULTI_WORD_COST_MULTIPLIER = 1.0f; -const float HeaderPolicy::MULTI_WORD_COST_MULTIPLIER_SCALE = 100.0f; +const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = "MULTIPLE_WORDS_DEMOTION_RATE"; +const char *const HeaderPolicy::USES_FORGETTING_CURVE_KEY = "USES_FORGETTING_CURVE"; +const char *const HeaderPolicy::LAST_UPDATED_TIME_KEY = "date"; +const float HeaderPolicy::DEFAULT_MULTIPLE_WORD_COST_MULTIPLIER = 1.0f; +const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; + +// Used for logging. Question mark is used to indicate that the key is not found. +void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, + int outValueSize) const { + if (outValueSize <= 0) return; + if (outValueSize == 1) { + outValue[0] = '\0'; + return; + } + std::vector<int> keyCodePointVector; + insertCharactersIntoVector(key, &keyCodePointVector); + HeaderReadingUtils::AttributeMap::const_iterator it = mAttributeMap.find(keyCodePointVector); + if (it == mAttributeMap.end()) { + // The key was not found. + outValue[0] = '?'; + outValue[1] = '\0'; + return; + } + const int terminalIndex = min(static_cast<int>(it->second.size()), outValueSize - 1); + for (int i = 0; i < terminalIndex; ++i) { + outValue[i] = it->second[i]; + } + outValue[terminalIndex] = '\0'; +} + +float HeaderPolicy::readMultipleWordCostMultiplier() const { + int attributeValue = 0; + if (getAttributeValueAsInt(MULTIPLE_WORDS_DEMOTION_RATE_KEY, &attributeValue)) { + if (attributeValue <= 0) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + return MULTIPLE_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(attributeValue); + } else { + return DEFAULT_MULTIPLE_WORD_COST_MULTIPLIER; + } +} + +bool HeaderPolicy::readUsesForgettingCurveFlag() const { + int attributeValue = 0; + if (getAttributeValueAsInt(USES_FORGETTING_CURVE_KEY, &attributeValue)) { + return attributeValue != 0; + } else { + return false; + } +} + +// Returns S_INT_MIN when the key is not found or the value is invalid. +int HeaderPolicy::readLastUpdatedTime() const { + int attributeValue = 0; + if (getAttributeValueAsInt(LAST_UPDATED_TIME_KEY, &attributeValue)) { + return attributeValue; + } else { + return S_INT_MIN; + } +} -float HeaderPolicy::readMultiWordCostMultiplier() const { - const int headerValue = HeaderReadingUtils::readHeaderValueInt( - mDictBuf, MULTIPLE_WORDS_DEMOTION_RATE_KEY); - if (headerValue == S_INT_MIN) { - // not found - return DEFAULT_MULTI_WORD_COST_MULTIPLIER; +// Returns whether the key is found or not and stores the found value into outValue. +bool HeaderPolicy::getAttributeValueAsInt(const char *const key, int *const outValue) const { + std::vector<int> keyVector; + insertCharactersIntoVector(key, &keyVector); + HeaderReadingUtils::AttributeMap::const_iterator it = mAttributeMap.find(keyVector); + if (it == mAttributeMap.end()) { + // The key was not found. + return false; } - if (headerValue <= 0) { - return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + *outValue = parseIntAttributeValue(&(it->second)); + return true; +} + +/* static */ HeaderReadingUtils::AttributeMap HeaderPolicy::createAttributeMapAndReadAllAttributes( + const uint8_t *const dictBuf) { + HeaderReadingUtils::AttributeMap attributeMap; + HeaderReadingUtils::fetchAllHeaderAttributes(dictBuf, &attributeMap); + return attributeMap; +} + +/* static */ int HeaderPolicy::parseIntAttributeValue( + const std::vector<int> *const attributeValue) { + int value = 0; + bool isNegative = false; + for (size_t i = 0; i < attributeValue->size(); ++i) { + if (i == 0 && attributeValue->at(i) == '-') { + isNegative = true; + } else { + if (!isdigit(attributeValue->at(i))) { + // If not a number, return S_INT_MIN + return S_INT_MIN; + } + value *= 10; + value += attributeValue->at(i) - '0'; + } + } + return isNegative ? -value : value; +} + +/* static */ void HeaderPolicy::insertCharactersIntoVector(const char *const characters, + std::vector<int> *const vector) { + for (int i = 0; characters[i]; ++i) { + vector->push_back(characters[i]); } - return MULTI_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(headerValue); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index e3e6fc077..930b475c7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -17,6 +17,7 @@ #ifndef LATINIME_HEADER_POLICY_H #define LATINIME_HEADER_POLICY_H +#include <cctype> #include <stdint.h> #include "defines.h" @@ -30,7 +31,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { explicit HeaderPolicy(const uint8_t *const dictBuf) : mDictBuf(dictBuf), mDictionaryFlags(HeaderReadingUtils::getFlags(dictBuf)), mSize(HeaderReadingUtils::getHeaderSize(dictBuf)), - mMultiWordCostMultiplier(readMultiWordCostMultiplier()) {} + mAttributeMap(createAttributeMapAndReadAllAttributes(mDictBuf)), + mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), + mUsesForgettingCurve(readUsesForgettingCurveFlag()), + mLastUpdatedTime(readLastUpdatedTime()) {} ~HeaderPolicy() {} @@ -55,34 +59,49 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mMultiWordCostMultiplier; } - AK_FORCE_INLINE void readHeaderValueOrQuestionMark(const char *const key, - int *outValue, int outValueSize) const { - if (outValueSize <= 0) return; - if (outValueSize == 1) { - outValue[0] = '\0'; - return; - } - if (!HeaderReadingUtils::readHeaderValue(mDictBuf, - key, outValue, outValueSize)) { - outValue[0] = '?'; - outValue[1] = '\0'; - } + AK_FORCE_INLINE bool usesForgettingCurve() const { + return mUsesForgettingCurve; } + AK_FORCE_INLINE int getLastUpdatedTime() const { + return mLastUpdatedTime; + } + + void readHeaderValueOrQuestionMark(const char *const key, + int *outValue, int outValueSize) const; + private: DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderPolicy); static const char *const MULTIPLE_WORDS_DEMOTION_RATE_KEY; - static const float DEFAULT_MULTI_WORD_COST_MULTIPLIER; - static const float MULTI_WORD_COST_MULTIPLIER_SCALE; + static const char *const USES_FORGETTING_CURVE_KEY; + static const char *const LAST_UPDATED_TIME_KEY; + static const float DEFAULT_MULTIPLE_WORD_COST_MULTIPLIER; + static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; const uint8_t *const mDictBuf; const HeaderReadingUtils::DictionaryFlags mDictionaryFlags; const int mSize; + HeaderReadingUtils::AttributeMap mAttributeMap; const float mMultiWordCostMultiplier; + const bool mUsesForgettingCurve; + const int mLastUpdatedTime; - float readMultiWordCostMultiplier() const; -}; + float readMultipleWordCostMultiplier() const; + + bool readUsesForgettingCurveFlag() const; + int readLastUpdatedTime() const; + + bool getAttributeValueAsInt(const char *const key, int *const outValue) const; + + static HeaderReadingUtils::AttributeMap createAttributeMapAndReadAllAttributes( + const uint8_t *const dictBuf); + + static int parseIntAttributeValue(const std::vector<int> *const attributeValue); + + static void insertCharactersIntoVector( + const char *const characters, std::vector<int> *const vector); +}; } // namespace latinime #endif /* LATINIME_HEADER_POLICY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp index f323876c4..186c043c1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp @@ -16,23 +16,22 @@ #include "suggest/policyimpl/dictionary/header/header_reading_utils.h" -#include <cctype> -#include <cstdlib> +#include <vector> #include "defines.h" #include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { -const int HeaderReadingUtils::MAX_OPTION_KEY_LENGTH = 256; +const int HeaderReadingUtils::MAX_ATTRIBUTE_KEY_LENGTH = 256; +const int HeaderReadingUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 256; const int HeaderReadingUtils::HEADER_MAGIC_NUMBER_SIZE = 4; const int HeaderReadingUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; const int HeaderReadingUtils::HEADER_FLAG_SIZE = 2; const int HeaderReadingUtils::HEADER_SIZE_FIELD_SIZE = 4; -const HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::NO_FLAGS = 0; +const HeaderReadingUtils::DictionaryFlags HeaderReadingUtils::NO_FLAGS = 0; // Flags for special processing // Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAG) or // something very bad (like, the apocalypse) will happen. Please update both at the same time. @@ -56,53 +55,27 @@ const HeaderReadingUtils::DictionaryFlags HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE); } -// Returns if the key is found or not and reads the found value into outValue. -/* static */ bool HeaderReadingUtils::readHeaderValue(const uint8_t *const dictBuf, - const char *const key, int *outValue, const int outValueSize) { - if (outValueSize <= 0) { - return false; - } +/* static */ void HeaderReadingUtils::fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes) { const int headerSize = getHeaderSize(dictBuf); int pos = getHeaderOptionsPosition(); if (pos == NOT_A_DICT_POS) { // The header doesn't have header options. - return false; + return; } + int keyBuffer[MAX_ATTRIBUTE_KEY_LENGTH]; + int valueBuffer[MAX_ATTRIBUTE_VALUE_LENGTH]; while (pos < headerSize) { - if(ByteArrayUtils::compareStringInBufferWithCharArray( - dictBuf, key, headerSize - pos, &pos) == 0) { - // The key was found. - const int length = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, outValueSize, - outValue, &pos); - // Add a 0 terminator to the string. - outValue[length < outValueSize ? length : outValueSize - 1] = '\0'; - return true; - } - ByteArrayUtils::advancePositionToBehindString(dictBuf, headerSize - pos, &pos); - } - // The key was not found. - return false; -} - -/* static */ int HeaderReadingUtils::readHeaderValueInt( - const uint8_t *const dictBuf, const char *const key) { - const int bufferSize = LARGEST_INT_DIGIT_COUNT; - int intBuffer[bufferSize]; - char charBuffer[bufferSize]; - if (!readHeaderValue(dictBuf, key, intBuffer, bufferSize)) { - return S_INT_MIN; - } - for (int i = 0; i < bufferSize; ++i) { - charBuffer[i] = intBuffer[i]; - if (charBuffer[i] == '0') { - break; - } - if (!isdigit(charBuffer[i])) { - // If not a number, return S_INT_MIN - return S_INT_MIN; - } + const int keyLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_KEY_LENGTH, keyBuffer, &pos); + std::vector<int> key; + key.insert(key.end(), keyBuffer, keyBuffer + keyLength); + const int valueLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_VALUE_LENGTH, valueBuffer, &pos); + std::vector<int> value; + value.insert(value.end(), valueBuffer, valueBuffer + valueLength); + headerAttributes->insert(AttributeMap::value_type(key, value)); } - return atoi(charBuffer); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h index c94919640..5716198fb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h @@ -17,7 +17,9 @@ #ifndef LATINIME_HEADER_READING_UTILS_H #define LATINIME_HEADER_READING_UTILS_H +#include <map> #include <stdint.h> +#include <vector> #include "defines.h" @@ -26,8 +28,7 @@ namespace latinime { class HeaderReadingUtils { public: typedef uint16_t DictionaryFlags; - - static const int MAX_OPTION_KEY_LENGTH; + typedef std::map<std::vector<int>, std::vector<int> > AttributeMap; static int getHeaderSize(const uint8_t *const dictBuf); @@ -50,14 +51,15 @@ class HeaderReadingUtils { + HEADER_SIZE_FIELD_SIZE; } - static bool readHeaderValue(const uint8_t *const dictBuf, - const char *const key, int *outValue, const int outValueSize); - - static int readHeaderValueInt(const uint8_t *const dictBuf, const char *const key); + static void fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes); private: DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderReadingUtils); + static const int MAX_ATTRIBUTE_KEY_LENGTH; + static const int MAX_ATTRIBUTE_VALUE_LENGTH; + static const int HEADER_MAGIC_NUMBER_SIZE; static const int HEADER_DICTIONARY_VERSION_SIZE; static const int HEADER_FLAG_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp index adcf2dbdf..d5a83a938 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp @@ -21,6 +21,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { @@ -306,7 +307,19 @@ int PatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const inWord, } } -int PatriciaTriePolicy::getUnigramProbability(const int nodePos) const { +int PatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } +} + +int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int nodePos) const { if (nodePos == NOT_A_VALID_WORD_POS) { return NOT_A_PROBABILITY; } @@ -324,7 +337,8 @@ int PatriciaTriePolicy::getUnigramProbability(const int nodePos) const { return NOT_A_PROBABILITY; } PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); - return PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + return getProbability(PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition( + mDictRoot, &pos), NOT_A_PROBABILITY); } int PatriciaTriePolicy::getShortcutPositionOfNode(const int nodePos) const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h index d0567fd85..75d976205 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -56,7 +56,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const; - int getUnigramProbability(const int nodePos) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; + + int getUnigramProbabilityOfPtNode(const int nodePos) const; int getShortcutPositionOfNode(const int nodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h index f76c38751..2b0646db2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h @@ -119,6 +119,29 @@ class PatriciaTrieReadingUtils { return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags); } + static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted, + const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets, + const bool hasBigrams, const bool hasMultipleChars, + const int childrenPositionFieldSize) { + NodeFlags nodeFlags = 0; + nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags; + nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags; + nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags; + nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags; + nodeFlags = hasBigrams ? (nodeFlags | FLAG_HAS_BIGRAMS) : nodeFlags; + nodeFlags = hasMultipleChars ? (nodeFlags | FLAG_HAS_MULTIPLE_CHARS) : nodeFlags; + if (childrenPositionFieldSize == 1) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_ONEBYTE; + } else if (childrenPositionFieldSize == 2) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_TWOBYTES; + } else if (childrenPositionFieldSize == 3) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_THREEBYTES; + } else { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_NOPOSITION; + } + return nodeFlags; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h index 5e9c52950..1803c09cb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h @@ -83,8 +83,8 @@ class DynamicShortcutListPolicy : public DictionaryShortcutsStructurePolicy { } // Copy shortcuts from the shortcut list that starts at fromPos to toPos and advance these - // positions after the shortcut lists. - void copyAllShortcuts(int *const fromPos, int *const toPos) { + // positions after the shortcut lists. This returns whether the copy was succeeded or not. + bool copyAllShortcutsAndReturnIfSucceededOrNot(int *const fromPos, int *const toPos) { const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); if (usesAdditionalBuffer) { @@ -93,16 +93,23 @@ class DynamicShortcutListPolicy : public DictionaryShortcutsStructurePolicy { const int shortcutListSize = ShortcutListReadingUtils ::getShortcutListSizeAndForwardPointer(buffer, fromPos); // Copy shortcut list size. - mBuffer->writeUintAndAdvancePosition( + if (!mBuffer->writeUintAndAdvancePosition( shortcutListSize + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), - ShortcutListReadingUtils::getShortcutListSizeFieldSize(), toPos); + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), toPos)) { + return false; + } + // Copy shortcut list. for (int i = 0; i < shortcutListSize; ++i) { - const uint8_t data = ByteArrayUtils::readUint8AndAdvancePosition(buffer, fromPos); - mBuffer->writeUintAndAdvancePosition(data, 1 /* size */, toPos); + const uint8_t data = ByteArrayUtils::readUint8AndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), fromPos); + if (!mBuffer->writeUintAndAdvancePosition(data, 1 /* size */, toPos)) { + return false; + } } if (usesAdditionalBuffer) { *fromPos += mBuffer->getOriginalBufferSize(); } + return true; } private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index 8582c4b81..dfdaebd18 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -22,4 +22,71 @@ const size_t BufferWithExtendableBuffer::INITIAL_ADDITIONAL_BUFFER_SIZE = 16 * 1 const size_t BufferWithExtendableBuffer::MAX_ADDITIONAL_BUFFER_SIZE = 1024 * 1024; const size_t BufferWithExtendableBuffer::EXTEND_ADDITIONAL_BUFFER_SIZE_STEP = 16 * 1024; +bool BufferWithExtendableBuffer::writeUintAndAdvancePosition(const uint32_t data, const int size, + int *const pos) { + if (!(size >= 1 && size <= 4)) { + AKLOGI("writeUintAndAdvancePosition() is called with invalid size: %d", size); + ASSERT(false); + return false; + } + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeUintAndAdvancePosition(buffer, data, size, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::writeCodePointsAndAdvancePosition(const int *const codePoints, + const int codePointCount, const bool writesTerminator ,int *const pos) { + const size_t size = ByteArrayUtils::calculateRequiredByteCountToStoreCodePoints( + codePoints, codePointCount, writesTerminator); + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeCodePointsAndAdvancePosition(buffer, codePoints, codePointCount, + writesTerminator, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::checkAndPrepareWriting(const int pos, const int size) { + if (isInAdditionalBuffer(pos)) { + const int tailPosition = getTailPosition(); + if (pos == tailPosition) { + // Append data to the tail. + if (pos + size > static_cast<int>(mAdditionalBuffer.size()) + mOriginalBufferSize) { + // Need to extend buffer. + if (!extendBuffer()) { + return false; + } + } + mUsedAdditionalBufferSize += size; + } else if (pos + size >= tailPosition) { + // The access will beyond the tail of used region. + return false; + } + } else { + if (pos < 0 || mOriginalBufferSize < pos + size) { + // Invalid position or violate the boundary. + return false; + } + } + return true; +} + } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h index ec871ec85..c6a484131 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h @@ -47,6 +47,7 @@ class BufferWithExtendableBuffer { return position >= mOriginalBufferSize; } + // TODO: Resolve the issue that the address can be changed when the vector is resized. // CAVEAT!: Be careful about array out of bound access with buffers AK_FORCE_INLINE const uint8_t *getBuffer(const bool usesAdditionalBuffer) const { if (usesAdditionalBuffer) { @@ -66,27 +67,10 @@ class BufferWithExtendableBuffer { * Writing is allowed for original buffer, already written region of additional buffer and the * tail of additional buffer. */ - AK_FORCE_INLINE bool writeUintAndAdvancePosition(const uint32_t data, const int size, - int *const pos) { - if (!(size >= 1 && size <= 4)) { - AKLOGI("writeUintAndAdvancePosition() is called with invalid size: %d", size); - ASSERT(false); - return false; - } - if (!checkAndPrepareWriting(*pos, size)) { - return false; - } - const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); - uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; - if (usesAdditionalBuffer) { - *pos -= mOriginalBufferSize; - } - ByteArrayUtils::writeUintAndAdvancePosition(buffer, data, size, pos); - if (usesAdditionalBuffer) { - *pos += mOriginalBufferSize; - } - return true; - } + bool writeUintAndAdvancePosition(const uint32_t data, const int size, int *const pos); + + bool writeCodePointsAndAdvancePosition(const int *const codePoints, const int codePointCount, + const bool writesTerminator, int *const pos); private: DISALLOW_COPY_AND_ASSIGN(BufferWithExtendableBuffer); @@ -112,29 +96,7 @@ class BufferWithExtendableBuffer { // Returns if it is possible to write size-bytes from pos. When pos is at the tail position of // the additional buffer, try extending the buffer. - AK_FORCE_INLINE bool checkAndPrepareWriting(const int pos, const int size) { - if (isInAdditionalBuffer(pos)) { - if (pos == mUsedAdditionalBufferSize) { - // Append data to the tail. - if (pos + size > static_cast<int>(mAdditionalBuffer.size())) { - // Need to extend buffer. - if (!extendBuffer()) { - return false; - } - } - mUsedAdditionalBufferSize += size; - } else if (pos + size >= mUsedAdditionalBufferSize) { - // The access will beyond the tail of used region. - return false; - } - } else { - if (pos < 0 || mOriginalBufferSize < pos + size) { - // Invalid position or violate the boundary. - return false; - } - } - return true; - } + AK_FORCE_INLINE bool checkAndPrepareWriting(const int pos, const int size); }; } #endif /* LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index 1d14929c7..f727ecf8e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -115,7 +115,7 @@ class ByteArrayUtils { } /** - * Code Point + * Code Point Reading * * 1 byte = bbbbbbbb match * case 000xxxxx: xxxxx << 16 + next byte << 8 + next byte @@ -149,7 +149,7 @@ class ByteArrayUtils { } /** - * String (array of code points) + * String (array of code points) Reading * * Reads code points until the terminator is found. */ @@ -176,37 +176,49 @@ class ByteArrayUtils { return length; } - // Returns an integer less than, equal to, or greater than zero when string starting from pos - // in buffer is less than, match, or is greater than charArray. - static AK_FORCE_INLINE int compareStringInBufferWithCharArray(const uint8_t *const buffer, - const char *const charArray, const int maxLength, int *const pos) { - int index = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); - const uint8_t *const uint8CharArrayForComparison = - reinterpret_cast<const uint8_t *>(charArray); - while (NOT_A_CODE_POINT != codePoint - && '\0' != uint8CharArrayForComparison[index] && index < maxLength) { - if (codePoint != uint8CharArrayForComparison[index]) { - // Different character is found. - // Skip the rest of the string in the buffer. - advancePositionToBehindString(buffer, maxLength - index, pos); - return codePoint - uint8CharArrayForComparison[index]; + /** + * String (array of code points) Writing + */ + static void writeCodePointsAndAdvancePosition(uint8_t *const buffer, + const int *const codePoints, const int codePointCount, const bool writesTerminator, + int *const pos) { + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + writeUint24AndAdvancePosition(buffer, codePoint, pos); + } else { + // one byte character. + writeUint8AndAdvancePosition(buffer, codePoint, pos); } - // Advance - codePoint = readCodePointAndAdvancePosition(buffer, pos); - ++index; } - if (NOT_A_CODE_POINT != codePoint && index < maxLength) { - // Skip the rest of the string in the buffer. - advancePositionToBehindString(buffer, maxLength - index, pos); + if (writesTerminator) { + writeUint8AndAdvancePosition(buffer, CHARACTER_ARRAY_TERMINATOR, pos); } - if (NOT_A_CODE_POINT == codePoint && '\0' == uint8CharArrayForComparison[index]) { - // When both of the last characters are terminals, we consider the string in the buffer - // matches the given char array - return 0; - } else { - return codePoint - uint8CharArrayForComparison[index]; + } + + static int calculateRequiredByteCountToStoreCodePoints(const int *const codePoints, + const int codePointCount, const bool writesTerminator) { + int byteCount = 0; + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + byteCount += 3; + } else { + // one byte character. + byteCount += 1; + } + } + if (writesTerminator) { + // The terminator is one byte. + byteCount += 1; } + return byteCount; } private: diff --git a/native/jni/src/suggest/core/dictionary/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h index 21fe355b8..21fe355b8 100644 --- a/native/jni/src/suggest/core/dictionary/probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h |