aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp15
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h24
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp41
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp71
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h11
-rw-r--r--native/jni/src/utils/int_array_view.h5
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp25
-rw-r--r--native/jni/tests/utils/int_array_view_test.cpp17
13 files changed, 199 insertions, 26 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 278f2b199..f7179f68d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
- AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
- sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
+ AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
+ prevWordIds[0], wordId);
return false;
}
const int ptNodePos =
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 5dc91ba10..f3bc4a0cb 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
const int terminalId, const ProbabilityEntry *const probabilityEntry) {
- const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+ const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
return false;
}
@@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
return true;
}
+int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
+ if (prevWordIds.empty()) {
+ return mTrieMap.getRootBitmapEntryIndex();
+ }
+ const int lastBitmapEntryIndex =
+ getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
+ if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ return TrieMap::INVALID_INDEX;
+ }
+ return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
+ lastBitmapEntryIndex);
+}
+
int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
for (const int wordId : prevWordIds) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 18f2e0170..104ee2520 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -76,7 +76,7 @@ class LanguageModelDictContent {
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
int *const outNgramCount);
-
+ int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
};
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index feff6b57f..ed77bd20e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -21,6 +21,8 @@
#include <cstdint>
#include "defines.h"
+#include "suggest/core/dictionary/property/bigram_property.h"
+#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "suggest/policyimpl/dictionary/utils/historical_info.h"
@@ -45,6 +47,20 @@ class ProbabilityEntry {
const HistoricalInfo *const historicalInfo)
: mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
+ // Create from unigram property.
+ // TODO: Set flags.
+ ProbabilityEntry(const UnigramProperty *const unigramProperty)
+ : mFlags(0), mProbability(unigramProperty->getProbability()),
+ mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
+ unigramProperty->getCount()) {}
+
+ // Create from bigram property.
+ // TODO: Set flags.
+ ProbabilityEntry(const BigramProperty *const bigramProperty)
+ : mFlags(0), mProbability(bigramProperty->getProbability()),
+ mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
+ bigramProperty->getCount()) {}
+
const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
}
@@ -54,6 +70,10 @@ class ProbabilityEntry {
return ProbabilityEntry(mFlags, mProbability, historicalInfo);
}
+ bool isValid() const {
+ return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
+ }
+
bool hasHistoricalInfo() const {
return mHistoricalInfo.isValid();
}
@@ -89,7 +109,7 @@ class ProbabilityEntry {
static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
if (hasHistoricalInfo) {
const int flags = readFromEncodedEntry(encodedEntry,
- Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+ Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::TIME_STAMP_FIELD_SIZE
+ Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
+ Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -106,7 +126,7 @@ class ProbabilityEntry {
return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
} else {
const int flags = readFromEncodedEntry(encodedEntry,
- Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+ Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
Ver4DictConstants::PROBABILITY_SIZE);
const int probability = readFromEncodedEntry(encodedEntry,
Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 93d4e562d..e622442ba 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =
const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
const int Ver4DictConstants::PROBABILITY_SIZE = 1;
-const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
+const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index 6950ca70f..8d29f60d4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -41,7 +41,7 @@ class Ver4DictConstants {
static const int NOT_A_TERMINAL_ID;
static const int PROBABILITY_SIZE;
- static const int FLAGS_IN_PROBABILITY_FILE_SIZE;
+ static const int FLAGS_IN_LANGUAGE_MODEL_SIZE;
static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE;
static const int NOT_A_TERMINAL_ADDRESS;
static const int TERMINAL_ID_FIELD_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 857222f5d..2c848cb29 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId());
- const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry,
- unigramProperty);
+ const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
+ const ProbabilityEntry updatedProbabilityEntry =
+ createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry);
+ toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
}
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -216,16 +217,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
}
// Write probability.
ProbabilityEntry newProbabilityEntry;
+ const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
- &newProbabilityEntry, unigramProperty);
+ &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
terminalId, &probabilityEntryToWrite);
}
bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
+ // TODO: Support n-gram.
+ LanguageModelDictContent *const languageModelDictContent =
+ mBuffers->getMutableLanguageModelDictContent();
+ const ProbabilityEntry probabilityEntry =
+ languageModelDictContent->getNgramProbabilityEntry(
+ prevWordIds.limit(1 /* maxSize */), wordId);
+ const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
+ const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
+ &probabilityEntry, &probabilityEntryOfBigramProperty);
+ if (!languageModelDictContent->setNgramProbabilityEntry(
+ prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) {
+ AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d",
+ prevWordIds[0], wordId);
+ return false;
+ }
+ if (!probabilityEntry.isValid() && outAddedNewBigram) {
+ *outAddedNewBigram = true;
+ }
+ // TODO: Remove.
if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) {
- AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
+ AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
prevWordIds[0], wordId);
return false;
}
@@ -234,6 +255,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
const int wordId) {
+ // TODO: Remove.
return mBigramPolicy->removeEntry(prevWordIds[0], wordId);
}
@@ -352,20 +374,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
- const UnigramProperty *const unigramProperty) const {
+ const ProbabilityEntry *const probabilityEntry) const {
// TODO: Consolidate historical info and probability.
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
- const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
- unigramProperty->getLevel(), unigramProperty->getCount());
const HistoricalInfo updatedHistoricalInfo =
ForgettingCurveUtils::createUpdatedHistoricalInfo(
originalProbabilityEntry->getHistoricalInfo(),
- unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
+ probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
+ mHeaderPolicy);
return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
&updatedHistoricalInfo);
} else {
return originalProbabilityEntry->createEntryWithUpdatedProbability(
- unigramProperty->getProbability());
+ probabilityEntry->getProbability());
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index 6703dba04..5d73b6ea3 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const PtNodeParams *const ptNodeParams, int *const outTerminalId,
int *const ptNodeWritingPos);
- // Create updated probability entry using given unigram property. In addition to the
+ // Create updated probability entry using given probability property. In addition to the
// probability, this method updates historical information if needed.
- // TODO: Update flags belonging to the unigram property.
+ // TODO: Update flags.
const ProbabilityEntry createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
- const UnigramProperty *const unigramProperty) const;
+ const ProbabilityEntry *const probabilityEntry) const;
bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
const bool isTerminal, const bool hasMultipleChars);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
index 407b8efd0..e630aba9a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp
@@ -26,6 +26,7 @@ const int TrieMap::FIELD1_SIZE = 3;
const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE;
const uint32_t TrieMap::VALUE_FLAG = 0x400000;
const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF;
+const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK;
const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000;
const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF;
const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5;
@@ -34,6 +35,7 @@ const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_O
const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0;
const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE;
const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0);
+const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry.
const uint64_t TrieMap::MAX_VALUE =
(static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1;
const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE;
@@ -76,7 +78,7 @@ int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIn
return terminalEntry.getValueEntryIndex() + 1;
}
// Create a value entry and a bitmap entry.
- const int valueEntryIndex = allocateTable(2 /* entryCount */);
+ const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) {
return INVALID_INDEX;
}
@@ -108,6 +110,31 @@ bool TrieMap::save(FILE *const file) const {
return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer);
}
+bool TrieMap::remove(const int key, const int bitmapEntryIndex) {
+ const Entry bitmapEntry = readEntry(bitmapEntryIndex);
+ const uint32_t unsignedKey = static_cast<uint32_t>(key);
+ const int terminalEntryIndex = getTerminalEntryIndex(
+ unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */);
+ if (terminalEntryIndex == INVALID_INDEX) {
+ // Not found.
+ return false;
+ }
+ const Entry terminalEntry = readEntry(terminalEntryIndex);
+ if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) {
+ return false;
+ }
+ if (terminalEntry.hasTerminalLink()) {
+ const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1);
+ if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
+ return false;
+ }
+ if (!removeInner(nextLevelBitmapEntry)){
+ return false;
+ }
+ }
+ return true;
+}
+
/**
* Iterate next entry in a certain level.
*
@@ -129,7 +156,7 @@ const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *con
if (entry.isBitmapEntry()) {
// Move to child.
iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex());
- } else {
+ } else if (entry.isValidTerminalEntry()) {
if (outKey) {
*outKey = entry.getKey();
}
@@ -162,12 +189,12 @@ uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const {
}
bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) {
- if (value <= VALUE_MASK) {
+ if (value < VALUE_MASK) {
// Write value into the terminal entry.
return writeField1(value | VALUE_FLAG, terminalEntryIndex);
}
// Create value entry and write value.
- const int valueEntryIndex = allocateTable(2 /* entryCount */);
+ const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT);
if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) {
return false;
}
@@ -227,6 +254,9 @@ int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey,
// Move to the next level.
return getTerminalEntryIndex(key, hashedKey, entry, level + 1);
}
+ if (!entry.isValidTerminalEntry()) {
+ return INVALID_INDEX;
+ }
if (entry.getKey() == key) {
// Terminal entry is found.
return entryIndex;
@@ -287,6 +317,10 @@ bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32
// Bitmap entry is found. Go to the next level.
return putInternal(key, value, hashedKey, entryIndex, entry, level + 1);
}
+ if (!entry.isValidTerminalEntry()) {
+ // Overwrite invalid terminal entry.
+ return writeTerminalEntry(key, value, entryIndex);
+ }
if (entry.getKey() == key) {
// Terminal entry for the key is found. Update the value.
return updateValue(entry, value, entryIndex);
@@ -384,4 +418,33 @@ bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t val
return true;
}
+bool TrieMap::removeInner(const Entry &bitmapEntry) {
+ const int tableSize = popCount(bitmapEntry.getBitmap());
+ for (int i = 0; i < tableSize; ++i) {
+ const int entryIndex = bitmapEntry.getTableIndex() + i;
+ const Entry entry = readEntry(entryIndex);
+ if (entry.isBitmapEntry()) {
+ // Delete next bitmap entry recursively.
+ if (!removeInner(entry)) {
+ return false;
+ }
+ } else {
+ // Invalidate terminal entry just in case.
+ if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) {
+ return false;
+ }
+ if (entry.hasTerminalLink()) {
+ const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1);
+ if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) {
+ return false;
+ }
+ if (!removeInner(nextLevelBitmapEntry)) {
+ return false;
+ }
+ }
+ }
+ }
+ return freeTable(bitmapEntry.getTableIndex(), tableSize);
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 3e5c4010c..6d91790b2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -202,6 +202,8 @@ class TrieMap {
bool save(FILE *const file) const;
+ bool remove(const int key, const int bitmapEntryIndex);
+
private:
DISALLOW_COPY_AND_ASSIGN(TrieMap);
@@ -245,6 +247,11 @@ class TrieMap {
}
// For terminal entry.
+ AK_FORCE_INLINE bool isValidTerminalEntry() const {
+ return hasTerminalLink() || ((mData1 & VALUE_MASK) != INVALID_VALUE_IN_KEY_VALUE_ENTRY);
+ }
+
+ // For terminal entry.
AK_FORCE_INLINE uint32_t getValueEntryIndex() const {
return mData1 & TERMINAL_LINK_MASK;
}
@@ -272,6 +279,7 @@ class TrieMap {
static const int ENTRY_SIZE;
static const uint32_t VALUE_FLAG;
static const uint32_t VALUE_MASK;
+ static const uint32_t INVALID_VALUE_IN_KEY_VALUE_ENTRY;
static const uint32_t TERMINAL_LINK_FLAG;
static const uint32_t TERMINAL_LINK_MASK;
static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL;
@@ -280,6 +288,7 @@ class TrieMap {
static const int ROOT_BITMAP_ENTRY_INDEX;
static const int ROOT_BITMAP_ENTRY_POS;
static const Entry EMPTY_BITMAP_ENTRY;
+ static const int TERMINAL_LINKED_ENTRY_COUNT;
static const int MAX_BUFFER_SIZE;
uint32_t getBitShuffledKey(const uint32_t key) const;
@@ -378,6 +387,8 @@ class TrieMap {
AK_FORCE_INLINE int getTailEntryIndex() const {
return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE;
}
+
+ bool removeInner(const Entry &bitmapEntry);
};
} // namespace latinime
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c1ddc9812..53f2d2971 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -91,6 +91,11 @@ class IntArrayView {
return mPtr + mSize;
}
+ // Returns the view whose size is smaller than or equal to the given count.
+ const IntArrayView limit(const size_t maxSize) const {
+ return IntArrayView(mPtr, std::min(maxSize, mSize));
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
index df778b6cf..8c8e8838a 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp
@@ -47,6 +47,31 @@ TEST(TrieMapTest, TestSetAndGet) {
EXPECT_EQ(0xFFFFFFFFFull, trieMap.getRoot(0).mValue);
}
+TEST(TrieMapTest, TestRemove) {
+ TrieMap trieMap;
+ trieMap.putRoot(10, 10);
+ EXPECT_EQ(10ull, trieMap.getRoot(10).mValue);
+ EXPECT_TRUE(trieMap.remove(10, trieMap.getRootBitmapEntryIndex()));
+ EXPECT_FALSE(trieMap.getRoot(10).mIsValid);
+ for (const auto &element : trieMap.getEntriesInRootLevel()) {
+ EXPECT_TRUE(false);
+ }
+ EXPECT_TRUE(trieMap.putRoot(10, 0x3FFFFF));
+ EXPECT_FALSE(trieMap.remove(11, trieMap.getRootBitmapEntryIndex()))
+ << "Should fail if the key does not exist.";
+ EXPECT_EQ(0x3FFFFFull, trieMap.getRoot(10).mValue);
+ trieMap.putRoot(12, 11);
+ const int nextLevel = trieMap.getNextLevelBitmapEntryIndex(10);
+ trieMap.put(10, 10, nextLevel);
+ EXPECT_EQ(0x3FFFFFull, trieMap.getRoot(10).mValue);
+ EXPECT_EQ(10ull, trieMap.get(10, nextLevel).mValue);
+ EXPECT_TRUE(trieMap.remove(10, trieMap.getRootBitmapEntryIndex()));
+ const TrieMap::Result result = trieMap.getRoot(10);
+ EXPECT_FALSE(result.mIsValid);
+ EXPECT_EQ(TrieMap::INVALID_INDEX, result.mNextLevelBitmapEntryIndex);
+ EXPECT_EQ(11ull, trieMap.getRoot(12).mValue);
+}
+
TEST(TrieMapTest, TestSetAndGetLarge) {
static const int ELEMENT_COUNT = 200000;
TrieMap trieMap;
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index bd843ab02..ecc451af0 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -53,9 +53,24 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
TEST(IntArrayViewTest, TestConstructFromObject) {
const int object = 10;
const auto intArrayView = IntArrayView::fromObject(&object);
- EXPECT_EQ(1, intArrayView.size());
+ EXPECT_EQ(1u, intArrayView.size());
EXPECT_EQ(object, intArrayView[0]);
}
+TEST(IntArrayViewTest, TestLimit) {
+ const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+ IntArrayView intArrayView(intVector);
+
+ EXPECT_TRUE(intArrayView.limit(0).empty());
+ EXPECT_EQ(intArrayView.size(), intArrayView.limit(intArrayView.size()).size());
+ EXPECT_EQ(intArrayView.size(), intArrayView.limit(1000).size());
+
+ IntArrayView subView = intArrayView.limit(4);
+ EXPECT_EQ(4u, subView.size());
+ for (size_t i = 0; i < subView.size(); ++i) {
+ EXPECT_EQ(intVector[i], subView[i]);
+ }
+}
+
} // namespace
} // namespace latinime