about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp | 44
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h | 10
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp | 29
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp | 6
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h | 4
5 files changed, 72 insertions, 21 deletions
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index bbcea2ee0..a66cfef76 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,8 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
+#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+
namespace latinime {
bool LanguageModelDictContent::save(FILE *const file) const {
@@ -118,4 +120,46 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
return bitmapEntryIndex;
}
+/**
+ * Visits every entry in the trie map level identified by bitmapEntryIndex and recurses into
+ * the next-level map of each entry that is kept.
+ *
+ * For each entry:
+ *  - If level exceeds MAX_PREV_WORD_COUNT_FOR_N_GRAM, an error is logged and false is returned.
+ *  - If this content has historical info and the entry does not represent beginning-of-sentence,
+ *    its historical info is recomputed with ForgettingCurveUtils::createHistoricalInfoToSave().
+ *    The entry is then either rewritten in place or, when ForgettingCurveUtils::needsToKeep()
+ *    says no, removed. A removed entry is not counted and its next-level map is not visited.
+ *  - Every remaining entry that is not beginning-of-sentence adds one to outEntryCounts[level].
+ *
+ * @param bitmapEntryIndex index of the trie map level to walk.
+ * @param level depth of that level; used as the index into outEntryCounts.
+ * @param headerPolicy policy handed to ForgettingCurveUtils.
+ * @param outEntryCounts per-level counters; this method only increments them.
+ * @return false if the level is invalid or a trie map put/remove fails, true otherwise.
+ */
+bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
+        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+            return false;
+        }
+        const ProbabilityEntry decodedEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        const bool isBeginningOfSentence = decodedEntry.representsBeginningOfSentence();
+        if (mHasHistoricalInfo && !isBeginningOfSentence) {
+            const HistoricalInfo refreshedInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
+                    decodedEntry.getHistoricalInfo(), headerPolicy);
+            if (!ForgettingCurveUtils::needsToKeep(&refreshedInfo, headerPolicy)) {
+                // Not worth keeping: drop the entry and skip both counting and its sub-levels.
+                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+                    return false;
+                }
+                continue;
+            }
+            // Still worth keeping: store the entry back with the refreshed historical info.
+            const ProbabilityEntry rewrittenEntry(decodedEntry.getFlags(), &refreshedInfo);
+            if (!mTrieMap.put(entry.key(), rewrittenEntry.encode(mHasHistoricalInfo),
+                    bitmapEntryIndex)) {
+                return false;
+            }
+        }
+        if (!isBeginningOfSentence) {
+            outEntryCounts[level] += 1;
+        }
+        // Descend only when the entry actually owns a next-level map.
+        if (entry.hasNextLevelMap() && !updateAllProbabilityEntriesInner(
+                entry.getNextLevelBitmapEntryIndex(), level + 1, headerPolicy, outEntryCounts)) {
+            return false;
+        }
+    }
+    return true;
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index bd07f2f62..31ee2fe24 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -29,6 +29,8 @@
namespace latinime {
+class HeaderPolicy;
+
/**
* Class representing language model.
*
@@ -73,6 +75,12 @@ class LanguageModelDictContent {
bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
+    /**
+     * Updates every probability entry held by this content, starting from the root level of the
+     * trie map (level 0), and reports how many entries remain at each level.
+     *
+     * @param headerPolicy forwarded unchanged to updateAllProbabilityEntriesInner().
+     * @param outEntryCounts output array indexed by level. It must have room for
+     *        MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1 ints; it is reset to zero here before counting.
+     * @return the result of updateAllProbabilityEntriesInner() for the root level.
+     */
+    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
+            int *const outEntryCounts) {
+        // The recursive traversal only ever increments the counters, so they must start from a
+        // defined value; callers may hand in an uninitialized stack array.
+        // NOTE(review): assumes MAX_PREV_WORD_COUNT_FOR_N_GRAM is visible in this header (it is
+        // used by the matching .cpp) - confirm the include.
+        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+            outEntryCounts[i] = 0;
+        }
+        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
+                headerPolicy, outEntryCounts);
+    }
+
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@@ -84,6 +92,8 @@ class LanguageModelDictContent {
int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
+ bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level, // Recursive helper behind updateAllProbabilityEntries(),
+ const HeaderPolicy *const headerPolicy, int *const outEntryCounts); // which invokes it with the root bitmap entry index and level 0.
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index fb6840ba6..b7c31bf75 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -161,29 +161,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA
const ProbabilityEntry originalProbabilityEntry =
mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
toBeUpdatedPtNodeParams->getTerminalId());
- if (originalProbabilityEntry.hasHistoricalInfo()) {
- const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
- originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
- const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
- &historicalInfo);
- if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
- AKLOGE("Cannot write updated probability entry. terminalId: %d",
- toBeUpdatedPtNodeParams->getTerminalId());
- return false;
- }
- const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
- if (!isValid) {
- if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
- AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
- return false;
- }
- }
- *outNeedsToKeepPtNode = isValid;
- } else {
- // No need to update probability.
+ if (originalProbabilityEntry.isValid()) {
*outNeedsToKeepPtNode = true;
+ return true;
+ }
+ if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
+ AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
+ return false;
}
+ *outNeedsToKeepPtNode = false;
return true;
}
@@ -380,6 +366,7 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
+// TODO: Move probability handling code to LanguageModelDictContent.
const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
const ProbabilityEntry *const originalProbabilityEntry,
const ProbabilityEntry *const probabilityEntry) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 4220312e0..35bc44b8f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -85,6 +85,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
&shortcutPolicy);
+ int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
+ entryCountTable)) {
+ AKLOGE("Failed to update probabilities in language model dict content.");
+ return false;
+ }
DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
DynamicPtGcEventListeners
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 6d91790b2..c2aeac211 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -84,6 +84,10 @@ class TrieMap {
return mValue;
}
+ AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { // Plain accessor for the stored next-level bitmap entry index.
+ return mNextLevelBitmapEntryIndex; // The caller in this change checks hasNextLevelMap() before using the value.
+ }
+
private:
const TrieMap *const mTrieMap;
const int mKey;