diff options
Diffstat (limited to 'native/jni/src')
33 files changed, 575 insertions, 173 deletions
diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index b6bf7a98c..1e2494e92 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -19,17 +19,18 @@ namespace latinime { const ErrorTypeUtils::ErrorType ErrorTypeUtils::NOT_AN_ERROR = 0x0; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_CASE_ERROR = 0x1; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR = 0x2; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x4; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x8; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x10; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x20; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x40; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_CASE = 0x1; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT = 0x2; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT = 0x4; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x8; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x10; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x20; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x40; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = - NOT_AN_ERROR | MATCH_WITH_CASE_ERROR | MATCH_WITH_ACCENT_ERROR | MATCH_WITH_DIGRAPH; + NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index e3e76b238..fd1d5fcff 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -30,8 +30,9 @@ class ErrorTypeUtils { typedef uint32_t ErrorType; static const ErrorType NOT_AN_ERROR; - static const ErrorType MATCH_WITH_CASE_ERROR; - static const ErrorType MATCH_WITH_ACCENT_ERROR; + static const ErrorType MATCH_WITH_WRONG_CASE; + static const ErrorType MATCH_WITH_MISSING_ACCENT; + static const ErrorType MATCH_WITH_WRONG_ACCENT; static const ErrorType MATCH_WITH_DIGRAPH; // Treat error as an intentional omission when the CorrectionType is omission and the node can // be intentional omission. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 278f2b199..97a8bcc98 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) { - AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", - sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); + AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d", + prevWordIds[0], wordId); return false; } const int ptNodePos = @@ -425,6 +425,18 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, return true; } +bool Ver4PatriciaTrieNodeWriter::suppressUnigramEntry(const PtNodeParams *const ptNodeParams) { + if (!mHeaderPolicy->hasHistoricalInfoOfWords()) { + // Require historical info to suppress unigram entry. + return false; + } + const HistoricalInfo suppressedHistorycalInfo(0 /* timestamp */, 0 /* level */, 0 /* count */); + const ProbabilityEntry probabilityEntryToWrite = + ProbabilityEntry().createEntryWithUpdatedHistoricalInfo(&suppressedHistorycalInfo); + return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry( + ptNodeParams->getTerminalId(), &probabilityEntryToWrite); +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h index d49d9a666..9d8a55bff 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h @@ -111,6 +111,11 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { bool updatePtNodeHasBigramsAndShortcutTargetsFlags(const PtNodeParams *const ptNodeParams); + // Suppress unigram not to use the word for generating suggestions. So, this method can be used + // only for dictionaries with historical info. Also, suppressed entries are included in unigram + // count. They will be removed from the dictionary during GC. + bool suppressUnigramEntry(const PtNodeParams *const ptNodeParams); + private: DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeWriter); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 1296b8acd..9c6452e40 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -210,7 +210,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -245,7 +245,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcut.getTargetCodePoints()->data(), shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -258,6 +258,20 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } } +bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int length) { + if (!mBuffers->isUpdatable()) { + AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); + return false; + } + const int ptNodePos = getTerminalPtNodePositionOfWord(word, length, + false /* forceLowerCaseSearch */); + if (ptNodePos == NOT_A_DICT_POS) { + return false; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + return mNodeWriter.suppressUnigramEntry(&ptNodeParams); +} + bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { if (!mBuffers->isUpdatable()) { @@ -275,7 +289,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 9e989b268..d77499636 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -108,10 +108,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const word, const int length) { - // Removing unigram entry is not supported. - return false; - } + bool removeUnigramEntry(const int *const word, const int length); bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index e4ea3da16..9fa93efc9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -111,8 +111,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str return nullptr; } const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::detectFormatVersion( - mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size()); + mmappedBuffer->getReadOnlyByteArrayView()); switch (formatVersion) { case FormatUtils::VERSION_2: AKLOGE("Given path is a directory but the format is version 2. path: %s", path); @@ -174,8 +173,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str if (!mmappedBuffer) { return nullptr; } - switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size())) { + switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h index 361dd2c74..20bae5943 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_BIGRAM_DICT_CONTENT_H #define LATINIME_BIGRAM_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -28,11 +27,12 @@ namespace latinime { +class ReadWriteByteArrayView; + class BigramDictContent : public SparseTableDictContent { public: - BigramDictContent(uint8_t *const *buffers, const int *bufferSizes, const bool hasHistoricalInfo) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, + BigramDictContent(const ReadWriteByteArrayView *const buffers, const bool hasHistoricalInfo) + : SparseTableDictContent(buffers, Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE), mHasHistoricalInfo(hasHistoricalInfo) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 5dc91ba10..ea2d24e67 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -16,6 +16,11 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" +#include <algorithm> +#include <cstring> + +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" + namespace latinime { bool LanguageModelDictContent::save(FILE *const file) const { @@ -45,12 +50,38 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( } bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, - const int terminalId, const ProbabilityEntry *const probabilityEntry) { + const int wordId, const ProbabilityEntry *const probabilityEntry) { + if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) { + return false; + } + const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds); + if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + return false; + } + return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); +} + +bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, + const int wordId) { const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + // Cannot find bitmap entry for the probability entry. The entry doesn't exist. return false; } - return mTrieMap.put(terminalId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); + return mTrieMap.remove(wordId, bitmapEntryIndex); +} + +bool LanguageModelDictContent::truncateEntries(const int *const entryCounts, + const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + if (entryCounts[i] <= maxEntryCounts[i]) { + continue; + } + if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) { + return false; + } + } + return true; } bool LanguageModelDictContent::runGCInner( @@ -80,6 +111,19 @@ bool LanguageModelDictContent::runGCInner( return true; } +int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { + if (prevWordIds.empty()) { + return mTrieMap.getRootBitmapEntryIndex(); + } + const int lastBitmapEntryIndex = + getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1)); + if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) { + return TrieMap::INVALID_INDEX; + } + return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1], + lastBitmapEntryIndex); +} + int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); for (const int wordId : prevWordIds) { @@ -92,4 +136,129 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } +bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex, + const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", + level, MAX_PREV_WORD_COUNT_FOR_N_GRAM); + return false; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) { + const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( + probabilityEntry.getHistoricalInfo(), headerPolicy); + if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { + // Update the entry. + const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); + if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), + bitmapEntryIndex)) { + return false; + } + } else { + // Remove the entry. + if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + } + if (!probabilityEntry.representsBeginningOfSentence()) { + outEntryCounts[level] += 1; + } + if (!entry.hasNextLevelMap()) { + continue; + } + if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1, + headerPolicy, outEntryCounts)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( + const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) { + std::vector<int> prevWordIds; + std::vector<EntryInfoToTurncate> entryInfoVector; + if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), + &prevWordIds, &entryInfoVector)) { + return false; + } + if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) { + return true; + } + const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount; + std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, + entryInfoVector.end(), + EntryInfoToTurncate::Comparator()); + for (int i = 0; i < entryCountToRemove; ++i) { + const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; + if (!removeNgramProbabilityEntry( + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, + const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const { + const int currentLevel = prevWordIds->size(); + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (currentLevel < targetLevel) { + if (!entry.hasNextLevelMap()) { + continue; + } + prevWordIds->push_back(entry.key()); + if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(), + prevWordIds, outEntryInfo)) { + return false; + } + prevWordIds->pop_back(); + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + const int probability = (mHasHistoricalInfo) ? + ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + headerPolicy) : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(probability, + probabilityEntry.getHistoricalInfo()->getTimeStamp(), + entry.key(), targetLevel, prevWordIds->data()); + } + return true; +} + +bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( + const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { + if (left.mProbability != right.mProbability) { + return left.mProbability < right.mProbability; + } + if (left.mTimestamp != right.mTimestamp) { + return left.mTimestamp > right.mTimestamp; + } + if (left.mKey != right.mKey) { + return left.mKey < right.mKey; + } + if (left.mEntryLevel != right.mEntryLevel) { + return left.mEntryLevel > right.mEntryLevel; + } + for (int i = 0; i < left.mEntryLevel; ++i) { + if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) { + return left.mPrevWordIds[i] < right.mPrevWordIds[i]; + } + } + // left and rigth represent the same entry. + return false; +} + +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, + const int timestamp, const int key, const int entryLevel, const int *const prevWordIds) + : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) { + memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0])); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 18f2e0170..43b2aab66 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -18,6 +18,7 @@ #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H #include <cstdio> +#include <vector> #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" @@ -29,6 +30,8 @@ namespace latinime { +class HeaderPolicy; + /** * Class representing language model. * @@ -61,23 +64,72 @@ class LanguageModelDictContent { return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } + bool removeProbabilityEntry(const int wordId) { + return removeNgramProbabilityEntry(WordIdArrayView(), wordId); + } + ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const; bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId, const ProbabilityEntry *const probabilityEntry); + bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); + + bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy, + int *const outEntryCounts) { + for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) { + outEntryCounts[i] = 0; + } + return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */, + headerPolicy, outEntryCounts); + } + + // entryCounts should be created by updateAllProbabilityEntries. + bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts, + const HeaderPolicy *const headerPolicy); + private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); + class EntryInfoToTurncate { + public: + class Comparator { + public: + bool operator()(const EntryInfoToTurncate &left, + const EntryInfoToTurncate &right) const; + private: + DISALLOW_ASSIGNMENT_OPERATOR(Comparator); + }; + + EntryInfoToTurncate(const int probability, const int timestamp, const int key, + const int entryLevel, const int *const prevWordIds); + + int mProbability; + int mTimestamp; + int mKey; + int mEntryLevel; + int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); + }; + TrieMap mTrieMap; const bool mHasHistoricalInfo; bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, int *const outNgramCount); - + int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; + bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, + const HeaderPolicy *const headerPolicy, int *const outEntryCounts); + bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, + const int maxEntryCount, const int targetLevel); + bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel, + const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index feff6b57f..3dfaba755 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -21,6 +21,8 @@ #include <cstdint> #include "defines.h" +#include "suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "suggest/policyimpl/dictionary/utils/historical_info.h" @@ -41,24 +43,32 @@ class ProbabilityEntry { : mFlags(flags), mProbability(probability), mHistoricalInfo() {} // Entry with historical information. - ProbabilityEntry(const int flags, const int probability, - const HistoricalInfo *const historicalInfo) - : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {} - - const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const { - return ProbabilityEntry(mFlags, probability, &mHistoricalInfo); - } - - const ProbabilityEntry createEntryWithUpdatedHistoricalInfo( - const HistoricalInfo *const historicalInfo) const { - return ProbabilityEntry(mFlags, mProbability, historicalInfo); + ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo) + : mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {} + + // Create from unigram property. + ProbabilityEntry(const UnigramProperty *const unigramProperty) + : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())), + mProbability(unigramProperty->getProbability()), + mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(), + unigramProperty->getCount()) {} + + // Create from bigram property. + // TODO: Set flags. + ProbabilityEntry(const BigramProperty *const bigramProperty) + : mFlags(0), mProbability(bigramProperty->getProbability()), + mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(), + bigramProperty->getCount()) {} + + bool isValid() const { + return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo(); } bool hasHistoricalInfo() const { return mHistoricalInfo.isValid(); } - int getFlags() const { + uint8_t getFlags() const { return mFlags; } @@ -70,6 +80,10 @@ class ProbabilityEntry { return &mHistoricalInfo; } + bool representsBeginningOfSentence() const { + return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0; + } + uint64_t encode(const bool hasHistoricalInfo) const { uint64_t encodedEntry = static_cast<uint64_t>(mFlags); if (hasHistoricalInfo) { @@ -89,7 +103,7 @@ class ProbabilityEntry { static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) { if (hasHistoricalInfo) { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::TIME_STAMP_FIELD_SIZE + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); @@ -103,10 +117,10 @@ class ProbabilityEntry { const int count = readFromEncodedEntry(encodedEntry, Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */); const HistoricalInfo historicalInfo(timestamp, level, count); - return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); + return ProbabilityEntry(flags, &historicalInfo); } else { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::PROBABILITY_SIZE); const int probability = readFromEncodedEntry(encodedEntry, Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */); @@ -118,7 +132,7 @@ class ProbabilityEntry { // Copy constructor is public to use this class as a type of return value. DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry); - const int mFlags; + const uint8_t mFlags; const int mProbability; const HistoricalInfo mHistoricalInfo; @@ -126,6 +140,14 @@ class ProbabilityEntry { return static_cast<int>( (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); } + + static uint8_t createFlags(const bool representsBeginningOfSentence) { + uint8_t flags = 0; + if (representsBeginningOfSentence) { + flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + } + return flags; + } }; } // namespace latinime #endif /* LATINIME_PROBABILITY_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h index 7b12aff16..85c9ce8d8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SHORTCUT_DICT_CONTENT_H #define LATINIME_SHORTCUT_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -27,11 +26,12 @@ namespace latinime { +class ReadWriteByteArrayView; + class ShortcutDictContent : public SparseTableDictContent { public: - ShortcutDictContent(uint8_t *const *buffers, const int *bufferSizes) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, + ShortcutDictContent(const ReadWriteByteArrayView *const buffers) + : SparseTableDictContent(buffers, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE) {} ShortcutDictContent() diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h index 921774181..309c434cf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SINGLE_DICT_CONTENT_H #define LATINIME_SINGLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -30,9 +29,9 @@ namespace latinime { class SingleDictContent { public: - SingleDictContent(uint8_t *const buffer, const int bufferSize) - : mExpandableContentBuffer(ReadWriteByteArrayView(buffer, bufferSize), - BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} + SingleDictContent(const ReadWriteByteArrayView buffer) + : mExpandableContentBuffer(buffer, + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} SingleDictContent() : mExpandableContentBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h index c98dd11fd..0ce2da7bf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SPARSE_TABLE_DICT_CONTENT_H #define LATINIME_SPARSE_TABLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -31,19 +30,13 @@ namespace latinime { // TODO: Support multiple contents. class SparseTableDictContent { public: - AK_FORCE_INLINE SparseTableDictContent(uint8_t *const *buffers, const int *bufferSizes, + AK_FORCE_INLINE SparseTableDictContent(const ReadWriteByteArrayView *const buffers, const int sparseTableBlockSize, const int sparseTableDataSize) - : mExpandableLookupTableBuffer( - ReadWriteByteArrayView(buffers[LOOKUP_TABLE_BUFFER_INDEX], - bufferSizes[LOOKUP_TABLE_BUFFER_INDEX]), + : mExpandableLookupTableBuffer(buffers[LOOKUP_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableAddressTableBuffer( - ReadWriteByteArrayView(buffers[ADDRESS_TABLE_BUFFER_INDEX], - bufferSizes[ADDRESS_TABLE_BUFFER_INDEX]), + mExpandableAddressTableBuffer(buffers[ADDRESS_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableContentBuffer( - ReadWriteByteArrayView(buffers[CONTENT_BUFFER_INDEX], - bufferSizes[CONTENT_BUFFER_INDEX]), + mExpandableContentBuffer(buffers[CONTENT_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer, sparseTableBlockSize, sparseTableDataSize) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h index b2262bf1e..febcbe5b4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h @@ -17,13 +17,13 @@ #ifndef LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H #define LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H -#include <cstdint> #include <cstdio> #include <unordered_map> #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -31,8 +31,8 @@ class TerminalPositionLookupTable : public SingleDictContent { public: typedef std::unordered_map<int, int> TerminalIdMap; - TerminalPositionLookupTable(uint8_t *const buffer, const int bufferSize) - : SingleDictContent(buffer, bufferSize), + TerminalPositionLookupTable(const ReadWriteByteArrayView buffer) + : SingleDictContent(buffer), mSize(getBuffer()->getTailPosition() / Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 3c8008dc4..1f40e3dd2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -45,16 +45,13 @@ namespace latinime { if (!bodyBuffer) { return Ver4DictBuffersPtr(nullptr); } - std::vector<uint8_t *> buffers; - std::vector<int> bufferSizes; + std::vector<ReadWriteByteArrayView> buffers; const ReadWriteByteArrayView buffer = bodyBuffer->getReadWriteByteArrayView(); int position = 0; while (position < static_cast<int>(buffer.size())) { const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition( buffer.data(), &position); - const ReadWriteByteArrayView subBuffer = buffer.subView(position, bufferSize); - buffers.push_back(subBuffer.data()); - bufferSizes.push_back(subBuffer.size()); + buffers.push_back(buffer.subView(position, bufferSize)); position += bufferSize; if (bufferSize < 0 || position < 0 || position > static_cast<int>(buffer.size())) { AKLOGE("The dict body file is corrupted."); @@ -66,7 +63,7 @@ namespace latinime { return Ver4DictBuffersPtr(nullptr); } return Ver4DictBuffersPtr(new Ver4DictBuffers(std::move(headerBuffer), std::move(bodyBuffer), - formatVersion, buffers, bufferSizes)); + formatVersion, buffers)); } bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath, @@ -178,29 +175,20 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, const std::vector<int> &contentBufferSizes) + const std::vector<ReadWriteByteArrayView> &contentBuffers) : mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(std::move(bodyBuffer)), mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion), mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableTrieBuffer( - ReadWriteByteArrayView(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX]), + mExpandableTrieBuffer(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( - contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX], - contentBufferSizes[ - Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mLanguageModelDictContent( - ReadWriteByteArrayView( - contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX]), + contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), + mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], mHeaderPolicy.hasHistoricalInfoOfWords()), mBigramDictContent(&contentBuffers[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], - &contentBufferSizes[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], mHeaderPolicy.hasHistoricalInfoOfWords()), - mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX], - &contentBufferSizes[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), + mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), mIsUpdatable(mDictBuffer->isUpdatable()) {} Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize) diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h index 68027dcb8..70a7983f1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h @@ -122,8 +122,7 @@ class Ver4DictBuffers { Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, - const std::vector<int> &contentBufferSizes); + const std::vector<ReadWriteByteArrayView> &contentBuffers); Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 93d4e562d..b085a6661 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX = const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1; const int Ver4DictConstants::PROBABILITY_SIZE = 1; -const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1; +const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1; const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3; const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0; const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4; @@ -54,6 +54,8 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4; const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; +const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; + const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16; const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 6950ca70f..230b3052d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -20,6 +20,7 @@ #include "defines.h" #include <cstddef> +#include <cstdint> namespace latinime { @@ -41,13 +42,15 @@ class Ver4DictConstants { static const int NOT_A_TERMINAL_ID; static const int PROBABILITY_SIZE; - static const int FLAGS_IN_PROBABILITY_FILE_SIZE; + static const int FLAGS_IN_LANGUAGE_MODEL_SIZE; static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE; static const int NOT_A_TERMINAL_ADDRESS; static const int TERMINAL_ID_FIELD_SIZE; static const int TIME_STAMP_FIELD_SIZE; static const int WORD_LEVEL_FIELD_SIZE; static const int WORD_COUNT_FIELD_SIZE; + // Flags in probability entry. + static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE; static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 857222f5d..b7c31bf75 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); - const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry, - unigramProperty); + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); + const ProbabilityEntry updatedProbabilityEntry = + createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry); + toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry); } bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( @@ -160,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); - if (originalProbabilityEntry.hasHistoricalInfo()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); - const ProbabilityEntry probabilityEntry = - originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo); - if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { - AKLOGE("Cannot write updated probability entry. terminalId: %d", - toBeUpdatedPtNodeParams->getTerminalId()); - return false; - } - const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy); - if (!isValid) { - if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { - AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); - return false; - } - } - *outNeedsToKeepPtNode = isValid; - } else { - // No need to update probability. + if (originalProbabilityEntry.isValid()) { *outNeedsToKeepPtNode = true; + return true; } + if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { + AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); + return false; + } + *outNeedsToKeepPtNode = false; return true; } @@ -216,16 +203,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( } // Write probability. ProbabilityEntry newProbabilityEntry; + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( - &newProbabilityEntry, unigramProperty); + &newProbabilityEntry, &probabilityEntryOfUnigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( terminalId, &probabilityEntryToWrite); } bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { + // TODO: Support n-gram. + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + const ProbabilityEntry probabilityEntry = + languageModelDictContent->getNgramProbabilityEntry( + prevWordIds.limit(1 /* maxSize */), wordId); + const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty); + const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom( + &probabilityEntry, &probabilityEntryOfBigramProperty); + if (!languageModelDictContent->setNgramProbabilityEntry( + prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) { + AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d", + prevWordIds[0], wordId); + return false; + } + if (!probabilityEntry.isValid() && outAddedNewBigram) { + *outAddedNewBigram = true; + } + // TODO: Remove. if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) { - AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", + AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d", prevWordIds[0], wordId); return false; } @@ -234,6 +241,15 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) { + // TODO: Support n-gram. + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + if (!languageModelDictContent->removeNgramProbabilityEntry(prevWordIds.limit(1 /* maxSize */), + wordId)) { + // TODO: Uncomment. + // return false; + } + // TODO: Remove. return mBigramPolicy->removeEntry(prevWordIds[0], wordId); } @@ -350,22 +366,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } +// TODO: Move probability handling code to LanguageModelDictContent. const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const { - // TODO: Consolidate historical info and probability. + const ProbabilityEntry *const probabilityEntry) const { if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(), - unigramProperty->getLevel(), unigramProperty->getCount()); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry->getHistoricalInfo(), - unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); - return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( - &updatedHistoricalInfo); + probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(), + mHeaderPolicy); + return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo); } else { - return originalProbabilityEntry->createEntryWithUpdatedProbability( - unigramProperty->getProbability()); + return *probabilityEntry; } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 6703dba04..5d73b6ea3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const PtNodeParams *const ptNodeParams, int *const outTerminalId, int *const ptNodeWritingPos); - // Create updated probability entry using given unigram property. In addition to the + // Create updated probability entry using given probability property. In addition to the // probability, this method updates historical information if needed. - // TODO: Update flags belonging to the unigram property. + // TODO: Update flags. const ProbabilityEntry createUpdatedEntryFrom( const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const; + const ProbabilityEntry *const probabilityEntry) const; bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, const bool hasMultipleChars); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 723808399..2ea248e86 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -127,21 +127,28 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; } - const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == ptNodePos - && bigramsIt.getProbability() != NOT_A_PROBABILITY) { - return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); - } + // TODO: Support n-gram. + const PtNodeParams prevWordPtNodeParams = + mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]); + const int prevWordTerminalId = prevWordPtNodeParams.getTerminalId(); + const ProbabilityEntry probabilityEntry = + mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry( + IntArrayView::fromObject(&prevWordTerminalId), + ptNodeParams.getTerminalId()); + if (!probabilityEntry.isValid()) { + return NOT_A_PROBABILITY; + } + if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + mHeaderPolicy); + } else { + return probabilityEntry.getProbability(); } - return NOT_A_PROBABILITY; } return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } @@ -200,7 +207,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -235,7 +242,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (!mUpdatingHelper.addShortcutTarget(wordPos, shortcut.getTargetCodePoints()->data(), shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -263,6 +270,11 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos); return false; } + if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry( + ptNodeParams.getTerminalId())) { + // TODO: Uncomment. + // return false; + } if (!ptNodeParams.representsNonWordInfo()) { mUnigramCount--; } @@ -286,7 +298,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", bigramProperty->getTargetCodePoints()->size()); return false; } int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 4220312e0..d53575aa7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -85,6 +85,27 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, &shortcutPolicy); + int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy, + entryCountTable)) { + AKLOGE("Failed to update probabilities in language model dict content."); + return false; + } + if (headerPolicy->isDecayingDict()) { + int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount(); + maxEntryCountTable[1] = headerPolicy->getMaxBigramCount(); + for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) { + // TODO: Have max n-gram count. + maxEntryCountTable[i] = headerPolicy->getMaxBigramCount(); + } + if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable, + maxEntryCountTable, headerPolicy)) { + AKLOGE("Failed to truncate entries in language model dict content."); + return false; + } + } + DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPtGcEventListeners @@ -187,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return true; } +// TODO: Remove. bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( const Ver4PatriciaTrieNodeReader *const ptNodeReader, Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { @@ -227,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( return true; } +// TODO: Remove. bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { const TerminalPositionLookupTable *const terminalPosLookupTable = mBuffers->getTerminalPositionLookupTable(); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index 833063c17..ecbe7922c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -31,7 +31,7 @@ uint32_t BufferWithExtendableBuffer::readUint(const int size, const int pos) con uint32_t BufferWithExtendableBuffer::readUintAndAdvancePosition(const int size, int *const pos) const { - const int value = readUint(size, *pos); + const uint32_t value = readUint(size, *pos); *pos += size; return value; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index c0a9fcb1d..4b3c98988 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -114,7 +114,7 @@ class ByteArrayUtils { return buffer[(*pos)++]; } - static AK_FORCE_INLINE int readUint(const uint8_t *const buffer, + static AK_FORCE_INLINE uint32_t readUint(const uint8_t *const buffer, const int size, const int pos) { // size must be in 1 to 4. ASSERT(size >= 1 && size <= 4); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 1916ea560..e6e7167c2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -23,7 +23,7 @@ namespace latinime { const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE; // Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12 -const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; +const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) { switch (formatVersion) { @@ -40,14 +40,14 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; } } /* static */ FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion( - const uint8_t *const dict, const int dictSize) { + const ReadOnlyByteArrayView dictBuffer) { // The magic number is stored big-endian. // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) { + if (dictBuffer.size() < DICTIONARY_MINIMUM_SIZE) { return UNKNOWN_VERSION; } - const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); + const uint32_t magicNumber = ByteArrayUtils::readUint32(dictBuffer.data(), 0); switch (magicNumber) { case MAGIC_NUMBER: // The layout of the header is as follows: @@ -58,7 +58,7 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; // Conceptually this converts the hardcoded value of the bytes in the file into // the symbolic value we use in the code. But we want the constants to be the // same so we use them for both here. - return getFormatVersion(ByteArrayUtils::readUint16(dict, 4)); + return getFormatVersion(ByteArrayUtils::readUint16(dictBuffer.data(), 4)); default: return UNKNOWN_VERSION; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 55ad5799f..51ad9877c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -42,12 +43,12 @@ class FormatUtils { static const uint32_t MAGIC_NUMBER; static FORMAT_VERSION getFormatVersion(const int formatVersion); - static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); + static FORMAT_VERSION detectFormatVersion(const ReadOnlyByteArrayView dictBuffer); private: DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils); - static const int DICTIONARY_MINIMUM_SIZE; + static const size_t DICTIONARY_MINIMUM_SIZE; }; } // namespace latinime #endif /* LATINIME_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp index 407b8efd0..39f417ebb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -26,6 +26,7 @@ const int TrieMap::FIELD1_SIZE = 3; const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE; const uint32_t TrieMap::VALUE_FLAG = 0x400000; const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF; +const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK; const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000; const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF; const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5; @@ -34,6 +35,7 @@ const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_O const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0; const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE; const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0); +const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry. const uint64_t TrieMap::MAX_VALUE = (static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1; const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE; @@ -76,14 +78,14 @@ int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIn return terminalEntry.getValueEntryIndex() + 1; } // Create a value entry and a bitmap entry. - const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) { return INVALID_INDEX; } if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) { return INVALID_INDEX; } - if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, valueEntryIndex)) { + if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex)) { return INVALID_INDEX; } return valueEntryIndex + 1; @@ -108,6 +110,31 @@ bool TrieMap::save(FILE *const file) const { return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer); } +bool TrieMap::remove(const int key, const int bitmapEntryIndex) { + const Entry bitmapEntry = readEntry(bitmapEntryIndex); + const uint32_t unsignedKey = static_cast<uint32_t>(key); + const int terminalEntryIndex = getTerminalEntryIndex( + unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */); + if (terminalEntryIndex == INVALID_INDEX) { + // Not found. + return false; + } + const Entry terminalEntry = readEntry(terminalEntryIndex); + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) { + return false; + } + if (terminalEntry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1); + if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)){ + return false; + } + } + return true; +} + /** * Iterate next entry in a certain level. * @@ -129,7 +156,7 @@ const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *con if (entry.isBitmapEntry()) { // Move to child. iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); - } else { + } else if (entry.isValidTerminalEntry()) { if (outKey) { *outKey = entry.getKey(); } @@ -162,12 +189,12 @@ uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const { } bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) { - if (value <= VALUE_MASK) { + if (value < VALUE_MASK) { // Write value into the terminal entry. return writeField1(value | VALUE_FLAG, terminalEntryIndex); } // Create value entry and write value. - const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) { return false; } @@ -227,6 +254,9 @@ int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey, // Move to the next level. return getTerminalEntryIndex(key, hashedKey, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + return INVALID_INDEX; + } if (entry.getKey() == key) { // Terminal entry is found. return entryIndex; @@ -287,6 +317,10 @@ bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32 // Bitmap entry is found. Go to the next level. return putInternal(key, value, hashedKey, entryIndex, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + // Overwrite invalid terminal entry. + return writeTerminalEntry(key, value, entryIndex); + } if (entry.getKey() == key) { // Terminal entry for the key is found. Update the value. return updateValue(entry, value, entryIndex); @@ -384,4 +418,37 @@ bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t val return true; } +bool TrieMap::removeInner(const Entry &bitmapEntry) { + const int tableSize = popCount(bitmapEntry.getBitmap()); + if (tableSize <= 0) { + // The table is empty. No need to remove any entries. + return true; + } + for (int i = 0; i < tableSize; ++i) { + const int entryIndex = bitmapEntry.getTableIndex() + i; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Delete next bitmap entry recursively. + if (!removeInner(entry)) { + return false; + } + } else { + // Invalidate terminal entry just in case. + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) { + return false; + } + if (entry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1); + if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)) { + return false; + } + } + } + } + return true; +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index 3e5c4010c..c2aeac211 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -84,6 +84,10 @@ class TrieMap { return mValue; } + AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { + return mNextLevelBitmapEntryIndex; + } + private: const TrieMap *const mTrieMap; const int mKey; @@ -202,6 +206,8 @@ class TrieMap { bool save(FILE *const file) const; + bool remove(const int key, const int bitmapEntryIndex); + private: DISALLOW_COPY_AND_ASSIGN(TrieMap); @@ -245,6 +251,11 @@ class TrieMap { } // For terminal entry. + AK_FORCE_INLINE bool isValidTerminalEntry() const { + return hasTerminalLink() || ((mData1 & VALUE_MASK) != INVALID_VALUE_IN_KEY_VALUE_ENTRY); + } + + // For terminal entry. AK_FORCE_INLINE uint32_t getValueEntryIndex() const { return mData1 & TERMINAL_LINK_MASK; } @@ -272,6 +283,7 @@ class TrieMap { static const int ENTRY_SIZE; static const uint32_t VALUE_FLAG; static const uint32_t VALUE_MASK; + static const uint32_t INVALID_VALUE_IN_KEY_VALUE_ENTRY; static const uint32_t TERMINAL_LINK_FLAG; static const uint32_t TERMINAL_LINK_MASK; static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL; @@ -280,6 +292,7 @@ class TrieMap { static const int ROOT_BITMAP_ENTRY_INDEX; static const int ROOT_BITMAP_ENTRY_POS; static const Entry EMPTY_BITMAP_ENTRY; + static const int TERMINAL_LINKED_ENTRY_COUNT; static const int MAX_BUFFER_SIZE; uint32_t getBitShuffledKey(const uint32_t key) const; @@ -378,6 +391,8 @@ class TrieMap { AK_FORCE_INLINE int getTailEntryIndex() const { return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE; } + + bool removeInner(const Entry &bitmapEntry); }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 04cb6603a..52c4251f0 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -51,10 +51,10 @@ class TypingScoring : public Scoring { } if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; } - if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; } if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 54f65c786..1d590c353 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -36,25 +36,34 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor // Compare the node code point with original primary code point on the keyboard. const ProximityInfoState *const pInfoState = traverseSession->getProximityInfoState(0); - const int primaryOriginalCodePoint = pInfoState->getPrimaryOriginalCodePointAt( + const int primaryCodePoint = pInfoState->getPrimaryCodePointAt( dicNode->getInputIndex(0)); const int nodeCodePoint = dicNode->getNodeCodePoint(); - if (primaryOriginalCodePoint == nodeCodePoint) { + // TODO: Check whether the input code point is on the keyboard. + if (primaryCodePoint == nodeCodePoint) { // Node code point is same as original code point on the keyboard. return ErrorTypeUtils::NOT_AN_ERROR; - } else if (CharUtils::toLowerCase(primaryOriginalCodePoint) == + } else if (CharUtils::toLowerCase(primaryCodePoint) == CharUtils::toLowerCase(nodeCodePoint)) { // Only cases of the code points are different. - return ErrorTypeUtils::MATCH_WITH_CASE_ERROR; - } else if (CharUtils::toBaseCodePoint(primaryOriginalCodePoint) == - CharUtils::toBaseCodePoint(nodeCodePoint)) { + return ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else if (primaryCodePoint == CharUtils::toBaseCodePoint(nodeCodePoint)) { // Node code point is a variant of original code point. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR; - } else { + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT; + } else if (CharUtils::toBaseCodePoint(primaryCodePoint) + == CharUtils::toBaseCodePoint(nodeCodePoint)) { + // Base code points are the same but the code point is intentionally input. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT; + } else if (CharUtils::toLowerCase(primaryCodePoint) + == CharUtils::toBaseLowerCase(nodeCodePoint)) { // Node code point is a variant of original code point and the cases are also // different. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR - | ErrorTypeUtils::MATCH_WITH_CASE_ERROR; + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else { + // Base code points are the same and the cases are different. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; } } break; diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h index 2c97c6d58..10d7ae278 100644 --- a/native/jni/src/utils/byte_array_view.h +++ b/native/jni/src/utils/byte_array_view.h @@ -77,10 +77,12 @@ class ReadWriteByteArrayView { } private: - DISALLOW_ASSIGNMENT_OPERATOR(ReadWriteByteArrayView); + // Default copy constructor and assignment operator are used for using this class with STL + // containers. - uint8_t *const mPtr; - const size_t mSize; + // These members cannot be const to have the assignment operator. + uint8_t *mPtr; + size_t mSize; }; } // namespace latinime diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index c1ddc9812..53f2d2971 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -91,6 +91,11 @@ class IntArrayView { return mPtr + mSize; } + // Returns the view whose size is smaller than or equal to the given count. + const IntArrayView limit(const size_t maxSize) const { + return IntArrayView(mPtr, std::min(maxSize, mSize)); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); |