diff options
Diffstat (limited to 'native/jni')
25 files changed, 277 insertions, 133 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 28aaf2d1a..e41fe1d43 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -178,10 +178,10 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, jlong proximityInfo, jlong dicTraverseSession, jintArray xCoordinatesArray, jintArray yCoordinatesArray, jintArray timesArray, jintArray pointerIdsArray, jintArray inputCodePointsArray, jint inputSize, jintArray suggestOptions, - jintArray prevWordCodePointsForBigrams, jintArray outSuggestionCount, - jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, - jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, - jfloatArray inOutLanguageWeight) { + jintArray prevWordCodePointsForBigrams, jboolean isBeginningOfSentence, + jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, + jintArray outSpaceIndicesArray, jintArray outTypesArray, + jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray inOutLanguageWeight) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); // Assign 0 to outSuggestionCount here in case of returning earlier in this method. JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, 0); @@ -274,7 +274,7 @@ static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, } static jint latinime_BinaryDictionary_getBigramProbability(JNIEnv *env, jclass clazz, - jlong dict, jintArray word0, jintArray word1) { + jlong dict, jintArray word0, jboolean isBeginningOfSentence, jintArray word1) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return JNI_FALSE; const jsize word0Length = env->GetArrayLength(word0); @@ -283,7 +283,7 @@ static jint latinime_BinaryDictionary_getBigramProbability(JNIEnv *env, jclass c int word1CodePoints[word1Length]; env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, false /* isStartOfSentence */); + const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, isBeginningOfSentence); return dictionary->getBigramProbability(&prevWordsInfo, word1CodePoints, word1Length); } @@ -326,7 +326,8 @@ static void latinime_BinaryDictionary_getWordProperty(JNIEnv *env, jclass clazz, static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz, jlong dict, jintArray word, jint probability, jintArray shortcutTarget, jint shortcutProbability, - jboolean isNotAWord, jboolean isBlacklisted, jint timestamp) { + jboolean isBeginningOfSentence, jboolean isNotAWord, jboolean isBlacklisted, + jint timestamp) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) { return; @@ -341,13 +342,14 @@ static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz, shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); } // Use 1 for count to indicate the word has inputted. - const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, - probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); - dictionary->addUnigramWord(codePoints, codePointCount, &unigramProperty); + const UnigramProperty unigramProperty(isBeginningOfSentence, isNotAWord, + isBlacklisted, probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); + dictionary->addUnigramEntry(codePoints, codePointCount, &unigramProperty); } static void latinime_BinaryDictionary_addBigramWords(JNIEnv *env, jclass clazz, jlong dict, - jintArray word0, jintArray word1, jint probability, jint timestamp) { + jintArray word0, jboolean isBeginningOfSentence, jintArray word1, jint probability, + jint timestamp) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) { return; @@ -363,11 +365,12 @@ static void latinime_BinaryDictionary_addBigramWords(JNIEnv *env, jclass clazz, // Use 1 for count to indicate the bigram has inputted. const BigramProperty bigramProperty(&bigramTargetCodePoints, probability, timestamp, 0 /* level */, 1 /* count */); - dictionary->addBigramWords(word0CodePoints, word0Length, &bigramProperty); + const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, isBeginningOfSentence); + dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); } static void latinime_BinaryDictionary_removeBigramWords(JNIEnv *env, jclass clazz, jlong dict, - jintArray word0, jintArray word1) { + jintArray word0, jboolean isBeginningOfSentence, jintArray word1) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) { return; @@ -378,8 +381,8 @@ static void latinime_BinaryDictionary_removeBigramWords(JNIEnv *env, jclass claz jsize word1Length = env->GetArrayLength(word1); int word1CodePoints[word1Length]; env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - dictionary->removeBigramWords(word0CodePoints, word0Length, word1CodePoints, - word1Length); + const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, isBeginningOfSentence); + dictionary->removeNgramEntry(&prevWordsInfo, word1CodePoints, word1Length); } // Returns how many language model params are processed. @@ -447,9 +450,10 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); } // Use 1 for count to indicate the word has inputted. - const UnigramProperty unigramProperty(isNotAWord, isBlacklisted, - unigramProbability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); - dictionary->addUnigramWord(word1CodePoints, word1Length, &unigramProperty); + const UnigramProperty unigramProperty(false /* isBeginningOfSentence */, isNotAWord, + isBlacklisted, unigramProbability, timestamp, 0 /* level */, 1 /* count */, + &shortcuts); + dictionary->addUnigramEntry(word1CodePoints, word1Length, &unigramProperty); if (word0) { jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId); const std::vector<int> bigramTargetCodePoints( @@ -457,7 +461,9 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j // Use 1 for count to indicate the bigram has inputted. const BigramProperty bigramProperty(&bigramTargetCodePoints, bigramProbability, timestamp, 0 /* level */, 1 /* count */); - dictionary->addBigramWords(word0CodePoints, word0Length, &bigramProperty); + const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, + false /* isBeginningOfSentence */); + dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); } if (dictionary->needsToRunGC(true /* mindsBlockByGC */)) { return i + 1; @@ -541,7 +547,7 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j return false; } } - if (!dictionaryStructureWithBufferPolicy->addUnigramWord(wordCodePoints, wordLength, + if (!dictionaryStructureWithBufferPolicy->addUnigramEntry(wordCodePoints, wordLength, wordProperty.getUnigramProperty())) { LogUtils::logToJava(env, "Cannot add unigram to the new dict."); return false; @@ -561,8 +567,10 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j return false; } } + const PrevWordsInfo prevWordsInfo(wordCodePoints, wordLength, + false /* isStartOfSentence */); for (const BigramProperty &bigramProperty : *wordProperty.getBigramProperties()) { - if (!dictionaryStructureWithBufferPolicy->addBigramWords(wordCodePoints, wordLength, + if (!dictionaryStructureWithBufferPolicy->addNgramEntry(&prevWordsInfo, &bigramProperty)) { LogUtils::logToJava(env, "Cannot add bigram to the new dict."); return false; @@ -617,7 +625,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast<char *>("getSuggestionsNative"), - const_cast<char *>("(JJJ[I[I[I[I[II[I[I[I[I[I[I[I[I[F)V"), + const_cast<char *>("(JJJ[I[I[I[I[II[I[IZ[I[I[I[I[I[I[F)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_getSuggestions) }, { @@ -627,7 +635,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast<char *>("getBigramProbabilityNative"), - const_cast<char *>("(J[I[I)I"), + const_cast<char *>("(J[IZ[I)I"), reinterpret_cast<void *>(latinime_BinaryDictionary_getBigramProbability) }, { @@ -643,17 +651,17 @@ static const JNINativeMethod sMethods[] = { }, { const_cast<char *>("addUnigramWordNative"), - const_cast<char *>("(J[II[IIZZI)V"), + const_cast<char *>("(J[II[IIZZZI)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_addUnigramWord) }, { const_cast<char *>("addBigramWordsNative"), - const_cast<char *>("(J[I[III)V"), + const_cast<char *>("(J[IZ[III)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_addBigramWords) }, { const_cast<char *>("removeBigramWordsNative"), - const_cast<char *>("(J[I[I)V"), + const_cast<char *>("(J[IZ[I)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_removeBigramWords) }, { diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index e69d2c46b..ef03d2b6d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -203,12 +203,12 @@ class DicNode { return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; } - // Used to get n-gram probability in DicNodeUtils + // Used to get n-gram probability in DicNodeUtils. int getPtNodePos() const { return mDicNodeProperties.getPtNodePos(); } - // Used to get n-gram probability in DicNodeUtils + // Used to get n-gram probability in DicNodeUtils. n is 1-indexed. int getNthPrevWordTerminalPtNodePos(const int n) const { if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { return NOT_A_DICT_POS; diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index c860d82af..bcf7d5905 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -74,28 +74,34 @@ int Dictionary::getProbability(const int *word, int length) const { return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); } -int Dictionary::getBigramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word1, - int length1) const { +int Dictionary::getBigramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, + int length) const { TimeKeeper::setCurrentTime(); - return mBigramDictionary.getBigramProbability(prevWordsInfo, word1, length1); + return mBigramDictionary.getBigramProbability(prevWordsInfo, word, length); } -void Dictionary::addUnigramWord(const int *const word, const int length, +void Dictionary::addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty) { + if (unigramProperty->representsBeginningOfSentence() + && !mDictionaryStructureWithBufferPolicy->getHeaderStructurePolicy() + ->supportsBeginningOfSentence()) { + AKLOGE("The dictionary doesn't support Beginning-of-Sentence."); + return; + } TimeKeeper::setCurrentTime(); - mDictionaryStructureWithBufferPolicy->addUnigramWord(word, length, unigramProperty); + mDictionaryStructureWithBufferPolicy->addUnigramEntry(word, length, unigramProperty); } -void Dictionary::addBigramWords(const int *const word0, const int length0, +void Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { TimeKeeper::setCurrentTime(); - mDictionaryStructureWithBufferPolicy->addBigramWords(word0, length0, bigramProperty); + mDictionaryStructureWithBufferPolicy->addNgramEntry(prevWordsInfo, bigramProperty); } -void Dictionary::removeBigramWords(const int *const word0, const int length0, - const int *const word1, const int length1) { +void Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const int *const word, const int length) { TimeKeeper::setCurrentTime(); - mDictionaryStructureWithBufferPolicy->removeBigramWords(word0, length0, word1, length1); + mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, word, length); } void Dictionary::flush(const char *const filePath) { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index b63c61fbb..817d9f7fc 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -73,16 +73,16 @@ class Dictionary { int getProbability(const int *word, int length) const; int getBigramProbability(const PrevWordsInfo *const prevWordsInfo, - const int *word1, int length1) const; + const int *word, int length) const; - void addUnigramWord(const int *const codePoints, const int codePointCount, + void addUnigramEntry(const int *const codePoints, const int codePointCount, const UnigramProperty *const unigramProperty); - void addBigramWords(const int *const word0, const int length0, + void addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - void removeBigramWords(const int *const word0, const int length0, const int *const word1, - const int length1); + void removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, + const int length); void flush(const char *const filePath); diff --git a/native/jni/src/suggest/core/dictionary/property/bigram_property.h b/native/jni/src/suggest/core/dictionary/property/bigram_property.h index 8d3429b5b..343af143c 100644 --- a/native/jni/src/suggest/core/dictionary/property/bigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/bigram_property.h @@ -23,6 +23,7 @@ namespace latinime { +// TODO: Change to NgramProperty. class BigramProperty { public: BigramProperty(const std::vector<int> *const targetCodePoints, diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h index d2551057b..902eb000f 100644 --- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h @@ -48,15 +48,21 @@ class UnigramProperty { }; UnigramProperty() - : mIsNotAWord(false), mIsBlacklisted(false), mProbability(NOT_A_PROBABILITY), - mTimestamp(NOT_A_TIMESTAMP), mLevel(0), mCount(0), mShortcuts() {} - - UnigramProperty(const bool isNotAWord, const bool isBlacklisted, const int probability, - const int timestamp, const int level, const int count, - const std::vector<ShortcutProperty> *const shortcuts) - : mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability), + : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), mIsBlacklisted(false), + mProbability(NOT_A_PROBABILITY), mTimestamp(NOT_A_TIMESTAMP), mLevel(0), mCount(0), + mShortcuts() {} + + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isBlacklisted, const int probability, const int timestamp, const int level, + const int count, const std::vector<ShortcutProperty> *const shortcuts) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability), mTimestamp(timestamp), mLevel(level), mCount(count), mShortcuts(*shortcuts) {} + bool representsBeginningOfSentence() const { + return mRepresentsBeginningOfSentence; + } + bool isNotAWord() const { return mIsNotAWord; } @@ -94,6 +100,7 @@ class UnigramProperty { DISALLOW_ASSIGNMENT_OPERATOR(UnigramProperty); // TODO: Make members const. + bool mRepresentsBeginningOfSentence; bool mIsNotAWord; bool mIsBlacklisted; int mProbability; diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index 845e629e6..a61227626 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -51,6 +51,8 @@ class DictionaryHeaderStructurePolicy { virtual const std::vector<int> *getLocale() const = 0; + virtual bool supportsBeginningOfSentence() const = 0; + protected: DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index ce5a49f83..3fd815f98 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -29,6 +29,7 @@ class DicNodeVector; class DictionaryBigramsStructurePolicy; class DictionaryHeaderStructurePolicy; class DictionaryShortcutsStructurePolicy; +class PrevWordsInfo; class UnigramProperty; /* @@ -69,16 +70,16 @@ class DictionaryStructureWithBufferPolicy { virtual const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const = 0; // Returns whether the update was success or not. - virtual bool addUnigramWord(const int *const word, const int length, + virtual bool addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty) = 0; // Returns whether the update was success or not. - virtual bool addBigramWords(const int *const word0, const int length0, + virtual bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) = 0; // Returns whether the update was success or not. - virtual bool removeBigramWords(const int *const word0, const int length0, - const int *const word1, const int length1) = 0; + virtual bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const int *const word, const int length) = 0; virtual void flush(const char *const filePath) = 0; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index dc2b66a2c..f1e411f38 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -36,7 +36,7 @@ void DicTraverseSession::init(const Dictionary *const dictionary, ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; prevWordsInfo->getPrevWordsTerminalPtNodePos( - getDictionaryStructurePolicy(), mPrevWordsPtNodePos); + getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */); } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h index 70a99ef38..a58000abb 100644 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ b/native/jni/src/suggest/core/session/prev_words_info.h @@ -20,11 +20,11 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/char_utils.h" namespace latinime { // TODO: Support n-gram. -// TODO: Support beginning of sentence. // This class does not take ownership of any code point buffers. class PrevWordsInfo { public: @@ -41,29 +41,48 @@ class PrevWordsInfo { mIsBeginningOfSentence[0] = isBeginningOfSentence; } + bool isValid() const { + for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + if (mPrevWordCodePointCount[i] > MAX_WORD_LENGTH) { + return false; + } + } + return true; + } + void getPrevWordsTerminalPtNodePos( const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordsTerminalPtNodePos) const { + int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const { for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, mPrevWordCodePoints[i], mPrevWordCodePointCount[i], - mIsBeginningOfSentence[i]); + mIsBeginningOfSentence[i], tryLowerCaseSearch); } } BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction( const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const { - int pos = getBigramListPositionForWord(dictStructurePolicy, mPrevWordCodePoints[0], - mPrevWordCodePointCount[0], false /* forceLowerCaseSearch */); - // getBigramListPositionForWord returns NOT_A_DICT_POS if this word isn't in the - // dictionary or has no bigrams - if (NOT_A_DICT_POS == pos) { - // If no bigrams for this exact word, search again in lower case. - pos = getBigramListPositionForWord(dictStructurePolicy, mPrevWordCodePoints[0], - mPrevWordCodePointCount[0], true /* forceLowerCaseSearch */); + const int bigramListPos = getBigramListPositionForWordWithTryingLowerCaseSearch( + dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0], + mIsBeginningOfSentence[0]); + return BinaryDictionaryBigramsIterator(dictStructurePolicy->getBigramsStructurePolicy(), + bigramListPos); + } + + // n is 1-indexed. + const int *getNthPrevWordCodePoints(const int n) const { + if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + return nullptr; + } + return mPrevWordCodePoints[n - 1]; + } + + // n is 1-indexed. + int getNthPrevWordCodePointCount(const int n) const { + if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + return 0; } - return BinaryDictionaryBigramsIterator( - dictStructurePolicy->getBigramsStructurePolicy(), pos); + return mPrevWordCodePointCount[n - 1]; } private: @@ -72,19 +91,57 @@ class PrevWordsInfo { static int getTerminalPtNodePosOfWord( const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, const int *const wordCodePoints, const int wordCodePointCount, - const bool isBeginningOfSentence) { + const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { if (!dictStructurePolicy || !wordCodePoints) { return NOT_A_DICT_POS; } + int codePoints[MAX_WORD_LENGTH]; + int codePointCount = wordCodePointCount; + memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); + if (isBeginningOfSentence) { + codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, + codePointCount, MAX_WORD_LENGTH); + if (codePointCount <= 0) { + return NOT_A_DICT_POS; + } + } const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( - wordCodePoints, wordCodePointCount, false /* forceLowerCaseSearch */); - if (wordPtNodePos != NOT_A_DICT_POS) { + codePoints, codePointCount, false /* forceLowerCaseSearch */); + if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { + // Return the position when when the word was found or doesn't try lower case + // search. return wordPtNodePos; } // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". return dictStructurePolicy->getTerminalPtNodePositionOfWord( - wordCodePoints, wordCodePointCount, true /* forceLowerCaseSearch */); + codePoints, codePointCount, true /* forceLowerCaseSearch */); + } + + static int getBigramListPositionForWordWithTryingLowerCaseSearch( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + const int *const wordCodePoints, const int wordCodePointCount, + const bool isBeginningOfSentence) { + int codePoints[MAX_WORD_LENGTH]; + int codePointCount = wordCodePointCount; + memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); + if (isBeginningOfSentence) { + codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, + codePointCount, MAX_WORD_LENGTH); + if (codePointCount <= 0) { + return NOT_A_DICT_POS; + } + } + int pos = getBigramListPositionForWord(dictStructurePolicy, codePoints, + codePointCount, false /* forceLowerCaseSearch */); + // getBigramListPositionForWord returns NOT_A_DICT_POS if this word isn't in the + // dictionary or has no bigrams + if (NOT_A_DICT_POS == pos) { + // If no bigrams for this exact word, search again in lower case. + pos = getBigramListPositionForWord(dictStructurePolicy, codePoints, + codePointCount, true /* forceLowerCaseSearch */); + } + return pos; } static int getBigramListPositionForWord( diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 479d15164..75f4fef90 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -139,6 +139,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { switch (mDictFormatVersion) { case FormatUtils::VERSION_2: return FormatUtils::VERSION_2; + case FormatUtils::VERSION_401: + return FormatUtils::VERSION_401; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: return FormatUtils::VERSION_4_ONLY_FOR_TESTING; case FormatUtils::VERSION_4: @@ -246,6 +248,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return &mLocale; } + bool supportsBeginningOfSentence() const { + return mDictFormatVersion > FormatUtils::VERSION_401; + } + private: DISALLOW_COPY_AND_ASSIGN(HeaderPolicy); diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index a8f8f284b..b13ad1879 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -98,6 +98,7 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; case FormatUtils::VERSION_2: // Version 2 dictionary writing is not supported. return false; + case FormatUtils::VERSION_401: case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4: case FormatUtils::VERSION_4_DEV: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.cpp index dde1af299..557a0b4c8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.cpp @@ -31,6 +31,7 @@ #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" +#include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -163,10 +164,10 @@ int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) cons ptNodeParams.getTerminalId()); } -bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); + AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); return false; } if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { @@ -218,10 +219,12 @@ bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int len } } -bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int length0, +bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { + const int length0 = prevWordsInfo->getNthPrevWordCodePointCount(1); + const int *word0 = prevWordsInfo->getNthPrevWordCodePoints(1); if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; } if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { @@ -257,8 +260,10 @@ bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int le } } -bool Ver4PatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0, +bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, const int length1) { + const int length0 = prevWordsInfo->getNthPrevWordCodePointCount(1); + const int *word0 = prevWordsInfo->getNthPrevWordCodePoints(1); if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); return false; @@ -427,8 +432,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code shortcuts.emplace_back(&target, shortcutProbability); } } - const UnigramProperty unigramProperty(ptNodeParams.isNotAWord(), - ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), + const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, + ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts); return WordProperty(&codePointVector, &unigramProperty, &bigrams); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.h index 2f8ad539c..95813881d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v401/ver4_patricia_trie_policy.h @@ -108,14 +108,14 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return &mShortcutPolicy; } - bool addUnigramWord(const int *const word, const int length, + bool addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty); - bool addBigramWords(const int *const word0, const int length0, + bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeBigramWords(const int *const word0, const int length0, const int *const word1, - const int length1); + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, + const int length); void flush(const char *const filePath); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index 59f1f29e9..93e330a2a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -57,13 +57,14 @@ namespace latinime { const DictionaryHeaderStructurePolicy::AttributeMap *const attributeMap) { FormatUtils::FORMAT_VERSION dictFormatVersion = FormatUtils::getFormatVersion(formatVersion); switch (dictFormatVersion) { - case FormatUtils::VERSION_4: { + case FormatUtils::VERSION_401: { return newPolicyForOnMemoryV4Dict<backward::v401::Ver4DictConstants, backward::v401::Ver4DictBuffers, backward::v401::Ver4DictBuffers::Ver4DictBuffersPtr, backward::v401::Ver4PatriciaTriePolicy>( dictFormatVersion, locale, attributeMap); } + case FormatUtils::VERSION_4: case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4_DEV: { return newPolicyForOnMemoryV4Dict<Ver4DictConstants, Ver4DictBuffers, @@ -115,13 +116,14 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str case FormatUtils::VERSION_2: AKLOGE("Given path is a directory but the format is version 2. path: %s", path); break; - case FormatUtils::VERSION_4: { + case FormatUtils::VERSION_401: { return newPolicyForV4Dict<backward::v401::Ver4DictConstants, backward::v401::Ver4DictBuffers, backward::v401::Ver4DictBuffers::Ver4DictBuffersPtr, backward::v401::Ver4PatriciaTriePolicy>( headerFilePath, formatVersion, std::move(mmappedBuffer)); } + case FormatUtils::VERSION_4: case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4_DEV: { return newPolicyForV4Dict<Ver4DictConstants, Ver4DictBuffers, @@ -145,7 +147,8 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str char dictPath[dictDirPathBufSize]; if (!FileUtils::getFilePathWithoutSuffix(headerFilePath, DictConstants::HEADER_FILE_EXTENSION, dictDirPathBufSize, dictPath)) { - AKLOGE("Dictionary file name is not valid as a ver4 dictionary. path: %s", path); + AKLOGE("Dictionary file name is not valid as a ver4 dictionary. header path: %s", + headerFilePath); ASSERT(false); return nullptr; } @@ -153,7 +156,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str DictBuffers::openVer4DictBuffers(dictPath, std::move(mmappedBuffer), formatVersion); if (!dictBuffers || !dictBuffers->isValid()) { AKLOGE("DICT: The dictionary doesn't satisfy ver4 format requirements. path: %s", - path); + dictPath); ASSERT(false); return nullptr; } @@ -176,6 +179,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str case FormatUtils::VERSION_2: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); + case FormatUtils::VERSION_401: case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4: case FormatUtils::VERSION_4_DEV: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp index 028e9ecbf..1f00fc6ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.cpp @@ -56,7 +56,7 @@ bool DynamicPtGcEventListeners } } else { mValueStack.back() += 1; - if (ptNodeParams->isTerminal()) { + if (ptNodeParams->isTerminal() && !ptNodeParams->representsNonWordInfo()) { mValidUnigramCount += 1; } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index 5704c2e90..b2e60a837 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -160,7 +160,12 @@ class PtNodeParams { } AK_FORCE_INLINE bool representsNonWordInfo() const { - return getCodePointCount() > 0 && CharUtils::isInUnicodeSpace(getCodePoints()[0]) + return getCodePointCount() > 0 && !CharUtils::isInUnicodeSpace(getCodePoints()[0]) + && isNotAWord(); + } + + AK_FORCE_INLINE int representsBeginningOfSentence() const { + return getCodePointCount() > 0 && getCodePoints()[0] == CODE_POINT_BEGINNING_OF_SENTENCE && isNotAWord(); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 30dcfba37..a6a470c4e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -383,8 +383,8 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin shortcuts.emplace_back(&shortcutTarget, shortcutProbability); } } - const UnigramProperty unigramProperty(ptNodeParams.isNotAWord(), - ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), + const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), + ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); return WordProperty(&codePointVector, &unigramProperty, &bigrams); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 54d1e0f6d..6240d46aa 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -81,24 +81,24 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return &mShortcutListPolicy; } - bool addUnigramWord(const int *const word, const int length, + bool addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty) { // This method should not be called for non-updatable dictionary. - AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); + AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); return false; } - bool addBigramWords(const int *const word0, const int length0, + bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { // This method should not be called for non-updatable dictionary. - AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; } - bool removeBigramWords(const int *const word0, const int length0, const int *const word1, - const int length1) { + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, + const int length) { // This method should not be called for non-updatable dictionary. - AKLOGI("Warning: removeBigramWords() is called for non-updatable dictionary."); + AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 7da9e3072..02478700a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -23,6 +23,7 @@ #include "suggest/core/dictionary/property/bigram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" +#include "suggest/core/session/prev_words_info.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -60,7 +61,7 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY; } readingHelper.readNextSiblingNode(ptNodeParams); - if (!ptNodeParams.representsNonWordInfo()) { + if (ptNodeParams.representsNonWordInfo()) { // Skip PtNodes that represent non-word information. continue; } @@ -155,10 +156,10 @@ int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) cons ptNodeParams.getTerminalId()); } -bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); + AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); return false; } if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { @@ -180,9 +181,19 @@ bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int len DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); bool addedNewUnigram = false; - if (mUpdatingHelper.addUnigramWord(&readingHelper, word, length, + int codePointsToAdd[MAX_WORD_LENGTH]; + int codePointCountToAdd = length; + memmove(codePointsToAdd, word, sizeof(int) * length); + if (unigramProperty->representsBeginningOfSentence()) { + codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd, + codePointCountToAdd, MAX_WORD_LENGTH); + } + if (codePointCountToAdd <= 0) { + return false; + } + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd, unigramProperty, &addedNewUnigram)) { - if (addedNewUnigram) { + if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { mUnigramCount++; } if (unigramProperty->getShortcuts().size() > 0) { @@ -210,10 +221,10 @@ bool Ver4PatriciaTriePolicy::addUnigramWord(const int *const word, const int len } } -bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int length0, +bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; } if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { @@ -221,15 +232,20 @@ bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int le mDictBuffer->getTailPosition()); return false; } - if (length0 > MAX_WORD_LENGTH - || bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("Either src word or target word is too long to insert the bigram to the dictionary. " - "length0: %d, length1: %d", length0, bigramProperty->getTargetCodePoints()->size()); + if (!prevWordsInfo->isValid()) { + AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary."); return false; } - const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0, - false /* forceLowerCaseSearch */); - if (word0Pos == NOT_A_DICT_POS) { + if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { + AKLOGE("The word is too long to insert the ngram to the dictionary. " + "length: %d", bigramProperty->getTargetCodePoints()->size()); + return false; + } + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + false /* tryLowerCaseSearch */); + // TODO: Support N-gram. + if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { return false; } const int word1Pos = getTerminalPtNodePositionOfWord( @@ -239,7 +255,8 @@ bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int le return false; } bool addedNewBigram = false; - if (mUpdatingHelper.addBigramWords(word0Pos, word1Pos, bigramProperty, &addedNewBigram)) { + if (mUpdatingHelper.addBigramWords(prevWordsPtNodePos[0], word1Pos, bigramProperty, + &addedNewBigram)) { if (addedNewBigram) { mBigramCount++; } @@ -249,10 +266,10 @@ bool Ver4PatriciaTriePolicy::addBigramWords(const int *const word0, const int le } } -bool Ver4PatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0, - const int *const word1, const int length1) { +bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, + const int *const word, const int length) { if (!mBuffers->isUpdatable()) { - AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; } if (mDictBuffer->getTailPosition() >= MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS) { @@ -260,22 +277,26 @@ bool Ver4PatriciaTriePolicy::removeBigramWords(const int *const word0, const int mDictBuffer->getTailPosition()); return false; } - if (length0 > MAX_WORD_LENGTH || length1 > MAX_WORD_LENGTH) { - AKLOGE("Either src word or target word is too long to remove the bigram to from the " - "dictionary. length0: %d, length1: %d", length0, length1); + if (!prevWordsInfo->isValid()) { + AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); return false; } - const int word0Pos = getTerminalPtNodePositionOfWord(word0, length0, - false /* forceLowerCaseSearch */); - if (word0Pos == NOT_A_DICT_POS) { + if (length > MAX_WORD_LENGTH) { + AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %d", length); + } + int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + false /* tryLowerCaseSerch */); + // TODO: Support N-gram. + if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { return false; } - const int word1Pos = getTerminalPtNodePositionOfWord(word1, length1, + const int wordPos = getTerminalPtNodePositionOfWord(word, length, false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + if (wordPos == NOT_A_DICT_POS) { return false; } - if (mUpdatingHelper.removeBigramWords(word0Pos, word1Pos)) { + if (mUpdatingHelper.removeBigramWords(prevWordsPtNodePos[0], wordPos)) { mBigramCount--; return true; } else { @@ -419,8 +440,8 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code shortcuts.emplace_back(&target, shortcutProbability); } } - const UnigramProperty unigramProperty(ptNodeParams.isNotAWord(), - ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), + const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), + ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts); return WordProperty(&codePointVector, &unigramProperty, &bigrams); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index b78576484..008f2e423 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -90,13 +90,13 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return &mShortcutPolicy; } - bool addUnigramWord(const int *const word, const int length, + bool addUnigramEntry(const int *const word, const int length, const UnigramProperty *const unigramProperty); - bool addBigramWords(const int *const word0, const int length0, + bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, const BigramProperty *const bigramProperty); - bool removeBigramWords(const int *const word0, const int length0, const int *const word1, + bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, const int length1); void flush(const char *const filePath); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp index 105363db5..a04551a44 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -41,11 +41,12 @@ const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = TimeKeeper::setCurrentTime(); const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::getFormatVersion(dictVersion); switch (formatVersion) { - case FormatUtils::VERSION_4: + case FormatUtils::VERSION_401: return createEmptyV4DictFile<backward::v401::Ver4DictConstants, backward::v401::Ver4DictBuffers, backward::v401::Ver4DictBuffers::Ver4DictBuffersPtr>( filePath, localeAsCodePointVector, attributeMap, formatVersion); + case FormatUtils::VERSION_4: case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4_DEV: return createEmptyV4DictFile<Ver4DictConstants, Ver4DictBuffers, diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index ba405b07e..18f558094 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -29,6 +29,8 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; switch (formatVersion) { case VERSION_2: return VERSION_2; + case VERSION_401: + return VERSION_401; case VERSION_4_ONLY_FOR_TESTING: return VERSION_4_ONLY_FOR_TESTING; case VERSION_4: @@ -60,6 +62,8 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; // same so we use them for both here. if (ByteArrayUtils::readUint16(dict, 4) == VERSION_2) { return VERSION_2; + } else if (ByteArrayUtils::readUint16(dict, 4) == VERSION_401) { + return VERSION_401; } else if (ByteArrayUtils::readUint16(dict, 4) == VERSION_4_ONLY_FOR_TESTING) { return VERSION_4_ONLY_FOR_TESTING; } else if (ByteArrayUtils::readUint16(dict, 4) == VERSION_4) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index c47f30ca4..b05cb2fc8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -32,8 +32,9 @@ class FormatUtils { // These MUST have the same values as the relevant constants in FormatSpec.java. VERSION_2 = 2, VERSION_4_ONLY_FOR_TESTING = 399, - VERSION_4 = 401, - VERSION_4_DEV = 402, + VERSION_401 = 401, + VERSION_4 = 402, + VERSION_4_DEV = 403, UNKNOWN_VERSION = -1 }; diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 634c45b04..f28ed5682 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -18,6 +18,7 @@ #define LATINIME_CHAR_UTILS_H #include <cctype> +#include <cstring> #include <vector> #include "defines.h" @@ -93,6 +94,19 @@ class CharUtils { static unsigned short latin_tolower(const unsigned short c); static const std::vector<int> EMPTY_STRING; + // Returns updated code point count. Returns 0 when the code points cannot be marked as a + // Beginning-of-Sentence. + static AK_FORCE_INLINE int attachBeginningOfSentenceMarker(int *const codePoints, + const int codePointCount, const int maxCodePoint) { + if (codePointCount >= maxCodePoint) { + // the code points cannot be marked as a Beginning-of-Sentence. + return 0; + } + memmove(codePoints + 1, codePoints, sizeof(int) * codePointCount); + codePoints[0] = CODE_POINT_BEGINNING_OF_SENTENCE; + return codePointCount + 1; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); |