aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp42
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.h18
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary_utils.cpp8
-rw-r--r--native/jni/src/suggest/core/dictionary/property/historical_info.h10
-rw-r--r--native/jni/src/suggest/core/dictionary/property/ngram_property.h10
-rw-r--r--native/jni/src/suggest/core/dictionary/property/unigram_property.h20
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h8
-rw-r--r--native/jni/src/suggest/core/policy/traversal.h2
-rw-r--r--native/jni/src/suggest/core/policy/weighting.cpp12
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.cpp6
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.h4
-rw-r--r--native/jni/src/suggest/core/session/ngram_context.h (renamed from native/jni/src/suggest/core/session/prev_words_info.h)31
-rw-r--r--native/jni/src/suggest/core/suggest.cpp8
-rw-r--r--native/jni/src/suggest/core/suggest_options.h9
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp53
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h9
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h4
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp125
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h22
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h10
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp33
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h13
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp86
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h13
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp10
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.cpp6
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.h4
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_traversal.h26
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp5
-rw-r--r--native/jni/src/utils/jni_data_utils.h6
36 files changed, 380 insertions, 257 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 7a69d3ceb..697e99ffb 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -23,7 +23,7 @@
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/core/result/suggestion_results.h"
#include "suggest/core/session/dic_traverse_session.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/core/suggest.h"
#include "suggest/core/suggest_options.h"
#include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h"
@@ -46,11 +46,11 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu
void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession,
int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints,
- int inputSize, const PrevWordsInfo *const prevWordsInfo,
+ int inputSize, const NgramContext *const ngramContext,
const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
- traverseSession->init(this, prevWordsInfo, suggestOptions);
+ traverseSession->init(this, ngramContext, suggestOptions);
const auto &suggest = suggestOptions->isGesture() ? mGestureSuggest : mTypingSuggest;
suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates,
ycoordinates, times, pointerIds, inputCodePoints, inputSize,
@@ -58,10 +58,10 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession
}
Dictionary::NgramListenerForPrediction::NgramListenerForPrediction(
- const PrevWordsInfo *const prevWordsInfo, const WordIdArrayView prevWordIds,
+ const NgramContext *const ngramContext, const WordIdArrayView prevWordIds,
SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy)
- : mPrevWordsInfo(prevWordsInfo), mPrevWordIds(prevWordIds),
+ : mNgramContext(ngramContext), mPrevWordIds(prevWordIds),
mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {}
void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability,
@@ -69,7 +69,7 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
if (targetWordId == NOT_A_WORD_ID) {
return;
}
- if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)
+ if (mNgramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
&& ngramProbability == NOT_A_PROBABILITY) {
return;
}
@@ -85,20 +85,20 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
wordAttributes.getProbability());
}
-void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
+void Dictionary::getPredictions(const NgramContext *const ngramContext,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(
mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray,
true /* tryLowerCaseSearch */);
- NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults,
+ NgramListenerForPrediction listener(ngramContext, prevWordIds, outSuggestionResults,
mDictionaryStructureWithBufferPolicy.get());
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
}
int Dictionary::getProbability(const CodePointArrayView codePoints) const {
- return getNgramProbability(nullptr /* prevWordsInfo */, codePoints);
+ return getNgramProbability(nullptr /* ngramContext */, codePoints);
}
int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const {
@@ -107,18 +107,18 @@ int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoi
mDictionaryStructureWithBufferPolicy.get(), codePoints);
}
-int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo,
+int Dictionary::getNgramProbability(const NgramContext *const ngramContext,
const CodePointArrayView codePoints) const {
TimeKeeper::setCurrentTime();
const int wordId = mDictionaryStructureWithBufferPolicy->getWordId(codePoints,
false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY;
- if (!prevWordsInfo) {
+ if (!ngramContext) {
return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId);
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds
- (mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray,
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(
+ mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray,
true /* tryLowerCaseSearch */);
return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
}
@@ -140,24 +140,24 @@ bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) {
return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints);
}
-bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Dictionary::addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty) {
TimeKeeper::setCurrentTime();
- return mDictionaryStructureWithBufferPolicy->addNgramEntry(prevWordsInfo, ngramProperty);
+ return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramContext, ngramProperty);
}
-bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Dictionary::removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView codePoints) {
TimeKeeper::setCurrentTime();
- return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, codePoints);
+ return mDictionaryStructureWithBufferPolicy->removeNgramEntry(ngramContext, codePoints);
}
-bool Dictionary::updateCounter(const PrevWordsInfo *const prevWordsInfo,
+bool Dictionary::updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView codePoints, const bool isValidWord,
const HistoricalInfo historicalInfo) {
TimeKeeper::setCurrentTime();
- return mDictionaryStructureWithBufferPolicy->updateCounter(prevWordsInfo, codePoints,
- isValidWord, historicalInfo);
+ return mDictionaryStructureWithBufferPolicy->updateEntriesForWordWithNgramContext(ngramContext,
+ codePoints, isValidWord, historicalInfo);
}
bool Dictionary::flush(const char *const filePath) {
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index a58dbfbd7..843aec473 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -33,7 +33,7 @@ namespace latinime {
class DictionaryStructureWithBufferPolicy;
class DicTraverseSession;
-class PrevWordsInfo;
+class NgramContext;
class ProximityInfo;
class SuggestionResults;
class SuggestOptions;
@@ -66,18 +66,18 @@ class Dictionary {
void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession,
int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints,
- int inputSize, const PrevWordsInfo *const prevWordsInfo,
+ int inputSize, const NgramContext *const ngramContext,
const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel,
SuggestionResults *const outSuggestionResults) const;
- void getPredictions(const PrevWordsInfo *const prevWordsInfo,
+ void getPredictions(const NgramContext *const ngramContext,
SuggestionResults *const outSuggestionResults) const;
int getProbability(const CodePointArrayView codePoints) const;
int getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const;
- int getNgramProbability(const PrevWordsInfo *const prevWordsInfo,
+ int getNgramProbability(const NgramContext *const ngramContext,
const CodePointArrayView codePoints) const;
bool addUnigramEntry(const CodePointArrayView codePoints,
@@ -85,13 +85,13 @@ class Dictionary {
bool removeUnigramEntry(const CodePointArrayView codePoints);
- bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty);
- bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView codePoints);
- bool updateCounter(const PrevWordsInfo *const prevWordsInfo,
+ bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView codePoints, const bool isValidWord,
const HistoricalInfo historicalInfo);
@@ -123,7 +123,7 @@ class Dictionary {
class NgramListenerForPrediction : public NgramListener {
public:
- NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo,
+ NgramListenerForPrediction(const NgramContext *const ngramContext,
const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy);
virtual void onVisitEntry(const int ngramProbability, const int targetWordId);
@@ -131,7 +131,7 @@ class Dictionary {
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction);
- const PrevWordsInfo *const mPrevWordsInfo;
+ const NgramContext *const mNgramContext;
const WordIdArrayView mPrevWordIds;
SuggestionResults *const mSuggestionResults;
const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy;
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
index b85f3622a..9573c37bc 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
@@ -21,7 +21,7 @@
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/dictionary/digraph_utils.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/int_array_view.h"
@@ -33,10 +33,10 @@ namespace latinime {
std::vector<DicNode> current;
std::vector<DicNode> next;
- // No prev words information.
- PrevWordsInfo emptyPrevWordsInfo;
+ // No ngram context.
+ NgramContext emptyNgramContext;
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds(
+ const WordIdArrayView prevWordIds = emptyNgramContext.getPrevWordIds(
dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */);
current.emplace_back();
DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front());
diff --git a/native/jni/src/suggest/core/dictionary/property/historical_info.h b/native/jni/src/suggest/core/dictionary/property/historical_info.h
index 5ed9ebfca..f9bd6fd8c 100644
--- a/native/jni/src/suggest/core/dictionary/property/historical_info.h
+++ b/native/jni/src/suggest/core/dictionary/property/historical_info.h
@@ -47,12 +47,12 @@ class HistoricalInfo {
}
private:
- // Default copy constructor and assign operator are used for using in std::vector.
+ // Default copy constructor is used for using in std::vector.
+ DISALLOW_ASSIGNMENT_OPERATOR(HistoricalInfo);
- // TODO: Make members const.
- int mTimestamp;
- int mLevel;
- int mCount;
+ const int mTimestamp;
+ const int mLevel;
+ const int mCount;
};
} // namespace latinime
#endif /* LATINIME_HISTORICAL_INFO_H */
diff --git a/native/jni/src/suggest/core/dictionary/property/ngram_property.h b/native/jni/src/suggest/core/dictionary/property/ngram_property.h
index dce460099..8709799f9 100644
--- a/native/jni/src/suggest/core/dictionary/property/ngram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/ngram_property.h
@@ -44,13 +44,13 @@ class NgramProperty {
}
private:
- // Default copy constructor and assign operator are used for using in std::vector.
+ // Default copy constructor is used for using in std::vector.
DISALLOW_DEFAULT_CONSTRUCTOR(NgramProperty);
+ DISALLOW_ASSIGNMENT_OPERATOR(NgramProperty);
- // TODO: Make members const.
- std::vector<int> mTargetCodePoints;
- int mProbability;
- HistoricalInfo mHistoricalInfo;
+ const std::vector<int> mTargetCodePoints;
+ const int mProbability;
+ const HistoricalInfo mHistoricalInfo;
};
} // namespace latinime
#endif // LATINIME_NGRAM_PROPERTY_H
diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
index d1f0ab4ca..5ed2e2602 100644
--- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h
+++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h
@@ -41,12 +41,11 @@ class UnigramProperty {
}
private:
- // Default copy constructor and assign operator are used for using in std::vector.
+ // Default copy constructor is used for using in std::vector.
DISALLOW_DEFAULT_CONSTRUCTOR(ShortcutProperty);
- // TODO: Make members const.
- std::vector<int> mTargetCodePoints;
- int mProbability;
+ const std::vector<int> mTargetCodePoints;
+ const int mProbability;
};
UnigramProperty()
@@ -104,13 +103,12 @@ class UnigramProperty {
// Default copy constructor is used for using as a return value.
DISALLOW_ASSIGNMENT_OPERATOR(UnigramProperty);
- // TODO: Make members const.
- bool mRepresentsBeginningOfSentence;
- bool mIsNotAWord;
- bool mIsBlacklisted;
- int mProbability;
- HistoricalInfo mHistoricalInfo;
- std::vector<ShortcutProperty> mShortcuts;
+ const bool mRepresentsBeginningOfSentence;
+ const bool mIsNotAWord;
+ const bool mIsBlacklisted;
+ const int mProbability;
+ const HistoricalInfo mHistoricalInfo;
+ const std::vector<ShortcutProperty> mShortcuts;
};
} // namespace latinime
#endif // LATINIME_UNIGRAM_PROPERTY_H
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index 6624b7921..ceda5c03f 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -33,7 +33,7 @@ class DicNodeVector;
class DictionaryHeaderStructurePolicy;
class MultiBigramMap;
class NgramListener;
-class PrevWordsInfo;
+class NgramContext;
class UnigramProperty;
/*
@@ -81,15 +81,15 @@ class DictionaryStructureWithBufferPolicy {
virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0;
// Returns whether the update was success or not.
- virtual bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ virtual bool addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty) = 0;
// Returns whether the update was success or not.
- virtual bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ virtual bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints) = 0;
// Returns whether the update was success or not.
- virtual bool updateCounter(const PrevWordsInfo *const prevWordsInfo,
+ virtual bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints, const bool isValidWord,
const HistoricalInfo historicalInfo) = 0;
diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h
index 6dfa7e314..5b6616d9a 100644
--- a/native/jni/src/suggest/core/policy/traversal.h
+++ b/native/jni/src/suggest/core/policy/traversal.h
@@ -44,7 +44,7 @@ class Traversal {
virtual bool needsToTraverseAllUserInput() const = 0;
virtual float getMaxSpatialDistance() const = 0;
virtual int getDefaultExpandDicNodeSize() const = 0;
- virtual int getMaxCacheSize(const int inputSize) const = 0;
+ virtual int getMaxCacheSize(const int inputSize, const float weightForLocale) const = 0;
virtual int getTerminalCacheSize() const = 0;
virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index c202b81fe..a06e7d070 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -110,10 +110,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return weighting->getOmissionCost(parentDicNode, dicNode);
case CT_ADDITIONAL_PROXIMITY:
// only used for typing
- return weighting->getAdditionalProximityCost();
+ // TODO: Quit calling getMatchedCost().
+ return weighting->getAdditionalProximityCost()
+ + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_SUBSTITUTION:
// only used for typing
- return weighting->getSubstitutionCost();
+ // TODO: Quit calling getMatchedCost().
+ return weighting->getSubstitutionCost()
+ + weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_NEW_WORD_SPACE_OMISSION:
return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
case CT_MATCH:
@@ -176,9 +180,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_OMISSION:
return 0;
case CT_ADDITIONAL_PROXIMITY:
- return 0; /* 0 because CT_MATCH will be called */
+ return 1;
case CT_SUBSTITUTION:
- return 0; /* 0 because CT_MATCH will be called */
+ return 1;
case CT_NEW_WORD_SPACE_OMISSION:
return 0;
case CT_MATCH:
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index b4d01d0f0..52dc2f86c 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -20,7 +20,7 @@
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
namespace latinime {
@@ -30,12 +30,12 @@ const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_S
256 * 1024;
void DicTraverseSession::init(const Dictionary *const dictionary,
- const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions) {
+ const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions) {
mDictionary = dictionary;
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions;
- mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(),
+ mPrevWordIdCount = ngramContext->getPrevWordIds(getDictionaryStructurePolicy(),
&mPrevWordIdArray, true /* tryLowerCaseSearch */).size();
}
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index 9f841aa3c..bc53167f0 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -30,7 +30,7 @@ namespace latinime {
class Dictionary;
class DictionaryStructureWithBufferPolicy;
-class PrevWordsInfo;
+class NgramContext;
class ProximityInfo;
class SuggestOptions;
@@ -61,7 +61,7 @@ class DicTraverseSession {
// Non virtual inline destructor -- never inherit this class
AK_FORCE_INLINE ~DicTraverseSession() {}
- void init(const Dictionary *dictionary, const PrevWordsInfo *const prevWordsInfo,
+ void init(const Dictionary *dictionary, const NgramContext *const ngramContext,
const SuggestOptions *const suggestOptions);
// TODO: Remove and merge into init
void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints,
diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/ngram_context.h
index 553d5ad07..64c71410f 100644
--- a/native/jni/src/suggest/core/session/prev_words_info.h
+++ b/native/jni/src/suggest/core/session/ngram_context.h
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-#ifndef LATINIME_PREV_WORDS_INFO_H
-#define LATINIME_PREV_WORDS_INFO_H
+#ifndef LATINIME_NGRAM_CONTEXT_H
+#define LATINIME_NGRAM_CONTEXT_H
#include <array>
@@ -26,25 +26,26 @@
namespace latinime {
-class PrevWordsInfo {
+// Rename to NgramContext.
+class NgramContext {
public:
// No prev word information.
- PrevWordsInfo() : mPrevWordCount(0) {
+ NgramContext() : mPrevWordCount(0) {
clear();
}
- PrevWordsInfo(const PrevWordsInfo &prevWordsInfo)
- : mPrevWordCount(prevWordsInfo.mPrevWordCount) {
+ NgramContext(const NgramContext &ngramContext)
+ : mPrevWordCount(ngramContext.mPrevWordCount) {
for (size_t i = 0; i < mPrevWordCount; ++i) {
- mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i];
- memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i],
+ mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i];
+ memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i],
sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
- mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i];
+ mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i];
}
}
// Construct from previous words.
- PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH],
+ NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH],
const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
const size_t prevWordCount)
: mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
@@ -61,7 +62,7 @@ class PrevWordsInfo {
}
// Construct from a previous word.
- PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount,
+ NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount,
const bool isBeginningOfSentence) : mPrevWordCount(1) {
clear();
if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
@@ -78,8 +79,8 @@ class PrevWordsInfo {
}
// TODO: Remove.
- const PrevWordsInfo getTrimmedPrevWordsInfo(const size_t maxPrevWordCount) const {
- return PrevWordsInfo(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence,
+ const NgramContext getTrimmedNgramContext(const size_t maxPrevWordCount) const {
+ return NgramContext(mPrevWordCodePoints, mPrevWordCodePointCount, mIsBeginningOfSentence,
std::min(mPrevWordCount, maxPrevWordCount));
}
@@ -122,7 +123,7 @@ class PrevWordsInfo {
}
private:
- DISALLOW_ASSIGNMENT_OPERATOR(PrevWordsInfo);
+ DISALLOW_ASSIGNMENT_OPERATOR(NgramContext);
static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
const int *const wordCodePoints, const int wordCodePointCount,
@@ -165,4 +166,4 @@ class PrevWordsInfo {
bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
};
} // namespace latinime
-#endif // LATINIME_PREV_WORDS_INFO_H
+#endif // LATINIME_NGRAM_CONTEXT_H
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 457414f2b..c71526293 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -28,6 +28,7 @@
#include "suggest/core/policy/weighting.h"
#include "suggest/core/result/suggestions_output_utils.h"
#include "suggest/core/session/dic_traverse_session.h"
+#include "suggest/core/suggest_options.h"
namespace latinime {
@@ -88,7 +89,8 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const {
traverseSession->getDicTraverseCache()->continueSearch();
} else {
// Restart recognition at the root.
- traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()),
+ traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize(),
+ traverseSession->getSuggestOptions()->weightForLocale()),
TRAVERSAL->getTerminalCacheSize());
// Create a new dic node here
DicNode rootNode;
@@ -282,7 +284,6 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver
// not treat the node as a terminal. There is no need to pass the bigram map in these cases.
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
- weightChildNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, childDicNode);
}
@@ -290,7 +291,6 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
DicNode *dicNode, DicNode *childDicNode) const {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
dicNode, childDicNode, 0 /* multiBigramMap */);
- weightChildNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, childDicNode);
}
@@ -401,7 +401,7 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN
if (dicNode->isCompletion(inputSize)) {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
- } else { // completion
+ } else {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
}
diff --git a/native/jni/src/suggest/core/suggest_options.h b/native/jni/src/suggest/core/suggest_options.h
index d456680dd..4d331292b 100644
--- a/native/jni/src/suggest/core/suggest_options.h
+++ b/native/jni/src/suggest/core/suggest_options.h
@@ -42,6 +42,12 @@ class SuggestOptions{
return getBoolOption(SPACE_AWARE_GESTURE_ENABLED);
}
+ AK_FORCE_INLINE float weightForLocale() const {
+ // The weight is in thousands and we want the real value, so we divide by 1000.
+ // NativeSuggestOptions#setWeightForLocale does the opposite processing in Java.
+ return static_cast<float>(getIntOption(WEIGHT_FOR_LOCALE_IN_THOUSANDS)) / 1000.0f;
+ }
+
AK_FORCE_INLINE bool getAdditionalFeaturesBoolOption(const int key) const {
return getBoolOption(key + ADDITIONAL_FEATURES_OPTIONS);
}
@@ -55,9 +61,10 @@ class SuggestOptions{
static const int USE_FULL_EDIT_DISTANCE = 1;
static const int BLOCK_OFFENSIVE_WORDS = 2;
static const int SPACE_AWARE_GESTURE_ENABLED = 3;
+ static const int WEIGHT_FOR_LOCALE_IN_THOUSANDS = 4;
// Additional features options are stored after the other options and used as setting values of
// experimental features.
- static const int ADDITIONAL_FEATURES_OPTIONS = 4;
+ static const int ADDITIONAL_FEATURES_OPTIONS = 5;
const int *const mOptions;
const int mLength;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 4c4dfc578..8fb256c54 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -38,15 +38,17 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
-const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT";
-const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT";
+const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
+const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
+const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
-const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000;
+const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
+const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
// Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index bc8eaded3..836bbe5a1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -253,11 +253,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
static const char *const MAX_UNIGRAM_COUNT_KEY;
static const char *const MAX_BIGRAM_COUNT_KEY;
+ static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
static const int DEFAULT_MAX_UNIGRAM_COUNT;
static const int DEFAULT_MAX_BIGRAM_COUNT;
+ static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 8d169743c..6243f14cc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -310,7 +310,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN
const int shortcutProbability) {
if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(),
targetCodePoints, targetCodePointCount, shortcutProbability)) {
- AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId());
+ AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId());
return false;
}
if (!ptNodeParams->hasShortcutTargets()) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 36eafa1e9..0eae934ae 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -33,7 +33,7 @@
#include "suggest/core/dictionary/property/ngram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/core/dictionary/property/word_property.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
#include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
@@ -186,7 +186,9 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI
if (bigramsIt.getBigramPos() == ptNodePos
&& bigramsIt.getProbability() != NOT_A_PROBABILITY) {
const int bigramConditionalProbability = getBigramConditionalProbability(
- prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
+ prevWordPtNodeParams.getProbability(),
+ prevWordPtNodeParams.representsBeginningOfSentence(),
+ bigramsIt.getProbability());
return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability);
}
}
@@ -209,15 +211,19 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
while (bigramsIt.hasNext()) {
bigramsIt.next();
const int bigramConditionalProbability = getBigramConditionalProbability(
- prevWordPtNodeParams.getProbability(), bigramsIt.getProbability());
+ prevWordPtNodeParams.getProbability(),
+ prevWordPtNodeParams.representsBeginningOfSentence(), bigramsIt.getProbability());
listener->onVisitEntry(bigramConditionalProbability,
getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()));
}
}
int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability,
- const int bigramProbability) const {
+ const bool isInBeginningOfSentenceContext, const int bigramProbability) const {
if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
+ if (isInBeginningOfSentenceContext) {
+ return bigramProbability;
+ }
// Calculate conditional probability.
return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability,
MAX_PROBABILITY);
@@ -338,7 +344,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return mNodeWriter.suppressUnigramEntry(&ptNodeParams);
}
-bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
@@ -349,8 +355,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
mDictBuffer->getTailPosition());
return false;
}
- if (!prevWordsInfo->isValid()) {
- AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary.");
+ if (!ngramContext->isValid()) {
+ AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary.");
return false;
}
if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
@@ -359,23 +365,23 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSearch */);
if (prevWordIds.empty()) {
return false;
}
if (prevWordIds[0] == NOT_A_WORD_ID) {
- if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
+ if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */, true /* isNotAWord */,
false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo());
- if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */),
+ if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) {
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false;
}
// Refresh word ids.
- prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
} else {
return false;
}
@@ -399,7 +405,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
}
-bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary.");
@@ -410,8 +416,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
mDictBuffer->getTailPosition());
return false;
}
- if (!prevWordsInfo->isValid()) {
- AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary.");
+ if (!ngramContext->isValid()) {
+ AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary.");
return false;
}
if (wordCodePoints.size() > MAX_WORD_LENGTH) {
@@ -419,7 +425,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
wordCodePoints.size());
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSerch */);
if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) {
return false;
@@ -440,26 +446,27 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
}
-bool Ver4PatriciaTriePolicy::updateCounter(const PrevWordsInfo *const prevWordsInfo,
- const CodePointArrayView wordCodePoints, const bool isValidWord,
- const HistoricalInfo historicalInfo) {
+bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
+ const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints,
+ const bool isValidWord, const HistoricalInfo historicalInfo) {
if (!mBuffers->isUpdatable()) {
- AKLOGI("Warning: updateCounter() is called for non-updatable dictionary.");
+ AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable "
+ "dictionary.");
return false;
}
const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
false /* isNotAWord */, false /*isBlacklisted*/, probability, historicalInfo);
if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
- AKLOGE("Cannot update unigarm entry in updateCounter().");
+ AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
}
- const int probabilityForNgram = prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)
+ const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)
? NOT_A_PROBABILITY : probability;
const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
historicalInfo);
- if (!addNgramEntry(prevWordsInfo, &ngramProperty)) {
- AKLOGE("Cannot update unigarm entry in updateCounter().");
+ if (!addNgramEntry(ngramContext, &ngramProperty)) {
+ AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext().");
return false;
}
return true;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index b82563e61..1ad5e7e36 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -112,13 +112,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
bool removeUnigramEntry(const CodePointArrayView wordCodePoints);
- bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty);
- bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints);
- bool updateCounter(const PrevWordsInfo *const prevWordsInfo,
+ bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints, const bool isValidWord,
const HistoricalInfo historicalInfo);
@@ -175,7 +175,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
const WordAttributes getWordAttributes(const int probability,
const PtNodeParams &ptNodeParams) const;
int getBigramConditionalProbability(const int prevWordUnigramProbability,
- const int bigramProbability) const;
+ const bool isInBeginningOfSentenceContext, const int bigramProbability) const;
};
} // namespace v402
} // namespace backward
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index d3d684bfa..b7f1199c5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -23,7 +23,7 @@
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 32a95bb6c..b17681388 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -93,25 +93,26 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return false;
}
- bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty) {
// This method should not be called for non-updatable dictionary.
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
return false;
}
- bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints) {
// This method should not be called for non-updatable dictionary.
AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary.");
return false;
}
- bool updateCounter(const PrevWordsInfo *const prevWordsInfo,
+ bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints, const bool isValidWord,
const HistoricalInfo historicalInfo) {
// This method should not be called for non-updatable dictionary.
- AKLOGI("Warning: updateCounter() is called for non-updatable dictionary.");
+ AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable "
+ "dictionary.");
return false;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp
index dc0ed96d0..90d4687dd 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp
@@ -37,7 +37,7 @@ const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNo
int shortcutPos = NOT_A_DICT_POS;
int bigramPos = NOT_A_DICT_POS;
int siblingPos = NOT_A_DICT_POS;
- PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortuctPolicy,
+ PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortcutPolicy,
mBigramPolicy, mCodePointTable, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints,
&probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos);
if (mergedNodeCodePointCount <= 0) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h
index 24ec5bcca..838d37314 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h
@@ -35,7 +35,7 @@ class Ver2ParticiaTrieNodeReader : public PtNodeReader {
const DictionaryBigramsStructurePolicy *const bigramPolicy,
const DictionaryShortcutsStructurePolicy *const shortcutPolicy,
const int *const codePointTable)
- : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortuctPolicy(shortcutPolicy),
+ : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy),
mCodePointTable(codePointTable) {}
virtual const PtNodeParams fetchPtNodeParamsInBufferFromPtNodePos(const int ptNodePos) const;
@@ -45,7 +45,7 @@ class Ver2ParticiaTrieNodeReader : public PtNodeReader {
const ReadOnlyByteArrayView mBuffer;
const DictionaryBigramsStructurePolicy *const mBigramPolicy;
- const DictionaryShortcutsStructurePolicy *const mShortuctPolicy;
+ const DictionaryShortcutsStructurePolicy *const mShortcutPolicy;
const int *const mCodePointTable;
};
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 956dabb4f..c4297f5d6 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -25,6 +25,7 @@ namespace latinime {
const int LanguageModelDictContent::UNIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 0;
const int LanguageModelDictContent::BIGRAM_COUNT_INDEX_IN_ENTRY_COUNT_TABLE = 1;
+const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
bool LanguageModelDictContent::save(FILE *const file) const {
return mTrieMap.save(file);
@@ -42,18 +43,18 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
- int maxLevel = 0;
+ int maxPrevWordCount = 0;
for (size_t i = 0; i < prevWordIds.size(); ++i) {
const int nextBitmapEntryIndex =
mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
break;
}
- maxLevel = i + 1;
+ maxPrevWordCount = i + 1;
bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
}
- for (int i = maxLevel; i >= 0; --i) {
+ for (int i = maxPrevWordCount; i >= 0; --i) {
const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
if (!result.mIsValid) {
continue;
@@ -68,9 +69,24 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
// The entry should not be treated as a valid entry.
continue;
}
- probability = std::min(rawProbability
- + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
- MAX_PROBABILITY);
+ if (i == 0) {
+ // unigram
+ probability = rawProbability;
+ } else {
+ const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry(
+ prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]);
+ if (!prevWordProbabilityEntry.isValid()) {
+ continue;
+ }
+ if (prevWordProbabilityEntry.representsBeginningOfSentence()) {
+ probability = rawProbability;
+ } else {
+ const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability(
+ prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy);
+ probability = std::min(MAX_PROBABILITY - prevWordRawProbability
+ + rawProbability, MAX_PROBABILITY);
+ }
+ }
} else {
probability = probabilityEntry.getProbability();
}
@@ -143,6 +159,56 @@ bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
return true;
}
+bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds,
+ const int wordId, const bool isValid, const HistoricalInfo historicalInfo,
+ const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount) {
+ if (outAddedNewNgramEntryCount) {
+ *outAddedNewNgramEntryCount = 0;
+ }
+ if (!mHasHistoricalInfo) {
+ AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info.");
+ return false;
+ }
+ const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId);
+ const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom(
+ originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy);
+ if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) {
+ return false;
+ }
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ if (prevWordIds[i] == NOT_A_WORD_ID) {
+ break;
+ }
+ // TODO: Optimize this code.
+ const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1);
+ const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry(
+ limitedPrevWordIds, wordId);
+ const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom(
+ originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy);
+ if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) {
+ return false;
+ }
+ if (!originalNgramProbabilityEntry.isValid() && outAddedNewNgramEntryCount) {
+ *outAddedNewNgramEntryCount += 1;
+ }
+ }
+ return true;
+}
+
+const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom(
+ const ProbabilityEntry &originalProbabilityEntry, const bool isValid,
+ const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const {
+ const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo(
+ originalProbabilityEntry.getHistoricalInfo(), isValid ?
+ DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY,
+ &historicalInfo, headerPolicy);
+ if (originalProbabilityEntry.isValid()) {
+ return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo);
+ } else {
+ return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo);
+ }
+}
+
bool LanguageModelDictContent::runGCInner(
const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
const TrieMap::TrieMapRange trieMapRange,
@@ -203,17 +269,27 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord
return bitmapEntryIndex;
}
-bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
- const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex,
+ const int prevWordCount, const HeaderPolicy *const headerPolicy,
+ int *const outEntryCounts) {
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
- if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
- AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
- level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+ if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+ AKLOGE("Invalid prevWordCount. prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+ prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
return false;
}
const ProbabilityEntry probabilityEntry =
ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
- if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+ if (prevWordCount > 0 && probabilityEntry.isValid()
+ && !mTrieMap.getRoot(entry.key()).mIsValid) {
+ // The entry is related to a word that has been removed. Remove the entry.
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+ return false;
+ }
+ continue;
+ }
+ if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()
+ && probabilityEntry.isValid()) {
const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
probabilityEntry.getHistoricalInfo(), headerPolicy);
if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
@@ -232,13 +308,13 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmap
}
}
if (!probabilityEntry.representsBeginningOfSentence()) {
- outEntryCounts[level] += 1;
+ outEntryCounts[prevWordCount] += 1;
}
if (!entry.hasNextLevelMap()) {
continue;
}
- if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
- headerPolicy, outEntryCounts)) {
+ if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(),
+ prevWordCount + 1, headerPolicy, outEntryCounts)) {
return false;
}
}
@@ -266,7 +342,7 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry(
- WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
return false;
}
}
@@ -276,9 +352,9 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
- const int currentLevel = prevWordIds->size();
+ const int prevWordCount = prevWordIds->size();
for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
- if (currentLevel < targetLevel) {
+ if (prevWordCount < targetLevel) {
if (!entry.hasNextLevelMap()) {
continue;
}
@@ -313,10 +389,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
if (left.mKey != right.mKey) {
return left.mKey < right.mKey;
}
- if (left.mEntryLevel != right.mEntryLevel) {
- return left.mEntryLevel > right.mEntryLevel;
+ if (left.mPrevWordCount != right.mPrevWordCount) {
+ return left.mPrevWordCount > right.mPrevWordCount;
}
- for (int i = 0; i < left.mEntryLevel; ++i) {
+ for (int i = 0; i < left.mPrevWordCount; ++i) {
if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
return left.mPrevWordIds[i] < right.mPrevWordIds[i];
}
@@ -326,9 +402,10 @@ bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
}
LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
- const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
- : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
- memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+ const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds)
+ : mProbability(probability), mTimestamp(timestamp), mKey(key),
+ mPrevWordCount(prevWordCount) {
+ memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0]));
}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index b7e4af977..51ef090e1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -154,19 +154,23 @@ class LanguageModelDictContent {
EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const;
- bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
+ bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy,
int *const outEntryCounts) {
for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
outEntryCounts[i] = 0;
}
- return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
- headerPolicy, outEntryCounts);
+ return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(),
+ 0 /* prevWordCount */, headerPolicy, outEntryCounts);
}
// entryCounts should be created by updateAllProbabilityEntries.
bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+ bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId,
+ const bool isValid, const HistoricalInfo historicalInfo,
+ const HeaderPolicy *const headerPolicy, int *const outAddedNewNgramEntryCount);
+
private:
DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
@@ -181,18 +185,21 @@ class LanguageModelDictContent {
};
EntryInfoToTurncate(const int probability, const int timestamp, const int key,
- const int entryLevel, const int *const prevWordIds);
+ const int prevWordCount, const int *const prevWordIds);
int mProbability;
int mTimestamp;
int mKey;
- int mEntryLevel;
+ int mPrevWordCount;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
private:
DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
};
+ // TODO: Remove
+ static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
+
TrieMap mTrieMap;
const bool mHasHistoricalInfo;
@@ -201,13 +208,16 @@ class LanguageModelDictContent {
int *const outNgramCount);
int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
- bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
+ bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount,
const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
const int maxEntryCount, const int targetLevel, int *const outEntryCount);
bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
+ const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry,
+ const bool isValid, const HistoricalInfo historicalInfo,
+ const HeaderPolicy *const headerPolicy) const;
};
} // namespace latinime
#endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index fa1415633..f4d340f86 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -98,17 +98,17 @@ class ProbabilityEntry {
}
uint64_t encode(const bool hasHistoricalInfo) const {
- uint64_t encodedEntry = static_cast<uint64_t>(mFlags);
+ uint64_t encodedEntry = static_cast<uint8_t>(mFlags);
if (hasHistoricalInfo) {
encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT))
- ^ static_cast<uint64_t>(mHistoricalInfo.getTimestamp());
+ | static_cast<uint32_t>(mHistoricalInfo.getTimestamp());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT))
- ^ static_cast<uint64_t>(mHistoricalInfo.getLevel());
+ | static_cast<uint8_t>(mHistoricalInfo.getLevel());
encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT))
- ^ static_cast<uint64_t>(mHistoricalInfo.getCount());
+ | static_cast<uint8_t>(mHistoricalInfo.getCount());
} else {
encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT))
- ^ static_cast<uint64_t>(mProbability);
+ | static_cast<uint8_t>(mProbability);
}
return encodedEntry;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index f13512d5a..794c63ffd 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -142,14 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
if (!toBeUpdatedPtNodeParams->isTerminal()) {
return false;
}
- const ProbabilityEntry originalProbabilityEntry =
- mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId());
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
- const ProbabilityEntry updatedProbabilityEntry =
- createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
+ toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty);
}
bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -203,10 +198,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
// Write probability.
ProbabilityEntry newProbabilityEntry;
const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
- const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
- &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
- terminalId, &probabilityEntryToWrite);
+ terminalId, &probabilityEntryOfUnigramProperty);
}
// TODO: Support counting ngram entries.
@@ -217,10 +210,8 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
const ProbabilityEntry probabilityEntry =
languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId);
const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty);
- const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
- &probabilityEntry, &probabilityEntryOfNgramProperty);
if (!languageModelDictContent->setNgramProbabilityEntry(
- prevWordIds, wordId, &updatedProbabilityEntry)) {
+ prevWordIds, wordId, &probabilityEntryOfNgramProperty)) {
AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d",
prevWordIds[0], prevWordIds.size(), wordId);
return false;
@@ -285,7 +276,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN
const int shortcutProbability) {
if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(),
targetCodePoints, targetCodePointCount, shortcutProbability)) {
- AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId());
+ AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId());
return false;
}
return true;
@@ -346,22 +337,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
}
-// TODO: Move probability handling code to LanguageModelDictContent.
-const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
- const ProbabilityEntry *const originalProbabilityEntry,
- const ProbabilityEntry *const probabilityEntry) const {
- if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
- const HistoricalInfo updatedHistoricalInfo =
- ForgettingCurveUtils::createUpdatedHistoricalInfo(
- originalProbabilityEntry->getHistoricalInfo(),
- probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
- mHeaderPolicy);
- return ProbabilityEntry(probabilityEntry->getFlags(), &updatedHistoricalInfo);
- } else {
- return *probabilityEntry;
- }
-}
-
bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal,
const bool hasMultipleChars) {
// Create node flags and write them.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index ea4f09904..4ecf88729 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -38,11 +38,10 @@ class Ver4ShortcutListPolicy;
class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
public:
Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer,
- Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy,
- const PtNodeReader *const ptNodeReader,
+ Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader,
const PtNodeArrayReader *const ptNodeArrayReader,
Ver4ShortcutListPolicy *const shortcutPolicy)
- : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy),
+ : mTrieBuffer(trieBuffer), mBuffers(buffers),
mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {}
virtual ~Ver4PatriciaTrieNodeWriter() {}
@@ -96,20 +95,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
const PtNodeParams *const ptNodeParams, int *const outTerminalId,
int *const ptNodeWritingPos);
- // Create updated probability entry using given probability property. In addition to the
- // probability, this method updates historical information if needed.
- // TODO: Update flags.
- const ProbabilityEntry createUpdatedEntryFrom(
- const ProbabilityEntry *const originalProbabilityEntry,
- const ProbabilityEntry *const probabilityEntry) const;
-
bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars);
static const int CHILDREN_POSITION_FIELD_SIZE;
BufferWithExtendableBuffer *const mTrieBuffer;
Ver4DictBuffers *const mBuffers;
- const HeaderPolicy *const mHeaderPolicy;
DynamicPtReadingHelper mReadingHelper;
Ver4ShortcutListPolicy *const mShortcutPolicy;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 036197c41..ea8c0dc22 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -26,7 +26,7 @@
#include "suggest/core/dictionary/property/ngram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
#include "suggest/core/dictionary/property/word_property.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
@@ -43,7 +43,6 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C
const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024;
const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS =
Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
-const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1;
void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode,
DicNodeVector *const childDicNodes) const {
@@ -151,8 +150,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI
}
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
- probabilityEntry.getHistoricalInfo(), mHeaderPolicy)
- + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */) :
+ probabilityEntry.getHistoricalInfo(), mHeaderPolicy) :
probabilityEntry.getProbability();
listener->onVisitEntry(probability, entry.getWordId());
}
@@ -266,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return true;
}
-bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary.");
@@ -277,8 +275,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
mDictBuffer->getTailPosition());
return false;
}
- if (!prevWordsInfo->isValid()) {
- AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary.");
+ if (!ngramContext->isValid()) {
+ AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary.");
return false;
}
if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) {
@@ -287,7 +285,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSearch */);
if (prevWordIds.empty()) {
return false;
@@ -296,19 +294,19 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
if (prevWordIds[i] != NOT_A_WORD_ID) {
continue;
}
- if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) {
+ if (!ngramContext->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) {
return false;
}
const UnigramProperty beginningOfSentenceUnigramProperty(
true /* representsBeginningOfSentence */, true /* isNotAWord */,
false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo());
- if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */),
+ if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
&beginningOfSentenceUnigramProperty)) {
AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
return false;
}
// Refresh word ids.
- prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
const int wordId = getWordId(CodePointArrayView(*ngramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */);
@@ -326,7 +324,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
}
-bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints) {
if (!mBuffers->isUpdatable()) {
AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary.");
@@ -337,8 +335,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
mDictBuffer->getTailPosition());
return false;
}
- if (!prevWordsInfo->isValid()) {
- AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary.");
+ if (!ngramContext->isValid()) {
+ AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary.");
return false;
}
if (wordCodePoints.size() > MAX_WORD_LENGTH) {
@@ -346,7 +344,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
wordCodePoints.size());
}
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
- const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSerch */);
if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) {
return false;
@@ -363,32 +361,52 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
}
}
-bool Ver4PatriciaTriePolicy::updateCounter(const PrevWordsInfo *const prevWordsInfo,
- const CodePointArrayView wordCodePoints, const bool isValidWord,
- const HistoricalInfo historicalInfo) {
+bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext(
+ const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints,
+ const bool isValidWord, const HistoricalInfo historicalInfo) {
if (!mBuffers->isUpdatable()) {
- AKLOGI("Warning: updateCounter() is called for non-updatable dictionary.");
+ AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable "
+ "dictionary.");
return false;
}
- // TODO: Have count up method in language model dict content.
- const int probability = isValidWord ? DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY;
- const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
- false /* isNotAWord */, false /*isBlacklisted*/, probability, historicalInfo);
- if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
- AKLOGE("Cannot update unigarm entry in updateCounter().");
- return false;
+ const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ?
+ false : isValidWord;
+ int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
+ if (wordId == NOT_A_WORD_ID) {
+ // The word is not in the dictionary.
+ const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */,
+ false /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
+ HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ if (!addUnigramEntry(wordCodePoints, &unigramProperty)) {
+ AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext().");
+ return false;
+ }
+ wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */);
}
- const int probabilityForNgram = prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)
- ? NOT_A_PROBABILITY : probability;
- const NgramProperty ngramProperty(wordCodePoints.toVector(), probabilityForNgram,
- historicalInfo);
- for (size_t i = 1; i <= prevWordsInfo->getPrevWordCount(); ++i) {
- const PrevWordsInfo trimmedPrevWordsInfo(prevWordsInfo->getTrimmedPrevWordsInfo(i));
- if (!addNgramEntry(&trimmedPrevWordsInfo, &ngramProperty)) {
- AKLOGE("Cannot update ngram entry in updateCounter().");
+
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray,
+ false /* tryLowerCaseSearch */);
+ if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID
+ && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) {
+ const UnigramProperty beginningOfSentenceUnigramProperty(
+ true /* representsBeginningOfSentence */,
+ true /* isNotAWord */, false /* isBlacklisted */, NOT_A_PROBABILITY,
+ HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */));
+ if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */),
+ &beginningOfSentenceUnigramProperty)) {
+ AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext().");
return false;
}
+ // Refresh word ids.
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
+ }
+ int addedNewNgramEntryCount = 0;
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds,
+ wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &addedNewNgramEntryCount)) {
+ return false;
}
+ mBigramCount += addedNewNgramEntryCount;
return true;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 662bb8d4b..c0532815c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -47,8 +47,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mShortcutPolicy(mBuffers->getMutableShortcutDictContent(),
mBuffers->getTerminalPositionLookupTable()),
mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer),
- mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader,
- &mPtNodeArrayReader, &mShortcutPolicy),
+ mNodeWriter(mDictBuffer, mBuffers.get(), &mNodeReader, &mPtNodeArrayReader,
+ &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
mUnigramCount(mHeaderPolicy->getUnigramCount()),
@@ -92,13 +92,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
bool removeUnigramEntry(const CodePointArrayView wordCodePoints);
- bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool addNgramEntry(const NgramContext *const ngramContext,
const NgramProperty *const ngramProperty);
- bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo,
+ bool removeNgramEntry(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints);
- bool updateCounter(const PrevWordsInfo *const prevWordsInfo,
+ bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext,
const CodePointArrayView wordCodePoints, const bool isValidWord,
const HistoricalInfo historicalInfo);
@@ -131,8 +131,6 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
// prevent the dictionary from overflowing.
static const int MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS;
static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS;
- // TODO: Remove
- static const int DUMMY_PROBABILITY_FOR_VALID_WORDS;
const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers;
const HeaderPolicy *const mHeaderPolicy;
@@ -144,6 +142,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
DynamicPtUpdatingHelper mUpdatingHelper;
Ver4PatriciaTrieWritingHelper mWritingHelper;
int mUnigramCount;
+ // TODO: Support counting ngram entries.
int mBigramCount;
std::vector<int> mTerminalPtNodePositionsForIteratingWords;
mutable bool mIsCorrupted;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index e1ff973de..f0d59c150 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -78,11 +78,11 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(),
mBuffers->getTerminalPositionLookupTable());
Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(),
- mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
+ mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
- if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
- entryCountTable)) {
+ if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC(
+ headerPolicy, entryCountTable)) {
AKLOGE("Failed to update probabilities in language model dict content.");
return false;
}
@@ -118,7 +118,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap;
readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(),
- buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
+ buffersToWrite, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy);
DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer
traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers,
buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap);
@@ -133,7 +133,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(),
buffersToWrite->getTerminalPositionLookupTable());
Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(),
- buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader,
+ buffersToWrite, &newPtNodeReader, &newPtNodeArrayreader,
&newShortcutPolicy);
// Re-assign terminal IDs for valid terminal PtNodes.
TerminalPositionLookupTable::TerminalIdMap terminalIdMap;
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index 3fc566e7a..6a2db687d 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -31,6 +31,7 @@ const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f;
// TODO: Unlimit max cache dic node size
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170;
const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310;
+const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE = 50;
const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4;
const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.1524f;
@@ -48,7 +49,7 @@ const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f;
const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f;
const float ScoringParams::TRANSPOSITION_COST = 0.5608f;
const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f;
-const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.4576f;
+const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f;
const float ScoringParams::SUBSTITUTION_COST = 0.3806f;
const float ScoringParams::COST_NEW_WORD = 0.0314f;
const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f;
@@ -61,4 +62,7 @@ const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f;
const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f;
const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f;
const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f;
+const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION = 0.99f;
+const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION = 0.99f;
+const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE = 0.99f;
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
index b12de6d87..731424f3d 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
@@ -30,6 +30,7 @@ class ScoringParams {
static const float AUTOCORRECT_OUTPUT_THRESHOLD;
static const int MAX_CACHE_DIC_NODE_SIZE;
static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT;
+ static const int MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE;
static const int THRESHOLD_SHORT_WORD_LENGTH;
static const float EXACT_MATCH_PROMOTION;
@@ -68,6 +69,9 @@ class ScoringParams {
static const float TYPING_BASE_OUTPUT_SCORE;
static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
static const float NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT;
+ static const float LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION;
+ static const float LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION;
+ static const float LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams);
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
index b64ee8be4..b9b6314ae 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
@@ -26,6 +26,7 @@
#include "suggest/core/layout/proximity_info_utils.h"
#include "suggest/core/policy/traversal.h"
#include "suggest/core/session/dic_traverse_session.h"
+#include "suggest/core/suggest_options.h"
#include "suggest/policyimpl/typing/scoring_params.h"
#include "utils/char_utils.h"
@@ -77,6 +78,13 @@ class TypingTraversal : public Traversal {
if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) {
return false;
}
+ if (traverseSession->getSuggestOptions()->weightForLocale()
+ < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION) {
+ // Space substitution is heavy, so we skip doing it if the weight for this language
+ // is low because we anticipate the suggestions out of this dictionary are not for
+ // the language the user intends to type in.
+ return false;
+ }
if (!canDoLookAheadCorrection(traverseSession, dicNode)) {
return false;
}
@@ -91,6 +99,13 @@ class TypingTraversal : public Traversal {
if (!CORRECT_NEW_WORD_SPACE_OMISSION) {
return false;
}
+ if (traverseSession->getSuggestOptions()->weightForLocale()
+ < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION) {
+ // Space omission is heavy, so we skip doing it if the weight for this language
+ // is low because we anticipate the suggestions out of this dictionary are not for
+ // the language the user intends to type in.
+ return false;
+ }
const int inputSize = traverseSession->getInputSize();
// TODO: Don't refer to isCompletion?
if (dicNode->isCompletion(inputSize)) {
@@ -141,9 +156,14 @@ class TypingTraversal : public Traversal {
return DicNodeVector::DEFAULT_NODES_SIZE_FOR_OPTIMIZATION;
}
- AK_FORCE_INLINE int getMaxCacheSize(const int inputSize) const {
- return (inputSize <= 1) ? ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT
- : ScoringParams::MAX_CACHE_DIC_NODE_SIZE;
+ AK_FORCE_INLINE int getMaxCacheSize(const int inputSize, const float weightForLocale) const {
+ if (inputSize <= 1) {
+ return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT;
+ }
+ if (weightForLocale < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE) {
+ return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE;
+ }
+ return ScoringParams::MAX_CACHE_DIC_NODE_SIZE;
}
AK_FORCE_INLINE int getTerminalCacheSize() const {
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index 1d590c353..db7a39efb 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -68,7 +68,8 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor
}
break;
case CT_ADDITIONAL_PROXIMITY:
- return ErrorTypeUtils::PROXIMITY_CORRECTION;
+ // TODO: Change to EDIT_CORRECTION.
+ return ErrorTypeUtils::PROXIMITY_CORRECTION;
case CT_OMISSION:
if (parentDicNode->canBeIntentionalOmission()) {
return ErrorTypeUtils::INTENTIONAL_OMISSION;
@@ -77,6 +78,8 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor
}
break;
case CT_SUBSTITUTION:
+ // TODO: Quit settng PROXIMITY_CORRECTION.
+ return ErrorTypeUtils::EDIT_CORRECTION | ErrorTypeUtils::PROXIMITY_CORRECTION;
case CT_INSERTION:
case CT_TERMINAL_INSERTION:
case CT_TRANSPOSITION:
diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h
index 235a03bba..25cc41742 100644
--- a/native/jni/src/utils/jni_data_utils.h
+++ b/native/jni/src/utils/jni_data_utils.h
@@ -21,7 +21,7 @@
#include "defines.h"
#include "jni.h"
-#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/session/ngram_context.h"
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h"
#include "utils/char_utils.h"
@@ -96,7 +96,7 @@ class JniDataUtils {
}
}
- static PrevWordsInfo constructPrevWordsInfo(JNIEnv *env, jobjectArray prevWordCodePointArrays,
+ static NgramContext constructNgramContext(JNIEnv *env, jobjectArray prevWordCodePointArrays,
jbooleanArray isBeginningOfSentenceArray, const size_t prevWordCount) {
int prevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
int prevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
@@ -119,7 +119,7 @@ class JniDataUtils {
&isBeginningOfSentenceBoolean);
isBeginningOfSentence[i] = isBeginningOfSentenceBoolean == JNI_TRUE;
}
- return PrevWordsInfo(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence,
+ return NgramContext(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence,
prevWordCount);
}