diff options
Diffstat (limited to 'native/jni/src')
62 files changed, 1713 insertions, 643 deletions
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 92f39ea25..d1b2c87be 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -117,7 +117,7 @@ class DicNode { int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { - newPrevWordsPtNodePos[i] = dicNode->getNthPrevWordTerminalPtNodePos(i); + newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; } mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, @@ -208,12 +208,9 @@ class DicNode { return mDicNodeProperties.getPtNodePos(); } - // Used to get n-gram probability in DicNodeUtils. n is 1-indexed. - int getNthPrevWordTerminalPtNodePos(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return NOT_A_DICT_POS; - } - return mDicNodeProperties.getPrevWordsTerminalPtNodePos()[n - 1]; + // TODO: Use view class to return PtNodePos array. + const int *getPrevWordsTerminalPtNodePos() const { + return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); } // Used in DicNodeUtils diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 4445f4aaf..69ea67418 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -85,17 +85,10 @@ namespace latinime { const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { const int unigramProbability = dicNode->getProbability(); - const int ptNodePos = dicNode->getPtNodePos(); - const int prevWordTerminalPtNodePos = dicNode->getNthPrevWordTerminalPtNodePos(1 /* n */); - if (NOT_A_DICT_POS == ptNodePos || NOT_A_DICT_POS == prevWordTerminalPtNodePos) { - // Note: Normally wordPos comes from the dictionary and should never equal - // NOT_A_VALID_WORD_POS. - return dictionaryStructurePolicy->getProbability(unigramProbability, - NOT_A_PROBABILITY); - } if (multiBigramMap) { + const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos(); return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, - prevWordTerminalPtNodePos, ptNodePos, unigramProbability); + prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability); } return dictionaryStructurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index fb25f757c..d62573970 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -59,42 +59,48 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } } +Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( + const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults, + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) + : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults), + mDictStructurePolicy(dictStructurePolicy) {} + +void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability, + const int targetPtNodePos) { + if (targetPtNodePos == NOT_A_DICT_POS) { + return; + } + if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + && ngramProbability == NOT_A_PROBABILITY) { + return; + } + int targetWordCodePoints[MAX_WORD_LENGTH]; + int unigramProbability = 0; + const int codePointCount = mDictStructurePolicy-> + getCodePointsAndProbabilityAndReturnCodePointCount(targetPtNodePos, + MAX_WORD_LENGTH, targetWordCodePoints, &unigramProbability); + if (codePointCount <= 0) { + return; + } + const int probability = mDictStructurePolicy->getProbability( + unigramProbability, ngramProbability); + mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability); +} + void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - int unigramProbability = 0; - int bigramCodePoints[MAX_WORD_LENGTH]; - BinaryDictionaryBigramsIterator bigramsIt = prevWordsInfo->getBigramsIteratorForPrediction( + NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, mDictionaryStructureWithBufferPolicy.get()); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { - continue; - } - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) - && bigramsIt.getProbability() == NOT_A_PROBABILITY) { - continue; - } - const int codePointCount = mDictionaryStructureWithBufferPolicy-> - getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(), - MAX_WORD_LENGTH, bigramCodePoints, &unigramProbability); - if (codePointCount <= 0) { - continue; - } - const int probability = mDictionaryStructureWithBufferPolicy->getProbability( - unigramProbability, bigramsIt.getProbability()); - outSuggestionResults->addPrediction(bigramCodePoints, codePointCount, probability); - } + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos( + mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + true /* tryLowerCaseSearch */); + mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); } int Dictionary::getProbability(const int *word, int length) const { - TimeKeeper::setCurrentTime(); - int pos = getDictionaryStructurePolicy()->getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (NOT_A_DICT_POS == pos) { - return NOT_A_PROBABILITY; - } - return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); + return getNgramProbability(nullptr /* prevWordsInfo */, word, length); } int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) const { @@ -109,18 +115,15 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(word, length, false /* forceLowerCaseSearch */); if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; - BinaryDictionaryBigramsIterator bigramsIt = prevWordsInfo->getBigramsIteratorForPrediction( - mDictionaryStructureWithBufferPolicy.get()); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == nextWordPos - && bigramsIt.getProbability() != NOT_A_PROBABILITY) { - return mDictionaryStructureWithBufferPolicy->getProbability( - mDictionaryStructureWithBufferPolicy->getUnigramProbabilityOfPtNode( - nextWordPos), bigramsIt.getProbability()); - } + if (!prevWordsInfo) { + return getDictionaryStructurePolicy()->getProbabilityOfPtNode( + nullptr /* prevWordsPtNodePos */, nextWordPos); } - return NOT_A_PROBABILITY; + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos( + mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + true /* tryLowerCaseSearch */); + return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); } bool Dictionary::addUnigramEntry(const int *const word, const int length, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 3b41088fe..732d3b199 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -21,6 +21,7 @@ #include "defines.h" #include "jni.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/word_property.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" @@ -114,6 +115,21 @@ class Dictionary { typedef std::unique_ptr<SuggestInterface> SuggestInterfacePtr; + class NgramListenerForPrediction : public NgramListener { + public: + NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo, + SuggestionResults *const suggestionResults, + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy); + virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction); + + const PrevWordsInfo *const mPrevWordsInfo; + SuggestionResults *const mSuggestionResults; + const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; + }; + static const int HEADER_ATTRIBUTE_BUFFER_SIZE; const DictionaryStructureWithBufferPolicy::StructurePolicyPtr diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp index 012e4dc9c..91f33a8dd 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -35,34 +35,30 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = // Also caches the bigrams if there is space remaining and they have not been cached already. int MultiBigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int wordPosition, const int nextWordPosition, const int unigramProbability) { + const int *const prevWordsPtNodePos, const int nextWordPosition, + const int unigramProbability) { + if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); + } std::unordered_map<int, BigramMap>::const_iterator mapPosition = - mBigramMaps.find(wordPosition); + mBigramMaps.find(prevWordsPtNodePos[0]); if (mapPosition != mBigramMaps.end()) { return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(structurePolicy, wordPosition); - return mBigramMaps[wordPosition].getBigramProbability(structurePolicy, + addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos); + return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy, nextWordPosition, unigramProbability); } - return readBigramProbabilityFromBinaryDictionary(structurePolicy, wordPosition, + return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos, nextWordPosition, unigramProbability); } void MultiBigramMap::BigramMap::init( - const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos) { - BinaryDictionaryBigramsIterator bigramsIt = - structurePolicy->getBigramsIteratorOfPtNode(nodePos); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { - continue; - } - mBigramMap[bigramsIt.getBigramPos()] = bigramsIt.getProbability(); - mBloomFilter.setInFilter(bigramsIt.getBigramPos()); - } + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordsPtNodePos) { + structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */); } int MultiBigramMap::BigramMap::getBigramProbability( @@ -79,25 +75,33 @@ int MultiBigramMap::BigramMap::getBigramProbability( return structurePolicy->getProbability(unigramProbability, bigramProbability); } +void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, + const int targetPtNodePos) { + if (targetPtNodePos == NOT_A_DICT_POS) { + return; + } + mBigramMap[targetPtNodePos] = ngramProbability; + mBloomFilter.setInFilter(targetPtNodePos); +} + void MultiBigramMap::addBigramsForWordPosition( - const DictionaryStructureWithBufferPolicy *const structurePolicy, const int position) { - mBigramMaps[position].init(structurePolicy, position); + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordsPtNodePos) { + if (prevWordsPtNodePos) { + mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos); + } } int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( - const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos, - const int nextWordPosition, const int unigramProbability) { - int bigramProbability = NOT_A_PROBABILITY; - BinaryDictionaryBigramsIterator bigramsIt = - structurePolicy->getBigramsIteratorOfPtNode(nodePos); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == nextWordPosition) { - bigramProbability = bigramsIt.getProbability(); - break; - } + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordsPtNodePos, const int nextWordPosition, + const int unigramProbability) { + const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos, + nextWordPosition); + if (bigramProbability != NOT_A_PROBABILITY) { + return bigramProbability; } - return structurePolicy->getProbability(unigramProbability, bigramProbability); + return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); } } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index 195b5e22f..ad36dde83 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -23,6 +23,7 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/bloom_filter.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" namespace latinime { @@ -38,7 +39,8 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int wordPosition, const int nextWordPosition, const int unigramProbability); + const int *const prevWordsPtNodePos, const int nextWordPosition, + const int unigramProbability); void clear() { mBigramMaps.clear(); @@ -47,32 +49,35 @@ class MultiBigramMap { private: DISALLOW_COPY_AND_ASSIGN(MultiBigramMap); - class BigramMap { + class BigramMap : public NgramListener { public: BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP), mBloomFilter() {} - ~BigramMap() {} + // Copy constructor needed for std::unordered_map. + BigramMap(const BigramMap &bigramMap) + : mBigramMap(bigramMap.mBigramMap), mBloomFilter(bigramMap.mBloomFilter) {} + virtual ~BigramMap() {} void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int nodePos); - + const int *const prevWordsPtNodePos); int getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nextWordPosition, const int unigramProbability) const; + virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); private: - // NOTE: The BigramMap class doesn't use DISALLOW_COPY_AND_ASSIGN() because its default - // copy constructor is needed for use in hash_map. static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; std::unordered_map<int, int> mBigramMap; BloomFilter mBloomFilter; }; void addBigramsForWordPosition( - const DictionaryStructureWithBufferPolicy *const structurePolicy, const int position); + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordsPtNodePos); int readBigramProbabilityFromBinaryDictionary( - const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos, - const int nextWordPosition, const int unigramProbability); + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int *const prevWordsPtNodePos, const int nextWordPosition, + const int unigramProbability); static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; std::unordered_map<int, BigramMap> mBigramMaps; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dict_content.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index c264aeac4..88b88bafb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dict_content.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -1,5 +1,5 @@ /* - * Copyright (C) 2013, The Android Open Source Project + * Copyright (C) 2014, The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,22 +14,27 @@ * limitations under the License. */ -#ifndef LATINIME_DICT_CONTENT_H -#define LATINIME_DICT_CONTENT_H +#ifndef LATINIME_NGRAM_LISTENER_H +#define LATINIME_NGRAM_LISTENER_H #include "defines.h" namespace latinime { -class DictContent { +/** + * Interface to iterate ngram entries. + */ +class NgramListener { public: - virtual ~DictContent() {} + virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; + virtual ~NgramListener() {}; protected: - DictContent() {} + NgramListener() {} private: - DISALLOW_COPY_AND_ASSIGN(DictContent); + DISALLOW_COPY_AND_ASSIGN(NgramListener); + }; } // namespace latinime -#endif /* LATINIME_DICT_CONTENT_H */ +#endif /* LATINIME_NGRAM_LISTENER_H */ diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index 6b1a319aa..e6180fe17 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -215,13 +215,13 @@ class ProximityInfoState { std::vector<float> mSpeedRates; std::vector<float> mDirections; // probabilities of skipping or mapping to a key for each point. - std::vector<std::unordered_map<int, float> > mCharProbabilities; + std::vector<std::unordered_map<int, float>> mCharProbabilities; // The vector for the key code set which holds nearby keys of some trailing sampled input points // for each sampled input point. These nearby keys contain the next characters which can be in // the dictionary. Specifically, currently we are looking for keys nearby trailing sampled // inputs including the current input point. std::vector<ProximityInfoStateUtils::NearKeycodesSet> mSampledSearchKeySets; - std::vector<std::vector<int> > mSampledSearchKeyVectors; + std::vector<std::vector<int>> mSampledSearchKeyVectors; bool mTouchPositionCorrectionEnabled; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mSampledInputSize; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index ea3b02216..0aeb36aad 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -621,7 +621,7 @@ namespace latinime { const std::vector<int> *const sampledLengthCache, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const ProximityInfo *const proximityInfo, - std::vector<std::unordered_map<int, float> > *charProbabilities) { + std::vector<std::unordered_map<int, float>> *charProbabilities) { charProbabilities->resize(sampledInputSize); // Calculates probabilities of using a point as a correlated point with the character // for each point. @@ -822,9 +822,9 @@ namespace latinime { /* static */ void ProximityInfoStateUtils::updateSampledSearchKeySets( const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, std::vector<NearKeycodesSet> *sampledSearchKeySets, - std::vector<std::vector<int> > *sampledSearchKeyVectors) { + std::vector<std::vector<int>> *sampledSearchKeyVectors) { sampledSearchKeySets->resize(sampledInputSize); sampledSearchKeyVectors->resize(sampledInputSize); const int readForwordLength = static_cast<int>( @@ -868,7 +868,7 @@ namespace latinime { /* static */ bool ProximityInfoStateUtils::suppressCharProbabilities(const int mostCommonKeyWidth, const int sampledInputSize, const std::vector<int> *const lengthCache, const int index0, const int index1, - std::vector<std::unordered_map<int, float> > *charProbabilities) { + std::vector<std::unordered_map<int, float>> *charProbabilities) { ASSERT(0 <= index0 && index0 < sampledInputSize); ASSERT(0 <= index1 && index1 < sampledInputSize); const float keyWidthFloat = static_cast<float>(mostCommonKeyWidth); @@ -933,7 +933,7 @@ namespace latinime { // returns probability of generating the word. /* static */ float ProximityInfoStateUtils::getMostProbableString( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, int *const codePointBuf) { ASSERT(sampledInputSize >= 0); memset(codePointBuf, 0, sizeof(codePointBuf[0]) * MAX_WORD_LENGTH); diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h index 211a79737..4043334e6 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h @@ -72,13 +72,13 @@ class ProximityInfoStateUtils { const std::vector<int> *const sampledLengthCache, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const ProximityInfo *const proximityInfo, - std::vector<std::unordered_map<int, float> > *charProbabilities); + std::vector<std::unordered_map<int, float>> *charProbabilities); static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, std::vector<NearKeycodesSet> *sampledSearchKeySets, - std::vector<std::vector<int> > *sampledSearchKeyVectors); + std::vector<std::vector<int>> *sampledSearchKeyVectors); static float getPointToKeyByIdLength(const float maxPointToKeyLength, const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); @@ -105,7 +105,7 @@ class ProximityInfoStateUtils { // TODO: Move to most_probable_string_utils.h static float getMostProbableString(const ProximityInfo *const proximityInfo, const int sampledInputSize, - const std::vector<std::unordered_map<int, float> > *const charProbabilities, + const std::vector<std::unordered_map<int, float>> *const charProbabilities, int *const codePointBuf); private: @@ -147,7 +147,7 @@ class ProximityInfoStateUtils { const int index2); static bool suppressCharProbabilities(const int mostCommonKeyWidth, const int sampledInputSize, const std::vector<int> *const lengthCache, const int index0, - const int index1, std::vector<std::unordered_map<int, float> > *charProbabilities); + const int index1, std::vector<std::unordered_map<int, float>> *charProbabilities); static float calculateSquaredDistanceFromSweetSpotCenter( const ProximityInfo *const proximityInfo, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, const int keyIndex, diff --git a/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h index 661ef1b1a..aa0d068aa 100644 --- a/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h @@ -30,7 +30,7 @@ class DictionaryBigramsStructurePolicy { virtual void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const = 0; - virtual void skipAllBigrams(int *const pos) const = 0; + virtual bool skipAllBigrams(int *const pos) const = 0; protected: DictionaryBigramsStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index a61227626..6da390e55 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -30,7 +30,7 @@ namespace latinime { */ class DictionaryHeaderStructurePolicy { public: - typedef std::map<std::vector<int>, std::vector<int> > AttributeMap; + typedef std::map<std::vector<int>, std::vector<int>> AttributeMap; virtual ~DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index a48d64473..e91f07682 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -20,16 +20,15 @@ #include <memory> #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/property/word_property.h" namespace latinime { class DicNode; class DicNodeVector; -class DictionaryBigramsStructurePolicy; class DictionaryHeaderStructurePolicy; class DictionaryShortcutsStructurePolicy; +class NgramListener; class PrevWordsInfo; class UnigramProperty; @@ -58,11 +57,13 @@ class DictionaryStructureWithBufferPolicy { virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual int getUnigramProbabilityOfPtNode(const int nodePos) const = 0; + virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, + const int nodePos) const = 0; - virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; + virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const = 0; - virtual BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int nodePos) const = 0; + virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; virtual const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const = 0; diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 292194bf2..9e75cace4 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -37,7 +37,6 @@ class Scoring { DicNode *const terminals, const int size) const = 0; virtual float getDoubleLetterDemotionDistanceCost( const DicNode *const terminalDicNode) const = 0; - virtual bool doesAutoCorrectValidWord() const = 0; virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; virtual bool sameAsTyped(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 7b0e7e1b4..0b99b75ec 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -117,8 +117,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const int finalScore = scoringPolicy->calculateFinalScore( compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), - (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), + (forceCommitMultiWords && terminalDicNode->hasMultipleWords()), boostExactMatches); // Don't output invalid or blocked offensive words. However, we still need to submit their @@ -145,12 +144,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; traverseSession->getDictionaryStructurePolicy() ->getShortcutPositionOfPtNode(terminalDicNode->getPtNodePos())); const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); - const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ? - scoringPolicy->calculateFinalScore(compoundDistance, - traverseSession->getInputSize(), - terminalDicNode->getContainedErrorTypes(), - true /* forceCommit */, boostExactMatches) : finalScore; - outputShortcuts(&shortcutIt, shortcutBaseScore, sameAsTyped, outSuggestionResults); + outputShortcuts(&shortcutIt, finalScore, sameAsTyped, outSuggestionResults); } } diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index 76276f528..e44e876e9 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -90,13 +90,6 @@ class PrevWordsInfo { } } - BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const { - return getBigramsIteratorForWordWithTryingLowerCaseSearch( - dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0], - mIsBeginningOfSentence[0]); - } - // n is 1-indexed. const int *getNthPrevWordCodePoints(const int n) const { if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { @@ -154,46 +147,6 @@ class PrevWordsInfo { codePoints, codePointCount, true /* forceLowerCaseSearch */); } - static BinaryDictionaryBigramsIterator getBigramsIteratorForWordWithTryingLowerCaseSearch( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - const int *const wordCodePoints, const int wordCodePointCount, - const bool isBeginningOfSentence) { - if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return BinaryDictionaryBigramsIterator(); - } - int codePoints[MAX_WORD_LENGTH]; - int codePointCount = wordCodePointCount; - memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); - if (isBeginningOfSentence) { - codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, - codePointCount, MAX_WORD_LENGTH); - if (codePointCount <= 0) { - return BinaryDictionaryBigramsIterator(); - } - } - BinaryDictionaryBigramsIterator bigramsIt = getBigramsIteratorForWord(dictStructurePolicy, - codePoints, codePointCount, false /* forceLowerCaseSearch */); - // getBigramsIteratorForWord returns an empty iterator if this word isn't in the dictionary - // or has no bigrams. - if (bigramsIt.hasNext()) { - return bigramsIt; - } - // If no bigrams for this exact word, search again in lower case. - return getBigramsIteratorForWord(dictStructurePolicy, codePoints, - codePointCount, true /* forceLowerCaseSearch */); - } - - static BinaryDictionaryBigramsIterator getBigramsIteratorForWord( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - const int *wordCodePoints, const int wordCodePointCount, - const bool forceLowerCaseSearch) { - if (!wordCodePoints || wordCodePointCount <= 0) return BinaryDictionaryBigramsIterator(); - const int terminalPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( - wordCodePoints, wordCodePointCount, forceLowerCaseSearch); - if (NOT_A_DICT_POS == terminalPtNodePos) return BinaryDictionaryBigramsIterator(); - return dictStructurePolicy->getBigramsIteratorOfPtNode(terminalPtNodePos); - } - void clear() { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { mPrevWordCodePointCount[i] = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h index 61623468e..50a4c9743 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h @@ -58,8 +58,9 @@ class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const bigramEntryPos) const; - void skipAllBigrams(int *const pos) const { + bool skipAllBigrams(int *const pos) const { // Do nothing because we don't need to skip bigram lists in ver4 dictionaries. + return true; } bool addNewEntry(const int terminalId, const int newTargetTerminalId, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/single_dict_content.h index 6433650b0..49f446814 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/single_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/single_dict_content.h @@ -30,6 +30,7 @@ #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" +#include "utils/byte_array_view.h" namespace latinime { namespace backward { @@ -40,8 +41,9 @@ class SingleDictContent : public DictContent { SingleDictContent(const char *const dictPath, const char *const contentFileName, const bool isUpdatable) : mMmappedBuffer(MmappedBuffer::openBuffer(dictPath, contentFileName, isUpdatable)), - mExpandableContentBuffer(mMmappedBuffer ? mMmappedBuffer->getBuffer() : nullptr, - mMmappedBuffer ? mMmappedBuffer->getBufferSize() : 0, + mExpandableContentBuffer( + mMmappedBuffer ? mMmappedBuffer->getReadWriteByteArrayView() : + ReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mIsValid(mMmappedBuffer) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/sparse_table_dict_content.h index c7233edd3..3c626df11 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/sparse_table_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/sparse_table_dict_content.h @@ -31,6 +31,7 @@ #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" #include "suggest/policyimpl/dictionary/utils/sparse_table.h" +#include "utils/byte_array_view.h" namespace latinime { namespace backward { @@ -50,15 +51,16 @@ class SparseTableDictContent : public DictContent { mContentBuffer( MmappedBuffer::openBuffer(dictPath, contentFileName, isUpdatable)), mExpandableLookupTableBuffer( - mLookupTableBuffer ? mLookupTableBuffer->getBuffer() : nullptr, - mLookupTableBuffer ? mLookupTableBuffer->getBufferSize() : 0, + mLookupTableBuffer ? mLookupTableBuffer->getReadWriteByteArrayView() : + ReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mExpandableAddressTableBuffer( - mAddressTableBuffer ? mAddressTableBuffer->getBuffer() : nullptr, - mAddressTableBuffer ? mAddressTableBuffer->getBufferSize() : 0, + mAddressTableBuffer ? mAddressTableBuffer->getReadWriteByteArrayView() : + ReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableContentBuffer(mContentBuffer ? mContentBuffer->getBuffer() : nullptr, - mContentBuffer ? mContentBuffer->getBufferSize() : 0, + mExpandableContentBuffer( + mContentBuffer ? mContentBuffer->getReadWriteByteArrayView() : + ReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer, sparseTableBlockSize, sparseTableDataSize), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_buffers.cpp index 93f192976..3dfbd1c94 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_buffers.cpp @@ -30,6 +30,7 @@ #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h" +#include "utils/byte_array_view.h" namespace latinime { namespace backward { @@ -130,12 +131,12 @@ Ver4DictBuffers::Ver4DictBuffers(const char *const dictPath, : mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(MmappedBuffer::openBuffer(dictPath, Ver4DictConstants::TRIE_FILE_EXTENSION, isUpdatable)), - mHeaderPolicy(mHeaderBuffer->getBuffer(), formatVersion), - mExpandableHeaderBuffer(mHeaderBuffer ? mHeaderBuffer->getBuffer() : nullptr, - mHeaderPolicy.getSize(), + mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion), + mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableTrieBuffer(mDictBuffer ? mDictBuffer->getBuffer() : nullptr, - mDictBuffer ? mDictBuffer->getBufferSize() : 0, + mExpandableTrieBuffer( + mDictBuffer ? mDictBuffer->getReadWriteByteArrayView() : + ReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable(dictPath, isUpdatable), mProbabilityDictContent(dictPath, mHeaderPolicy.hasHistoricalInfoOfWords(), isUpdatable), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 4220a9561..278f2b199 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -231,30 +231,31 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( &probabilityEntryToWrite); } -bool Ver4PatriciaTrieNodeWriter::addNewBigramEntry( - const PtNodeParams *const sourcePtNodeParams, const PtNodeParams *const targetPtNodeParam, - const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - if (!mBigramPolicy->addNewEntry(sourcePtNodeParams->getTerminalId(), - targetPtNodeParam->getTerminalId(), bigramProperty, outAddedNewBigram)) { +bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, + const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { + if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) { AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); return false; } - if (!sourcePtNodeParams->hasBigrams()) { + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(prevWordIds[0]); + const PtNodeParams sourcePtNodeParams = + mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (!sourcePtNodeParams.hasBigrams()) { // Update has bigrams flag. - return updatePtNodeFlags(sourcePtNodeParams->getHeadPos(), - sourcePtNodeParams->isBlacklisted(), sourcePtNodeParams->isNotAWord(), - sourcePtNodeParams->isTerminal(), sourcePtNodeParams->hasShortcutTargets(), + return updatePtNodeFlags(sourcePtNodeParams.getHeadPos(), + sourcePtNodeParams.isBlacklisted(), sourcePtNodeParams.isNotAWord(), + sourcePtNodeParams.isTerminal(), sourcePtNodeParams.hasShortcutTargets(), true /* hasBigrams */, - sourcePtNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); + sourcePtNodeParams.getCodePointCount() > 1 /* hasMultipleChars */); } return true; } -bool Ver4PatriciaTrieNodeWriter::removeBigramEntry( - const PtNodeParams *const sourcePtNodeParams, const PtNodeParams *const targetPtNodeParam) { - return mBigramPolicy->removeEntry(sourcePtNodeParams->getTerminalId(), - targetPtNodeParam->getTerminalId()); +bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, + const int wordId) { + return mBigramPolicy->removeEntry(prevWordIds[0], wordId); } bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries( diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h index 08226ea26..d49d9a666 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h @@ -29,6 +29,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/probability_entry.h" +#include "utils/int_array_view.h" namespace latinime { namespace backward { @@ -61,8 +62,8 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const PtNodeArrayReader *const ptNodeArrayReader, Ver4BigramListPolicy *const bigramPolicy, Ver4ShortcutListPolicy *const shortcutPolicy) : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy), - mReadingHelper(ptNodeReader, ptNodeArrayReader), mBigramPolicy(bigramPolicy), - mShortcutPolicy(shortcutPolicy) {} + mPtNodeReader(ptNodeReader), mReadingHelper(ptNodeReader, ptNodeArrayReader), + mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy) {} virtual ~Ver4PatriciaTrieNodeWriter() {} @@ -92,12 +93,10 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { virtual bool writeNewTerminalPtNodeAndAdvancePosition(const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos); - virtual bool addNewBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam, const BigramProperty *const bigramProperty, - bool *const outAddedNewBigram); + virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, + const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); - virtual bool removeBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam); + virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId); virtual bool updateAllBigramEntriesAndDeleteUselessEntries( const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount); @@ -135,6 +134,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { BufferWithExtendableBuffer *const mTrieBuffer; Ver4DictBuffers *const mBuffers; const HeaderPolicy *const mHeaderPolicy; + const PtNodeReader *const mPtNodeReader; DynamicPtReadingHelper mReadingHelper; Ver4BigramListPolicy *const mBigramPolicy; Ver4ShortcutListPolicy *const mShortcutPolicy; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index f478d9b91..1296b8acd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -28,6 +28,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" @@ -131,7 +132,8 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { +int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, + const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; } @@ -139,9 +141,34 @@ int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) c if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } + if (prevWordsPtNodePos) { + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == ptNodePos + && bigramsIt.getProbability() != NOT_A_PROBABILITY) { + return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); + } + } + return NOT_A_PROBABILITY; + } return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -154,12 +181,6 @@ int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) con ptNodeParams.getTerminalId()); } -BinaryDictionaryBigramsIterator Ver4PatriciaTriePolicy::getBigramsIteratorOfPtNode( - const int ptNodePos) const { - const int bigramsPosition = getBigramsPositionOfPtNode(ptNodePos); - return BinaryDictionaryBigramsIterator(&mBigramPolicy, bigramsPosition); -} - int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -288,8 +309,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI return false; } bool addedNewBigram = false; - if (mUpdatingHelper.addBigramWords(prevWordsPtNodePos[0], word1Pos, bigramProperty, - &addedNewBigram)) { + if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos), + word1Pos, bigramProperty, &addedNewBigram)) { if (addedNewBigram) { mBigramCount++; } @@ -329,7 +350,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor if (wordPos == NOT_A_DICT_POS) { return false; } - if (mUpdatingHelper.removeBigramWords(prevWordsPtNodePos[0], wordPos)) { + if (mUpdatingHelper.removeNgramEntry( + PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) { mBigramCount--; return true; } else { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 6d97c7cc8..9e989b268 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -90,11 +90,12 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getUnigramProbabilityOfPtNode(const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; - BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index e4b5fa267..e4ea3da16 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -31,6 +31,7 @@ #include "suggest/policyimpl/dictionary/utils/file_utils.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -110,7 +111,8 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str return nullptr; } const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::detectFormatVersion( - mmappedBuffer->getBuffer(), mmappedBuffer->getBufferSize()); + mmappedBuffer->getReadOnlyByteArrayView().data(), + mmappedBuffer->getReadOnlyByteArrayView().size()); switch (formatVersion) { case FormatUtils::VERSION_2: AKLOGE("Given path is a directory but the format is version 2. path: %s", path); @@ -172,8 +174,8 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str if (!mmappedBuffer) { return nullptr; } - switch (FormatUtils::detectFormatVersion(mmappedBuffer->getBuffer(), - mmappedBuffer->getBufferSize())) { + switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView().data(), + mmappedBuffer->getReadOnlyByteArrayView().size())) { case FormatUtils::VERSION_2: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp index 08b4e0b5e..f7fd5c071 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp @@ -38,9 +38,14 @@ const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::FLAG_ATTRI const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; -/* static */ void BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( - const uint8_t *const bigramsBuf, BigramFlags *const outBigramFlags, +/* static */ bool BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + const uint8_t *const bigramsBuf, const int bufSize, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos) { + if (bufSize <= *bigramEntryPos) { + AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). bufSize: %d, " + "bigramEntryPos: %d.", bufSize, *bigramEntryPos); + return false; + } const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, bigramEntryPos); if (outBigramFlags) { @@ -51,15 +56,19 @@ const BigramListReadWriteUtils::BigramFlags if (outTargetPtNodePos) { *outTargetPtNodePos = targetPos; } + return true; } -/* static */ void BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, - int *const bigramListPos) { +/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, + const int bufSize, int *const bigramListPos) { BigramFlags flags; do { - getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, &flags, 0 /* outTargetPtNodePos */, - bigramListPos); + if (!getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, bufSize, &flags, + 0 /* outTargetPtNodePos */, bigramListPos)) { + return false; + } } while(hasNext(flags)); + return true; } /* static */ int BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h index 15f924a6a..10f93fb7a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h @@ -30,8 +30,8 @@ class BigramListReadWriteUtils { public: typedef uint8_t BigramFlags; - static void getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, - BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + static bool getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, + const int bufSize, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos); static AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { @@ -43,7 +43,8 @@ public: } // Bigrams reading methods - static void skipExistingBigrams(const uint8_t *const bigramsBuf, int *const bigramListPos); + static bool skipExistingBigrams(const uint8_t *const bigramsBuf, const int bufSize, + int *const bigramListPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index 2e05bf397..b7262581a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -26,7 +26,6 @@ namespace latinime { -class DictionaryBigramsStructurePolicy; class DictionaryShortcutsStructurePolicy; class PtNodeArrayReader; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp index f31c914d2..3c62e2e56 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp @@ -84,23 +84,39 @@ bool DynamicPtUpdatingHelper::addUnigramWord( unigramProperty, &pos); } -bool DynamicPtUpdatingHelper::addBigramWords(const int word0Pos, const int word1Pos, - const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - const PtNodeParams sourcePtNodeParams( - mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(word0Pos)); - const PtNodeParams targetPtNodeParams( - mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(word1Pos)); - return mPtNodeWriter->addNewBigramEntry(&sourcePtNodeParams, &targetPtNodeParams, - bigramProperty, outAddedNewBigram); +bool DynamicPtUpdatingHelper::addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, + const int wordPos, const BigramProperty *const bigramProperty, + bool *const outAddedNewEntry) { + if (prevWordsPtNodePos.empty()) { + return false; + } + ASSERT(prevWordsPtNodePos.size() <= MAX_PREV_WORD_COUNT_FOR_N_GRAM); + int prevWordTerminalIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { + prevWordTerminalIds[i] = mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos( + prevWordsPtNodePos[i]).getTerminalId(); + } + const WordIdArrayView prevWordIds(prevWordTerminalIds, prevWordsPtNodePos.size()); + const int wordId = + mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos).getTerminalId(); + return mPtNodeWriter->addNgramEntry(prevWordIds, wordId, bigramProperty, outAddedNewEntry); } -// Remove a bigram relation from word0Pos to word1Pos. -bool DynamicPtUpdatingHelper::removeBigramWords(const int word0Pos, const int word1Pos) { - const PtNodeParams sourcePtNodeParams( - mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(word0Pos)); - const PtNodeParams targetPtNodeParams( - mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(word1Pos)); - return mPtNodeWriter->removeBigramEntry(&sourcePtNodeParams, &targetPtNodeParams); +bool DynamicPtUpdatingHelper::removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, + const int wordPos) { + if (prevWordsPtNodePos.empty()) { + return false; + } + ASSERT(prevWordsPtNodePos.size() <= MAX_PREV_WORD_COUNT_FOR_N_GRAM); + int prevWordTerminalIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) { + prevWordTerminalIds[i] = mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos( + prevWordsPtNodePos[i]).getTerminalId(); + } + const WordIdArrayView prevWordIds(prevWordTerminalIds, prevWordsPtNodePos.size()); + const int wordId = + mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos).getTerminalId(); + return mPtNodeWriter->removeNgramEntry(prevWordIds, wordId); } bool DynamicPtUpdatingHelper::addShortcutTarget(const int wordPos, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h index f10d15a9b..97c05c1ea 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h @@ -19,6 +19,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" +#include "utils/int_array_view.h" namespace latinime { @@ -42,12 +43,12 @@ class DynamicPtUpdatingHelper { const int *const wordCodePoints, const int codePointCount, const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); - // Add a bigram relation from word0Pos to word1Pos. - bool addBigramWords(const int word0Pos, const int word1Pos, - const BigramProperty *const bigramProperty, bool *const outAddedNewBigram); + // Add an n-gram entry. + bool addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos, + const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); - // Remove a bigram relation from word0Pos to word1Pos. - bool removeBigramWords(const int word0Pos, const int word1Pos); + // Remove an n-gram entry. + bool removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos); // Add a shortcut target. bool addShortcutTarget(const int wordPos, const int *const targetCodePoints, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h index a8029f73f..955d779ac 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" +#include "utils/int_array_view.h" namespace latinime { @@ -70,12 +71,10 @@ class PtNodeWriter { virtual bool writeNewTerminalPtNodeAndAdvancePosition(const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos) = 0; - virtual bool addNewBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam, const BigramProperty *const bigramProperty, - bool *const outAddedNewBigram) = 0; + virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, + const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) = 0; - virtual bool removeBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam) = 0; + virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) = 0; virtual bool updateAllBigramEntriesAndDeleteUselessEntries( const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h index 00bb502dc..73e291ec2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h @@ -27,27 +27,34 @@ namespace latinime { class BigramListPolicy : public DictionaryBigramsStructurePolicy { public: - explicit BigramListPolicy(const uint8_t *const bigramsBuf) : mBigramsBuf(bigramsBuf) {} + BigramListPolicy(const uint8_t *const bigramsBuf, const int bufSize) + : mBigramsBuf(bigramsBuf), mBufSize(bufSize) {} ~BigramListPolicy() {} void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const { BigramListReadWriteUtils::BigramFlags flags; - BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, &flags, - outBigramPos, pos); + if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, + mBufSize, &flags, outBigramPos, pos)) { + AKLOGE("Cannot read bigram entry. mBufSize: %d, pos: %d. ", mBufSize, *pos); + *outProbability = NOT_A_PROBABILITY; + *outHasNext = false; + return; + } *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(flags); *outHasNext = BigramListReadWriteUtils::hasNext(flags); } - void skipAllBigrams(int *const pos) const { - BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, pos); + bool skipAllBigrams(int *const pos) const { + return BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, mBufSize, pos); } private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListPolicy); const uint8_t *const mBigramsBuf; + const int mBufSize; }; } // namespace latinime #endif // LATINIME_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 91d76040f..ea32eb2a9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -21,6 +21,8 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/ngram_listener.h" +#include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" @@ -223,7 +225,14 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { - mBigramListPolicy.skipAllBigrams(&pos); + if (!mBigramListPolicy.skipAllBigrams(&pos)) { + AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, + pos); + mIsCorrupted = true; + ASSERT(false); + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } } } } else { @@ -240,7 +249,13 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { - mBigramListPolicy.skipAllBigrams(&pos); + if (!mBigramListPolicy.skipAllBigrams(&pos)) { + AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, pos); + mIsCorrupted = true; + ASSERT(false); + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } } } @@ -282,7 +297,8 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { +int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, + const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; } @@ -294,9 +310,34 @@ int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const // for shortcuts). return NOT_A_PROBABILITY; } + if (prevWordsPtNodePos) { + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == ptNodePos + && bigramsIt.getProbability() != NOT_A_PROBABILITY) { + return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); + } + } + return NOT_A_PROBABILITY; + } return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -304,12 +345,6 @@ int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { return mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos).getShortcutPos(); } -BinaryDictionaryBigramsIterator PatriciaTriePolicy::getBigramsIteratorOfPtNode( - const int ptNodePos) const { - const int bigramsPosition = getBigramsPositionOfPtNode(ptNodePos); - return BinaryDictionaryBigramsIterator(&mBigramListPolicy, bigramsPosition); -} - int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 7c0b9d3c5..70351d147 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -29,6 +29,7 @@ #include "suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -39,10 +40,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer) : mMmappedBuffer(std::move(mmappedBuffer)), - mHeaderPolicy(mMmappedBuffer->getBuffer(), FormatUtils::VERSION_2), - mDictRoot(mMmappedBuffer->getBuffer() + mHeaderPolicy.getSize()), - mDictBufferSize(mMmappedBuffer->getBufferSize() - mHeaderPolicy.getSize()), - mBigramListPolicy(mDictRoot), mShortcutListPolicy(mDictRoot), + mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(), + FormatUtils::VERSION_2), + mDictRoot(mMmappedBuffer->getReadOnlyByteArrayView().data() + + mHeaderPolicy.getSize()), + mDictBufferSize(mMmappedBuffer->getReadOnlyByteArrayView().size() + - mHeaderPolicy.getSize()), + mBigramListPolicy(mDictRoot, mDictBufferSize), mShortcutListPolicy(mDictRoot), mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy), mPtNodeArrayReader(mDictRoot, mDictBufferSize), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {} @@ -63,11 +67,12 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getUnigramProbabilityOfPtNode(const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; - BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return &mHeaderPolicy; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h index 55ba613a5..4b3bb3725 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h @@ -40,8 +40,9 @@ class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const bigramEntryPos) const; - void skipAllBigrams(int *const pos) const { + bool skipAllBigrams(int *const pos) const { // Do nothing because we don't need to skip bigram lists in ver4 dictionaries. + return true; } bool addNewEntry(const int terminalId, const int newTargetTerminalId, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp new file mode 100644 index 000000000..5dc91ba10 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" + +namespace latinime { + +bool LanguageModelDictContent::save(FILE *const file) const { + return mTrieMap.save(file); +} + +bool LanguageModelDictContent::runGC( + const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, + const LanguageModelDictContent *const originalContent, + int *const outNgramCount) { + return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(), + 0 /* nextLevelBitmapEntryIndex */, outNgramCount); +} + +ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( + const WordIdArrayView prevWordIds, const int wordId) const { + const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); + if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + return ProbabilityEntry(); + } + const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex); + if (!result.mIsValid) { + // Not found. + return ProbabilityEntry(); + } + return ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); +} + +bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, + const int terminalId, const ProbabilityEntry *const probabilityEntry) { + const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); + if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + return false; + } + return mTrieMap.put(terminalId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); +} + +bool LanguageModelDictContent::runGCInner( + const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, + const TrieMap::TrieMapRange trieMapRange, + const int nextLevelBitmapEntryIndex, int *const outNgramCount) { + for (auto &entry : trieMapRange) { + const auto it = terminalIdMap->find(entry.key()); + if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) { + // The word has been removed. + continue; + } + if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) { + return false; + } + if (outNgramCount) { + *outNgramCount += 1; + } + if (entry.hasNextLevelMap()) { + if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(), + mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex), + outNgramCount)) { + return false; + } + } + } + return true; +} + +int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { + int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); + for (const int wordId : prevWordIds) { + const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndex); + if (!result.mIsValid) { + return TrieMap::INVALID_INDEX; + } + bitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + } + return bitmapEntryIndex; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h new file mode 100644 index 000000000..18f2e0170 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H +#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H + +#include <cstdio> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" +#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "suggest/policyimpl/dictionary/utils/trie_map.h" +#include "utils/byte_array_view.h" +#include "utils/int_array_view.h" + +namespace latinime { + +/** + * Class representing language model. + * + * This class provides methods to get and store unigram/n-gram probability information and flags. + */ +class LanguageModelDictContent { + public: + LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, + const bool hasHistoricalInfo) + : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} + + explicit LanguageModelDictContent(const bool hasHistoricalInfo) + : mTrieMap(), mHasHistoricalInfo(hasHistoricalInfo) {} + + bool isNearSizeLimit() const { + return mTrieMap.isNearSizeLimit(); + } + + bool save(FILE *const file) const; + + bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, + const LanguageModelDictContent *const originalContent, + int *const outNgramCount); + + ProbabilityEntry getProbabilityEntry(const int wordId) const { + return getNgramProbabilityEntry(WordIdArrayView(), wordId); + } + + bool setProbabilityEntry(const int wordId, const ProbabilityEntry *const probabilityEntry) { + return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); + } + + ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds, + const int wordId) const; + + bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId, + const ProbabilityEntry *const probabilityEntry); + + private: + DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); + + TrieMap mTrieMap; + const bool mHasHistoricalInfo; + + bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, + const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, + int *const outNgramCount); + + int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; +}; +} // namespace latinime +#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.cpp deleted file mode 100644 index 2425b3b2f..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.cpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h" - -#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" - -namespace latinime { - -const ProbabilityEntry ProbabilityDictContent::getProbabilityEntry(const int terminalId) const { - if (terminalId < 0 || terminalId >= mSize) { - // This method can be called with invalid terminal id during GC. - return ProbabilityEntry(0 /* flags */, NOT_A_PROBABILITY); - } - const BufferWithExtendableBuffer *const buffer = getBuffer(); - int entryPos = getEntryPos(terminalId); - const int flags = buffer->readUintAndAdvancePosition( - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, &entryPos); - const int probability = buffer->readUintAndAdvancePosition( - Ver4DictConstants::PROBABILITY_SIZE, &entryPos); - if (mHasHistoricalInfo) { - const int timestamp = buffer->readUintAndAdvancePosition( - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, &entryPos); - const int level = buffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, &entryPos); - const int count = buffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, &entryPos); - const HistoricalInfo historicalInfo(timestamp, level, count); - return ProbabilityEntry(flags, probability, &historicalInfo); - } else { - return ProbabilityEntry(flags, probability); - } -} - -bool ProbabilityDictContent::setProbabilityEntry(const int terminalId, - const ProbabilityEntry *const probabilityEntry) { - if (terminalId < 0) { - return false; - } - const int entryPos = getEntryPos(terminalId); - if (terminalId >= mSize) { - ProbabilityEntry dummyEntry; - // Write new entry. - int writingPos = getBuffer()->getTailPosition(); - while (writingPos <= entryPos) { - // Fulfilling with dummy entries until writingPos. - if (!writeEntry(&dummyEntry, writingPos)) { - AKLOGE("Cannot write dummy entry. pos: %d, mSize: %d", writingPos, mSize); - return false; - } - writingPos += getEntrySize(); - mSize++; - } - } - return writeEntry(probabilityEntry, entryPos); -} - -bool ProbabilityDictContent::flushToFile(FILE *const file) const { - if (getEntryPos(mSize) < getBuffer()->getTailPosition()) { - ProbabilityDictContent probabilityDictContentToWrite(mHasHistoricalInfo); - for (int i = 0; i < mSize; ++i) { - const ProbabilityEntry probabilityEntry = getProbabilityEntry(i); - if (!probabilityDictContentToWrite.setProbabilityEntry(i, &probabilityEntry)) { - AKLOGE("Cannot set probability entry in flushToFile. terminalId: %d", i); - return false; - } - } - return probabilityDictContentToWrite.flush(file); - } else { - return flush(file); - } -} - -bool ProbabilityDictContent::runGC( - const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const ProbabilityDictContent *const originalProbabilityDictContent) { - mSize = 0; - for (TerminalPositionLookupTable::TerminalIdMap::const_iterator it = terminalIdMap->begin(); - it != terminalIdMap->end(); ++it) { - const ProbabilityEntry probabilityEntry = - originalProbabilityDictContent->getProbabilityEntry(it->first); - if (!setProbabilityEntry(it->second, &probabilityEntry)) { - AKLOGE("Cannot set probability entry in runGC. terminalId: %d", it->second); - return false; - } - mSize++; - } - return true; -} - -int ProbabilityDictContent::getEntrySize() const { - if (mHasHistoricalInfo) { - return Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE - + Ver4DictConstants::PROBABILITY_SIZE - + Ver4DictConstants::TIME_STAMP_FIELD_SIZE - + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE - + Ver4DictConstants::WORD_COUNT_FIELD_SIZE; - } else { - return Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE - + Ver4DictConstants::PROBABILITY_SIZE; - } -} - -int ProbabilityDictContent::getEntryPos(const int terminalId) const { - return terminalId * getEntrySize(); -} - -bool ProbabilityDictContent::writeEntry(const ProbabilityEntry *const probabilityEntry, - const int entryPos) { - BufferWithExtendableBuffer *const bufferToWrite = getWritableBuffer(); - int writingPos = entryPos; - if (!bufferToWrite->writeUintAndAdvancePosition(probabilityEntry->getFlags(), - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, &writingPos)) { - AKLOGE("Cannot write flags in probability dict content. pos: %d", writingPos); - return false; - } - if (!bufferToWrite->writeUintAndAdvancePosition(probabilityEntry->getProbability(), - Ver4DictConstants::PROBABILITY_SIZE, &writingPos)) { - AKLOGE("Cannot write probability in probability dict content. pos: %d", writingPos); - return false; - } - if (mHasHistoricalInfo) { - const HistoricalInfo *const historicalInfo = probabilityEntry->getHistoricalInfo(); - if (!bufferToWrite->writeUintAndAdvancePosition(historicalInfo->getTimeStamp(), - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, &writingPos)) { - AKLOGE("Cannot write timestamp in probability dict content. pos: %d", writingPos); - return false; - } - if (!bufferToWrite->writeUintAndAdvancePosition(historicalInfo->getLevel(), - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, &writingPos)) { - AKLOGE("Cannot write level in probability dict content. pos: %d", writingPos); - return false; - } - if (!bufferToWrite->writeUintAndAdvancePosition(historicalInfo->getCount(), - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, &writingPos)) { - AKLOGE("Cannot write count in probability dict content. pos: %d", writingPos); - return false; - } - } - return true; -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h deleted file mode 100644 index 80e992c1c..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_PROBABILITY_DICT_CONTENT_H -#define LATINIME_PROBABILITY_DICT_CONTENT_H - -#include <cstdint> -#include <cstdio> - -#include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" - -namespace latinime { - -class ProbabilityEntry; - -class ProbabilityDictContent : public SingleDictContent { - public: - ProbabilityDictContent(uint8_t *const buffer, const int bufferSize, - const bool hasHistoricalInfo) - : SingleDictContent(buffer, bufferSize), - mHasHistoricalInfo(hasHistoricalInfo), - mSize(getBuffer()->getTailPosition() / getEntrySize()) {} - - ProbabilityDictContent(const bool hasHistoricalInfo) - : mHasHistoricalInfo(hasHistoricalInfo), mSize(0) {} - - const ProbabilityEntry getProbabilityEntry(const int terminalId) const; - - bool setProbabilityEntry(const int terminalId, const ProbabilityEntry *const probabilityEntry); - - bool flushToFile(FILE *const file) const; - - bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const ProbabilityDictContent *const originalProbabilityDictContent); - - private: - DISALLOW_COPY_AND_ASSIGN(ProbabilityDictContent); - - int getEntrySize() const; - - int getEntryPos(const int terminalId) const; - - bool writeEntry(const ProbabilityEntry *const probabilityEntry, const int entryPos); - - bool mHasHistoricalInfo; - int mSize; -}; -} // namespace latinime -#endif /* LATINIME_PROBABILITY_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index 36ba82be1..feff6b57f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -17,6 +17,9 @@ #ifndef LATINIME_PROBABILITY_ENTRY_H #define LATINIME_PROBABILITY_ENTRY_H +#include <climits> +#include <cstdint> + #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/utils/historical_info.h" @@ -67,6 +70,50 @@ class ProbabilityEntry { return &mHistoricalInfo; } + uint64_t encode(const bool hasHistoricalInfo) const { + uint64_t encodedEntry = static_cast<uint64_t>(mFlags); + if (hasHistoricalInfo) { + encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT)) + ^ static_cast<uint64_t>(mHistoricalInfo.getTimeStamp()); + encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) + ^ static_cast<uint64_t>(mHistoricalInfo.getLevel()); + encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) + ^ static_cast<uint64_t>(mHistoricalInfo.getCount()); + } else { + encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) + ^ static_cast<uint64_t>(mProbability); + } + return encodedEntry; + } + + static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) { + if (hasHistoricalInfo) { + const int flags = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::TIME_STAMP_FIELD_SIZE + + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); + const int timestamp = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::TIME_STAMP_FIELD_SIZE, + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); + const int level = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); + const int count = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */); + const HistoricalInfo historicalInfo(timestamp, level, count); + return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); + } else { + const int flags = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::PROBABILITY_SIZE); + const int probability = readFromEncodedEntry(encodedEntry, + Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */); + return ProbabilityEntry(flags, probability); + } + } + private: // Copy constructor is public to use this class as a type of return value. DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry); @@ -74,6 +121,11 @@ class ProbabilityEntry { const int mFlags; const int mProbability; const HistoricalInfo mHistoricalInfo; + + static int readFromEncodedEntry(const uint64_t encodedEntry, const int size, const int pos) { + return static_cast<int>( + (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); + } }; } // namespace latinime #endif /* LATINIME_PROBABILITY_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h index 69a11425f..921774181 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h @@ -21,17 +21,17 @@ #include <cstdio> #include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "utils/byte_array_view.h" namespace latinime { -class SingleDictContent : public DictContent { +class SingleDictContent { public: SingleDictContent(uint8_t *const buffer, const int bufferSize) - : mExpandableContentBuffer(buffer, bufferSize, + : mExpandableContentBuffer(ReadWriteByteArrayView(buffer, bufferSize), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} SingleDictContent() diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h index cdf870bd2..c98dd11fd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h @@ -21,26 +21,29 @@ #include <cstdio> #include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" #include "suggest/policyimpl/dictionary/utils/sparse_table.h" +#include "utils/byte_array_view.h" namespace latinime { // TODO: Support multiple contents. -class SparseTableDictContent : public DictContent { +class SparseTableDictContent { public: AK_FORCE_INLINE SparseTableDictContent(uint8_t *const *buffers, const int *bufferSizes, const int sparseTableBlockSize, const int sparseTableDataSize) - : mExpandableLookupTableBuffer(buffers[LOOKUP_TABLE_BUFFER_INDEX], - bufferSizes[LOOKUP_TABLE_BUFFER_INDEX], + : mExpandableLookupTableBuffer( + ReadWriteByteArrayView(buffers[LOOKUP_TABLE_BUFFER_INDEX], + bufferSizes[LOOKUP_TABLE_BUFFER_INDEX]), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableAddressTableBuffer(buffers[ADDRESS_TABLE_BUFFER_INDEX], - bufferSizes[ADDRESS_TABLE_BUFFER_INDEX], + mExpandableAddressTableBuffer( + ReadWriteByteArrayView(buffers[ADDRESS_TABLE_BUFFER_INDEX], + bufferSizes[ADDRESS_TABLE_BUFFER_INDEX]), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableContentBuffer(buffers[CONTENT_BUFFER_INDEX], - bufferSizes[CONTENT_BUFFER_INDEX], + mExpandableContentBuffer( + ReadWriteByteArrayView(buffers[CONTENT_BUFFER_INDEX], + bufferSizes[CONTENT_BUFFER_INDEX]), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer, sparseTableBlockSize, sparseTableDataSize) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 36ab9963a..3c8008dc4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -26,6 +26,7 @@ #include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -46,14 +47,16 @@ namespace latinime { } std::vector<uint8_t *> buffers; std::vector<int> bufferSizes; - uint8_t *const buffer = bodyBuffer->getBuffer(); + const ReadWriteByteArrayView buffer = bodyBuffer->getReadWriteByteArrayView(); int position = 0; - while (position < bodyBuffer->getBufferSize()) { - const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition(buffer, &position); - buffers.push_back(buffer + position); - bufferSizes.push_back(bufferSize); + while (position < static_cast<int>(buffer.size())) { + const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition( + buffer.data(), &position); + const ReadWriteByteArrayView subBuffer = buffer.subView(position, bufferSize); + buffers.push_back(subBuffer.data()); + bufferSizes.push_back(subBuffer.size()); position += bufferSize; - if (bufferSize < 0 || position < 0 || position > bodyBuffer->getBufferSize()) { + if (bufferSize < 0 || position < 0 || position > static_cast<int>(buffer.size())) { AKLOGE("The dict body file is corrupted."); return Ver4DictBuffersPtr(nullptr); } @@ -118,7 +121,7 @@ bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath, } FILE *const file = fdopen(fd, "wb"); if (!file) { - AKLOGE("fdopen failed for the file %s. errno: %d", filePath, errno); + AKLOGE("fdopen failed for the file %s. errno: %d", bodyFilePath, errno); ASSERT(false); return false; } @@ -146,27 +149,27 @@ bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath, bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { // Write trie. if (!DictFileWritingUtils::writeBufferToFileTail(file, &mExpandableTrieBuffer)) { - AKLOGE("Trie cannot be written. %s", tmpDirPath); + AKLOGE("Trie cannot be written."); return false; } // Write terminal position lookup table. if (!mTerminalPositionLookupTable.flushToFile(file)) { - AKLOGE("Terminal position lookup table cannot be written. %s", tmpDirPath); + AKLOGE("Terminal position lookup table cannot be written."); return false; } - // Write probability dict content. - if (!mProbabilityDictContent.flushToFile(file)) { - AKLOGE("Probability dict content cannot be written. %s", tmpDirPath); + // Write language model content. + if (!mLanguageModelDictContent.save(file)) { + AKLOGE("Language model dict content cannot be written."); return false; } // Write bigram dict content. if (!mBigramDictContent.flushToFile(file)) { - AKLOGE("Bigram dict content cannot be written. %s", tmpDirPath); + AKLOGE("Bigram dict content cannot be written."); return false; } // Write shortcut dict content. if (!mShortcutDictContent.flushToFile(file)) { - AKLOGE("Shortcut dict content cannot be written. %s", tmpDirPath); + AKLOGE("Shortcut dict content cannot be written."); return false; } return true; @@ -177,20 +180,21 @@ Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, const FormatUtils::FORMAT_VERSION formatVersion, const std::vector<uint8_t *> &contentBuffers, const std::vector<int> &contentBufferSizes) : mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(std::move(bodyBuffer)), - mHeaderPolicy(mHeaderBuffer->getBuffer(), formatVersion), - mExpandableHeaderBuffer(mHeaderBuffer ? mHeaderBuffer->getBuffer() : nullptr, - mHeaderPolicy.getSize(), + mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion), + mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableTrieBuffer(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX], + mExpandableTrieBuffer( + ReadWriteByteArrayView(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], + contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX]), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX], contentBufferSizes[ Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mProbabilityDictContent( - contentBuffers[Ver4DictConstants::PROBABILITY_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::PROBABILITY_BUFFER_INDEX], + mLanguageModelDictContent( + ReadWriteByteArrayView( + contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], + contentBufferSizes[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX]), mHeaderPolicy.hasHistoricalInfoOfWords()), mBigramDictContent(&contentBuffers[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], &contentBufferSizes[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], @@ -203,7 +207,7 @@ Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const i : mHeaderBuffer(nullptr), mDictBuffer(nullptr), mHeaderPolicy(headerPolicy), mExpandableHeaderBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE), mExpandableTrieBuffer(maxTrieSize), mTerminalPositionLookupTable(), - mProbabilityDictContent(headerPolicy->hasHistoricalInfoOfWords()), + mLanguageModelDictContent(headerPolicy->hasHistoricalInfoOfWords()), mBigramDictContent(headerPolicy->hasHistoricalInfoOfWords()), mShortcutDictContent(), mIsUpdatable(true) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h index 433411cb8..68027dcb8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h @@ -23,7 +23,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" @@ -52,7 +52,7 @@ class Ver4DictBuffers { AK_FORCE_INLINE bool isNearSizeLimit() const { return mExpandableTrieBuffer.isNearSizeLimit() || mTerminalPositionLookupTable.isNearSizeLimit() - || mProbabilityDictContent.isNearSizeLimit() + || mLanguageModelDictContent.isNearSizeLimit() || mBigramDictContent.isNearSizeLimit() || mShortcutDictContent.isNearSizeLimit(); } @@ -81,12 +81,12 @@ class Ver4DictBuffers { return &mTerminalPositionLookupTable; } - AK_FORCE_INLINE ProbabilityDictContent *getMutableProbabilityDictContent() { - return &mProbabilityDictContent; + AK_FORCE_INLINE LanguageModelDictContent *getMutableLanguageModelDictContent() { + return &mLanguageModelDictContent; } - AK_FORCE_INLINE const ProbabilityDictContent *getProbabilityDictContent() const { - return &mProbabilityDictContent; + AK_FORCE_INLINE const LanguageModelDictContent *getLanguageModelDictContent() const { + return &mLanguageModelDictContent; } AK_FORCE_INLINE BigramDictContent *getMutableBigramDictContent() { @@ -135,7 +135,7 @@ class Ver4DictBuffers { BufferWithExtendableBuffer mExpandableHeaderBuffer; BufferWithExtendableBuffer mExpandableTrieBuffer; TerminalPositionLookupTable mTerminalPositionLookupTable; - ProbabilityDictContent mProbabilityDictContent; + LanguageModelDictContent mLanguageModelDictContent; BigramDictContent mBigramDictContent; ShortcutDictContent mShortcutDictContent; const int mIsUpdatable; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index d45dfe377..93d4e562d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -27,18 +27,20 @@ const int Ver4DictConstants::MAX_DICTIONARY_SIZE = 8 * 1024 * 1024; // limited to 1MB to prevent from inefficient traversing. const int Ver4DictConstants::MAX_DICT_EXTENDED_REGION_SIZE = 1 * 1024 * 1024; -// NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT for Trie, TerminalAddressLookupTable and Probability. +// NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT for Trie and TerminalAddressLookupTable. +// NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT for language model. // NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT for bigram and shortcut. const size_t Ver4DictConstants::NUM_OF_CONTENT_BUFFERS_IN_BODY_FILE = - NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT * 3 + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT * 2 + + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT * 2; const int Ver4DictConstants::TRIE_BUFFER_INDEX = 0; const int Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX = TRIE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; -const int Ver4DictConstants::PROBABILITY_BUFFER_INDEX = +const int Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX = TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; const int Ver4DictConstants::BIGRAM_BUFFERS_INDEX = - PROBABILITY_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; + LANGUAGE_MODEL_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX = BIGRAM_BUFFERS_INDEX + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; @@ -73,5 +75,6 @@ const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT = 1; const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT = 3; +const size_t Ver4DictConstants::NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT = 1; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index e8f6739ba..6950ca70f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -35,7 +35,7 @@ class Ver4DictConstants { static const size_t NUM_OF_CONTENT_BUFFERS_IN_BODY_FILE; static const int TRIE_BUFFER_INDEX; static const int TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX; - static const int PROBABILITY_BUFFER_INDEX; + static const int LANGUAGE_MODEL_BUFFER_INDEX; static const int BIGRAM_BUFFERS_INDEX; static const int SHORTCUT_BUFFERS_INDEX; @@ -71,6 +71,7 @@ class Ver4DictConstants { static const size_t NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; static const size_t NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; + static const size_t NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; }; } // namespace latinime #endif /* LATINIME_VER4_DICT_CONSTANTS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp index 0a435e91c..731092efd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp @@ -18,7 +18,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/probability_dict_content.h" +#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" @@ -61,8 +61,9 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce terminalIdFieldPos += mBuffer->getOriginalBufferSize(); } terminalId = Ver4PatriciaTrieReadingUtils::getTerminalIdAndAdvancePosition(dictBuf, &pos); + // TODO: Quit reading probability here. const ProbabilityEntry probabilityEntry = - mProbabilityDictContent->getProbabilityEntry(terminalId); + mLanguageModelDictContent->getProbabilityEntry(terminalId); if (probabilityEntry.hasHistoricalInfo()) { probability = ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), mHeaderPolicy); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h index 22ed4a6c0..a91ad5728 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h @@ -25,18 +25,18 @@ namespace latinime { class BufferWithExtendableBuffer; class HeaderPolicy; -class ProbabilityDictContent; +class LanguageModelDictContent; /* * This class is used for helping to read nodes of ver4 patricia trie. This class handles moved - * node and reads node attributes including probability form probabilityBuffer. + * node and reads node attributes including probability form language model. */ class Ver4PatriciaTrieNodeReader : public PtNodeReader { public: Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer, - const ProbabilityDictContent *const probabilityDictContent, + const LanguageModelDictContent *const languageModelDictContent, const HeaderPolicy *const headerPolicy) - : mBuffer(buffer), mProbabilityDictContent(probabilityDictContent), + : mBuffer(buffer), mLanguageModelDictContent(languageModelDictContent), mHeaderPolicy(headerPolicy) {} ~Ver4PatriciaTrieNodeReader() {} @@ -50,7 +50,7 @@ class Ver4PatriciaTrieNodeReader : public PtNodeReader { DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeReader); const BufferWithExtendableBuffer *const mBuffer; - const ProbabilityDictContent *const mProbabilityDictContent; + const LanguageModelDictContent *const mLanguageModelDictContent; const HeaderPolicy *const mHeaderPolicy; const PtNodeParams fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 3d8da9173..857222f5d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -143,11 +143,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( return false; } const ProbabilityEntry originalProbabilityEntry = - mBuffers->getProbabilityDictContent()->getProbabilityEntry( + mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry, unigramProperty); - return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry( + return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry); } @@ -158,14 +158,14 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA return false; } const ProbabilityEntry originalProbabilityEntry = - mBuffers->getProbabilityDictContent()->getProbabilityEntry( + mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); if (originalProbabilityEntry.hasHistoricalInfo()) { const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); const ProbabilityEntry probabilityEntry = originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo); - if (!mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry( + if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { AKLOGE("Cannot write updated probability entry. terminalId: %d", toBeUpdatedPtNodeParams->getTerminalId()); @@ -218,26 +218,23 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( ProbabilityEntry newProbabilityEntry; const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( &newProbabilityEntry, unigramProperty); - return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry(terminalId, - &probabilityEntryToWrite); + return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( + terminalId, &probabilityEntryToWrite); } -bool Ver4PatriciaTrieNodeWriter::addNewBigramEntry( - const PtNodeParams *const sourcePtNodeParams, const PtNodeParams *const targetPtNodeParam, +bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - if (!mBigramPolicy->addNewEntry(sourcePtNodeParams->getTerminalId(), - targetPtNodeParam->getTerminalId(), bigramProperty, outAddedNewBigram)) { + if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) { AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", - sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); + prevWordIds[0], wordId); return false; } return true; } -bool Ver4PatriciaTrieNodeWriter::removeBigramEntry( - const PtNodeParams *const sourcePtNodeParams, const PtNodeParams *const targetPtNodeParam) { - return mBigramPolicy->removeEntry(sourcePtNodeParams->getTerminalId(), - targetPtNodeParam->getTerminalId()); +bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, + const int wordId) { + return mBigramPolicy->removeEntry(prevWordIds[0], wordId); } bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries( diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 162dc9b1d..6703dba04 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -75,12 +75,10 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { virtual bool writeNewTerminalPtNodeAndAdvancePosition(const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos); - virtual bool addNewBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam, const BigramProperty *const bigramProperty, - bool *const outAddedNewBigram); + virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, + const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); - virtual bool removeBigramEntry(const PtNodeParams *const sourcePtNodeParams, - const PtNodeParams *const targetPtNodeParam); + virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId); virtual bool updateAllBigramEntriesAndDeleteUselessEntries( const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 0b5764aba..723808399 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -20,6 +20,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" @@ -121,7 +122,8 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { +int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, + const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; } @@ -129,9 +131,34 @@ int Ver4PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) c if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } + if (prevWordsPtNodePos) { + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == ptNodePos + && bigramsIt.getProbability() != NOT_A_PROBABILITY) { + return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); + } + } + return NOT_A_PROBABILITY; + } return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } +void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const { + if (!prevWordsPtNodePos) { + return; + } + const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + } +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -144,12 +171,6 @@ int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) con ptNodeParams.getTerminalId()); } -BinaryDictionaryBigramsIterator Ver4PatriciaTriePolicy::getBigramsIteratorOfPtNode( - const int ptNodePos) const { - const int bigramsPosition = getBigramsPositionOfPtNode(ptNodePos); - return BinaryDictionaryBigramsIterator(&mBigramPolicy, bigramsPosition); -} - int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -271,6 +292,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, false /* tryLowerCaseSearch */); + const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); // TODO: Support N-gram. if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { @@ -298,10 +320,10 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI if (word1Pos == NOT_A_DICT_POS) { return false; } - bool addedNewBigram = false; - if (mUpdatingHelper.addBigramWords(prevWordsPtNodePos[0], word1Pos, bigramProperty, - &addedNewBigram)) { - if (addedNewBigram) { + bool addedNewEntry = false; + if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty, + &addedNewEntry)) { + if (addedNewEntry) { mBigramCount++; } return true; @@ -331,6 +353,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, false /* tryLowerCaseSerch */); + const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); // TODO: Support N-gram. if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { return false; @@ -340,7 +363,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor if (wordPos == NOT_A_DICT_POS) { return false; } - if (mUpdatingHelper.removeBigramWords(prevWordsPtNodePos[0], wordPos)) { + if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) { mBigramCount--; return true; } else { @@ -431,7 +454,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code std::vector<int> codePointVector(ptNodeParams.getCodePoints(), ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); const ProbabilityEntry probabilityEntry = - mBuffers->getProbabilityDictContent()->getProbabilityEntry( + mBuffers->getLanguageModelDictContent()->getProbabilityEntry( ptNodeParams.getTerminalId()); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); // Fetch bigram information. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 85929b785..faad4290d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -46,7 +46,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mBuffers->getTerminalPositionLookupTable(), mHeaderPolicy), mShortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()), - mNodeReader(mDictBuffer, mBuffers->getProbabilityDictContent(), mHeaderPolicy), + mNodeReader(mDictBuffer, mBuffers->getLanguageModelDictContent(), mHeaderPolicy), mPtNodeArrayReader(mDictBuffer), mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader, &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), @@ -72,11 +72,12 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getProbability(const int unigramProbability, const int bigramProbability) const; - int getUnigramProbabilityOfPtNode(const int ptNodePos) const; + int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + void iterateNgramEntries(const int *const prevWordsPtNodePos, + NgramListener *const listener) const; - BinaryDictionaryBigramsIterator getBigramsIteratorOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 0e658f8e3..4220312e0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -75,7 +75,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, int *const outBigramCount) { Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(), - mBuffers->getProbabilityDictContent(), headerPolicy); + mBuffers->getLanguageModelDictContent(), headerPolicy); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); Ver4BigramListPolicy bigramPolicy(mBuffers->getMutableBigramDictContent(), mBuffers->getTerminalPositionLookupTable(), headerPolicy); @@ -138,7 +138,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, // Create policy instances for the GCed dictionary. Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(), - buffersToWrite->getProbabilityDictContent(), headerPolicy); + buffersToWrite->getLanguageModelDictContent(), headerPolicy); Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer()); Ver4BigramListPolicy newBigramPolicy(buffersToWrite->getMutableBigramDictContent(), buffersToWrite->getTerminalPositionLookupTable(), headerPolicy); @@ -154,8 +154,8 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return false; } // Run GC for probability dict content. - if (!buffersToWrite->getMutableProbabilityDictContent()->runGC(&terminalIdMap, - mBuffers->getProbabilityDictContent())) { + if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap, + mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) { return false; } // Run GC for bigram dict content. @@ -201,7 +201,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( continue; } const ProbabilityEntry probabilityEntry = - mBuffers->getProbabilityDictContent()->getProbabilityEntry(i); + mBuffers->getLanguageModelDictContent()->getProbabilityEntry(i); const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index 825b72c6a..833063c17 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -25,7 +25,7 @@ const size_t BufferWithExtendableBuffer::EXTEND_ADDITIONAL_BUFFER_SIZE_STEP = 12 uint32_t BufferWithExtendableBuffer::readUint(const int size, const int pos) const { const bool readingPosIsInAdditionalBuffer = isInAdditionalBuffer(pos); - const int posInBuffer = readingPosIsInAdditionalBuffer ? pos - mOriginalBufferSize : pos; + const int posInBuffer = readingPosIsInAdditionalBuffer ? pos - mOriginalBuffer.size() : pos; return ByteArrayUtils::readUint(getBuffer(readingPosIsInAdditionalBuffer), size, posInBuffer); } @@ -40,12 +40,12 @@ void BufferWithExtendableBuffer::readCodePointsAndAdvancePosition(const int maxC int *const outCodePoints, int *outCodePointCount, int *const pos) const { const bool readingPosIsInAdditionalBuffer = isInAdditionalBuffer(*pos); if (readingPosIsInAdditionalBuffer) { - *pos -= mOriginalBufferSize; + *pos -= mOriginalBuffer.size(); } *outCodePointCount = ByteArrayUtils::readStringAndAdvancePosition( getBuffer(readingPosIsInAdditionalBuffer), maxCodePointCount, outCodePoints, pos); if (readingPosIsInAdditionalBuffer) { - *pos += mOriginalBufferSize; + *pos += mOriginalBuffer.size(); } } @@ -69,13 +69,14 @@ bool BufferWithExtendableBuffer::writeUintAndAdvancePosition(const uint32_t data return false; } const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); - uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + uint8_t *const buffer = + usesAdditionalBuffer ? mAdditionalBuffer.data() : mOriginalBuffer.data(); if (usesAdditionalBuffer) { - *pos -= mOriginalBufferSize; + *pos -= mOriginalBuffer.size(); } ByteArrayUtils::writeUintAndAdvancePosition(buffer, data, size, pos); if (usesAdditionalBuffer) { - *pos += mOriginalBufferSize; + *pos += mOriginalBuffer.size(); } return true; } @@ -88,14 +89,15 @@ bool BufferWithExtendableBuffer::writeCodePointsAndAdvancePosition(const int *co return false; } const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); - uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + uint8_t *const buffer = + usesAdditionalBuffer ? mAdditionalBuffer.data() : mOriginalBuffer.data(); if (usesAdditionalBuffer) { - *pos -= mOriginalBufferSize; + *pos -= mOriginalBuffer.size(); } ByteArrayUtils::writeCodePointsAndAdvancePosition(buffer, codePoints, codePointCount, writesTerminator, pos); if (usesAdditionalBuffer) { - *pos += mOriginalBufferSize; + *pos += mOriginalBuffer.size(); } return true; } @@ -119,7 +121,7 @@ bool BufferWithExtendableBuffer::checkAndPrepareWriting(const int pos, const int const size_t totalRequiredSize = static_cast<size_t>(pos + size); if (!isInAdditionalBuffer(pos)) { // Here don't need to care about the additional buffer. - if (static_cast<size_t>(mOriginalBufferSize) < totalRequiredSize) { + if (mOriginalBuffer.size() < totalRequiredSize) { // Violate the boundary. return false; } @@ -137,7 +139,7 @@ bool BufferWithExtendableBuffer::checkAndPrepareWriting(const int pos, const int return false; } const size_t extendSize = totalRequiredSize - - std::min(mAdditionalBuffer.size() + mOriginalBufferSize, totalRequiredSize); + std::min(mAdditionalBuffer.size() + mOriginalBuffer.size(), totalRequiredSize); if (extendSize > 0 && !extendBuffer(extendSize)) { // Failed to extend the buffer. return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h index 5e1362eee..fad83aa25 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h @@ -23,6 +23,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -34,20 +35,18 @@ class BufferWithExtendableBuffer { public: static const size_t DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE; - BufferWithExtendableBuffer(uint8_t *const originalBuffer, const int originalBufferSize, + BufferWithExtendableBuffer(const ReadWriteByteArrayView originalBuffer, const int maxAdditionalBufferSize) - : mOriginalBuffer(originalBuffer), mOriginalBufferSize(originalBufferSize), - mAdditionalBuffer(0), mUsedAdditionalBufferSize(0), + : mOriginalBuffer(originalBuffer), mAdditionalBuffer(), mUsedAdditionalBufferSize(0), mMaxAdditionalBufferSize(maxAdditionalBufferSize) {} // Without original buffer. BufferWithExtendableBuffer(const int maxAdditionalBufferSize) - : mOriginalBuffer(0), mOriginalBufferSize(0), - mAdditionalBuffer(0), mUsedAdditionalBufferSize(0), + : mOriginalBuffer(), mAdditionalBuffer(), mUsedAdditionalBufferSize(0), mMaxAdditionalBufferSize(maxAdditionalBufferSize) {} AK_FORCE_INLINE int getTailPosition() const { - return mOriginalBufferSize + mUsedAdditionalBufferSize; + return mOriginalBuffer.size() + mUsedAdditionalBufferSize; } AK_FORCE_INLINE int getUsedAdditionalBufferSize() const { @@ -58,16 +57,16 @@ class BufferWithExtendableBuffer { * For reading. */ AK_FORCE_INLINE bool isInAdditionalBuffer(const int position) const { - return position >= mOriginalBufferSize; + return position >= static_cast<int>(mOriginalBuffer.size()); } // TODO: Resolve the issue that the address can be changed when the vector is resized. // CAVEAT!: Be careful about array out of bound access with buffers AK_FORCE_INLINE const uint8_t *getBuffer(const bool usesAdditionalBuffer) const { if (usesAdditionalBuffer) { - return &mAdditionalBuffer[0]; + return mAdditionalBuffer.data(); } else { - return mOriginalBuffer; + return mOriginalBuffer.data(); } } @@ -79,7 +78,7 @@ class BufferWithExtendableBuffer { int *const outCodePoints, int *outCodePointCount, int *const pos) const; AK_FORCE_INLINE int getOriginalBufferSize() const { - return mOriginalBufferSize; + return mOriginalBuffer.size(); } AK_FORCE_INLINE bool isNearSizeLimit() const { @@ -110,8 +109,7 @@ class BufferWithExtendableBuffer { static const int NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE; static const size_t EXTEND_ADDITIONAL_BUFFER_SIZE_STEP; - uint8_t *const mOriginalBuffer; - const int mOriginalBufferSize; + const ReadWriteByteArrayView mOriginalBuffer; std::vector<uint8_t> mAdditionalBuffer; int mUsedAdditionalBufferSize; const size_t mMaxAdditionalBufferSize; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index 3ff80aeec..9910777b8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -84,7 +84,7 @@ class ForgettingCurveUtils { static const int STRONG_BASE_PROBABILITY; static const int AGGRESSIVE_BASE_PROBABILITY; - std::vector<std::vector<std::vector<int> > > mTables; + std::vector<std::vector<std::vector<int>>> mTables; static int getBaseProbabilityForLevel(const int tableId, const int level); }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h index 8460087ab..e25310373 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h @@ -21,6 +21,7 @@ #include <memory> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -39,12 +40,12 @@ class MmappedBuffer { ~MmappedBuffer(); - AK_FORCE_INLINE uint8_t *getBuffer() const { - return mBuffer; + ReadWriteByteArrayView getReadWriteByteArrayView() const { + return mByteArrayView; } - AK_FORCE_INLINE int getBufferSize() const { - return mBufferSize; + ReadOnlyByteArrayView getReadOnlyByteArrayView() const { + return mByteArrayView.getReadOnlyView(); } AK_FORCE_INLINE bool isUpdatable() const { @@ -55,18 +56,17 @@ class MmappedBuffer { AK_FORCE_INLINE MmappedBuffer(uint8_t *const buffer, const int bufferSize, void *const mmappedBuffer, const int alignedSize, const int mmapFd, const bool isUpdatable) - : mBuffer(buffer), mBufferSize(bufferSize), mMmappedBuffer(mmappedBuffer), + : mByteArrayView(buffer, bufferSize), mMmappedBuffer(mmappedBuffer), mAlignedSize(alignedSize), mMmapFd(mmapFd), mIsUpdatable(isUpdatable) {} // Empty file. We have to handle an empty file as a valid part of a dictionary. AK_FORCE_INLINE MmappedBuffer(const bool isUpdatable) - : mBuffer(nullptr), mBufferSize(0), mMmappedBuffer(nullptr), mAlignedSize(0), + : mByteArrayView(), mMmappedBuffer(nullptr), mAlignedSize(0), mMmapFd(0), mIsUpdatable(isUpdatable) {} DISALLOW_IMPLICIT_CONSTRUCTORS(MmappedBuffer); - uint8_t *const mBuffer; - const int mBufferSize; + const ReadWriteByteArrayView mByteArrayView; void *const mMmappedBuffer; const int mAlignedSize; const int mMmapFd; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp new file mode 100644 index 000000000..407b8efd0 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -0,0 +1,387 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/trie_map.h" + +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" + +namespace latinime { + +const int TrieMap::INVALID_INDEX = -1; +const int TrieMap::FIELD0_SIZE = 4; +const int TrieMap::FIELD1_SIZE = 3; +const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE; +const uint32_t TrieMap::VALUE_FLAG = 0x400000; +const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF; +const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000; +const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF; +const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5; +const uint32_t TrieMap::LABEL_MASK = 0x1F; +const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_ONE_LEVEL; +const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0; +const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE; +const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0); +const uint64_t TrieMap::MAX_VALUE = + (static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1; +const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE; + +TrieMap::TrieMap() : mBuffer(MAX_BUFFER_SIZE) { + mBuffer.extend(ROOT_BITMAP_ENTRY_POS); + writeEntry(EMPTY_BITMAP_ENTRY, ROOT_BITMAP_ENTRY_INDEX); +} + +TrieMap::TrieMap(const ReadWriteByteArrayView buffer) + : mBuffer(buffer, BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} + +void TrieMap::dump(const int from, const int to) const { + AKLOGI("BufSize: %d", mBuffer.getTailPosition()); + for (int i = from; i < to; ++i) { + AKLOGI("Entry[%d]: %x, %x", i, readField0(i), readField1(i)); + } + int unusedRegionSize = 0; + for (int i = 1; i <= MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL; ++i) { + int index = readEmptyTableLink(i); + while (index != ROOT_BITMAP_ENTRY_INDEX) { + index = readField0(index); + unusedRegionSize += i; + } + } + AKLOGI("Unused Size: %d", unusedRegionSize); +} + +int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIndex) { + const Entry bitmapEntry = readEntry(bitmapEntryIndex); + const uint32_t unsignedKey = static_cast<uint32_t>(key); + const int terminalEntryIndex = getTerminalEntryIndex( + unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */); + if (terminalEntryIndex == INVALID_INDEX) { + // Not found. + return INVALID_INDEX; + } + const Entry terminalEntry = readEntry(terminalEntryIndex); + if (terminalEntry.hasTerminalLink()) { + return terminalEntry.getValueEntryIndex() + 1; + } + // Create a value entry and a bitmap entry. + const int valueEntryIndex = allocateTable(2 /* entryCount */); + if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) { + return INVALID_INDEX; + } + if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) { + return INVALID_INDEX; + } + if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, valueEntryIndex)) { + return INVALID_INDEX; + } + return valueEntryIndex + 1; +} + +const TrieMap::Result TrieMap::get(const int key, const int bitmapEntryIndex) const { + const uint32_t unsignedKey = static_cast<uint32_t>(key); + return getInternal(unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntryIndex, + 0 /* level */); +} + +bool TrieMap::put(const int key, const uint64_t value, const int bitmapEntryIndex) { + if (value > MAX_VALUE) { + return false; + } + const uint32_t unsignedKey = static_cast<uint32_t>(key); + return putInternal(unsignedKey, value, getBitShuffledKey(unsignedKey), bitmapEntryIndex, + readEntry(bitmapEntryIndex), 0 /* level */); +} + +bool TrieMap::save(FILE *const file) const { + return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer); +} + +/** + * Iterate next entry in a certain level. + * + * @param iterationState the iteration state that will be read and updated in this method. + * @param outKey the output key + * @return Result instance. mIsValid is false when all entries are iterated. + */ +const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *const iterationState, + int *const outKey) const { + while (!iterationState->empty()) { + TableIterationState &state = iterationState->back(); + if (state.mTableSize <= state.mCurrentIndex) { + // Move to parent. + iterationState->pop_back(); + } else { + const int entryIndex = state.mTableIndex + state.mCurrentIndex; + state.mCurrentIndex += 1; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Move to child. + iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); + } else { + if (outKey) { + *outKey = entry.getKey(); + } + if (!entry.hasTerminalLink()) { + return Result(entry.getValue(), true, INVALID_INDEX); + } + const int valueEntryIndex = entry.getValueEntryIndex(); + const Entry valueEntry = readEntry(valueEntryIndex); + return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1); + } + } + } + // Visited all entries. + return Result(0, false, INVALID_INDEX); +} + +/** + * Shuffle bits of the key in the fixed order. + * + * This method is used as a hash function. This returns different values for different inputs. + */ +uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const { + uint32_t shuffledKey = 0; + for (int i = 0; i < 4; ++i) { + const uint32_t keyPiece = (key >> (i * 8)) & 0xFF; + shuffledKey ^= ((keyPiece ^ (keyPiece << 7) ^ (keyPiece << 14) ^ (keyPiece << 21)) + & 0x11111111) << i; + } + return shuffledKey; +} + +bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) { + if (value <= VALUE_MASK) { + // Write value into the terminal entry. + return writeField1(value | VALUE_FLAG, terminalEntryIndex); + } + // Create value entry and write value. + const int valueEntryIndex = allocateTable(2 /* entryCount */); + if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) { + return false; + } + if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) { + return false; + } + return writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex); +} + +bool TrieMap::updateValue(const Entry &terminalEntry, const uint64_t value, + const int terminalEntryIndex) { + if (!terminalEntry.hasTerminalLink()) { + return writeValue(value, terminalEntryIndex); + } + const int valueEntryIndex = terminalEntry.getValueEntryIndex(); + return writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex); +} + +bool TrieMap::freeTable(const int tableIndex, const int entryCount) { + if (!writeField0(readEmptyTableLink(entryCount), tableIndex)) { + return false; + } + return writeEmptyTableLink(tableIndex, entryCount); +} + +/** + * Allocate table with entryCount-entries. Reuse freed table if possible. + */ +int TrieMap::allocateTable(const int entryCount) { + if (entryCount > 0 && entryCount <= MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL) { + const int tableIndex = readEmptyTableLink(entryCount); + if (tableIndex > 0) { + if (!writeEmptyTableLink(readField0(tableIndex), entryCount)) { + return INVALID_INDEX; + } + // Reuse the table. + return tableIndex; + } + } + // Allocate memory space at tail position of the buffer. + const int mapIndex = getTailEntryIndex(); + if (!mBuffer.extend(entryCount * ENTRY_SIZE)) { + return INVALID_INDEX; + } + return mapIndex; +} + +int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey, + const Entry &bitmapEntry, const int level) const { + const int label = getLabel(hashedKey, level); + if (!exists(bitmapEntry.getBitmap(), label)) { + return INVALID_INDEX; + } + const int entryIndex = bitmapEntry.getTableIndex() + popCount(bitmapEntry.getBitmap(), label); + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Move to the next level. + return getTerminalEntryIndex(key, hashedKey, entry, level + 1); + } + if (entry.getKey() == key) { + // Terminal entry is found. + return entryIndex; + } + return INVALID_INDEX; +} + +/** + * Get Result corresponding to the key. + * + * @param key the key. + * @param hashedKey the hashed key. + * @param bitmapEntryIndex the index of bitmap entry + * @param level current level + * @return Result instance corresponding to the key. mIsValid indicates whether the key is in the + * map. + */ +const TrieMap::Result TrieMap::getInternal(const uint32_t key, const uint32_t hashedKey, + const int bitmapEntryIndex, const int level) const { + const int terminalEntryIndex = getTerminalEntryIndex(key, hashedKey, + readEntry(bitmapEntryIndex), level); + if (terminalEntryIndex == INVALID_INDEX) { + // Not found. + return Result(0, false, INVALID_INDEX); + } + const Entry terminalEntry = readEntry(terminalEntryIndex); + if (!terminalEntry.hasTerminalLink()) { + return Result(terminalEntry.getValue(), true, INVALID_INDEX); + } + const int valueEntryIndex = terminalEntry.getValueEntryIndex(); + const Entry valueEntry = readEntry(valueEntryIndex); + return Result(valueEntry.getValueOfValueEntry(), true, valueEntryIndex + 1); +} + +/** + * Put key to value mapping to the map. + * + * @param key the key. + * @param value the value + * @param hashedKey the hashed key. + * @param bitmapEntryIndex the index of bitmap entry + * @param bitmapEntry the bitmap entry + * @param level current level + * @return whether the key-value has been correctly inserted to the map or not. + */ +bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32_t hashedKey, + const int bitmapEntryIndex, const Entry &bitmapEntry, const int level) { + const int label = getLabel(hashedKey, level); + const uint32_t bitmap = bitmapEntry.getBitmap(); + const int mapIndex = bitmapEntry.getTableIndex(); + if (!exists(bitmap, label)) { + // Current map doesn't contain the label. + return addNewEntryByExpandingTable(key, value, mapIndex, bitmap, bitmapEntryIndex, label); + } + const int entryIndex = mapIndex + popCount(bitmap, label); + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Bitmap entry is found. Go to the next level. + return putInternal(key, value, hashedKey, entryIndex, entry, level + 1); + } + if (entry.getKey() == key) { + // Terminal entry for the key is found. Update the value. + return updateValue(entry, value, entryIndex); + } + // Conflict with the existing key. + return addNewEntryByResolvingConflict(key, value, hashedKey, entry, entryIndex, level); +} + +/** + * Resolve a conflict in the current level and add new entry. + * + * @param key the key + * @param value the value + * @param hashedKey the hashed key + * @param conflictedEntry the existing conflicted entry + * @param conflictedEntryIndex the index of existing conflicted entry + * @param level current level + * @return whether the key-value has been correctly inserted to the map or not. + */ +bool TrieMap::addNewEntryByResolvingConflict(const uint32_t key, const uint64_t value, + const uint32_t hashedKey, const Entry &conflictedEntry, const int conflictedEntryIndex, + const int level) { + const int conflictedKeyNextLabel = + getLabel(getBitShuffledKey(conflictedEntry.getKey()), level + 1); + const int nextLabel = getLabel(hashedKey, level + 1); + if (conflictedKeyNextLabel == nextLabel) { + // Conflicted again in the next level. + const int newTableIndex = allocateTable(1 /* entryCount */); + if (newTableIndex == INVALID_INDEX) { + return false; + } + if (!writeEntry(conflictedEntry, newTableIndex)) { + return false; + } + const Entry newBitmapEntry(setExist(0 /* bitmap */, nextLabel), newTableIndex); + if (!writeEntry(newBitmapEntry, conflictedEntryIndex)) { + return false; + } + return putInternal(key, value, hashedKey, conflictedEntryIndex, newBitmapEntry, level + 1); + } + // The conflict has been resolved. Create a table that contains 2 entries. + const int newTableIndex = allocateTable(2 /* entryCount */); + if (newTableIndex == INVALID_INDEX) { + return false; + } + if (nextLabel < conflictedKeyNextLabel) { + if (!writeTerminalEntry(key, value, newTableIndex)) { + return false; + } + if (!writeEntry(conflictedEntry, newTableIndex + 1)) { + return false; + } + } else { // nextLabel > conflictedKeyNextLabel + if (!writeEntry(conflictedEntry, newTableIndex)) { + return false; + } + if (!writeTerminalEntry(key, value, newTableIndex + 1)) { + return false; + } + } + const uint32_t updatedBitmap = + setExist(setExist(0 /* bitmap */, nextLabel), conflictedKeyNextLabel); + return writeEntry(Entry(updatedBitmap, newTableIndex), conflictedEntryIndex); +} + +/** + * Add new entry to the existing table. + */ +bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t value, + const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex, const int label) { + // Current map doesn't contain the label. + const int entryCount = popCount(bitmap); + const int newTableIndex = allocateTable(entryCount + 1); + if (newTableIndex == INVALID_INDEX) { + return false; + } + const int newEntryIndexInTable = popCount(bitmap, label); + // Copy from existing table to the new table. + for (int i = 0; i < entryCount; ++i) { + if (!copyEntry(tableIndex + i, newTableIndex + i + (i >= newEntryIndexInTable ? 1 : 0))) { + return false; + } + } + // Write new terminal entry. + if (!writeTerminalEntry(key, value, newTableIndex + newEntryIndexInTable)) { + return false; + } + // Update bitmap. + if (!writeEntry(Entry(setExist(bitmap, label), newTableIndex), bitmapEntryIndex)) { + return false; + } + if (entryCount > 0) { + return freeTable(tableIndex, entryCount); + } + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h new file mode 100644 index 000000000..3e5c4010c --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -0,0 +1,384 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TRIE_MAP_H +#define LATINIME_TRIE_MAP_H + +#include <climits> +#include <cstdint> +#include <cstdio> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "utils/byte_array_view.h" + +namespace latinime { + +/** + * Trie map derived from Phil Bagwell's Hash Array Mapped Trie. + * key is int and value is uint64_t. + * This supports multiple level map. Terminal entries can have a bitmap for the next level map. + * This doesn't support root map resizing. + */ +class TrieMap { + public: + struct Result { + const uint64_t mValue; + const bool mIsValid; + const int mNextLevelBitmapEntryIndex; + + Result(const uint64_t value, const bool isValid, const int nextLevelBitmapEntryIndex) + : mValue(value), mIsValid(isValid), + mNextLevelBitmapEntryIndex(nextLevelBitmapEntryIndex) {} + }; + + /** + * Struct to record iteration state in a table. + */ + struct TableIterationState { + int mTableSize; + int mTableIndex; + int mCurrentIndex; + + TableIterationState(const int tableSize, const int tableIndex) + : mTableSize(tableSize), mTableIndex(tableIndex), mCurrentIndex(0) {} + }; + + class TrieMapRange; + class TrieMapIterator { + public: + class IterationResult { + public: + IterationResult(const TrieMap *const trieMap, const int key, const uint64_t value, + const int nextLeveBitmapEntryIndex) + : mTrieMap(trieMap), mKey(key), mValue(value), + mNextLevelBitmapEntryIndex(nextLeveBitmapEntryIndex) {} + + const TrieMapRange getEntriesInNextLevel() const { + return TrieMapRange(mTrieMap, mNextLevelBitmapEntryIndex); + } + + bool hasNextLevelMap() const { + return mNextLevelBitmapEntryIndex != INVALID_INDEX; + } + + AK_FORCE_INLINE int key() const { + return mKey; + } + + AK_FORCE_INLINE uint64_t value() const { + return mValue; + } + + private: + const TrieMap *const mTrieMap; + const int mKey; + const uint64_t mValue; + const int mNextLevelBitmapEntryIndex; + }; + + TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), + mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { + if (!trieMap) { + return; + } + const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); + mStateStack.emplace_back( + mTrieMap->popCount(bitmapEntry.getBitmap()), bitmapEntry.getTableIndex()); + this->operator++(); + } + + const IterationResult operator*() const { + return IterationResult(mTrieMap, mKey, mValue, mNextLevelBitmapEntryIndex); + } + + bool operator!=(const TrieMapIterator &other) const { + // Caveat: This works only for for loops. + return mIsValid || other.mIsValid; + } + + const TrieMapIterator &operator++() { + const Result result = mTrieMap->iterateNext(&mStateStack, &mKey); + mValue = result.mValue; + mIsValid = result.mIsValid; + mNextLevelBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapIterator); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapIterator); + + const TrieMap *const mTrieMap; + std::vector<TrieMap::TableIterationState> mStateStack; + const int mBaseBitmapEntryIndex; + int mKey; + uint64_t mValue; + bool mIsValid; + int mNextLevelBitmapEntryIndex; + }; + + /** + * Class to support iterating entries in TrieMap by range base for loops. + */ + class TrieMapRange { + public: + TrieMapRange(const TrieMap *const trieMap, const int bitmapEntryIndex) + : mTrieMap(trieMap), mBaseBitmapEntryIndex(bitmapEntryIndex) {}; + + TrieMapIterator begin() const { + return TrieMapIterator(mTrieMap, mBaseBitmapEntryIndex); + } + + const TrieMapIterator end() const { + return TrieMapIterator(nullptr, INVALID_INDEX); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(TrieMapRange); + DISALLOW_ASSIGNMENT_OPERATOR(TrieMapRange); + + const TrieMap *const mTrieMap; + const int mBaseBitmapEntryIndex; + }; + + static const int INVALID_INDEX; + static const uint64_t MAX_VALUE; + + TrieMap(); + // Construct TrieMap using existing data in the memory region written by save(). + TrieMap(const ReadWriteByteArrayView buffer); + void dump(const int from = 0, const int to = 0) const; + + bool isNearSizeLimit() const { + return mBuffer.isNearSizeLimit(); + } + + int getRootBitmapEntryIndex() const { + return ROOT_BITMAP_ENTRY_INDEX; + } + + // Returns bitmapEntryIndex. Create the next level map if it doesn't exist. + int getNextLevelBitmapEntryIndex(const int key) { + return getNextLevelBitmapEntryIndex(key, ROOT_BITMAP_ENTRY_INDEX); + } + + int getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIndex); + + const Result getRoot(const int key) const { + return get(key, ROOT_BITMAP_ENTRY_INDEX); + } + + const Result get(const int key, const int bitmapEntryIndex) const; + + bool putRoot(const int key, const uint64_t value) { + return put(key, value, ROOT_BITMAP_ENTRY_INDEX); + } + + bool put(const int key, const uint64_t value, const int bitmapEntryIndex); + + const TrieMapRange getEntriesInRootLevel() const { + return getEntriesInSpecifiedLevel(ROOT_BITMAP_ENTRY_INDEX); + } + + const TrieMapRange getEntriesInSpecifiedLevel(const int bitmapEntryIndex) const { + return TrieMapRange(this, bitmapEntryIndex); + } + + bool save(FILE *const file) const; + + private: + DISALLOW_COPY_AND_ASSIGN(TrieMap); + + /** + * Struct represents an entry. + * + * Entry is one of these entry types. All entries are fixed size and have 2 fields FIELD_0 and + * FIELD_1. + * 1. bitmap entry. bitmap entry contains bitmap and the link to hash table. + * FIELD_0(bitmap) FIELD_1(LINK_TO_HASH_TABLE) + * 2. terminal entry. terminal entry contains hashed key and value or terminal link. terminal + * entry have terminal link when the value is not fit to FIELD_1 or there is a next level map + * for the key. + * FIELD_0(hashed key) (FIELD_1(VALUE_FLAG VALUE) | FIELD_1(TERMINAL_LINK_FLAG TERMINAL_LINK)) + * 3. value entry. value entry represents a value. Upper order bytes are stored in FIELD_0 and + * lower order bytes are stored in FIELD_1. + * FIELD_0(value (upper order bytes)) FIELD_1(value (lower order bytes)) + */ + struct Entry { + const uint32_t mData0; + const uint32_t mData1; + + Entry(const uint32_t data0, const uint32_t data1) : mData0(data0), mData1(data1) {} + + AK_FORCE_INLINE bool isBitmapEntry() const { + return (mData1 & VALUE_FLAG) == 0 && (mData1 & TERMINAL_LINK_FLAG) == 0; + } + + AK_FORCE_INLINE bool hasTerminalLink() const { + return (mData1 & TERMINAL_LINK_FLAG) != 0; + } + + // For terminal entry. + AK_FORCE_INLINE uint32_t getKey() const { + return mData0; + } + + // For terminal entry. + AK_FORCE_INLINE uint32_t getValue() const { + return mData1 & VALUE_MASK; + } + + // For terminal entry. + AK_FORCE_INLINE uint32_t getValueEntryIndex() const { + return mData1 & TERMINAL_LINK_MASK; + } + + // For bitmap entry. + AK_FORCE_INLINE uint32_t getBitmap() const { + return mData0; + } + + // For bitmap entry. + AK_FORCE_INLINE int getTableIndex() const { + return static_cast<int>(mData1); + } + + // For value entry. + AK_FORCE_INLINE uint64_t getValueOfValueEntry() const { + return ((static_cast<uint64_t>(mData0) << (FIELD1_SIZE * CHAR_BIT)) ^ mData1); + } + }; + + BufferWithExtendableBuffer mBuffer; + + static const int FIELD0_SIZE; + static const int FIELD1_SIZE; + static const int ENTRY_SIZE; + static const uint32_t VALUE_FLAG; + static const uint32_t VALUE_MASK; + static const uint32_t TERMINAL_LINK_FLAG; + static const uint32_t TERMINAL_LINK_MASK; + static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL; + static const uint32_t LABEL_MASK; + static const int MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL; + static const int ROOT_BITMAP_ENTRY_INDEX; + static const int ROOT_BITMAP_ENTRY_POS; + static const Entry EMPTY_BITMAP_ENTRY; + static const int MAX_BUFFER_SIZE; + + uint32_t getBitShuffledKey(const uint32_t key) const; + bool writeValue(const uint64_t value, const int terminalEntryIndex); + bool updateValue(const Entry &terminalEntry, const uint64_t value, + const int terminalEntryIndex); + bool freeTable(const int tableIndex, const int entryCount); + int allocateTable(const int entryCount); + int getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey, + const Entry &bitmapEntry, const int level) const; + const Result getInternal(const uint32_t key, const uint32_t hashedKey, + const int bitmapEntryIndex, const int level) const; + bool putInternal(const uint32_t key, const uint64_t value, const uint32_t hashedKey, + const int bitmapEntryIndex, const Entry &bitmapEntry, const int level); + bool addNewEntryByResolvingConflict(const uint32_t key, const uint64_t value, + const uint32_t hashedKey, const Entry &conflictedEntry, const int conflictedEntryIndex, + const int level); + bool addNewEntryByExpandingTable(const uint32_t key, const uint64_t value, + const int tableIndex, const uint32_t bitmap, const int bitmapEntryIndex, + const int label); + const Result iterateNext(std::vector<TableIterationState> *const iterationState, + int *const outKey) const; + + AK_FORCE_INLINE const Entry readEntry(const int entryIndex) const { + return Entry(readField0(entryIndex), readField1(entryIndex)); + } + + // Returns whether an entry for the index is existing by testing if the index-th bit in the + // bitmap is set or not. + AK_FORCE_INLINE bool exists(const uint32_t bitmap, const int index) const { + return (bitmap & (1 << index)) != 0; + } + + // Set index-th bit in the bitmap. + AK_FORCE_INLINE uint32_t setExist(const uint32_t bitmap, const int index) const { + return bitmap | (1 << index); + } + + // Count set bits before index in the bitmap. + AK_FORCE_INLINE int popCount(const uint32_t bitmap, const int index) const { + return popCount(bitmap & ((1 << index) - 1)); + } + + // Count set bits in the bitmap. + AK_FORCE_INLINE int popCount(const uint32_t bitmap) const { + return __builtin_popcount(bitmap); + // int v = bitmap - ((bitmap >> 1) & 0x55555555); + // v = (v & 0x33333333) + ((v >> 2) & 0x33333333); + // return (((v + (v >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24; + } + + AK_FORCE_INLINE int getLabel(const uint32_t hashedKey, const int level) const { + return (hashedKey >> (level * NUM_OF_BITS_USED_FOR_ONE_LEVEL)) & LABEL_MASK; + } + + AK_FORCE_INLINE uint32_t readField0(const int entryIndex) const { + return mBuffer.readUint(FIELD0_SIZE, ROOT_BITMAP_ENTRY_POS + entryIndex * ENTRY_SIZE); + } + + AK_FORCE_INLINE uint32_t readField1(const int entryIndex) const { + return mBuffer.readUint(FIELD1_SIZE, + ROOT_BITMAP_ENTRY_POS + entryIndex * ENTRY_SIZE + FIELD0_SIZE); + } + + AK_FORCE_INLINE int readEmptyTableLink(const int entryCount) const { + return mBuffer.readUint(FIELD1_SIZE, (entryCount - 1) * FIELD1_SIZE); + } + + AK_FORCE_INLINE bool writeEmptyTableLink(const int tableIndex, const int entryCount) { + return mBuffer.writeUint(tableIndex, FIELD1_SIZE, (entryCount - 1) * FIELD1_SIZE); + } + + AK_FORCE_INLINE bool writeField0(const uint32_t data, const int entryIndex) { + return mBuffer.writeUint(data, FIELD0_SIZE, + ROOT_BITMAP_ENTRY_POS + entryIndex * ENTRY_SIZE); + } + + AK_FORCE_INLINE bool writeField1(const uint32_t data, const int entryIndex) { + return mBuffer.writeUint(data, FIELD1_SIZE, + ROOT_BITMAP_ENTRY_POS + entryIndex * ENTRY_SIZE + FIELD0_SIZE); + } + + AK_FORCE_INLINE bool writeEntry(const Entry &entry, const int entryIndex) { + return writeField0(entry.mData0, entryIndex) && writeField1(entry.mData1, entryIndex); + } + + AK_FORCE_INLINE bool writeTerminalEntry(const uint32_t key, const uint64_t value, + const int entryIndex) { + return writeField0(key, entryIndex) && writeValue(value, entryIndex); + } + + AK_FORCE_INLINE bool copyEntry(const int originalEntryIndex, const int newEntryIndex) { + return writeEntry(readEntry(originalEntryIndex), newEntryIndex); + } + + AK_FORCE_INLINE int getTailEntryIndex() const { + return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE; + } +}; + +} // namespace latinime +#endif /* LATINIME_TRIE_MAP_H */ diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 66ea62406..04cb6603a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -69,10 +69,6 @@ class TypingScoring : public Scoring { return 0.0f; } - AK_FORCE_INLINE bool doesAutoCorrectValidWord() const { - return false; - } - AK_FORCE_INLINE bool autoCorrectsToMultiWordSuggestionIfTop() const { return true; } diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h new file mode 100644 index 000000000..2c97c6d58 --- /dev/null +++ b/native/jni/src/utils/byte_array_view.h @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BYTE_ARRAY_VIEW_H +#define LATINIME_BYTE_ARRAY_VIEW_H + +#include <cstdint> +#include <cstdlib> + +#include "defines.h" + +namespace latinime { + +/** + * Helper class used to keep track of read accesses for a given memory region. + */ +class ReadOnlyByteArrayView { + public: + ReadOnlyByteArrayView() : mPtr(nullptr), mSize(0) {} + + ReadOnlyByteArrayView(const uint8_t *const ptr, const size_t size) + : mPtr(ptr), mSize(size) {} + + AK_FORCE_INLINE size_t size() const { + return mSize; + } + + AK_FORCE_INLINE const uint8_t *data() const { + return mPtr; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(ReadOnlyByteArrayView); + + const uint8_t *const mPtr; + const size_t mSize; +}; + +/** + * Helper class used to keep track of read-write accesses for a given memory region. + */ +class ReadWriteByteArrayView { + public: + ReadWriteByteArrayView() : mPtr(nullptr), mSize(0) {} + + ReadWriteByteArrayView(uint8_t *const ptr, const size_t size) + : mPtr(ptr), mSize(size) {} + + AK_FORCE_INLINE size_t size() const { + return mSize; + } + + AK_FORCE_INLINE uint8_t *data() const { + return mPtr; + } + + AK_FORCE_INLINE ReadOnlyByteArrayView getReadOnlyView() const { + return ReadOnlyByteArrayView(mPtr, mSize); + } + + ReadWriteByteArrayView subView(const size_t start, const size_t n) const { + ASSERT(start + n <= mSize); + return ReadWriteByteArrayView(mPtr + start, n); + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(ReadWriteByteArrayView); + + uint8_t *const mPtr; + const size_t mSize; +}; + +} // namespace latinime +#endif // LATINIME_BYTE_ARRAY_VIEW_H diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h new file mode 100644 index 000000000..c1ddc9812 --- /dev/null +++ b/native/jni/src/utils/int_array_view.h @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_INT_ARRAY_VIEW_H +#define LATINIME_INT_ARRAY_VIEW_H + +#include <cstdint> +#include <cstdlib> +#include <vector> + +#include "defines.h" + +namespace latinime { + +/** + * Helper class used to provide a read-only view of a given range of integer array. This class + * does not take ownership of the underlying integer array but is designed to be a lightweight + * object that obeys value semantics. + * + * Example: + * <code> + * bool constinsX(IntArrayView view) { + * for (size_t i = 0; i < view.size(); ++i) { + * if (view[i] == 'X') { + * return true; + * } + * } + * return false; + * } + * + * const int codePointArray[] = { 'A', 'B', 'X', 'Z' }; + * auto view = IntArrayView(codePointArray, NELEMS(codePointArray)); + * const bool hasX = constinsX(view); + * </code> + */ +class IntArrayView { + public: + IntArrayView() : mPtr(nullptr), mSize(0) {} + + IntArrayView(const int *const ptr, const size_t size) + : mPtr(ptr), mSize(size) {} + + explicit IntArrayView(const std::vector<int> &vector) + : mPtr(vector.data()), mSize(vector.size()) {} + + template <int N> + AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) { + return IntArrayView(array, N); + } + + // Returns a view that points one int object. Does not take ownership of the given object. + AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) { + return IntArrayView(object, 1); + } + + AK_FORCE_INLINE int operator[](const size_t index) const { + ASSERT(index < mSize); + return mPtr[index]; + } + + AK_FORCE_INLINE bool empty() const { + return size() == 0; + } + + AK_FORCE_INLINE size_t size() const { + return mSize; + } + + AK_FORCE_INLINE const int *data() const { + return mPtr; + } + + AK_FORCE_INLINE const int *begin() const { + return mPtr; + } + + AK_FORCE_INLINE const int *end() const { + return mPtr + mSize; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); + + const int *const mPtr; + const size_t mSize; +}; + +using WordIdArrayView = IntArrayView; +using PtNodePosArrayView = IntArrayView; + +} // namespace latinime +#endif // LATINIME_MEMORY_VIEW_H |