diff options
Diffstat (limited to 'native/jni/src')
31 files changed, 629 insertions, 92 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 89dfa39b3..742e388e4 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -299,6 +299,19 @@ static inline void prof_out(void) { #define NOT_A_PROBABILITY (-1) #define NOT_A_DICT_POS (S_INT_MIN) +// A special value to mean the first word confidence makes no sense in this case, +// e.g. this is not a multi-word suggestion. +#define NOT_A_FIRST_WORD_CONFIDENCE (S_INT_MAX) +// How high the confidence needs to be for us to auto-commit. Arbitrary. +// This needs to be the same as CONFIDENCE_FOR_AUTO_COMMIT in BinaryDictionary.java +#define CONFIDENCE_FOR_AUTO_COMMIT (1000000) +// 80% of the full confidence +#define DISTANCE_WEIGHT_FOR_AUTO_COMMIT (80 * CONFIDENCE_FOR_AUTO_COMMIT / 100) +// 100% of the full confidence +#define LENGTH_WEIGHT_FOR_AUTO_COMMIT (CONFIDENCE_FOR_AUTO_COMMIT) +// 80% of the full confidence +#define SPACE_COUNT_WEIGHT_FOR_AUTO_COMMIT (80 * CONFIDENCE_FOR_AUTO_COMMIT / 100) + #define KEYCODE_SPACE ' ' #define KEYCODE_SINGLE_QUOTE '\'' #define KEYCODE_HYPHEN_MINUS '-' @@ -375,7 +388,7 @@ typedef enum { CT_TERMINAL, CT_TERMINAL_INSERTION, // Create new word with space omission - CT_NEW_WORD_SPACE_OMITTION, + CT_NEW_WORD_SPACE_OMISSION, // Create new word with space substitution CT_NEW_WORD_SPACE_SUBSTITUTION, } CorrectionType; diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 41ef9d2b2..49cfdecac 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -38,10 +38,10 @@ INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf, \ NELEMS(prevWordCharBuf)); \ - AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ + AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d, %5f,", header, \ getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ - getInputIndex(0)); \ + getInputIndex(0), getNormalizedCompoundDistanceAfterFirstWord()); \ } while (0) #else #define LOGI_SHOW_ADD_COST_PROP @@ -271,7 +271,7 @@ class DicNode { return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; } - bool shouldBeFilterdBySafetyNetForBigram() const { + bool shouldBeFilteredBySafetyNetForBigram() const { const uint16_t currentDepth = getNodeCodePointCount(); const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; @@ -321,6 +321,16 @@ class DicNode { DUMP_WORD_AND_SCORE("OUTPUT"); } + // "Total" in this context (and other methods in this class) means the whole suggestion. When + // this represents a multi-word suggestion, the referenced PtNode (in mDicNodeState) is only + // the one that corresponds to the last word of the suggestion, and all the previous words + // are concatenated together in mPrevWord - which contains a space at the end. + int getTotalNodeSpaceCount() const { + if (isFirstWord()) return 0; + return CharUtils::getSpaceCount(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()); + } + int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { const int inputIndex = mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(); if (inputIndex == NOT_AN_INDEX) { @@ -434,6 +444,13 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.getLanguageDistance(); } + // For space-aware gestures, we store the normalized distance at the char index + // that ends the first word of the suggestion. We call this the distance after + // first word. + float getNormalizedCompoundDistanceAfterFirstWord() const { + return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistanceAfterFirstWord(); + } + float getLanguageDistanceRatePerWordForScoring() const { const float langDist = getLanguageDistanceForScoring(); const float totalWordCount = @@ -565,6 +582,12 @@ class DicNode { inputSize, getTotalInputIndex(), errorType); } + // Saves the current normalized compound distance for space-aware gestures. + // See getNormalizedCompoundDistanceAfterFirstWord for details. + AK_FORCE_INLINE void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { + mDicNodeState.mDicNodeStateScoring.saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); + } + // Caveat: Must not be called outside Weighting // This restriction is guaranteed by "friend" AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index 4c884225a..3c85d0e9d 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -31,7 +31,8 @@ class DicNodeStateScoring { mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), - mRawLength(0.0f), mExactMatch(true) { + mRawLength(0.0f), mExactMatch(true), + mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) { } virtual ~DicNodeStateScoring() {} @@ -45,6 +46,7 @@ class DicNodeStateScoring { mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; + mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING; mExactMatch = true; } @@ -58,6 +60,8 @@ class DicNodeStateScoring { mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDigraphIndex = scoring->mDigraphIndex; mExactMatch = scoring->mExactMatch; + mNormalizedCompoundDistanceAfterFirstWord = + scoring->mNormalizedCompoundDistanceAfterFirstWord; } void addCost(const float spatialCost, const float languageCost, const bool doNormalization, @@ -86,6 +90,17 @@ class DicNodeStateScoring { } } + // Saves the current normalized distance for space-aware gestures. + // See getNormalizedCompoundDistanceAfterFirstWord for details. + void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { + // We get called here after each word. We only want to store the distance after + // the first word, so if we already have a distance we skip saving -- hence "IfNoneYet" + // in the method name. + if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) { + mNormalizedCompoundDistanceAfterFirstWord = getNormalizedCompoundDistance(); + } + } + void addRawLength(const float rawLength) { mRawLength += rawLength; } @@ -102,6 +117,13 @@ class DicNodeStateScoring { return mNormalizedCompoundDistance; } + // For space-aware gestures, we store the normalized distance at the char index + // that ends the first word of the suggestion. We call this the distance after + // first word. + float getNormalizedCompoundDistanceAfterFirstWord() const { + return mNormalizedCompoundDistanceAfterFirstWord; + } + float getSpatialDistance() const { return mSpatialDistance; } @@ -178,6 +200,7 @@ class DicNodeStateScoring { float mLanguageDistance; float mRawLength; bool mExactMatch; + float mNormalizedCompoundDistanceAfterFirstWord; AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, bool doNormalization, int inputSize, int totalInputIndex) { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index b1d01ed86..59ead1894 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -55,14 +55,14 @@ int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, int commitPoint, const SuggestOptions *const suggestOptions, int *outWords, int *frequencies, - int *spaceIndices, int *outputTypes) const { + int *spaceIndices, int *outputTypes, int *outputAutoCommitFirstWordConfidence) const { int result = 0; if (suggestOptions->isGesture()) { DicTraverseSession::initSessionInstance( traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); result = mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, outWords, - frequencies, spaceIndices, outputTypes); + frequencies, spaceIndices, outputTypes, outputAutoCommitFirstWordConfidence); if (DEBUG_DICT) { DUMP_RESULT(outWords, frequencies); } @@ -72,7 +72,8 @@ int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); result = mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, - outWords, frequencies, spaceIndices, outputTypes); + outWords, frequencies, spaceIndices, outputTypes, + outputAutoCommitFirstWordConfidence); if (DEBUG_DICT) { DUMP_RESULT(outWords, frequencies); } @@ -128,7 +129,7 @@ bool Dictionary::needsToRunGC(const bool mindsBlockByGC) { } void Dictionary::getProperty(const char *const query, char *const outResult, - const int maxResultLength) const { + const int maxResultLength) { return mDictionaryStructureWithBufferPolicy->getProperty(query, outResult, maxResultLength); } diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index d8a0f3e58..0195d5bf0 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -60,7 +60,7 @@ class Dictionary { int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, int commitPoint, const SuggestOptions *const suggestOptions, int *outWords, int *frequencies, - int *spaceIndices, int *outputTypes) const; + int *spaceIndices, int *outputTypes, int *outputAutoCommitFirstWordConfidence) const; int getBigrams(const int *word, int length, int *outWords, int *frequencies, int *outputTypes) const; @@ -84,7 +84,7 @@ class Dictionary { bool needsToRunGC(const bool mindsBlockByGC); void getProperty(const char *const query, char *const outResult, - const int maxResultLength) const; + const int maxResultLength); const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const { return mDictionaryStructureWithBufferPolicy; diff --git a/native/jni/src/suggest/core/layout/proximity_info_params.cpp b/native/jni/src/suggest/core/layout/proximity_info_params.cpp index 0e887f700..49df10301 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_params.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_params.cpp @@ -69,13 +69,13 @@ const float ProximityInfoParams::STRAIGHT_ANGLE_THRESHOLD = M_PI_F * 15.0f / 180 const float ProximityInfoParams::SKIP_CORNER_PROBABILITY = 0.4f; const float ProximityInfoParams::SPEED_MARGIN = 0.1f; const float ProximityInfoParams::CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION = 0.0f; -// TODO: The variance is critical for accuracy; thus, adjusting these parameter by machine +// TODO: The variance is critical for accuracy; thus, adjusting these parameters by machine // learning or something would be efficient. -const float ProximityInfoParams::SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION = 0.3f; -const float ProximityInfoParams::MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION = 0.25f; -const float ProximityInfoParams::SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION = 0.5f; -const float ProximityInfoParams::MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION = 0.15f; -const float ProximityInfoParams::MIN_STANDERD_DIVIATION = 0.37f; +const float ProximityInfoParams::SPEEDxANGLE_WEIGHT_FOR_STANDARD_DEVIATION = 0.3f; +const float ProximityInfoParams::MAX_SPEEDxANGLE_RATE_FOR_STANDARD_DEVIATION = 0.25f; +const float ProximityInfoParams::SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DEVIATION = 0.5f; +const float ProximityInfoParams::MAX_SPEEDxNEAREST_RATE_FOR_STANDARD_DEVIATION = 0.15f; +const float ProximityInfoParams::MIN_STANDARD_DEVIATION = 0.37f; const float ProximityInfoParams::PREV_DISTANCE_WEIGHT = 0.5f; const float ProximityInfoParams::NEXT_DISTANCE_WEIGHT = 0.6f; diff --git a/native/jni/src/suggest/core/layout/proximity_info_params.h b/native/jni/src/suggest/core/layout/proximity_info_params.h index 4e47f7308..ae1f82c22 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_params.h +++ b/native/jni/src/suggest/core/layout/proximity_info_params.h @@ -73,11 +73,11 @@ class ProximityInfoParams { static const float SKIP_CORNER_PROBABILITY; static const float SPEED_MARGIN; static const float CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION; - static const float SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION; - static const float MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION; - static const float SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION; - static const float MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION; - static const float MIN_STANDERD_DIVIATION; + static const float SPEEDxANGLE_WEIGHT_FOR_STANDARD_DEVIATION; + static const float MAX_SPEEDxANGLE_RATE_FOR_STANDARD_DEVIATION; + static const float SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DEVIATION; + static const float MAX_SPEEDxNEAREST_RATE_FOR_STANDARD_DEVIATION; + static const float MIN_STANDARD_DEVIATION; static const float PREV_DISTANCE_WEIGHT; static const float NEXT_DISTANCE_WEIGHT; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index 904671f7f..e1b35340b 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -708,13 +708,13 @@ namespace latinime { const float inputCharProbability = 1.0f - skipProbability; const float speedxAngleRate = min(speedRate * currentAngle / M_PI_F - * ProximityInfoParams::SPEEDxANGLE_WEIGHT_FOR_STANDARD_DIVIATION, - ProximityInfoParams::MAX_SPEEDxANGLE_RATE_FOR_STANDERD_DIVIATION); + * ProximityInfoParams::SPEEDxANGLE_WEIGHT_FOR_STANDARD_DEVIATION, + ProximityInfoParams::MAX_SPEEDxANGLE_RATE_FOR_STANDARD_DEVIATION); const float speedxNearestKeyDistanceRate = min(speedRate * nearestKeyDistance - * ProximityInfoParams::SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DIVIATION, - ProximityInfoParams::MAX_SPEEDxNEAREST_RATE_FOR_STANDERD_DIVIATION); + * ProximityInfoParams::SPEEDxNEAREST_WEIGHT_FOR_STANDARD_DEVIATION, + ProximityInfoParams::MAX_SPEEDxNEAREST_RATE_FOR_STANDARD_DEVIATION); const float sigma = speedxAngleRate + speedxNearestKeyDistanceRate - + ProximityInfoParams::MIN_STANDERD_DIVIATION; + + ProximityInfoParams::MIN_STANDARD_DEVIATION; ProximityInfoUtils::NormalDistribution distribution(ProximityInfoParams::CENTER_VALUE_OF_NORMALIZED_DISTRIBUTION, sigma); diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index c7ffef0d5..41f82049f 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -80,8 +80,10 @@ class DictionaryStructureWithBufferPolicy { virtual bool needsToRunGC(const bool mindsBlockByGC) const = 0; + // Currently, this method is used only for testing. You may want to consider creating new + // dedicated method instead of this if you want to use this in the production. virtual void getProperty(const char *const query, char *const outResult, - const int maxResultLength) const = 0; + const int maxResultLength) = 0; protected: DictionaryStructureWithBufferPolicy() {} diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index f9b777df2..0c4016893 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -38,7 +38,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: PROF_SUBSTITUTION(node->mProfiler); return; - case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_OMISSION: PROF_NEW_WORD(node->mProfiler); return; case CT_MATCH: @@ -93,6 +93,11 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n } dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), inputSize, errorType); + if (CT_NEW_WORD_SPACE_OMISSION == correctionType) { + // When we are on a terminal, we save the current distance for evaluating + // when to auto-commit partial suggestions. + dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); + } } /* static */ float Weighting::getSpatialCost(const Weighting *const weighting, @@ -108,7 +113,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: // only used for typing return weighting->getSubstitutionCost(); - case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_OMISSION: return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); @@ -138,7 +143,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0.0f; case CT_SUBSTITUTION: return 0.0f; - case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_OMISSION: return weighting->getNewWordBigramLanguageCost( traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: @@ -173,7 +178,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0; /* 0 because CT_MATCH will be called */ case CT_SUBSTITUTION: return 0; /* 0 because CT_MATCH will be called */ - case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_OMISSION: return 0; case CT_MATCH: return 1; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index b1340e12f..73ccebc88 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -49,7 +49,7 @@ const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int commitPoint, int *outWords, int *frequencies, int *outputIndices, - int *outputTypes) const { + int *outputTypes, int *outputAutoCommitFirstWordConfidence) const { PROF_OPEN; PROF_START(0); const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); @@ -70,7 +70,8 @@ int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, } PROF_END(1); PROF_START(2); - const int size = outputSuggestions(tSession, frequencies, outWords, outputIndices, outputTypes); + const int size = outputSuggestions(tSession, frequencies, outWords, outputIndices, outputTypes, + outputAutoCommitFirstWordConfidence); PROF_END(2); PROF_CLOSE; return size; @@ -117,7 +118,8 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo * Outputs the final list of suggestions (i.e., terminal nodes). */ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const { + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes, + int *outputAutoCommitFirstWordConfidence) const { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -164,6 +166,12 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen // TODO: have partial commit work even with multiple pointers. const bool outputSecondWordFirstLetterInputIndex = traverseSession->isOnlyOnePointerUsed(0 /* pointerId */); + if (terminalSize > 0) { + // If we have no suggestions, don't write this + outputAutoCommitFirstWordConfidence[0] = + computeFirstWordConfidence(&terminals[0]); + } + // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -251,6 +259,57 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen return outputWordIndex; } +int Suggest::computeFirstWordConfidence(const DicNode *const terminalDicNode) const { + // Get the number of spaces in the first suggestion + const int spaceCount = terminalDicNode->getTotalNodeSpaceCount(); + // Get the number of characters in the first suggestion + const int length = terminalDicNode->getTotalNodeCodePointCount(); + // Get the distance for the first word of the suggestion + const float distance = terminalDicNode->getNormalizedCompoundDistanceAfterFirstWord(); + + // Arbitrarily, we give a score whose useful values range from 0 to 1,000,000. + // 1,000,000 will be the cutoff to auto-commit. It's fine if the number is under 0 or + // above 1,000,000 : under 0 just means it's very bad to commit, and above 1,000,000 means + // we are very confident. + // Expected space count is 1 ~ 5 + static const int MIN_EXPECTED_SPACE_COUNT = 1; + static const int MAX_EXPECTED_SPACE_COUNT = 5; + // Expected length is about 4 ~ 30 + static const int MIN_EXPECTED_LENGTH = 4; + static const int MAX_EXPECTED_LENGTH = 30; + // Expected distance is about 0.2 ~ 2.0, but consider 0.0 ~ 2.0 + static const float MIN_EXPECTED_DISTANCE = 0.0; + static const float MAX_EXPECTED_DISTANCE = 2.0; + // This is not strict: it's where most stuff will be falling, but it's still fine if it's + // outside these values. We want to output a value that reflects all of these. Each factor + // contributes a bit. + + // We need at least a space. + if (spaceCount < 1) return NOT_A_FIRST_WORD_CONFIDENCE; + + // The smaller the edit distance, the higher the contribution. MIN_EXPECTED_DISTANCE means 0 + // contribution, while MAX_EXPECTED_DISTANCE means full contribution according to the + // weight of the distance. Clamp to avoid overflows. + const float clampedDistance = distance < MIN_EXPECTED_DISTANCE ? MIN_EXPECTED_DISTANCE + : distance > MAX_EXPECTED_DISTANCE ? MAX_EXPECTED_DISTANCE : distance; + const int distanceContribution = DISTANCE_WEIGHT_FOR_AUTO_COMMIT + * (MAX_EXPECTED_DISTANCE - clampedDistance) + / (MAX_EXPECTED_DISTANCE - MIN_EXPECTED_DISTANCE); + // The larger the suggestion length, the larger the contribution. MIN_EXPECTED_LENGTH is no + // contribution, MAX_EXPECTED_LENGTH is full contribution according to the weight of the + // length. Length is guaranteed to be between 1 and 48, so we don't need to clamp. + const int lengthContribution = LENGTH_WEIGHT_FOR_AUTO_COMMIT + * (length - MIN_EXPECTED_LENGTH) / (MAX_EXPECTED_LENGTH - MIN_EXPECTED_LENGTH); + // The more spaces, the larger the contribution. MIN_EXPECTED_SPACE_COUNT space is no + // contribution, MAX_EXPECTED_SPACE_COUNT spaces is full contribution according to the + // weight of the space count. + const int spaceContribution = SPACE_COUNT_WEIGHT_FOR_AUTO_COMMIT + * (spaceCount - MIN_EXPECTED_SPACE_COUNT) + / (MAX_EXPECTED_SPACE_COUNT - MIN_EXPECTED_SPACE_COUNT); + + return distanceContribution + lengthContribution + spaceContribution; +} + /** * Expands the dicNodes in the current search priority queue by advancing to the possible child * nodes based on the next touch point(s) (or no touch points for lookahead) @@ -386,7 +445,7 @@ void Suggest::processTerminalDicNode( if (!dicNode->isTerminalWordNode()) { return; } - if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { + if (dicNode->shouldBeFilteredBySafetyNetForBigram()) { return; } // Create a non-cached node here. @@ -574,7 +633,7 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode DicNodeUtils::initAsRootWithPreviousWord( traverseSession->getDictionaryStructurePolicy(), dicNode, &newDicNode); const CorrectionType correctionType = spaceSubstitution ? - CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; + CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMISSION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getMultiBigramMap()); if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index b24019632..b20343d29 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -48,14 +48,17 @@ class Suggest : public SuggestInterface { AK_FORCE_INLINE virtual ~Suggest() {} int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int commitPoint, - int *outWords, int *frequencies, int *outputIndices, int *outputTypes) const; + int *outWords, int *frequencies, int *outputIndices, int *outputTypes, + int *outputAutoCommitFirstWordConfidence) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const; int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const; + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes, + int *outputAutoCommitFirstWordConfidence) const; + int computeFirstWordConfidence(const DicNode *const terminalDicNode) const; void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index 0bb85d7e5..4deb4d924 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -28,7 +28,7 @@ class SuggestInterface { virtual int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int commitPoint, int *outWords, int *frequencies, int *outputIndices, - int *outputTypes) const = 0; + int *outputTypes, int *outputAutoCommitFirstWordConfidence) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp index e02f4cbf1..8753c6eb0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp @@ -17,10 +17,10 @@ #include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" #include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" -#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" namespace latinime { @@ -41,9 +41,14 @@ void DynamicBigramListPolicy::getNextBigram(int *const outBigramPos, int *const if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { originalBigramPos += mBuffer->getOriginalBufferSize(); } - *outBigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags); *outHasNext = BigramListReadWriteUtils::hasNext(bigramFlags); + if (mIsDecayingDict && !ForgettingCurveUtils::isValidEncodedProbability(*outProbability)) { + // This bigram is too weak to output. + *outBigramPos = NOT_A_DICT_POS; + } else { + *outBigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + } if (usesAdditionalBuffer) { *bigramEntryPos += mBuffer->getOriginalBufferSize(); } @@ -153,15 +158,21 @@ bool DynamicBigramListPolicy::updateAllBigramEntriesAndDeleteUselessEntries( const int bigramTargetNodePos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); nodeReader.fetchNodeInfoInBufferFromPtNodePos(bigramTargetNodePos); - // TODO: Update probability for supporting probability decaying. if (nodeReader.isDeleted() || !nodeReader.isTerminal() || bigramTargetNodePos == NOT_A_DICT_POS) { // The target is no longer valid terminal. Invalidate the current bigram entry. if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, - NOT_A_DICT_POS /* targetOffset */, &bigramEntryPos)) { + NOT_A_DICT_POS /* targetPtNodePos */, &bigramEntryPos)) { return false; } - } else { + continue; + } + bool isRemoved = false; + if (!updateProbabilityForDecay(bigramFlags, bigramTargetNodePos, &bigramEntryPos, + &isRemoved)) { + return false; + } + if (!isRemoved) { (*outValidBigramEntryCount) += 1; } } while(BigramListReadWriteUtils::hasNext(bigramFlags)); @@ -247,8 +258,14 @@ bool DynamicBigramListPolicy::addNewBigramEntryToBigramList(const int bigramTarg if (followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos) == bigramTargetPos) { // Update this bigram entry. *outAddedNewBigram = false; + const int originalProbability = BigramListReadWriteUtils::getProbabilityFromFlags( + bigramFlags); + const int probabilityToWrite = mIsDecayingDict ? + ForgettingCurveUtils::getUpdatedEncodedProbability(originalProbability, + probability) : probability; const BigramListReadWriteUtils::BigramFlags updatedFlags = - BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags, probability); + BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags, + probabilityToWrite); return BigramListReadWriteUtils::writeBigramEntry(mBuffer, updatedFlags, originalBigramPos, &entryPos); } @@ -276,8 +293,11 @@ bool DynamicBigramListPolicy::addNewBigramEntryToBigramList(const int bigramTarg bool DynamicBigramListPolicy::writeNewBigramEntry(const int bigramTargetPos, const int probability, int *const writingPos) { // hasNext is false because we are adding a new bigram entry at the end of the bigram list. + const int probabilityToWrite = mIsDecayingDict ? + ForgettingCurveUtils::getUpdatedEncodedProbability(NOT_A_PROBABILITY, probability) : + probability; return BigramListReadWriteUtils::createAndWriteBigramEntry(mBuffer, bigramTargetPos, - probability, false /* hasNext */, writingPos); + probabilityToWrite, false /* hasNext */, writingPos); } bool DynamicBigramListPolicy::removeBigram(const int bigramListPos, const int bigramTargetPos) { @@ -339,4 +359,33 @@ int DynamicBigramListPolicy::followBigramLinkAndGetCurrentBigramPtNodePos( return currentPos; } +bool DynamicBigramListPolicy::updateProbabilityForDecay( + BigramListReadWriteUtils::BigramFlags bigramFlags, const int targetPtNodePos, + int *const bigramEntryPos, bool *const outRemoved) const { + *outRemoved = false; + if (mIsDecayingDict) { + // Update bigram probability for decaying. + const int newProbability = ForgettingCurveUtils::getEncodedProbabilityToSave( + BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags)); + if (ForgettingCurveUtils::isValidEncodedProbability(newProbability)) { + // Write new probability. + const BigramListReadWriteUtils::BigramFlags updatedBigramFlags = + BigramListReadWriteUtils::setProbabilityInFlags( + bigramFlags, newProbability); + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, updatedBigramFlags, + targetPtNodePos, bigramEntryPos)) { + return false; + } + } else { + // Remove current bigram entry. + *outRemoved = true; + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + NOT_A_DICT_POS /* targetPtNodePos */, bigramEntryPos)) { + return false; + } + } + } + return true; +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h index 3ebf69946..b358b4ed5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" namespace latinime { @@ -34,8 +35,9 @@ class DictionaryShortcutsStructurePolicy; class DynamicBigramListPolicy : public DictionaryBigramsStructurePolicy { public: DynamicBigramListPolicy(BufferWithExtendableBuffer *const buffer, - const DictionaryShortcutsStructurePolicy *const shortcutPolicy) - : mBuffer(buffer), mShortcutPolicy(shortcutPolicy) {} + const DictionaryShortcutsStructurePolicy *const shortcutPolicy, + const bool isDecayingDict) + : mBuffer(buffer), mShortcutPolicy(shortcutPolicy), mIsDecayingDict(isDecayingDict) {} ~DynamicBigramListPolicy() {} @@ -74,9 +76,13 @@ class DynamicBigramListPolicy : public DictionaryBigramsStructurePolicy { BufferWithExtendableBuffer *const mBuffer; const DictionaryShortcutsStructurePolicy *const mShortcutPolicy; + const bool mIsDecayingDict; // Follow bigram link and return the position of bigram target PtNode that is currently valid. int followBigramLinkAndGetCurrentBigramPtNodePos(const int originalBigramPos) const; + + bool updateProbabilityForDecay(BigramListReadWriteUtils::BigramFlags bigramFlags, + const int targetPtNodePos, int *const bigramEntryPos, bool *const outRemoved) const; }; } // namespace latinime #endif // LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp index 5f755c302..324b53062 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp @@ -16,6 +16,8 @@ #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h" +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" + namespace latinime { bool DynamicPatriciaTrieGcEventListeners @@ -25,6 +27,19 @@ bool DynamicPatriciaTrieGcEventListeners // PtNode is useless when the PtNode is not a terminal and doesn't have any not useless // children. bool isUselessPtNode = !node->isTerminal(); + if (node->isTerminal() && mIsDecayingDict) { + const int newProbability = + ForgettingCurveUtils::getEncodedProbabilityToSave(node->getProbability()); + int writingPos = node->getProbabilityFieldPos(); + // Update probability. + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition( + mBuffer, newProbability, &writingPos)) { + return false; + } + if (!ForgettingCurveUtils::isValidEncodedProbability(newProbability)) { + isUselessPtNode = false; + } + } if (mChildrenValue > 0) { isUselessPtNode = false; } else if (node->isTerminal()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h index 301998882..463715af5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h @@ -39,9 +39,9 @@ class DynamicPatriciaTrieGcEventListeners { public: TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( DynamicPatriciaTrieWritingHelper *const writingHelper, - BufferWithExtendableBuffer *const buffer) - : mWritingHelper(writingHelper), mBuffer(buffer), mValueStack(), - mChildrenValue(0), mValidUnigramCount(0) {} + BufferWithExtendableBuffer *const buffer, const bool isDecayingDict) + : mWritingHelper(writingHelper), mBuffer(buffer), mIsDecayingDict(isDecayingDict), + mValueStack(), mChildrenValue(0), mValidUnigramCount(0) {} ~TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted() {}; @@ -74,6 +74,7 @@ class DynamicPatriciaTrieGcEventListeners { DynamicPatriciaTrieWritingHelper *const mWritingHelper; BufferWithExtendableBuffer *const mBuffer; + const int mIsDecayingDict; std::vector<int> mValueStack; int mChildrenValue; int mValidUnigramCount; diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp index 8c0890e2e..60d0db0c0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -18,6 +18,7 @@ #include <cstdio> #include <cstring> +#include <ctime> #include "defines.h" #include "suggest/core/dicnode/dic_node.h" @@ -27,12 +28,21 @@ #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { +// Note that these are corresponding definitions in Java side in BinaryDictionaryTests and +// BinaryDictionaryDecayingTests. const char *const DynamicPatriciaTriePolicy::UNIGRAM_COUNT_QUERY = "UNIGRAM_COUNT"; const char *const DynamicPatriciaTriePolicy::BIGRAM_COUNT_QUERY = "BIGRAM_COUNT"; +const char *const DynamicPatriciaTriePolicy::SET_NEEDS_TO_DECAY_FOR_TESTING_QUERY = + "SET_NEEDS_TO_DECAY_FOR_TESTING"; +const int DynamicPatriciaTriePolicy::MAX_DICT_EXTENDED_REGION_SIZE = 1024 * 1024; +const int DynamicPatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS = + DynamicPatriciaTrieWritingHelper::MAX_DICTIONARY_SIZE - 1024; +const int DynamicPatriciaTriePolicy::DECAY_INTERVAL_FOR_DECAYING_DICTS = 2 * 60 * 60; void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const { @@ -143,14 +153,17 @@ int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const in int DynamicPatriciaTriePolicy::getProbability(const int unigramProbability, const int bigramProbability) const { - // TODO: check mHeaderPolicy.usesForgettingCurve(); - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return ProbabilityUtils::backoff(unigramProbability); + if (mHeaderPolicy.isDecayingDict()) { + return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); } else { - return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, - bigramProbability); + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } } } @@ -199,11 +212,16 @@ bool DynamicPatriciaTriePolicy::addUnigramWord(const int *const word, const int AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); return false; } + if (mBufferWithExtendableBuffer.getTailPosition() + >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { + AKLOGE("The dictionary is too large to dynamically update."); + return false; + } DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, getBigramsStructurePolicy(), getShortcutsStructurePolicy()); readingHelper.initWithPtNodeArrayPos(getRootPosition()); DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, - &mBigramListPolicy, &mShortcutListPolicy); + &mBigramListPolicy, &mShortcutListPolicy, mHeaderPolicy.isDecayingDict()); bool addedNewUnigram = false; if (writingHelper.addUnigramWord(&readingHelper, word, length, probability, &addedNewUnigram)) { @@ -222,6 +240,11 @@ bool DynamicPatriciaTriePolicy::addBigramWords(const int *const word0, const int AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); return false; } + if (mBufferWithExtendableBuffer.getTailPosition() + >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { + AKLOGE("The dictionary is too large to dynamically update."); + return false; + } const int word0Pos = getTerminalNodePositionOfWord(word0, length0, false /* forceLowerCaseSearch */); if (word0Pos == NOT_A_DICT_POS) { @@ -233,7 +256,7 @@ bool DynamicPatriciaTriePolicy::addBigramWords(const int *const word0, const int return false; } DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, - &mBigramListPolicy, &mShortcutListPolicy); + &mBigramListPolicy, &mShortcutListPolicy, mHeaderPolicy.isDecayingDict()); bool addedNewBigram = false; if (writingHelper.addBigramWords(word0Pos, word1Pos, probability, &addedNewBigram)) { if (addedNewBigram) { @@ -251,6 +274,11 @@ bool DynamicPatriciaTriePolicy::removeBigramWords(const int *const word0, const AKLOGI("Warning: removeBigramWords() is called for non-updatable dictionary."); return false; } + if (mBufferWithExtendableBuffer.getTailPosition() + >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { + AKLOGE("The dictionary is too large to dynamically update."); + return false; + } const int word0Pos = getTerminalNodePositionOfWord(word0, length0, false /* forceLowerCaseSearch */); if (word0Pos == NOT_A_DICT_POS) { @@ -262,7 +290,7 @@ bool DynamicPatriciaTriePolicy::removeBigramWords(const int *const word0, const return false; } DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, - &mBigramListPolicy, &mShortcutListPolicy); + &mBigramListPolicy, &mShortcutListPolicy, mHeaderPolicy.isDecayingDict()); if (writingHelper.removeBigramWords(word0Pos, word1Pos)) { mBigramCount--; return true; @@ -277,7 +305,7 @@ void DynamicPatriciaTriePolicy::flush(const char *const filePath) { return; } DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, - &mBigramListPolicy, &mShortcutListPolicy); + &mBigramListPolicy, &mShortcutListPolicy, false /* needsToDecay */); writingHelper.writeToDictFile(filePath, &mHeaderPolicy, mUnigramCount, mBigramCount); } @@ -286,9 +314,15 @@ void DynamicPatriciaTriePolicy::flushWithGC(const char *const filePath) { AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); return; } + const bool runGCwithDecay = needsToDecay(); + DynamicBigramListPolicy bigramListPolicyForGC(&mBufferWithExtendableBuffer, + &mShortcutListPolicy, runGCwithDecay); DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, - &mBigramListPolicy, &mShortcutListPolicy); + &bigramListPolicyForGC, &mShortcutListPolicy, runGCwithDecay); writingHelper.writeToDictFileWithGC(getRootPosition(), filePath, &mHeaderPolicy); + if (runGCwithDecay) { + mNeedsToDecayForTesting = false; + } } bool DynamicPatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { @@ -296,17 +330,48 @@ bool DynamicPatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); return false; } - // TODO: Implement more properly. - return mBufferWithExtendableBuffer.isNearSizeLimit(); + if (mBufferWithExtendableBuffer.isNearSizeLimit()) { + // Additional buffer size is near the limit. + return true; + } else if (mHeaderPolicy.getExtendedRegionSize() + + mBufferWithExtendableBuffer.getUsedAdditionalBufferSize() + > MAX_DICT_EXTENDED_REGION_SIZE) { + // Total extended region size exceeds the limit. + return true; + } else if (mBufferWithExtendableBuffer.getTailPosition() + >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS + && mBufferWithExtendableBuffer.getUsedAdditionalBufferSize() > 0) { + // Needs to reduce dictionary size. + return true; + } else if (mHeaderPolicy.isDecayingDict()) { + if (mUnigramCount >= ForgettingCurveUtils::MAX_UNIGRAM_COUNT) { + // Unigram count exceeds the limit. + return true; + } else if (mBigramCount >= ForgettingCurveUtils::MAX_BIGRAM_COUNT) { + // Bigram count exceeds the limit. + return true; + } else if (mindsBlockByGC && needsToDecay()) { + // Time to update probabilities for decaying. + return true; + } + } + return false; } void DynamicPatriciaTriePolicy::getProperty(const char *const query, char *const outResult, - const int maxResultLength) const { + const int maxResultLength) { if (strncmp(query, UNIGRAM_COUNT_QUERY, maxResultLength) == 0) { snprintf(outResult, maxResultLength, "%d", mUnigramCount); } else if (strncmp(query, BIGRAM_COUNT_QUERY, maxResultLength) == 0) { snprintf(outResult, maxResultLength, "%d", mBigramCount); + } else if (strncmp(query, SET_NEEDS_TO_DECAY_FOR_TESTING_QUERY, maxResultLength) == 0) { + mNeedsToDecayForTesting = true; } } +bool DynamicPatriciaTriePolicy::needsToDecay() const { + return mHeaderPolicy.isDecayingDict() && (mNeedsToDecayForTesting + || mHeaderPolicy.getLastDecayedTime() + DECAY_INTERVAL_FOR_DECAYING_DICTS < time(0)); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h index bdb436c8e..c3bbe9977 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -37,9 +37,10 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mBufferWithExtendableBuffer(mBuffer->getBuffer() + mHeaderPolicy.getSize(), mBuffer->getBufferSize() - mHeaderPolicy.getSize()), mShortcutListPolicy(&mBufferWithExtendableBuffer), - mBigramListPolicy(&mBufferWithExtendableBuffer, &mShortcutListPolicy), + mBigramListPolicy(&mBufferWithExtendableBuffer, &mShortcutListPolicy, + mHeaderPolicy.isDecayingDict()), mUnigramCount(mHeaderPolicy.getUnigramCount()), - mBigramCount(mHeaderPolicy.getBigramCount()) {} + mBigramCount(mHeaderPolicy.getBigramCount()), mNeedsToDecayForTesting(false) {} ~DynamicPatriciaTriePolicy() { delete mBuffer; @@ -94,13 +95,17 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool needsToRunGC(const bool mindsBlockByGC) const; void getProperty(const char *const query, char *const outResult, - const int maxResultLength) const; + const int maxResultLength); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTriePolicy); - static const char*const UNIGRAM_COUNT_QUERY; - static const char*const BIGRAM_COUNT_QUERY; + static const char *const UNIGRAM_COUNT_QUERY; + static const char *const BIGRAM_COUNT_QUERY; + static const char *const SET_NEEDS_TO_DECAY_FOR_TESTING_QUERY; + static const int MAX_DICT_EXTENDED_REGION_SIZE; + static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS; + static const int DECAY_INTERVAL_FOR_DECAYING_DICTS; const MmappedBuffer *const mBuffer; const HeaderPolicy mHeaderPolicy; @@ -109,6 +114,9 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DynamicBigramListPolicy mBigramListPolicy; int mUnigramCount; int mBigramCount; + int mNeedsToDecayForTesting; + + bool needsToDecay() const; }; } // namespace latinime #endif // LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp index 2a2e9bcbe..70a9ee564 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp @@ -26,6 +26,7 @@ #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" #include "utils/hash_map_compat.h" namespace latinime { @@ -57,7 +58,9 @@ bool DynamicPatriciaTrieWritingHelper::addUnigramWord( wordCodePoints[matchedCodePointCount + j])) { *outAddedNewUnigram = true; return reallocatePtNodeAndAddNewPtNodes(nodeReader, - readingHelper->getMergedNodeCodePoints(), j, probability, + readingHelper->getMergedNodeCodePoints(), j, + getUpdatedProbability(NOT_A_PROBABILITY /* originalProbability */, + probability), wordCodePoints + matchedCodePointCount, codePointCount - matchedCodePointCount); } @@ -69,7 +72,8 @@ bool DynamicPatriciaTrieWritingHelper::addUnigramWord( } if (!nodeReader->hasChildren()) { *outAddedNewUnigram = true; - return createChildrenPtNodeArrayAndAChildPtNode(nodeReader, probability, + return createChildrenPtNodeArrayAndAChildPtNode(nodeReader, + getUpdatedProbability(NOT_A_PROBABILITY /* originalProbability */, probability), wordCodePoints + readingHelper->getTotalCodePointCount(), codePointCount - readingHelper->getTotalCodePointCount()); } @@ -86,7 +90,7 @@ bool DynamicPatriciaTrieWritingHelper::addUnigramWord( return createAndInsertNodeIntoPtNodeArray(parentPos, wordCodePoints + readingHelper->getPrevTotalCodePointCount(), codePointCount - readingHelper->getPrevTotalCodePointCount(), - probability, &pos); + getUpdatedProbability(NOT_A_PROBABILITY /* originalProbability */, probability), &pos); } bool DynamicPatriciaTrieWritingHelper::addBigramWords(const int word0Pos, const int word1Pos, @@ -149,7 +153,7 @@ void DynamicPatriciaTrieWritingHelper::writeToDictFile(const char *const fileNam const int extendedRegionSize = headerPolicy->getExtendedRegionSize() + mBuffer->getUsedAdditionalBufferSize(); if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, false /* updatesLastUpdatedTime */, - unigramCount, bigramCount, extendedRegionSize)) { + false /* updatesLastDecayedTime */, unigramCount, bigramCount, extendedRegionSize)) { return; } DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, mBuffer); @@ -166,7 +170,7 @@ void DynamicPatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNod } BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */, - unigramCount, bigramCount, 0 /* extendedRegionSize */)) { + mNeedsToDecay, unigramCount, bigramCount, 0 /* extendedRegionSize */)) { return; } DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, &newDictBuffer); @@ -351,9 +355,11 @@ bool DynamicPatriciaTrieWritingHelper::setPtNodeProbability( if (originalPtNode->isTerminal()) { // Overwrites the probability. *outAddedNewUnigram = false; + const int probabilityToWrite = getUpdatedProbability(originalPtNode->getProbability(), + probability); int probabilityFieldPos = originalPtNode->getProbabilityFieldPos(); if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(mBuffer, - probability, &probabilityFieldPos)) { + probabilityToWrite, &probabilityFieldPos)) { return false; } } else { @@ -365,7 +371,8 @@ bool DynamicPatriciaTrieWritingHelper::setPtNodeProbability( } if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, originalPtNode, originalPtNode->getParentPos(), codePoints, originalPtNode->getCodePointCount(), - probability, &movedPos)) { + getUpdatedProbability(NOT_A_PROBABILITY /* originalProbability */, probability), + &movedPos)) { return false; } } @@ -481,11 +488,15 @@ bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, DynamicPatriciaTrieGcEventListeners ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( - this, mBuffer); + this, mBuffer, mNeedsToDecay); if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { return false; } + if (mNeedsToDecay && traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + .getValidUnigramCount() > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) { + // TODO: Remove more unigrams. + } readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateBigramProbability @@ -495,6 +506,11 @@ bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return false; } + if (mNeedsToDecay && traversePolicyToUpdateBigramProbability.getValidBigramEntryCount() + > ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC) { + // TODO: Remove more bigrams. + } + // Mapping from positions in mBuffer to positions in bufferToWrite. DictPositionRelocationMap dictPositionRelocationMap; readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); @@ -508,7 +524,8 @@ bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, // Create policy instance for the GCed dictionary. DynamicShortcutListPolicy newDictShortcutPolicy(bufferToWrite); - DynamicBigramListPolicy newDictBigramPolicy(bufferToWrite, &newDictShortcutPolicy); + DynamicBigramListPolicy newDictBigramPolicy(bufferToWrite, &newDictShortcutPolicy, + mNeedsToDecay); // Create reading helper for the GCed dictionary. DynamicPatriciaTrieReadingHelper newDictReadingHelper(bufferToWrite, &newDictBigramPolicy, &newDictShortcutPolicy); @@ -525,4 +542,14 @@ bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, return true; } +int DynamicPatriciaTrieWritingHelper::getUpdatedProbability(const int originalProbability, + const int newProbability) { + if (mNeedsToDecay) { + return ForgettingCurveUtils::getUpdatedEncodedProbability(originalProbability, + newProbability); + } else { + return newProbability; + } +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h index 827b6097f..0caf29120 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h @@ -47,10 +47,13 @@ class DynamicPatriciaTrieWritingHelper { DISALLOW_COPY_AND_ASSIGN(DictPositionRelocationMap); }; + static const size_t MAX_DICTIONARY_SIZE; + DynamicPatriciaTrieWritingHelper(BufferWithExtendableBuffer *const buffer, DynamicBigramListPolicy *const bigramPolicy, - DynamicShortcutListPolicy *const shortcutPolicy) - : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy) {} + DynamicShortcutListPolicy *const shortcutPolicy, const bool needsToDecay) + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy), + mNeedsToDecay(needsToDecay) {} ~DynamicPatriciaTrieWritingHelper() {} @@ -87,11 +90,11 @@ class DynamicPatriciaTrieWritingHelper { DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingHelper); static const int CHILDREN_POSITION_FIELD_SIZE; - static const size_t MAX_DICTIONARY_SIZE; BufferWithExtendableBuffer *const mBuffer; DynamicBigramListPolicy *const mBigramPolicy; DynamicShortcutListPolicy *const mShortcutPolicy; + const bool mNeedsToDecay; bool markNodeAsMovedAndSetPosition(const DynamicPatriciaTrieNodeReader *const nodeToUpdate, const int movedPos, const int bigramLinkedNodePos); @@ -127,6 +130,8 @@ class DynamicPatriciaTrieWritingHelper { bool runGC(const int rootPtNodeArrayPos, BufferWithExtendableBuffer *const bufferToWrite, int *const outUnigramCount, int *const outBigramCount); + + int getUpdatedProbability(const int originalProbability, const int newProbability); }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 9ce9994dd..eb072fbaf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -23,6 +23,7 @@ const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = "MULTIPLE_WOR // TODO: Change attribute string to "IS_DECAYING_DICT". const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE"; const char *const HeaderPolicy::LAST_UPDATED_TIME_KEY = "date"; +const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME"; const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT"; const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT"; const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE"; @@ -63,8 +64,8 @@ float HeaderPolicy::readMultipleWordCostMultiplier() const { } bool HeaderPolicy::writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferToWrite, - const bool updatesLastUpdatedTime, const int unigramCount, const int bigramCount, - const int extendedRegionSize) const { + const bool updatesLastUpdatedTime, const bool updatesLastDecayedTime, + const int unigramCount, const int bigramCount, const int extendedRegionSize) const { int writingPos = 0; if (!HeaderReadWriteUtils::writeDictionaryVersion(bufferToWrite, mDictFormatVersion, &writingPos)) { @@ -90,6 +91,11 @@ bool HeaderPolicy::writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferT HeaderReadWriteUtils::setIntAttribute(&attributeMapTowrite, LAST_UPDATED_TIME_KEY, time(0)); } + if (updatesLastDecayedTime) { + // Set current time as a last updated time. + HeaderReadWriteUtils::setIntAttribute(&attributeMapTowrite, LAST_DECAYED_TIME_KEY, + time(0)); + } if (!HeaderReadWriteUtils::writeHeaderAttributes(bufferToWrite, &attributeMapTowrite, &writingPos)) { return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 4261667fa..a9c7805a8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -40,6 +40,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { IS_DECAYING_DICT_KEY, false /* defaultValue */)), mLastUpdatedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, LAST_UPDATED_TIME_KEY, time(0) /* defaultValue */)), + mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + LAST_DECAYED_TIME_KEY, time(0) /* defaultValue */)), mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, UNIGRAM_COUNT_KEY, 0 /* defaultValue */)), mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, @@ -58,6 +60,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { IS_DECAYING_DICT_KEY, false /* defaultValue */)), mLastUpdatedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, LAST_UPDATED_TIME_KEY, time(0) /* defaultValue */)), + mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + LAST_UPDATED_TIME_KEY, time(0) /* defaultValue */)), mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0) {} ~HeaderPolicy() {} @@ -90,6 +94,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mLastUpdatedTime; } + AK_FORCE_INLINE int getLastDecayedTime() const { + return mLastDecayedTime; + } + AK_FORCE_INLINE int getUnigramCount() const { return mUnigramCount; } @@ -106,8 +114,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { int *outValue, int outValueSize) const; bool writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferToWrite, - const bool updatesLastUpdatedTime, const int unigramCount, - const int bigramCount, const int extendedRegionSize) const; + const bool updatesLastUpdatedTime, const bool updatesLastDecayedTime, + const int unigramCount, const int bigramCount, const int extendedRegionSize) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderPolicy); @@ -115,6 +123,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const MULTIPLE_WORDS_DEMOTION_RATE_KEY; static const char *const IS_DECAYING_DICT_KEY; static const char *const LAST_UPDATED_TIME_KEY; + static const char *const LAST_DECAYED_TIME_KEY; static const char *const UNIGRAM_COUNT_KEY; static const char *const BIGRAM_COUNT_KEY; static const char *const EXTENDED_REGION_SIZE_KEY; @@ -128,6 +137,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const float mMultiWordCostMultiplier; const bool mIsDecayingDict; const int mLastUpdatedTime; + const int mLastDecayedTime; const int mUnigramCount; const int mBigramCount; const int mExtendedRegionSize; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index 2694ce8d5..5ded8f6a1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -139,6 +139,9 @@ const char *const HeaderReadWriteUtils::REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY int *const writingPos) { for (AttributeMap::const_iterator it = headerAttributes->begin(); it != headerAttributes->end(); ++it) { + if (it->first.empty() || it->second.empty()) { + continue; + } // Write a key. if (!buffer->writeCodePointsAndAdvancePosition(&(it->first.at(0)), it->first.size(), true /* writesTerminator */, writingPos)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h index 8d88c68e8..0f8662aea 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -114,7 +114,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { } void getProperty(const char *const query, char *const outResult, - const int maxResultLength) const { + const int maxResultLength) { // getProperty is not supported for this class. if (maxResultLength > 0) { outResult[0] = '\0'; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp index f22e94c6a..994826fa8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -44,7 +44,8 @@ const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); HeaderPolicy headerPolicy(FormatUtils::VERSION_3, attributeMap); headerPolicy.writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */, - 0 /* unigramCount */, 0 /* bigramCount */, 0 /* extendedRegionSize */); + true /* updatesLastDecayedTime */, 0 /* unigramCount */, 0 /* bigramCount */, + 0 /* extendedRegionSize */); BufferWithExtendableBuffer bodyBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); if (!DynamicPatriciaTrieWritingUtils::writeEmptyDictionary(&bodyBuffer, 0 /* rootPos */)) { return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp new file mode 100644 index 000000000..b502fe25d --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <cmath> +#include <stdlib.h> + +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" + +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +namespace latinime { + +const int ForgettingCurveUtils::MAX_UNIGRAM_COUNT = 12000; +const int ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC = 10000; +const int ForgettingCurveUtils::MAX_BIGRAM_COUNT = 12000; +const int ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC = 10000; + +const int ForgettingCurveUtils::MAX_COMPUTED_PROBABILITY = 127; +const int ForgettingCurveUtils::MAX_ENCODED_PROBABILITY = 15; +const int ForgettingCurveUtils::MIN_VALID_ENCODED_PROBABILITY = 3; +const int ForgettingCurveUtils::ENCODED_PROBABILITY_STEP = 1; +// Currently, we try to decay each uni/bigram once every 2 hours. Accordingly, the expected +// duration of the decay is approximately 66hours. +const float ForgettingCurveUtils::MIN_PROBABILITY_TO_DECAY = 0.03f; + +const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable; + +/* static */ int ForgettingCurveUtils::getProbability(const int encodedUnigramProbability, + const int encodedBigramProbability) { + if (encodedUnigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (encodedBigramProbability == NOT_A_PROBABILITY) { + return backoff(decodeProbability(encodedUnigramProbability)); + } else { + const int unigramProbability = decodeProbability(encodedUnigramProbability); + const int bigramProbability = decodeProbability(encodedBigramProbability); + return min(max(unigramProbability, bigramProbability), MAX_COMPUTED_PROBABILITY); + } +} + +// Caveat: Unlike getProbability(), this method doesn't assume special bigram probability encoding +// (i.e. unigram probability + bigram probability delta). +/* static */ int ForgettingCurveUtils::getUpdatedEncodedProbability( + const int originalEncodedProbability, const int newProbability) { + if (originalEncodedProbability == NOT_A_PROBABILITY) { + // The bigram relation is not in this dictionary. + if (newProbability == NOT_A_PROBABILITY) { + // The bigram target is not in other dictionaries. + return 0; + } else { + return MIN_VALID_ENCODED_PROBABILITY; + } + } else { + if (newProbability != NOT_A_PROBABILITY + && originalEncodedProbability < MIN_VALID_ENCODED_PROBABILITY) { + return MIN_VALID_ENCODED_PROBABILITY; + } + return min(originalEncodedProbability + ENCODED_PROBABILITY_STEP, MAX_ENCODED_PROBABILITY); + } +} + +/* static */ int ForgettingCurveUtils::isValidEncodedProbability(const int encodedProbability) { + return encodedProbability >= MIN_VALID_ENCODED_PROBABILITY; +} + +/* static */ int ForgettingCurveUtils::getEncodedProbabilityToSave(const int encodedProbability) { + const int currentEncodedProbability = max(min(encodedProbability, MAX_ENCODED_PROBABILITY), 0); + // TODO: Implement the decay in more proper way. + const float currentRate = static_cast<float>(currentEncodedProbability) + / static_cast<float>(MAX_ENCODED_PROBABILITY); + const float thresholdToDecay = MIN_PROBABILITY_TO_DECAY + + (1.0f - MIN_PROBABILITY_TO_DECAY) * (1.0f - currentRate); + const float randValue = static_cast<float>(rand()) / static_cast<float>(RAND_MAX); + if (thresholdToDecay < randValue) { + return max(currentEncodedProbability - ENCODED_PROBABILITY_STEP, 0); + } else { + return currentEncodedProbability; + } +} + +/* static */ int ForgettingCurveUtils::decodeProbability(const int encodedProbability) { + if (encodedProbability < MIN_VALID_ENCODED_PROBABILITY) { + return NOT_A_PROBABILITY; + } else { + return min(sProbabilityTable.getProbability(encodedProbability), MAX_ENCODED_PROBABILITY); + } +} + +// See comments in ProbabilityUtils::backoff(). +/* static */ int ForgettingCurveUtils::backoff(const int unigramProbability) { + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else { + return max(unigramProbability - 8, 0); + } +} + +ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTable() { + // Table entry is as follows: + // 1, 1, 1, 2, 3, 5, 6, 9, 13, 18, 25, 34, 48, 66, 91, 127. + // Note that first MIN_VALID_ENCODED_PROBABILITY values are not used. + mTable.resize(MAX_ENCODED_PROBABILITY + 1); + for (int i = 0; i <= MAX_ENCODED_PROBABILITY; ++i) { + const int probability = static_cast<int>(powf(static_cast<float>(MAX_COMPUTED_PROBABILITY), + static_cast<float>(i) / static_cast<float>(MAX_ENCODED_PROBABILITY))); + mTable[i] = min(MAX_COMPUTED_PROBABILITY, max(0, probability)); + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h new file mode 100644 index 000000000..d666f22aa --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_FORGETTING_CURVE_UTILS_H +#define LATINIME_FORGETTING_CURVE_UTILS_H + +#include <vector> + +#include "defines.h" + +namespace latinime { + +// TODO: Check the elapsed time and decrease the probability depending on the time. Time field is +// required to introduced to each terminal PtNode and bigram entry. +// TODO: Quit using bigram probability to indicate the delta. +class ForgettingCurveUtils { + public: + static const int MAX_UNIGRAM_COUNT; + static const int MAX_UNIGRAM_COUNT_AFTER_GC; + static const int MAX_BIGRAM_COUNT; + static const int MAX_BIGRAM_COUNT_AFTER_GC; + + static int getProbability(const int encodedUnigramProbability, + const int encodedBigramProbability); + + static int getUpdatedEncodedProbability(const int originalEncodedProbability, + const int newProbability); + + static int isValidEncodedProbability(const int encodedProbability); + + static int getEncodedProbabilityToSave(const int encodedProbability); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ForgettingCurveUtils); + + class ProbabilityTable { + public: + ProbabilityTable(); + + int getProbability(const int encodedProbability) const { + if (encodedProbability < 0 || encodedProbability > static_cast<int>(mTable.size())) { + return NOT_A_PROBABILITY; + } + return mTable[encodedProbability]; + } + + private: + DISALLOW_COPY_AND_ASSIGN(ProbabilityTable); + + std::vector<int> mTable; + }; + + static const int MAX_COMPUTED_PROBABILITY; + static const int MAX_ENCODED_PROBABILITY; + static const int MIN_VALID_ENCODED_PROBABILITY; + static const int ENCODED_PROBABILITY_STEP; + static const float MIN_PROBABILITY_TO_DECAY; + + static const ProbabilityTable sProbabilityTable; + + static int decodeProbability(const int encodedProbability); + + static int backoff(const int unigramProbability); +}; +} // namespace latinime +#endif /* LATINIME_FORGETTING_CURVE_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index 89e53f441..007c19e0a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -101,7 +101,7 @@ class TypingTraversal : public Traversal { } const int16_t pointIndex = dicNode->getInputIndex(0); return pointIndex <= inputSize && !dicNode->isTotalInputSizeExceedingLimit() - && !dicNode->shouldBeFilterdBySafetyNetForBigram(); + && !dicNode->shouldBeFilteredBySafetyNetForBigram(); } AK_FORCE_INLINE bool shouldDepthLevelCache( diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 408b12ae9..5b6b5e874 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -47,7 +47,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: return ET_EDIT_CORRECTION; - case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_OMISSION: case CT_NEW_WORD_SPACE_SUBSTITUTION: return ET_NEW_WORD; case CT_TERMINAL: diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 2e735a81c..41663c81a 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -75,6 +75,16 @@ class CharUtils { return c; } + static AK_FORCE_INLINE int getSpaceCount(const int *const codePointBuffer, const int length) { + int spaceCount = 0; + for (int i = 0; i < length; ++i) { + if (codePointBuffer[i] == KEYCODE_SPACE) { + ++spaceCount; + } + } + return spaceCount; + } + static unsigned short latin_tolower(const unsigned short c); private: |