aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKeisuke Kuroyanagi <ksk@google.com>2014-08-21 12:48:24 +0900
committerKeisuke Kuroyanagi <ksk@google.com>2014-08-22 20:13:04 +0900
commit063f86d40f2cb0d250b2166af8e1cf98ab135f8c (patch)
treec4bc899bf78b41c5e1a22a89d67834d4799fef8a
parent9aa6699107de4da356b8eb89fb3ca38100e19c9d (diff)
downloadlatinime-063f86d40f2cb0d250b2166af8e1cf98ab135f8c.tar.gz
latinime-063f86d40f2cb0d250b2166af8e1cf98ab135f8c.tar.xz
latinime-063f86d40f2cb0d250b2166af8e1cf98ab135f8c.zip
Truncate entries in language model dict content.
Bug: 14425059 Change-Id: I023c1d5109a2c43fcea3bb11a0fd7198c82891ba
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp99
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h36
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp17
3 files changed, 152 insertions, 0 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index a66cfef76..ea2d24e67 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,9 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
+#include <algorithm>
+#include <cstring>
+
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
namespace latinime {
@@ -68,6 +71,19 @@ bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView
return mTrieMap.remove(wordId, bitmapEntryIndex);
}
+bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
+ const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
+ for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+ if (entryCounts[i] <= maxEntryCounts[i]) {
+ continue;
+ }
+ if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange,
@@ -162,4 +178,87 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
return true;
}
+bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
+ const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
+ std::vector<int> prevWordIds;
+ std::vector<EntryInfoToTurncate> entryInfoVector;
+ if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
+ &prevWordIds, &entryInfoVector)) {
+ return false;
+ }
+ if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
+ return true;
+ }
+ const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
+ std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
+ entryInfoVector.end(),
+ EntryInfoToTurncate::Comparator());
+ for (int i = 0; i < entryCountToRemove; ++i) {
+ const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
+ if (!removeNgramProbabilityEntry(
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
+ const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+ std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
+ const int currentLevel = prevWordIds->size();
+ for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+ if (currentLevel < targetLevel) {
+ if (!entry.hasNextLevelMap()) {
+ continue;
+ }
+ prevWordIds->push_back(entry.key());
+ if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
+ prevWordIds, outEntryInfo)) {
+ return false;
+ }
+ prevWordIds->pop_back();
+ continue;
+ }
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+ const int probability = (mHasHistoricalInfo) ?
+ ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
+ headerPolicy) : probabilityEntry.getProbability();
+ outEntryInfo->emplace_back(probability,
+ probabilityEntry.getHistoricalInfo()->getTimeStamp(),
+ entry.key(), targetLevel, prevWordIds->data());
+ }
+ return true;
+}
+
+bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
+ const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
+ if (left.mProbability != right.mProbability) {
+ return left.mProbability < right.mProbability;
+ }
+ if (left.mTimestamp != right.mTimestamp) {
+ return left.mTimestamp > right.mTimestamp;
+ }
+ if (left.mKey != right.mKey) {
+ return left.mKey < right.mKey;
+ }
+ if (left.mEntryLevel != right.mEntryLevel) {
+ return left.mEntryLevel > right.mEntryLevel;
+ }
+ for (int i = 0; i < left.mEntryLevel; ++i) {
+ if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
+ return left.mPrevWordIds[i] < right.mPrevWordIds[i];
+ }
+ }
+ // left and rigth represent the same entry.
+ return false;
+}
+
+LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
+ const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
+ : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
+ memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 31ee2fe24..43b2aab66 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -18,6 +18,7 @@
#define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
#include <cstdio>
+#include <vector>
#include "defines.h"
#include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@@ -77,13 +78,43 @@ class LanguageModelDictContent {
bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
+ for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+ outEntryCounts[i] = 0;
+ }
return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
headerPolicy, outEntryCounts);
}
+ // entryCounts should be created by updateAllProbabilityEntries.
+ bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
+ const HeaderPolicy *const headerPolicy);
+
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
+ class EntryInfoToTurncate {
+ public:
+ class Comparator {
+ public:
+ bool operator()(const EntryInfoToTurncate &left,
+ const EntryInfoToTurncate &right) const;
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
+ };
+
+ EntryInfoToTurncate(const int probability, const int timestamp, const int key,
+ const int entryLevel, const int *const prevWordIds);
+
+ int mProbability;
+ int mTimestamp;
+ int mKey;
+ int mEntryLevel;
+ int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+
+ private:
+ DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
+ };
+
TrieMap mTrieMap;
const bool mHasHistoricalInfo;
@@ -94,6 +125,11 @@ class LanguageModelDictContent {
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
+ const int maxEntryCount, const int targetLevel);
+ bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+ const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+ std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 35bc44b8f..d53575aa7 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -91,6 +91,21 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
+ if (headerPolicy->isDecayingDict()) {
+ int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
+ maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
+ for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
+ // TODO: Have max n-gram count.
+ maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
+ }
+ if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
+ maxEntryCountTable, headerPolicy)) {
+ AKLOGE("Failed to truncate entries in language model dict content.");
+ return false;
+ }
+ }
+
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners
@@ -193,6 +208,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return true;
}
+// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
const Ver4PatriciaTrieNodeReader *const ptNodeReader,
Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@@ -233,6 +249,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
return true;
}
+// TODO: Remove.
bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
const TerminalPositionLookupTable *const terminalPosLookupTable =
mBuffers->getTerminalPositionLookupTable();