diff options
Diffstat (limited to 'native')
19 files changed, 162 insertions, 100 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 18b78c4df..28aaf2d1a 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -28,6 +28,7 @@ #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" #include "suggest/core/result/suggestion_results.h" +#include "suggest/core/session/prev_words_info.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.h" #include "utils/char_utils.h" @@ -190,7 +191,9 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo); DicTraverseSession *traverseSession = reinterpret_cast<DicTraverseSession *>(dicTraverseSession); - + if (!traverseSession) { + return; + } // Input values int xCoordinates[inputSize]; int yCoordinates[inputSize]; @@ -245,15 +248,15 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, float languageWeight; env->GetFloatArrayRegion(inOutLanguageWeight, 0, 1 /* len */, &languageWeight); SuggestionResults suggestionResults(MAX_RESULTS); + const PrevWordsInfo prevWordsInfo(prevWordCodePoints, prevWordCodePointsLength, + false /* isStartOfSentence */); if (givenSuggestOptions.isGesture() || inputSize > 0) { // TODO: Use SuggestionResults to return suggestions. dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, - times, pointerIds, inputCodePoints, inputSize, prevWordCodePoints, - prevWordCodePointsLength, &givenSuggestOptions, languageWeight, - &suggestionResults); + times, pointerIds, inputCodePoints, inputSize, &prevWordsInfo, + &givenSuggestOptions, languageWeight, &suggestionResults); } else { - dictionary->getPredictions(prevWordCodePoints, prevWordCodePointsLength, - &suggestionResults); + dictionary->getPredictions(&prevWordsInfo, &suggestionResults); } suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, outScoresArray, outSpaceIndicesArray, outTypesArray, @@ -280,8 +283,8 @@ static jint latinime_BinaryDictionary_getBigramProbability(JNIEnv *env, jclass c int word1CodePoints[word1Length]; env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - return dictionary->getBigramProbability(word0CodePoints, word0Length, word1CodePoints, - word1Length); + const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, false /* isStartOfSentence */); + return dictionary->getBigramProbability(&prevWordsInfo, word1CodePoints, word1Length); } // Method to iterate all words in the dictionary for makedict. @@ -467,16 +470,6 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j return languageModelParamCount; } -static int latinime_BinaryDictionary_calculateProbabilityNative(JNIEnv *env, jclass clazz, - jlong dict, jint unigramProbability, jint bigramProbability) { - Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); - if (!dictionary) { - return NOT_A_PROBABILITY; - } - return dictionary->getDictionaryStructurePolicy()->getProbability(unigramProbability, - bigramProbability); -} - static jstring latinime_BinaryDictionary_getProperty(JNIEnv *env, jclass clazz, jlong dict, jstring query) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); @@ -670,11 +663,6 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionary_addMultipleDictionaryEntries) }, { - const_cast<char *>("calculateProbabilityNative"), - const_cast<char *>("(JII)I"), - reinterpret_cast<void *>(latinime_BinaryDictionary_calculateProbabilityNative) - }, - { const_cast<char *>("getPropertyNative"), const_cast<char *>("(JLjava/lang/String;)Ljava/lang/String;"), reinterpret_cast<void *>(latinime_BinaryDictionary_getProperty) diff --git a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp index 386643332..766064153 100644 --- a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp +++ b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp @@ -22,6 +22,7 @@ #include "jni.h" #include "jni_common.h" #include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/session/prev_words_info.h" namespace latinime { class Dictionary; @@ -34,16 +35,19 @@ static jlong latinime_setDicTraverseSession(JNIEnv *env, jclass clazz, jstring l static void latinime_initDicTraverseSession(JNIEnv *env, jclass clazz, jlong traverseSession, jlong dictionary, jintArray previousWord, jint previousWordLength) { DicTraverseSession *ts = reinterpret_cast<DicTraverseSession *>(traverseSession); + if (!ts) { + return; + } Dictionary *dict = reinterpret_cast<Dictionary *>(dictionary); if (!previousWord) { - DicTraverseSession::initSessionInstance( - ts, dict, 0 /* prevWord */, 0 /* prevWordLength*/, 0 /* suggestOptions */); + PrevWordsInfo prevWordsInfo; + ts->init(dict, &prevWordsInfo, 0 /* suggestOptions */); return; } int prevWord[previousWordLength]; env->GetIntArrayRegion(previousWord, 0, previousWordLength, prevWord); - DicTraverseSession::initSessionInstance( - ts, dict, prevWord, previousWordLength, 0 /* suggestOptions */); + PrevWordsInfo prevWordsInfo(prevWord, previousWordLength, false /* isStartOfSentence */); + ts->init(dict, &prevWordsInfo, 0 /* suggestOptions */); } static void latinime_releaseDicTraverseSession(JNIEnv *env, jclass clazz, jlong traverseSession) { diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index a80c97530..24d04e51f 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -336,6 +336,9 @@ static inline void prof_out(void) { #define MAX_POINTER_COUNT 1 #define MAX_POINTER_COUNT_G 2 +// (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram is supported. +#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 1 + #define DISALLOW_DEFAULT_CONSTRUCTOR(TypeName) \ TypeName() = delete diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index f793363a8..847fa1b02 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -26,6 +26,7 @@ #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/result/suggestion_results.h" +#include "suggest/core/session/prev_words_info.h" #include "utils/char_utils.h" namespace latinime { @@ -42,19 +43,18 @@ BigramDictionary::~BigramDictionary() { } /* Parameters : - * prevWord: the word before, the one for which we need to look up bigrams. - * prevWordLength: its length. + * prevWordsInfo: Information of previous words to get the predictions. * outSuggestionResults: SuggestionResults to put the predictions. */ -void BigramDictionary::getPredictions(const int *prevWord, const int prevWordLength, +void BigramDictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { - int pos = getBigramListPositionForWord(prevWord, prevWordLength, - false /* forceLowerCaseSearch */); + int pos = getBigramListPositionForWord(prevWordsInfo->getPrevWordCodePoints(), + prevWordsInfo->getPrevWordCodePointCount(), false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams if (NOT_A_DICT_POS == pos) { // If no bigrams for this exact word, search again in lower case. - pos = getBigramListPositionForWord(prevWord, prevWordLength, - true /* forceLowerCaseSearch */); + pos = getBigramListPositionForWord(prevWordsInfo->getPrevWordCodePoints(), + prevWordsInfo->getPrevWordCodePointCount(), true /* forceLowerCaseSearch */); } // If still no bigrams, we really don't have them! if (NOT_A_DICT_POS == pos) return; @@ -96,9 +96,10 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in return mDictionaryStructurePolicy->getBigramsPositionOfPtNode(pos); } -int BigramDictionary::getBigramProbability(const int *word0, int length0, const int *word1, - int length1) const { - int pos = getBigramListPositionForWord(word0, length0, false /* forceLowerCaseSearch */); +int BigramDictionary::getBigramProbability(const PrevWordsInfo *const prevWordsInfo, + const int *word1, int length1) const { + int pos = getBigramListPositionForWord(prevWordsInfo->getPrevWordCodePoints(), + prevWordsInfo->getPrevWordCodePointCount(), false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams if (NOT_A_DICT_POS == pos) return NOT_A_PROBABILITY; int nextWordPos = mDictionaryStructurePolicy->getTerminalPtNodePositionOfWord(word1, length1, diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h index 12aaf20d3..bd3aed1bd 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h @@ -22,15 +22,17 @@ namespace latinime { class DictionaryStructureWithBufferPolicy; +class PrevWordsInfo; class SuggestionResults; class BigramDictionary { public: BigramDictionary(const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy); - void getPredictions(const int *word, int length, + void getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const; - int getBigramProbability(const int *word1, int length1, const int *word2, int length2) const; + int getBigramProbability(const PrevWordsInfo *const prevWordsInfo, + const int *word1, int length1) const; ~BigramDictionary(); private: diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index fdc893653..c860d82af 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -44,12 +44,11 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, int *prevWordCodePoints, int prevWordLength, + int inputSize, const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions, const float languageWeight, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - DicTraverseSession::initSessionInstance( - traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); + traverseSession->init(this, prevWordsInfo, suggestOptions); const auto &suggest = suggestOptions->isGesture() ? mGestureSuggest : mTypingSuggest; suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, @@ -59,11 +58,10 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } } -void Dictionary::getPredictions(const int *word, int length, +void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - if (length <= 0) return; - mBigramDictionary.getPredictions(word, length, outSuggestionResults); + mBigramDictionary.getPredictions(prevWordsInfo, outSuggestionResults); } int Dictionary::getProbability(const int *word, int length) const { @@ -76,10 +74,10 @@ int Dictionary::getProbability(const int *word, int length) const { return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); } -int Dictionary::getBigramProbability(const int *word0, int length0, const int *word1, +int Dictionary::getBigramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word1, int length1) const { TimeKeeper::setCurrentTime(); - return mBigramDictionary.getBigramProbability(word0, length0, word1, length1); + return mBigramDictionary.getBigramProbability(prevWordsInfo, word1, length1); } void Dictionary::addUnigramWord(const int *const word, const int length, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index f0a7e5b6a..b63c61fbb 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -31,6 +31,7 @@ namespace latinime { class DictionaryStructureWithBufferPolicy; class DicTraverseSession; +class PrevWordsInfo; class ProximityInfo; class SuggestionResults; class SuggestOptions; @@ -62,16 +63,17 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, int *prevWordCodePoints, int prevWordLength, + int inputSize, const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions, const float languageWeight, SuggestionResults *const outSuggestionResults) const; - void getPredictions(const int *word, int length, + void getPredictions(const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const outSuggestionResults) const; int getProbability(const int *word, int length) const; - int getBigramProbability(const int *word0, int length0, const int *word1, int length1) const; + int getBigramProbability(const PrevWordsInfo *const prevWordsInfo, + const int *word1, int length1) const; void addUnigramWord(const int *const codePoints, const int codePointCount, const UnigramProperty *const unigramProperty); diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 77b634e07..b9e9db719 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -20,6 +20,7 @@ #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "suggest/core/session/prev_words_info.h" namespace latinime { @@ -28,24 +29,26 @@ namespace latinime { const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION = 256 * 1024; -void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, - int prevWordLength, const SuggestOptions *const suggestOptions) { +void DicTraverseSession::init(const Dictionary *const dictionary, + const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - if (!prevWord) { - mPrevWordPtNodePos = NOT_A_DICT_POS; + if (!prevWordsInfo->getPrevWordCodePoints()) { + mPrevWordsPtNodePos[0] = NOT_A_DICT_POS; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. - mPrevWordPtNodePos = getDictionaryStructurePolicy()->getTerminalPtNodePositionOfWord( - prevWord, prevWordLength, false /* forceLowerCaseSearch */); - if (mPrevWordPtNodePos == NOT_A_DICT_POS) { + mPrevWordsPtNodePos[0] = getDictionaryStructurePolicy()->getTerminalPtNodePositionOfWord( + prevWordsInfo->getPrevWordCodePoints(), prevWordsInfo->getPrevWordCodePointCount(), + false /* forceLowerCaseSearch */); + if (mPrevWordsPtNodePos[0] == NOT_A_DICT_POS) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". - mPrevWordPtNodePos = getDictionaryStructurePolicy()->getTerminalPtNodePositionOfWord( - prevWord, prevWordLength, true /* forceLowerCaseSearch */); + mPrevWordsPtNodePos[0] = getDictionaryStructurePolicy()->getTerminalPtNodePositionOfWord( + prevWordsInfo->getPrevWordCodePoints(), prevWordsInfo->getPrevWordCodePointCount(), + true /* forceLowerCaseSearch */); } } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 843ca85a0..90aff06c3 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -29,6 +29,7 @@ namespace latinime { class Dictionary; class DictionaryStructureWithBufferPolicy; +class PrevWordsInfo; class ProximityInfo; class SuggestOptions; @@ -44,32 +45,25 @@ class DicTraverseSession { dictSize >= DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION); } - static AK_FORCE_INLINE void initSessionInstance(DicTraverseSession *traverseSession, - const Dictionary *const dictionary, const int *prevWord, const int prevWordLength, - const SuggestOptions *const suggestOptions) { - if (traverseSession) { - DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); - tSession->init(dictionary, prevWord, prevWordLength, suggestOptions); - } - } - static AK_FORCE_INLINE void releaseSessionInstance(DicTraverseSession *traverseSession) { delete traverseSession; } AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) - : mPrevWordPtNodePos(NOT_A_DICT_POS), mProximityInfo(nullptr), - mDictionary(nullptr), mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), - mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), + : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), + mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. + for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) { + mPrevWordsPtNodePos[i] = NOT_A_DICT_POS; + } } // Non virtual inline destructor -- never inherit this class AK_FORCE_INLINE ~DicTraverseSession() {} - void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength, + void init(const Dictionary *dictionary, const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions); // TODO: Remove and merge into init void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, @@ -85,9 +79,7 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - int getPrevWordPtNodePos() const { return mPrevWordPtNodePos; } - // TODO: REMOVE - void setPrevWordPtNodePos(const int ptNodePos) { mPrevWordPtNodePos = ptNodePos; } + int getPrevWordPtNodePos() const { return mPrevWordsPtNodePos[0]; } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -174,7 +166,7 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - int mPrevWordPtNodePos; + int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h new file mode 100644 index 000000000..bc685945e --- /dev/null +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_PREV_WORDS_INFO_H +#define LATINIME_PREV_WORDS_INFO_H + +#include "defines.h" + +namespace latinime { + +// TODO: Support n-gram. +// TODO: Support beginning of sentence. +// This class does not take ownership of any code point buffers. +class PrevWordsInfo { + public: + // No prev word information. + PrevWordsInfo() { + clear(); + } + + PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, + const bool isBeginningOfSentence) { + clear(); + mPrevWordCodePoints[0] = prevWordCodePoints; + mPrevWordCodePointCount[0] = prevWordCodePointCount; + mIsBeginningOfSentence[0] = isBeginningOfSentence; + } + const int *getPrevWordCodePoints() const { + return mPrevWordCodePoints[0]; + } + + int getPrevWordCodePointCount() const { + return mPrevWordCodePointCount[0]; + } + + private: + DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); + + void clear() { + for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + mPrevWordCodePoints[i] = nullptr; + mPrevWordCodePointCount[i] = 0; + mIsBeginningOfSentence[i] = false; + } + } + + const int *mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; +}; +} // namespace latinime +#endif // LATINIME_PREV_WORDS_INFO_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp index 8f42df6d2..028e9ecbf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp @@ -29,10 +29,10 @@ bool DynamicPtGcEventListeners // PtNode is useless when the PtNode is not a terminal and doesn't have any not useless // children. bool isUselessPtNode = !ptNodeParams->isTerminal(); - if (ptNodeParams->isTerminal()) { + if (ptNodeParams->isTerminal() && !ptNodeParams->representsNonWordInfo()) { bool needsToKeepPtNode = true; - if (!mPtNodeWriter->updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(ptNodeParams, - &needsToKeepPtNode)) { + if (!mPtNodeWriter->updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( + ptNodeParams, &needsToKeepPtNode)) { AKLOGE("Cannot update PtNode probability or get needs to keep PtNode after GC."); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index bef401f87..5704c2e90 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -160,7 +160,8 @@ class PtNodeParams { } AK_FORCE_INLINE bool representsNonWordInfo() const { - return getCodePointCount() > 0 && CharUtils::isInUnicodeSpace(getCodePoints()[0]); + return getCodePointCount() > 0 && CharUtils::isInUnicodeSpace(getCodePoints()[0]) + && isNotAWord(); } // Parent node position diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp index 56f19dbae..d53922763 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp @@ -38,8 +38,6 @@ const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( int level = 0; int count = 0; if (mHasHistoricalInfo) { - probability = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::PROBABILITY_SIZE, bigramEntryPos); timestamp = bigramListBuffer->readUintAndAdvancePosition( Ver4DictConstants::TIME_STAMP_FIELD_SIZE, bigramEntryPos); level = bigramListBuffer->readUintAndAdvancePosition( @@ -47,7 +45,8 @@ const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( count = bigramListBuffer->readUintAndAdvancePosition( Ver4DictConstants::WORD_COUNT_FIELD_SIZE, bigramEntryPos); } else { - probability = bigramFlags & Ver4DictConstants::BIGRAM_PROBABILITY_MASK; + probability = bigramListBuffer->readUintAndAdvancePosition( + Ver4DictConstants::PROBABILITY_SIZE, bigramEntryPos); } const int encodedTargetTerminalId = bigramListBuffer->readUintAndAdvancePosition( Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE, bigramEntryPos); @@ -65,21 +64,13 @@ const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( bool BigramDictContent::writeBigramEntryAndAdvancePosition( const BigramEntry *const bigramEntryToWrite, int *const entryWritingPos) { BufferWithExtendableBuffer *const bigramListBuffer = getWritableContentBuffer(); - const int bigramFlags = createAndGetBigramFlags( - mHasHistoricalInfo ? 0 : bigramEntryToWrite->getProbability(), - bigramEntryToWrite->hasNext()); + const int bigramFlags = createAndGetBigramFlags(bigramEntryToWrite->hasNext()); if (!bigramListBuffer->writeUintAndAdvancePosition(bigramFlags, Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE, entryWritingPos)) { AKLOGE("Cannot write bigram flags. pos: %d, flags: %x", *entryWritingPos, bigramFlags); return false; } if (mHasHistoricalInfo) { - if (!bigramListBuffer->writeUintAndAdvancePosition(bigramEntryToWrite->getProbability(), - Ver4DictConstants::PROBABILITY_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram probability. pos: %d, probability: %d", *entryWritingPos, - bigramEntryToWrite->getProbability()); - return false; - } const HistoricalInfo *const historicalInfo = bigramEntryToWrite->getHistoricalInfo(); if (!bigramListBuffer->writeUintAndAdvancePosition(historicalInfo->getTimeStamp(), Ver4DictConstants::TIME_STAMP_FIELD_SIZE, entryWritingPos)) { @@ -99,6 +90,13 @@ bool BigramDictContent::writeBigramEntryAndAdvancePosition( historicalInfo->getCount()); return false; } + } else { + if (!bigramListBuffer->writeUintAndAdvancePosition(bigramEntryToWrite->getProbability(), + Ver4DictConstants::PROBABILITY_SIZE, entryWritingPos)) { + AKLOGE("Cannot write bigram probability. pos: %d, probability: %d", *entryWritingPos, + bigramEntryToWrite->getProbability()); + return false; + } } const int targetTerminalIdToWrite = (bigramEntryToWrite->getTargetTerminalId() == Ver4DictConstants::NOT_A_TERMINAL_ID) ? diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h index 944e0f9e2..b8bdb63a8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h @@ -95,9 +95,8 @@ class BigramDictContent : public SparseTableDictContent { private: DISALLOW_COPY_AND_ASSIGN(BigramDictContent); - int createAndGetBigramFlags(const int probability, const bool hasNext) const { - return (probability & Ver4DictConstants::BIGRAM_PROBABILITY_MASK) - | (hasNext ? Ver4DictConstants::BIGRAM_HAS_NEXT_MASK : 0); + int createAndGetBigramFlags(const bool hasNext) const { + return hasNext ? Ver4DictConstants::BIGRAM_HAS_NEXT_MASK : 0; } bool runGCBigramList(const int bigramListPos, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 8373dc549..7da9e3072 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -115,9 +115,7 @@ int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, } else if (bigramProbability == NOT_A_PROBABILITY) { return ProbabilityUtils::backoff(unigramProbability); } else { - // bigramProbability is a bigram probability delta. - return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, - bigramProbability); + return bigramProbability; } } } @@ -398,7 +396,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code const int probability = bigramEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - getProbability(word1Probability, bigramEntry.getProbability()); + bigramEntry.getProbability(); bigrams.emplace_back(&word1, probability, historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount()); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index f31c50253..e868ddf6f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -213,13 +213,16 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( // Delete unigrams. while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) { const int ptNodePos = priorityQueue.top().getDictPos(); + priorityQueue.pop(); const PtNodeParams ptNodeParams = ptNodeReader->fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (ptNodeParams.representsNonWordInfo()) { + continue; + } if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) { AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos); return false; } - priorityQueue.pop(); } return true; } diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index fa9600c74..3fc566e7a 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -37,6 +37,7 @@ const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.1524f; const float ScoringParams::PROXIMITY_COST = 0.0694f; const float ScoringParams::FIRST_CHAR_PROXIMITY_COST = 0.072f; const float ScoringParams::FIRST_PROXIMITY_COST = 0.07788f; +const float ScoringParams::INTENTIONAL_OMISSION_COST = 0.1f; const float ScoringParams::OMISSION_COST = 0.467f; const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.345f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.5256f; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index b66962019..b12de6d87 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -44,6 +44,7 @@ class ScoringParams { static const float PROXIMITY_COST; static const float FIRST_CHAR_PROXIMITY_COST; static const float FIRST_PROXIMITY_COST; + static const float INTENTIONAL_OMISSION_COST; static const float OMISSION_COST; static const float OMISSION_COST_SAME_CHAR; static const float OMISSION_COST_FIRST_CHAR; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 0ba439b47..84077174d 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -54,12 +54,15 @@ class TypingWeighting : public Weighting { float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); + const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission(); const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the traversal omitted the first letter then the dicNode should now be on the second. const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; float cost = 0.0f; if (isZeroCostOmission) { cost = 0.0f; + } else if (isIntentionalOmission) { + cost = ScoringParams::INTENTIONAL_OMISSION_COST; } else if (isFirstLetterOmission) { cost = ScoringParams::OMISSION_COST_FIRST_CHAR; } else { |