aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp43
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h22
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp37
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h14
5 files changed, 54 insertions, 65 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 934c4f470..a88996524 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -23,8 +23,6 @@
namespace latinime {
-const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
-const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
bool LanguageModelDictContent::save(FILE *const file) const {
@@ -33,10 +31,9 @@ bool LanguageModelDictContent::save(FILE *const file) const {
bool LanguageModelDictContent::runGC(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount) {
+ const LanguageModelDictContent *const originalContent) {
return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(),
- 0 /* nextLevelBitmapEntryIndex */, outNgramCount);
+ 0 /* nextLevelBitmapEntryIndex */);
}
const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds,
@@ -143,18 +140,23 @@ LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEnt
return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo);
}
-bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
- const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- if (entryCounts[i] <= maxEntryCounts[i]) {
- outEntryCounts[i] = entryCounts[i];
+bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts,
+ const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy,
+ MutableEntryCounters *const outEntryCounters) {
+ for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
+ const int totalWordCount = prevWordCount + 1;
+ if (currentEntryCounts.getNgramCount(totalWordCount)
+ <= maxEntryCounts.getNgramCount(totalWordCount)) {
+ outEntryCounters->setNgramCount(totalWordCount,
+ currentEntryCounts.getNgramCount(totalWordCount));
continue;
}
- if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i,
- &outEntryCounts[i])) {
+ int entryCount = 0;
+ if (!turncateEntriesInSpecifiedLevel(headerPolicy,
+ maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
return false;
}
+ outEntryCounters->setNgramCount(totalWordCount, entryCount);
}
return true;
}
@@ -208,8 +210,7 @@ const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange,
- const int nextLevelBitmapEntryIndex, int *const outNgramCount) {
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) {
for (auto &entry : trieMapRange) {
const auto it = terminalIdMap->find(entry.key());
if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) {
@@ -219,13 +220,9 @@ bool LanguageModelDictContent::runGCInner(
if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) {
return false;
}
- if (outNgramCount) {
- *outNgramCount += 1;
- }
if (entry.hasNextLevelMap()) {
if (!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(),
- mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex),
- outNgramCount)) {
+ mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) {
return false;
}
}
@@ -268,7 +265,7 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
const int prevWordCount, const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
+ MutableEntryCounters *const outEntryCounters) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
@@ -305,13 +302,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
if (!probabilityEntry.representsBeginningOfSentence()) {
- outEntryCounts[prevWordCount] += 1;
+ outEntryCounters->incrementNgramCount(prevWordCount + 1);
}
if (!entry.hasNextLevelMap()) {
continue;
}
if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
- prevWordCount + 1, headerPolicy, outEntryCounts)) {
+ prevWordCount + 1, headerPolicy, outEntryCounters)) {
return false;
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 9a5f87741..41a429a54 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -41,9 +41,6 @@ class HeaderPolicy;
*/
class LanguageModelDictContent {
public:
- static const int UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
- static const int BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE;
-
// Pair of word id and probability entry used for iteration.
class WordIdAndProbabilityEntry {
public:
@@ -127,8 +124,7 @@ class LanguageModelDictContent {
bool save(FILE *const file) const;
bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const LanguageModelDictContent *const originalContent,
- int *const outNgramCount);
+ const LanguageModelDictContent *const originalContent);
const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId,
const HeaderPolicy *const headerPolicy) const;
@@ -156,17 +152,14 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
- int *const outEntryCounts) {
- for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
- outEntryCounts[i] = 0;
- }
+ MutableEntryCounters *const outEntryCounters) {
return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
- 0 /* prevWordCount */, headerPolicy, outEntryCounts);
+ 0 /* prevWordCount */, headerPolicy, outEntryCounters);
}
// entryCounts should be created by updateAllProbabilityEntries.
- bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts,
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
const bool isValid, const HistoricalInfo historicalInfo,
@@ -206,12 +199,11 @@ class LanguageModelDictContent {
const bool mHasHistoricalInfo;
bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
- const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
- int *const outNgramCount);
+ const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
- const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index e49d0308e..c7ffcc860 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -57,16 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers(
Ver4DictBuffers::createVer4DictBuffers(headerPolicy,
Ver4DictConstants::MAX_DICTIONARY_SIZE));
- int unigramCount = 0;
- int bigramCount = 0;
- if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) {
+ MutableEntryCounters entryCounters;
+ if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) {
return false;
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
- 0 /* extendedRegionSize */, &headerBuffer)) {
+ entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -74,7 +72,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite,
- int *const outUnigramCount, int *const outBigramCount) {
+ MutableEntryCounters *const outEntryCounters) {
Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer());
Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer());
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
@@ -82,24 +80,17 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
- int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
- headerPolicy, entryCountTable)) {
+ headerPolicy, outEntryCounters)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
if (headerPolicy->isDecayingDict()) {
- int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
- maxEntryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxUnigramCount();
- maxEntryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE] =
- headerPolicy->getMaxBigramCount();
- for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
- // TODO: Have max n-gram count.
- maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
- }
- if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
- maxEntryCountTable, headerPolicy, entryCountTable)) {
+ const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
+ headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
+ if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
+ outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
+ outEntryCounters)) {
AKLOGE("Failed to truncate entries in language model dict content.");
return false;
}
@@ -143,9 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&terminalIdMap)) {
return false;
}
- // Run GC for probability dict content.
+ // Run GC for language model dict content.
if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap,
- mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) {
+ mBuffers->getLanguageModelDictContent())) {
return false;
}
// Run GC for shortcut dict content.
@@ -168,10 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
&traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) {
return false;
}
- *outUnigramCount =
- entryCountTable[LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
- *outBigramCount =
- entryCountTable[LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE];
return true;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
index 57a1f7bb1..c56cea5cf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
@@ -67,8 +67,7 @@ class Ver4PatriciaTrieWritingHelper {
};
bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
- Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
- int *const outBigramCount);
+ Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters);
Ver4DictBuffers *const mBuffers;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
index b8fa5aa9e..73dc42a18 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
@@ -46,6 +46,13 @@ class EntryCounts final {
return mEntryCounts[2];
}
+ int getNgramCount(const size_t n) const {
+ if (n < 1 || n > mEntryCounts.size()) {
+ return 0;
+ }
+ return mEntryCounts[n - 1];
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
@@ -110,6 +117,13 @@ class MutableEntryCounters final {
--mEntryCounters[n - 1];
}
+ void setNgramCount(const size_t n, const int count) {
+ if (n < 1 || n > mEntryCounters.size()) {
+ return;
+ }
+ mEntryCounters[n - 1] = count;
+ }
+
private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);