diff options
Diffstat (limited to 'native/jni/src')
3 files changed, 100 insertions, 4 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index b60499e9f..10f90523a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -121,8 +121,14 @@ bool Ver4PatriciaTrieNodeWriter::markPtNodeAsWillBecomeNonTerminal( const PatriciaTrieReadingUtils::NodeFlags updatedFlags = DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, false /* isDeleted */, true /* willBecomeNonTerminal */); - int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); + if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition( + toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) { + AKLOGE("Cannot update terminal position lookup table. terminal id: %d", + toBeUpdatedPtNodeParams->getTerminalId()); + return false; + } // Update flags. + int writingPos = toBeUpdatedPtNodeParams->getHeadPos(); return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags, &writingPos); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 21d009ecb..77fb41dc5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -17,6 +17,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include <cstring> +#include <queue> #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" @@ -97,10 +98,16 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { return false; } + const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + .getValidUnigramCount(); if (headerPolicy->isDecayingDict() - && traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted - .getValidUnigramCount() > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { - // TODO: Remove more unigrams. + && unigramCount > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { + if (!turncateUnigrams(&ptNodeReader, &ptNodeWriter, + ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC)) { + AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount, + ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC); + return false; + } } readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); @@ -179,6 +186,42 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return true; } +bool Ver4PatriciaTrieWritingHelper::turncateUnigrams( + const Ver4PatriciaTrieNodeReader *const ptNodeReader, + Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { + const TerminalPositionLookupTable *const terminalPosLookupTable = + mBuffers->getTerminalPositionLookupTable(); + const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); + std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator> + priorityQueue; + for (int i = 0; i < nextTerminalId; ++i) { + const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i); + if (terminalPos == NOT_A_DICT_POS) { + continue; + } + const ProbabilityEntry probabilityEntry = + mBuffers->getProbabilityDictContent()->getProbabilityEntry(i); + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo()) : + probabilityEntry.getProbability(); + priorityQueue.push(DictProbability(terminalPos, probability, + probabilityEntry.getHistoricalInfo()->getTimeStamp())); + } + + // Delete unigrams. + while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) { + const int ptNodePos = priorityQueue.top().getDictPos(); + const PtNodeParams ptNodeParams = + ptNodeReader->fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) { + AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos); + return false; + } + priorityQueue.pop(); + } + return true; +} + bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) { if (!ptNodeParams->isTerminal()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index 82877fdcc..26eb678b0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -25,6 +25,7 @@ namespace latinime { class HeaderPolicy; class Ver4DictBuffers; +class Ver4PatriciaTrieNodeReader; class Ver4PatriciaTrieNodeWriter; class Ver4PatriciaTrieWritingHelper { @@ -64,10 +65,56 @@ class Ver4PatriciaTrieWritingHelper { const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap; }; + // For truncateUnigrams(). + class DictProbability { + public: + DictProbability(const int dictPos, const int probability, const int timestamp) + : mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {} + + int getDictPos() const { + return mDictPos; + } + + int getProbability() const { + return mProbability; + } + + int getTimestamp() const { + return mTimestamp; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability); + + int mDictPos; + int mProbability; + int mTimestamp; + }; + + // For truncateUnigrams(). + class DictProbabilityComparator { + public: + bool operator()(const DictProbability &left, const DictProbability &right) { + if (left.getProbability() != right.getProbability()) { + return left.getProbability() > right.getProbability(); + } + if (left.getTimestamp() != right.getTimestamp()) { + return left.getTimestamp() < right.getTimestamp(); + } + return left.getDictPos() > right.getDictPos(); + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator); + }; + bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, int *const outBigramCount); + bool turncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader, + Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount); + Ver4DictBuffers *const mBuffers; }; } // namespace latinime |