aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/suggest/core/dictionary/error_type_utils.cpp19
-rw-r--r--native/jni/src/suggest/core/dictionary/error_type_utils.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp16
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp20
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp173
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h54
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h54
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h15
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp28
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp83
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp38
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp23
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp10
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp77
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h15
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_scoring.h4
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp29
-rw-r--r--native/jni/src/utils/byte_array_view.h8
-rw-r--r--native/jni/src/utils/int_array_view.h5
33 files changed, 575 insertions, 173 deletions
diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
index b6bf7a98c..1e2494e92 100644
--- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp
@@ -19,17 +19,18 @@
namespace latinime {
const ErrorTypeUtils::ErrorType ErrorTypeUtils::NOT_AN_ERROR = 0x0;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_CASE_ERROR = 0x1;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR = 0x2;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x4;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x8;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x10;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x20;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x40;
-const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x80;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_CASE = 0x1;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT = 0x2;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT = 0x4;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x8;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x10;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x20;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x40;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x80;
+const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100;
const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH =
- NOT_AN_ERROR | MATCH_WITH_CASE_ERROR | MATCH_WITH_ACCENT_ERROR | MATCH_WITH_DIGRAPH;
+ NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH;
const ErrorTypeUtils::ErrorType
ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION =
diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h
index e3e76b238..fd1d5fcff 100644
--- a/native/jni/src/suggest/core/dictionary/error_type_utils.h
+++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h
@@ -30,8 +30,9 @@ class ErrorTypeUtils {
typedef uint32_t ErrorType;
static const ErrorType NOT_AN_ERROR;
- static const ErrorType MATCH_WITH_CASE_ERROR;
- static const ErrorType MATCH_WITH_ACCENT_ERROR;
+ static const ErrorType MATCH_WITH_WRONG_CASE;
+ static const ErrorType MATCH_WITH_MISSING_ACCENT;
+ static const ErrorType MATCH_WITH_WRONG_ACCENT;
static const ErrorType MATCH_WITH_DIGRAPH;
// Treat error as an intentional omission when the CorrectionType is omission and the node can
// be intentional omission.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 278f2b199..97a8bcc98 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
- AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
- sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
+ AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
+ prevWordIds[0], wordId);
return false;
}
const int ptNodePos =
@@ -425,6 +425,18 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos,
return true;
}
+bool Ver4PatriciaTrieNodeWriter::suppressUnigramEntry(const PtNodeParams *const ptNodeParams) {
+ if (!mHeaderPolicy->hasHistoricalInfoOfWords()) {
+ // Require historical info to suppress unigram entry.
+ return false;
+ }
+ const HistoricalInfo suppressedHistorycalInfo(0 /* timestamp */, 0 /* level */, 0 /* count */);
+ const ProbabilityEntry probabilityEntryToWrite =
+ ProbabilityEntry().createEntryWithUpdatedHistoricalInfo(&suppressedHistorycalInfo);
+ return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry(
+ ptNodeParams->getTerminalId(), &probabilityEntryToWrite);
+}
+
} // namespace v402
} // namespace backward
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h
index d49d9a666..9d8a55bff 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h
@@ -111,6 +111,11 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
bool updatePtNodeHasBigramsAndShortcutTargetsFlags(const PtNodeParams *const ptNodeParams);
+ // Suppress unigram not to use the word for generating suggestions. So, this method can be used
+ // only for dictionaries with historical info. Also, suppressed entries are included in unigram
+ // count. They will be removed from the dictionary during GC.
+ bool suppressUnigramEntry(const PtNodeParams *const ptNodeParams);
+
private:
DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeWriter);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 1296b8acd..9c6452e40 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -210,7 +210,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le
}
for (const auto &shortcut : unigramProperty->getShortcuts()) {
if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
- AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d",
+ AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd",
shortcut.getTargetCodePoints()->size());
return false;
}
@@ -245,7 +245,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le
if (!mUpdatingHelper.addShortcutTarget(wordPos,
shortcut.getTargetCodePoints()->data(),
shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) {
- AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, "
+ AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, "
"probability: %d", wordPos, shortcut.getTargetCodePoints()->size(),
shortcut.getProbability());
return false;
@@ -258,6 +258,20 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le
}
}
+bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int length) {
+ if (!mBuffers->isUpdatable()) {
+ AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary.");
+ return false;
+ }
+ const int ptNodePos = getTerminalPtNodePositionOfWord(word, length,
+ false /* forceLowerCaseSearch */);
+ if (ptNodePos == NOT_A_DICT_POS) {
+ return false;
+ }
+ const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ return mNodeWriter.suppressUnigramEntry(&ptNodeParams);
+}
+
bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
const BigramProperty *const bigramProperty) {
if (!mBuffers->isUpdatable()) {
@@ -275,7 +289,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
AKLOGE("The word is too long to insert the ngram to the dictionary. "
- "length: %d", bigramProperty->getTargetCodePoints()->size());
+ "length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 9e989b268..d77499636 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -108,10 +108,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
bool addUnigramEntry(const int *const word, const int length,
const UnigramProperty *const unigramProperty);
- bool removeUnigramEntry(const int *const word, const int length) {
- // Removing unigram entry is not supported.
- return false;
- }
+ bool removeUnigramEntry(const int *const word, const int length);
bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
const BigramProperty *const bigramProperty);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
index e4ea3da16..9fa93efc9 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp
@@ -111,8 +111,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
return nullptr;
}
const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::detectFormatVersion(
- mmappedBuffer->getReadOnlyByteArrayView().data(),
- mmappedBuffer->getReadOnlyByteArrayView().size());
+ mmappedBuffer->getReadOnlyByteArrayView());
switch (formatVersion) {
case FormatUtils::VERSION_2:
AKLOGE("Given path is a directory but the format is version 2. path: %s", path);
@@ -174,8 +173,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str
if (!mmappedBuffer) {
return nullptr;
}
- switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView().data(),
- mmappedBuffer->getReadOnlyByteArrayView().size())) {
+ switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) {
case FormatUtils::VERSION_2:
return DictionaryStructureWithBufferPolicy::StructurePolicyPtr(
new PatriciaTriePolicy(std::move(mmappedBuffer)));
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h
index 361dd2c74..20bae5943 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h
@@ -17,7 +17,6 @@
#ifndef LATINIME_BIGRAM_DICT_CONTENT_H
#define LATINIME_BIGRAM_DICT_CONTENT_H
-#include <cstdint>
#include <cstdio>
#include "defines.h"
@@ -28,11 +27,12 @@
namespace latinime {
+class ReadWriteByteArrayView;
+
class BigramDictContent : public SparseTableDictContent {
public:
- BigramDictContent(uint8_t *const *buffers, const int *bufferSizes, const bool hasHistoricalInfo)
- : SparseTableDictContent(buffers, bufferSizes,
- Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE,
+ BigramDictContent(const ReadWriteByteArrayView *const buffers, const bool hasHistoricalInfo)
+ : SparseTableDictContent(buffers, Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE,
Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE),
mHasHistoricalInfo(hasHistoricalInfo) {}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 5dc91ba10..ea2d24e67 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,11 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
+#include <algorithm>
+#include <cstring>
+
+#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+
namespace latinime {
bool LanguageModelDictContent::save(FILE *const file) const {
@@ -45,12 +50,38 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
}
bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
- const int terminalId, const ProbabilityEntry *const probabilityEntry) {
+ const int wordId, const ProbabilityEntry *const probabilityEntry) {
+ if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) {
+ return false;
+ }
+ const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
+ if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ return false;
+ }
+ return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex);
+}
+
+bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds,
+ const int wordId) {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ // Cannot find bitmap entry for the probability entry. The entry doesn't exist.
return false;
}
- return mTrieMap.put(terminalId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex);
+ return mTrieMap.remove(wordId, bitmapEntryIndex);
+}
+
+bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
+ const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
+ for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+ if (entryCounts[i] <= maxEntryCounts[i]) {
+ continue;
+ }
+ if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
+ return false;
+ }
+ }
+ return true;
}
bool LanguageModelDictContent::runGCInner(
@@ -80,6 +111,19 @@ bool LanguageModelDictContent::runGCInner(
return true;
}
+int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
+ if (prevWordIds.empty()) {
+ return mTrieMap.getRootBitmapEntryIndex();
+ }
+ const int lastBitmapEntryIndex =
+ getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
+ if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ return TrieMap::INVALID_INDEX;
+ }
+ return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
+ lastBitmapEntryIndex);
+}
+
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
for (const int wordId : prevWordIds) {
@@ -92,4 +136,129 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
return bitmapEntryIndex;
}
+bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
+ const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+ for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+ if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+ AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+ level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+ return false;
+ }
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+ if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+ const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
+ probabilityEntry.getHistoricalInfo(), headerPolicy);
+ if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
+ // Update the entry.
+ const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
+ if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
+ bitmapEntryIndex)) {
+ return false;
+ }
+ } else {
+ // Remove the entry.
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+ return false;
+ }
+ continue;
+ }
+ }
+ if (!probabilityEntry.representsBeginningOfSentence()) {
+ outEntryCounts[level] += 1;
+ }
+ if (!entry.hasNextLevelMap()) {
+ continue;
+ }
+ if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
+ headerPolicy, outEntryCounts)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
+ const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
+ std::vector<int> prevWordIds;
+ std::vector<EntryInfoToTurncate> entryInfoVector;
+ if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
+ &prevWordIds, &entryInfoVector)) {
+ return false;
+ }
+ if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
+ return true;
+ }
+ const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
+ std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
+ entryInfoVector.end(),
+ EntryInfoToTurncate::Comparator());
+ for (int i = 0; i < entryCountToRemove; ++i) {
+ const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
+ if (!removeNgramProbabilityEntry(
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
+ const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+ std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
+ const int currentLevel = prevWordIds->size();
+ for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+ if (currentLevel < targetLevel) {
+ if (!entry.hasNextLevelMap()) {
+ continue;
+ }
+ prevWordIds->push_back(entry.key());
+ if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
+ prevWordIds, outEntryInfo)) {
+ return false;
+ }
+ prevWordIds->pop_back();
+ continue;
+ }
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+ const int probability = (mHasHistoricalInfo) ?
+ ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
+ headerPolicy) : probabilityEntry.getProbability();
+ outEntryInfo->emplace_back(probability,
+ probabilityEntry.getHistoricalInfo()->getTimeStamp(),
+ entry.key(), targetLevel, prevWordIds->data());
+ }
+ return true;
+}
+
+bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
+ const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
+ if (left.mProbability != right.mProbability) {
+ return left.mProbability < right.mProbability;
+ }
+ if (left.mTimestamp != right.mTimestamp) {
+ return left.mTimestamp > right.mTimestamp;
+ }
+ if (left.mKey != right.mKey) {
+ return left.mKey < right.mKey;
+ }
+ if (left.mEntryLevel != right.mEntryLevel) {
+ return left.mEntryLevel > right.mEntryLevel;
+ }
+ for (int i = 0; i < left.mEntryLevel; ++i) {
+ if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
+ return left.mPrevWordIds[i] < right.mPrevWordIds[i];
+ }
+ }
+ // left and rigth represent the same entry.
+ return false;
+}
+
+LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
+ const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
+ : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
+ memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 18f2e0170..43b2aab66 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -18,6 +18,7 @@
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
#include <cstdio>
+#include <vector>
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@@ -29,6 +30,8 @@
namespace latinime {
+class HeaderPolicy;
+
/**
* Class representing language model.
*
@@ -61,23 +64,72 @@ class LanguageModelDictContent {
return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry);
}
+ bool removeProbabilityEntry(const int wordId) {
+ return removeNgramProbabilityEntry(WordIdArrayView(), wordId);
+ }
+
ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds,
const int wordId) const;
bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId,
const ProbabilityEntry *const probabilityEntry);
+ bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
+
+ bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
+ int *const outEntryCounts) {
+ for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+ outEntryCounts[i] = 0;
+ }
+ return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
+ headerPolicy, outEntryCounts);
+ }
+
+ // entryCounts should be created by updateAllProbabilityEntries.
+ bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
+ const HeaderPolicy *const headerPolicy);
+
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
+ class EntryInfoToTurncate {
+ public:
+ class Comparator {
+ public:
+ bool operator()(const EntryInfoToTurncate &left,
+ const EntryInfoToTurncate &right) const;
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
+ };
+
+ EntryInfoToTurncate(const int probability, const int timestamp, const int key,
+ const int entryLevel, const int *const prevWordIds);
+
+ int mProbability;
+ int mTimestamp;
+ int mKey;
+ int mEntryLevel;
+ int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
+ };
+
TrieMap mTrieMap;
const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
int *const outNgramCount);
-
+ int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
+ bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
+ const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
+ const int maxEntryCount, const int targetLevel);
+ bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+ const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+ std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index feff6b57f..3dfaba755 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -21,6 +21,8 @@
#include <cstdint>
#include "defines.h"
+#include "suggest/core/dictionary/property/bigram_property.h"
+#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"
@@ -41,24 +43,32 @@ class ProbabilityEntry {
: mFlags(flags), mProbability(probability), mHistoricalInfo() {}
// Entry with historical information.
- ProbabilityEntry(const int flags, const int probability,
- const HistoricalInfo *const historicalInfo)
- : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
-
- const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
- return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
- }
-
- const ProbabilityEntry createEntryWithUpdatedHistoricalInfo(
- const HistoricalInfo *const historicalInfo) const {
- return ProbabilityEntry(mFlags, mProbability, historicalInfo);
+ ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo)
+ : mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {}
+
+ // Create from unigram property.
+ ProbabilityEntry(const UnigramProperty *const unigramProperty)
+ : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())),
+ mProbability(unigramProperty->getProbability()),
+ mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
+ unigramProperty->getCount()) {}
+
+ // Create from bigram property.
+ // TODO: Set flags.
+ ProbabilityEntry(const BigramProperty *const bigramProperty)
+ : mFlags(0), mProbability(bigramProperty->getProbability()),
+ mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
+ bigramProperty->getCount()) {}
+
+ bool isValid() const {
+ return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
}
bool hasHistoricalInfo() const {
return mHistoricalInfo.isValid();
}
- int getFlags() const {
+ uint8_t getFlags() const {
return mFlags;
}
@@ -70,6 +80,10 @@ class ProbabilityEntry {
return &mHistoricalInfo;
}
+ bool representsBeginningOfSentence() const {
+ return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0;
+ }
+
uint64_t encode(const bool hasHistoricalInfo) const {
uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
if (hasHistoricalInfo) {
@@ -89,7 +103,7 @@ class ProbabilityEntry {
static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
if (hasHistoricalInfo) {
const int flags = readFromEncodedEntry(encodedEntry,
- Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+ Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::TIME_STAMP_FIELD_SIZE
+ Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
+ Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -103,10 +117,10 @@ class ProbabilityEntry {
const int count = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */);
const HistoricalInfo historicalInfo(timestamp, level, count);
- return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
+ return ProbabilityEntry(flags, &historicalInfo);
} else {
const int flags = readFromEncodedEntry(encodedEntry,
- Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+ Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::PROBABILITY_SIZE);
const int probability = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
@@ -118,7 +132,7 @@ class ProbabilityEntry {
// Copy constructor is public to use this class as a type of return value.
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry);
- const int mFlags;
+ const uint8_t mFlags;
const int mProbability;
const HistoricalInfo mHistoricalInfo;
@@ -126,6 +140,14 @@ class ProbabilityEntry {
return static_cast<int>(
(encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1));
}
+
+ static uint8_t createFlags(const bool representsBeginningOfSentence) {
+ uint8_t flags = 0;
+ if (representsBeginningOfSentence) {
+ flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
+ }
+ return flags;
+ }
};
} // namespace latinime
#endif /* LATINIME_PROBABILITY_ENTRY_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h
index 7b12aff16..85c9ce8d8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h
@@ -17,7 +17,6 @@
#ifndef LATINIME_SHORTCUT_DICT_CONTENT_H
#define LATINIME_SHORTCUT_DICT_CONTENT_H
-#include <cstdint>
#include <cstdio>
#include "defines.h"
@@ -27,11 +26,12 @@
namespace latinime {
+class ReadWriteByteArrayView;
+
class ShortcutDictContent : public SparseTableDictContent {
public:
- ShortcutDictContent(uint8_t *const *buffers, const int *bufferSizes)
- : SparseTableDictContent(buffers, bufferSizes,
- Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE,
+ ShortcutDictContent(const ReadWriteByteArrayView *const buffers)
+ : SparseTableDictContent(buffers, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE,
Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE) {}
ShortcutDictContent()
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h
index 921774181..309c434cf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h
@@ -17,7 +17,6 @@
#ifndef LATINIME_SINGLE_DICT_CONTENT_H
#define LATINIME_SINGLE_DICT_CONTENT_H
-#include <cstdint>
#include <cstdio>
#include "defines.h"
@@ -30,9 +29,9 @@ namespace latinime {
class SingleDictContent {
public:
- SingleDictContent(uint8_t *const buffer, const int bufferSize)
- : mExpandableContentBuffer(ReadWriteByteArrayView(buffer, bufferSize),
- BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {}
+ SingleDictContent(const ReadWriteByteArrayView buffer)
+ : mExpandableContentBuffer(buffer,
+ BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {}
SingleDictContent()
: mExpandableContentBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE) {}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h
index c98dd11fd..0ce2da7bf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h
@@ -17,7 +17,6 @@
#ifndef LATINIME_SPARSE_TABLE_DICT_CONTENT_H
#define LATINIME_SPARSE_TABLE_DICT_CONTENT_H
-#include <cstdint>
#include <cstdio>
#include "defines.h"
@@ -31,19 +30,13 @@ namespace latinime {
// TODO: Support multiple contents.
class SparseTableDictContent {
public:
- AK_FORCE_INLINE SparseTableDictContent(uint8_t *const *buffers, const int *bufferSizes,
+ AK_FORCE_INLINE SparseTableDictContent(const ReadWriteByteArrayView *const buffers,
const int sparseTableBlockSize, const int sparseTableDataSize)
- : mExpandableLookupTableBuffer(
- ReadWriteByteArrayView(buffers[LOOKUP_TABLE_BUFFER_INDEX],
- bufferSizes[LOOKUP_TABLE_BUFFER_INDEX]),
+ : mExpandableLookupTableBuffer(buffers[LOOKUP_TABLE_BUFFER_INDEX],
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
- mExpandableAddressTableBuffer(
- ReadWriteByteArrayView(buffers[ADDRESS_TABLE_BUFFER_INDEX],
- bufferSizes[ADDRESS_TABLE_BUFFER_INDEX]),
+ mExpandableAddressTableBuffer(buffers[ADDRESS_TABLE_BUFFER_INDEX],
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
- mExpandableContentBuffer(
- ReadWriteByteArrayView(buffers[CONTENT_BUFFER_INDEX],
- bufferSizes[CONTENT_BUFFER_INDEX]),
+ mExpandableContentBuffer(buffers[CONTENT_BUFFER_INDEX],
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer,
sparseTableBlockSize, sparseTableDataSize) {}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h
index b2262bf1e..febcbe5b4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h
@@ -17,13 +17,13 @@
#ifndef LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H
#define LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H
-#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
+#include "utils/byte_array_view.h"
namespace latinime {
@@ -31,8 +31,8 @@ class TerminalPositionLookupTable : public SingleDictContent {
public:
typedef std::unordered_map<int, int> TerminalIdMap;
- TerminalPositionLookupTable(uint8_t *const buffer, const int bufferSize)
- : SingleDictContent(buffer, bufferSize),
+ TerminalPositionLookupTable(const ReadWriteByteArrayView buffer)
+ : SingleDictContent(buffer),
mSize(getBuffer()->getTailPosition()
/ Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE) {}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
index 3c8008dc4..1f40e3dd2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp
@@ -45,16 +45,13 @@ namespace latinime {
if (!bodyBuffer) {
return Ver4DictBuffersPtr(nullptr);
}
- std::vector<uint8_t *> buffers;
- std::vector<int> bufferSizes;
+ std::vector<ReadWriteByteArrayView> buffers;
const ReadWriteByteArrayView buffer = bodyBuffer->getReadWriteByteArrayView();
int position = 0;
while (position < static_cast<int>(buffer.size())) {
const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition(
buffer.data(), &position);
- const ReadWriteByteArrayView subBuffer = buffer.subView(position, bufferSize);
- buffers.push_back(subBuffer.data());
- bufferSizes.push_back(subBuffer.size());
+ buffers.push_back(buffer.subView(position, bufferSize));
position += bufferSize;
if (bufferSize < 0 || position < 0 || position > static_cast<int>(buffer.size())) {
AKLOGE("The dict body file is corrupted.");
@@ -66,7 +63,7 @@ namespace latinime {
return Ver4DictBuffersPtr(nullptr);
}
return Ver4DictBuffersPtr(new Ver4DictBuffers(std::move(headerBuffer), std::move(bodyBuffer),
- formatVersion, buffers, bufferSizes));
+ formatVersion, buffers));
}
bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath,
@@ -178,29 +175,20 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const {
Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer,
MmappedBuffer::MmappedBufferPtr &&bodyBuffer,
const FormatUtils::FORMAT_VERSION formatVersion,
- const std::vector<uint8_t *> &contentBuffers, const std::vector<int> &contentBufferSizes)
+ const std::vector<ReadWriteByteArrayView> &contentBuffers)
: mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(std::move(bodyBuffer)),
mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion),
mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(),
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
- mExpandableTrieBuffer(
- ReadWriteByteArrayView(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX],
- contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX]),
+ mExpandableTrieBuffer(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX],
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE),
mTerminalPositionLookupTable(
- contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX],
- contentBufferSizes[
- Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]),
- mLanguageModelDictContent(
- ReadWriteByteArrayView(
- contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX],
- contentBufferSizes[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX]),
+ contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]),
+ mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX],
mHeaderPolicy.hasHistoricalInfoOfWords()),
mBigramDictContent(&contentBuffers[Ver4DictConstants::BIGRAM_BUFFERS_INDEX],
- &contentBufferSizes[Ver4DictConstants::BIGRAM_BUFFERS_INDEX],
mHeaderPolicy.hasHistoricalInfoOfWords()),
- mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX],
- &contentBufferSizes[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]),
+ mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]),
mIsUpdatable(mDictBuffer->isUpdatable()) {}
Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize)
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h
index 68027dcb8..70a7983f1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h
@@ -122,8 +122,7 @@ class Ver4DictBuffers {
Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer,
MmappedBuffer::MmappedBufferPtr &&bodyBuffer,
const FormatUtils::FORMAT_VERSION formatVersion,
- const std::vector<uint8_t *> &contentBuffers,
- const std::vector<int> &contentBufferSizes);
+ const std::vector<ReadWriteByteArrayView> &contentBuffers);
Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 93d4e562d..b085a6661 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =
const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
const int Ver4DictConstants::PROBABILITY_SIZE = 1;
-const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
+const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
@@ -54,6 +54,8 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4;
const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1;
const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1;
+const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1;
+
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16;
const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4;
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index 6950ca70f..230b3052d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -20,6 +20,7 @@
#include "defines.h"
#include <cstddef>
+#include <cstdint>
namespace latinime {
@@ -41,13 +42,15 @@ class Ver4DictConstants {
static const int NOT_A_TERMINAL_ID;
static const int PROBABILITY_SIZE;
- static const int FLAGS_IN_PROBABILITY_FILE_SIZE;
+ static const int FLAGS_IN_LANGUAGE_MODEL_SIZE;
static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE;
static const int NOT_A_TERMINAL_ADDRESS;
static const int TERMINAL_ID_FIELD_SIZE;
static const int TIME_STAMP_FIELD_SIZE;
static const int WORD_LEVEL_FIELD_SIZE;
static const int WORD_COUNT_FIELD_SIZE;
+ // Flags in probability entry.
+ static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE;
static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE;
static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 857222f5d..b7c31bf75 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId());
- const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry,
- unigramProperty);
+ const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
+ const ProbabilityEntry updatedProbabilityEntry =
+ createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry);
+ toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
}
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -160,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId());
- if (originalProbabilityEntry.hasHistoricalInfo()) {
- const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
- originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
- const ProbabilityEntry probabilityEntry =
- originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo);
- if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
- AKLOGE("Cannot write updated probability entry. terminalId: %d",
- toBeUpdatedPtNodeParams->getTerminalId());
- return false;
- }
- const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
- if (!isValid) {
- if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
- AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
- return false;
- }
- }
- *outNeedsToKeepPtNode = isValid;
- } else {
- // No need to update probability.
+ if (originalProbabilityEntry.isValid()) {
*outNeedsToKeepPtNode = true;
+ return true;
}
+ if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
+ AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
+ return false;
+ }
+ *outNeedsToKeepPtNode = false;
return true;
}
@@ -216,16 +203,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
}
// Write probability.
ProbabilityEntry newProbabilityEntry;
+ const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
- &newProbabilityEntry, unigramProperty);
+ &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
terminalId, &probabilityEntryToWrite);
}
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
+ // TODO: Support n-gram.
+ LanguageModelDictContent *const languageModelDictContent =
+ mBuffers->getMutableLanguageModelDictContent();
+ const ProbabilityEntry probabilityEntry =
+ languageModelDictContent->getNgramProbabilityEntry(
+ prevWordIds.limit(1 /* maxSize */), wordId);
+ const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
+ const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
+ &probabilityEntry, &probabilityEntryOfBigramProperty);
+ if (!languageModelDictContent->setNgramProbabilityEntry(
+ prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) {
+ AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d",
+ prevWordIds[0], wordId);
+ return false;
+ }
+ if (!probabilityEntry.isValid() && outAddedNewBigram) {
+ *outAddedNewBigram = true;
+ }
+ // TODO: Remove.
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) {
- AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
+ AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
prevWordIds[0], wordId);
return false;
}
@@ -234,6 +241,15 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
const int wordId) {
+ // TODO: Support n-gram.
+ LanguageModelDictContent *const languageModelDictContent =
+ mBuffers->getMutableLanguageModelDictContent();
+ if (!languageModelDictContent->removeNgramProbabilityEntry(prevWordIds.limit(1 /* maxSize */),
+ wordId)) {
+ // TODO: Uncomment.
+ // return false;
+ }
+ // TODO: Remove.
return mBigramPolicy->removeEntry(prevWordIds[0], wordId);
}
@@ -350,22 +366,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
+// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
- const UnigramProperty *const unigramProperty) const {
- // TODO: Consolidate historical info and probability.
+ const ProbabilityEntry *const probabilityEntry) const {
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
- const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
- unigramProperty->getLevel(), unigramProperty->getCount());
const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalProbabilityEntry->getHistoricalInfo(),
- unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
- return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
- &updatedHistoricalInfo);
+ probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
+ mHeaderPolicy);
+ return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
} else {
- return originalProbabilityEntry->createEntryWithUpdatedProbability(
- unigramProperty->getProbability());
+ return *probabilityEntry;
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index 6703dba04..5d73b6ea3 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const PtNodeParams *const ptNodeParams, int *const outTerminalId,
int *const ptNodeWritingPos);
- // Create updated probability entry using given unigram property. In addition to the
+ // Create updated probability entry using given probability property. In addition to the
// probability, this method updates historical information if needed.
- // TODO: Update flags belonging to the unigram property.
+ // TODO: Update flags.
const ProbabilityEntry createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
- const UnigramProperty *const unigramProperty) const;
+ const ProbabilityEntry *const probabilityEntry) const;
bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
const bool isTerminal, const bool hasMultipleChars);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 723808399..2ea248e86 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -127,21 +127,28 @@ int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtN
if (ptNodePos == NOT_A_DICT_POS) {
return NOT_A_PROBABILITY;
}
- const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
+ const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) {
return NOT_A_PROBABILITY;
}
if (prevWordsPtNodePos) {
- const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]);
- BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition);
- while (bigramsIt.hasNext()) {
- bigramsIt.next();
- if (bigramsIt.getBigramPos() == ptNodePos
- && bigramsIt.getProbability() != NOT_A_PROBABILITY) {
- return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability());
- }
+ // TODO: Support n-gram.
+ const PtNodeParams prevWordPtNodeParams =
+ mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordsPtNodePos[0]);
+ const int prevWordTerminalId = prevWordPtNodeParams.getTerminalId();
+ const ProbabilityEntry probabilityEntry =
+ mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
+ IntArrayView::fromObject(&prevWordTerminalId),
+ ptNodeParams.getTerminalId());
+ if (!probabilityEntry.isValid()) {
+ return NOT_A_PROBABILITY;
+ }
+ if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
+ return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
+ mHeaderPolicy);
+ } else {
+ return probabilityEntry.getProbability();
}
- return NOT_A_PROBABILITY;
}
return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
}
@@ -200,7 +207,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le
}
for (const auto &shortcut : unigramProperty->getShortcuts()) {
if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
- AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d",
+ AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd",
shortcut.getTargetCodePoints()->size());
return false;
}
@@ -235,7 +242,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le
if (!mUpdatingHelper.addShortcutTarget(wordPos,
shortcut.getTargetCodePoints()->data(),
shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) {
- AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, "
+ AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, "
"probability: %d", wordPos, shortcut.getTargetCodePoints()->size(),
shortcut.getProbability());
return false;
@@ -263,6 +270,11 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int
AKLOGE("Cannot remove unigram. ptNodePos: %d", ptNodePos);
return false;
}
+ if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(
+ ptNodeParams.getTerminalId())) {
+ // TODO: Uncomment.
+ // return false;
+ }
if (!ptNodeParams.representsNonWordInfo()) {
mUnigramCount--;
}
@@ -286,7 +298,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
AKLOGE("The word is too long to insert the ngram to the dictionary. "
- "length: %d", bigramProperty->getTargetCodePoints()->size());
+ "length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 4220312e0..d53575aa7 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -85,6 +85,27 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
&shortcutPolicy);
+ int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
+ entryCountTable)) {
+ AKLOGE("Failed to update probabilities in language model dict content.");
+ return false;
+ }
+ if (headerPolicy->isDecayingDict()) {
+ int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
+ maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
+ for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
+ // TODO: Have max n-gram count.
+ maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
+ }
+ if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
+ maxEntryCountTable, headerPolicy)) {
+ AKLOGE("Failed to truncate entries in language model dict content.");
+ return false;
+ }
+ }
+
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners
@@ -187,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return true;
}
+// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
const Ver4PatriciaTrieNodeReader *const ptNodeReader,
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@@ -227,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
return true;
}
+// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
const TerminalPositionLookupTable *const terminalPosLookupTable =
mBuffers->getTerminalPositionLookupTable();
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp
index 833063c17..ecbe7922c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp
@@ -31,7 +31,7 @@ uint32_t BufferWithExtendableBuffer::readUint(const int size, const int pos) con
uint32_t BufferWithExtendableBuffer::readUintAndAdvancePosition(const int size,
int *const pos) const {
- const int value = readUint(size, *pos);
+ const uint32_t value = readUint(size, *pos);
*pos += size;
return value;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h
index c0a9fcb1d..4b3c98988 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h
@@ -114,7 +114,7 @@ class ByteArrayUtils {
return buffer[(*pos)++];
}
- static AK_FORCE_INLINE int readUint(const uint8_t *const buffer,
+ static AK_FORCE_INLINE uint32_t readUint(const uint8_t *const buffer,
const int size, const int pos) {
// size must be in 1 to 4.
ASSERT(size >= 1 && size <= 4);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
index 1916ea560..e6e7167c2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp
@@ -23,7 +23,7 @@ namespace latinime {
const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE;
// Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12
-const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12;
+const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12;
/* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) {
switch (formatVersion) {
@@ -40,14 +40,14 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12;
}
}
/* static */ FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion(
- const uint8_t *const dict, const int dictSize) {
+ const ReadOnlyByteArrayView dictBuffer) {
// The magic number is stored big-endian.
// If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
// understand this format.
- if (dictSize < DICTIONARY_MINIMUM_SIZE) {
+ if (dictBuffer.size() < DICTIONARY_MINIMUM_SIZE) {
return UNKNOWN_VERSION;
}
- const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0);
+ const uint32_t magicNumber = ByteArrayUtils::readUint32(dictBuffer.data(), 0);
switch (magicNumber) {
case MAGIC_NUMBER:
// The layout of the header is as follows:
@@ -58,7 +58,7 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12;
// Conceptually this converts the hardcoded value of the bytes in the file into
// the symbolic value we use in the code. But we want the constants to be the
// same so we use them for both here.
- return getFormatVersion(ByteArrayUtils::readUint16(dict, 4));
+ return getFormatVersion(ByteArrayUtils::readUint16(dictBuffer.data(), 4));
default:
return UNKNOWN_VERSION;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
index 55ad5799f..51ad9877c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h
@@ -20,6 +20,7 @@
#include <cstdint>
#include "defines.h"
+#include "utils/byte_array_view.h"
namespace latinime {
@@ -42,12 +43,12 @@ class FormatUtils {
static const uint32_t MAGIC_NUMBER;
static FORMAT_VERSION getFormatVersion(const int formatVersion);
- static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize);
+ static FORMAT_VERSION detectFormatVersion(const ReadOnlyByteArrayView dictBuffer);
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils);
- static const int DICTIONARY_MINIMUM_SIZE;
+ static const size_t DICTIONARY_MINIMUM_SIZE;
};
} // namespace latinime
#endif /* LATINIME_FORMAT_UTILS_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
index 407b8efd0..39f417ebb 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
@@ -26,6 +26,7 @@ const int TrieMap::FIELD1_SIZE = 3;
const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE;
const uint32_t TrieMap::VALUE_FLAG = 0x400000;
const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF;
+const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK;
const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000;
const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF;
const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5;
@@ -34,6 +35,7 @@ const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_O
const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0;
const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE;
const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0);
+const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry.
const uint64_t TrieMap::MAX_VALUE =
(static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1;
const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE;
@@ -76,14 +78,14 @@ int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIn
return terminalEntry.getValueEntryIndex() + 1;
}
// Create a value entry and a bitmap entry.
- const int valueEntryIndex = allocateTable(2 /* entryCount */);
+ const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) {
return INVALID_INDEX;
}
if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) {
return INVALID_INDEX;
}
- if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, valueEntryIndex)) {
+ if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex)) {
return INVALID_INDEX;
}
return valueEntryIndex + 1;
@@ -108,6 +110,31 @@ bool TrieMap::save(FILE *const file) const {
return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer);
}
+bool TrieMap::remove(const int key, const int bitmapEntryIndex) {
+ const Entry bitmapEntry = readEntry(bitmapEntryIndex);
+ const uint32_t unsignedKey = static_cast<uint32_t>(key);
+ const int terminalEntryIndex = getTerminalEntryIndex(
+ unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */);
+ if (terminalEntryIndex == INVALID_INDEX) {
+ // Not found.
+ return false;
+ }
+ const Entry terminalEntry = readEntry(terminalEntryIndex);
+ if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) {
+ return false;
+ }
+ if (terminalEntry.hasTerminalLink()) {
+ const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1);
+ if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
+ return false;
+ }
+ if (!removeInner(nextLevelBitmapEntry)){
+ return false;
+ }
+ }
+ return true;
+}
+
/**
* Iterate next entry in a certain level.
*
@@ -129,7 +156,7 @@ const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *con
if (entry.isBitmapEntry()) {
// Move to child.
iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
- } else {
+ } else if (entry.isValidTerminalEntry()) {
if (outKey) {
*outKey = entry.getKey();
}
@@ -162,12 +189,12 @@ uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const {
}
bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) {
- if (value <= VALUE_MASK) {
+ if (value < VALUE_MASK) {
// Write value into the terminal entry.
return writeField1(value | VALUE_FLAG, terminalEntryIndex);
}
// Create value entry and write value.
- const int valueEntryIndex = allocateTable(2 /* entryCount */);
+ const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) {
return false;
}
@@ -227,6 +254,9 @@ int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey,
// Move to the next level.
return getTerminalEntryIndex(key, hashedKey, entry, level + 1);
}
+ if (!entry.isValidTerminalEntry()) {
+ return INVALID_INDEX;
+ }
if (entry.getKey() == key) {
// Terminal entry is found.
return entryIndex;
@@ -287,6 +317,10 @@ bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32
// Bitmap entry is found. Go to the next level.
return putInternal(key, value, hashedKey, entryIndex, entry, level + 1);
}
+ if (!entry.isValidTerminalEntry()) {
+ // Overwrite invalid terminal entry.
+ return writeTerminalEntry(key, value, entryIndex);
+ }
if (entry.getKey() == key) {
// Terminal entry for the key is found. Update the value.
return updateValue(entry, value, entryIndex);
@@ -384,4 +418,37 @@ bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t val
return true;
}
+bool TrieMap::removeInner(const Entry &bitmapEntry) {
+ const int tableSize = popCount(bitmapEntry.getBitmap());
+ if (tableSize <= 0) {
+ // The table is empty. No need to remove any entries.
+ return true;
+ }
+ for (int i = 0; i < tableSize; ++i) {
+ const int entryIndex = bitmapEntry.getTableIndex() + i;
+ const Entry entry = readEntry(entryIndex);
+ if (entry.isBitmapEntry()) {
+ // Delete next bitmap entry recursively.
+ if (!removeInner(entry)) {
+ return false;
+ }
+ } else {
+ // Invalidate terminal entry just in case.
+ if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) {
+ return false;
+ }
+ if (entry.hasTerminalLink()) {
+ const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1);
+ if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
+ return false;
+ }
+ if (!removeInner(nextLevelBitmapEntry)) {
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 3e5c4010c..c2aeac211 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -84,6 +84,10 @@ class TrieMap {
return mValue;
}
+ AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
+ return mNextLevelBitmapEntryIndex;
+ }
+
private:
const TrieMap *const mTrieMap;
const int mKey;
@@ -202,6 +206,8 @@ class TrieMap {
bool save(FILE *const file) const;
+ bool remove(const int key, const int bitmapEntryIndex);
+
private:
DISALLOW_COPY_AND_ASSIGN(TrieMap);
@@ -245,6 +251,11 @@ class TrieMap {
}
// For terminal entry.
+ AK_FORCE_INLINE bool isValidTerminalEntry() const {
+ return hasTerminalLink() || ((mData1 & VALUE_MASK) != INVALID_VALUE_IN_KEY_VALUE_ENTRY);
+ }
+
+ // For terminal entry.
AK_FORCE_INLINE uint32_t getValueEntryIndex() const {
return mData1 & TERMINAL_LINK_MASK;
}
@@ -272,6 +283,7 @@ class TrieMap {
static const int ENTRY_SIZE;
static const uint32_t VALUE_FLAG;
static const uint32_t VALUE_MASK;
+ static const uint32_t INVALID_VALUE_IN_KEY_VALUE_ENTRY;
static const uint32_t TERMINAL_LINK_FLAG;
static const uint32_t TERMINAL_LINK_MASK;
static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL;
@@ -280,6 +292,7 @@ class TrieMap {
static const int ROOT_BITMAP_ENTRY_INDEX;
static const int ROOT_BITMAP_ENTRY_POS;
static const Entry EMPTY_BITMAP_ENTRY;
+ static const int TERMINAL_LINKED_ENTRY_COUNT;
static const int MAX_BUFFER_SIZE;
uint32_t getBitShuffledKey(const uint32_t key) const;
@@ -378,6 +391,8 @@ class TrieMap {
AK_FORCE_INLINE int getTailEntryIndex() const {
return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE;
}
+
+ bool removeInner(const Entry &bitmapEntry);
};
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
index 04cb6603a..52c4251f0 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
@@ -51,10 +51,10 @@ class TypingScoring : public Scoring {
}
if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
score += ScoringParams::EXACT_MATCH_PROMOTION;
- if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) {
+ if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) {
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
}
- if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) {
+ if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) {
score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH;
}
if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) {
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index 54f65c786..1d590c353 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -36,25 +36,34 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor
// Compare the node code point with original primary code point on the keyboard.
const ProximityInfoState *const pInfoState =
traverseSession->getProximityInfoState(0);
- const int primaryOriginalCodePoint = pInfoState->getPrimaryOriginalCodePointAt(
+ const int primaryCodePoint = pInfoState->getPrimaryCodePointAt(
dicNode->getInputIndex(0));
const int nodeCodePoint = dicNode->getNodeCodePoint();
- if (primaryOriginalCodePoint == nodeCodePoint) {
+ // TODO: Check whether the input code point is on the keyboard.
+ if (primaryCodePoint == nodeCodePoint) {
// Node code point is same as original code point on the keyboard.
return ErrorTypeUtils::NOT_AN_ERROR;
- } else if (CharUtils::toLowerCase(primaryOriginalCodePoint) ==
+ } else if (CharUtils::toLowerCase(primaryCodePoint) ==
CharUtils::toLowerCase(nodeCodePoint)) {
// Only cases of the code points are different.
- return ErrorTypeUtils::MATCH_WITH_CASE_ERROR;
- } else if (CharUtils::toBaseCodePoint(primaryOriginalCodePoint) ==
- CharUtils::toBaseCodePoint(nodeCodePoint)) {
+ return ErrorTypeUtils::MATCH_WITH_WRONG_CASE;
+ } else if (primaryCodePoint == CharUtils::toBaseCodePoint(nodeCodePoint)) {
// Node code point is a variant of original code point.
- return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR;
- } else {
+ return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT;
+ } else if (CharUtils::toBaseCodePoint(primaryCodePoint)
+ == CharUtils::toBaseCodePoint(nodeCodePoint)) {
+ // Base code points are the same but the code point is intentionally input.
+ return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT;
+ } else if (CharUtils::toLowerCase(primaryCodePoint)
+ == CharUtils::toBaseLowerCase(nodeCodePoint)) {
// Node code point is a variant of original code point and the cases are also
// different.
- return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR
- | ErrorTypeUtils::MATCH_WITH_CASE_ERROR;
+ return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT
+ | ErrorTypeUtils::MATCH_WITH_WRONG_CASE;
+ } else {
+ // Base code points are the same and the cases are different.
+ return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT
+ | ErrorTypeUtils::MATCH_WITH_WRONG_CASE;
}
}
break;
diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h
index 2c97c6d58..10d7ae278 100644
--- a/native/jni/src/utils/byte_array_view.h
+++ b/native/jni/src/utils/byte_array_view.h
@@ -77,10 +77,12 @@ class ReadWriteByteArrayView {
}
private:
- DISALLOW_ASSIGNMENT_OPERATOR(ReadWriteByteArrayView);
+ // Default copy constructor and assignment operator are used for using this class with STL
+ // containers.
- uint8_t *const mPtr;
- const size_t mSize;
+ // These members cannot be const to have the assignment operator.
+ uint8_t *mPtr;
+ size_t mSize;
};
} // namespace latinime
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c1ddc9812..53f2d2971 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -91,6 +91,11 @@ class IntArrayView {
return mPtr + mSize;
}
+ // Returns the view whose size is smaller than or equal to the given count.
+ const IntArrayView limit(const size_t maxSize) const {
+ return IntArrayView(mPtr, std::min(maxSize, mSize));
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);