diff options
Diffstat (limited to 'native')
18 files changed, 155 insertions, 23 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index c919ebd91..f5c3ee63c 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -29,6 +29,7 @@ #include "suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.h" #include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "utils/autocorrection_threshold_utils.h" +#include "utils/char_utils.h" #include "utils/time_keeper.h" namespace latinime { @@ -37,13 +38,15 @@ class ProximityInfo; // TODO: Move to makedict. static jboolean latinime_BinaryDictionary_createEmptyDictFile(JNIEnv *env, jclass clazz, - jstring filePath, jlong dictVersion, jobjectArray attributeKeyStringArray, + jstring filePath, jlong dictVersion, jstring locale, jobjectArray attributeKeyStringArray, jobjectArray attributeValueStringArray) { const jsize filePathUtf8Length = env->GetStringUTFLength(filePath); char filePathChars[filePathUtf8Length + 1]; env->GetStringUTFRegion(filePath, 0, env->GetStringLength(filePath), filePathChars); filePathChars[filePathUtf8Length] = '\0'; - + jsize localeLength = env->GetStringLength(locale); + jchar localeCodePoints[localeLength]; + env->GetStringRegion(locale, 0, localeLength, localeCodePoints); const int keyCount = env->GetArrayLength(attributeKeyStringArray); const int valueCount = env->GetArrayLength(attributeValueStringArray); if (keyCount != valueCount) { @@ -73,7 +76,7 @@ static jboolean latinime_BinaryDictionary_createEmptyDictFile(JNIEnv *env, jclas } return DictFileWritingUtils::createEmptyDictFile(filePathChars, static_cast<int>(dictVersion), - &attributeMap); + CharUtils::convertShortArrayToIntVector(localeCodePoints, localeLength), &attributeMap); } static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir, @@ -137,6 +140,17 @@ static void latinime_BinaryDictionary_close(JNIEnv *env, jclass clazz, jlong dic delete dictionary; } +static void latinime_BinaryDictionary_getHeaderInfo(JNIEnv *env, jclass clazz, jlong dict, + jintArray outHeaderSize, jintArray outFormatVersion, jobject outAttributeKeys, + jobject outAttributeValues) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) return; + const int formatVersion = dictionary->getFormatVersionNumber(); + env->SetIntArrayRegion(outFormatVersion, 0 /* start */, 1 /* len */, &formatVersion); + // TODO: Implement + return; +} + static int latinime_BinaryDictionary_getFormatVersion(JNIEnv *env, jclass clazz, jlong dict) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return 0; @@ -492,7 +506,8 @@ static int latinime_BinaryDictionary_setCurrentTimeForTest(JNIEnv *env, jclass c static const JNINativeMethod sMethods[] = { { const_cast<char *>("createEmptyDictFileNative"), - const_cast<char *>("(Ljava/lang/String;J[Ljava/lang/String;[Ljava/lang/String;)Z"), + const_cast<char *>( + "(Ljava/lang/String;JLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)Z"), reinterpret_cast<void *>(latinime_BinaryDictionary_createEmptyDictFile) }, { @@ -511,6 +526,11 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionary_getFormatVersion) }, { + const_cast<char *>("getHeaderInfoNative"), + const_cast<char *>("(J[I[ILjava/util/ArrayList;Ljava/util/ArrayList;)V"), + reinterpret_cast<void *>(latinime_BinaryDictionary_getHeaderInfo) + }, + { const_cast<char *>("flushNative"), const_cast<char *>("(JLjava/lang/String;)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_flush) diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 1969ebae0..22cc4c02b 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -311,7 +311,7 @@ static inline void prof_out(void) { // A special value to mean the first word confidence makes no sense in this case, // e.g. this is not a multi-word suggestion. -#define NOT_A_FIRST_WORD_CONFIDENCE (S_INT_MAX) +#define NOT_A_FIRST_WORD_CONFIDENCE (S_INT_MIN) // How high the confidence needs to be for us to auto-commit. Arbitrary. // This needs to be the same as CONFIDENCE_FOR_AUTO_COMMIT in BinaryDictionary.java #define CONFIDENCE_FOR_AUTO_COMMIT (1000000) diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp index b8106377c..e37811b88 100644 --- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp @@ -78,7 +78,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; outputAutoCommitFirstWordConfidence[0] = computeFirstWordConfidence(&terminals[0]); } - + const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()-> + getHeaderStructurePolicy()->shouldBoostExactMatches(); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -102,7 +103,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; && !(isPossiblyOffensiveWord && isFirstCharUppercase); const int outputTypeFlags = (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); + | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); @@ -113,7 +114,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord())); + || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), + boostExactMatches); if (maxScore < finalScore && isValidWord) { maxScore = finalScore; } @@ -147,7 +149,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; scoringPolicy->calculateFinalScore(compoundDistance, traverseSession->getInputSize(), terminalDicNode->getContainedErrorTypes(), - true /* forceCommit */) : finalScore; + true /* forceCommit */, boostExactMatches) : finalScore; const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h index b76b13971..417620e00 100644 --- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -40,6 +40,8 @@ class DictionaryHeaderStructurePolicy { virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const = 0; + virtual bool shouldBoostExactMatches() const = 0; + protected: DictionaryHeaderStructurePolicy() {} diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 783383450..e581a97c3 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -28,7 +28,8 @@ class DicTraverseSession; class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, - const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0; + const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, + const bool boostExactMatches) const = 0; virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, const int terminalSize, const float languageWeight, int *const outputCodePoints, int *const type, int *const freq) const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 7504524f0..b5b5ed740 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -32,6 +32,7 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE // Historical info is information that is needed to support decaying such as timestamp, level and // count. const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO"; +const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; @@ -59,6 +60,10 @@ void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *out outValue[terminalIndex] = '\0'; } +const std::vector<int> HeaderPolicy::readLocale() const { + return HeaderReadWriteUtils::readCodePointVectorAttributeValue(&mAttributeMap, LOCALE_KEY); +} + float HeaderPolicy::readMultipleWordCostMultiplier() const { const int demotionRate = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, MULTIPLE_WORDS_DEMOTION_RATE_KEY, DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE); @@ -116,6 +121,7 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int uni // Set the current time as the generation time. HeaderReadWriteUtils::setIntAttribute(outAttributeMap, DATE_KEY, TimeKeeper::peekCurrentTime()); + HeaderReadWriteUtils::setCodePointVectorAttribute(outAttributeMap, LOCALE_KEY, mLocale); if (updatesLastDecayedTime) { // Set current time as the last updated time. HeaderReadWriteUtils::setIntAttribute(outAttributeMap, LAST_DECAYED_TIME_KEY, diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index a44f9f0fc..a05e00c39 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -23,6 +23,7 @@ #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" +#include "utils/char_utils.h" #include "utils/time_keeper.h" namespace latinime { @@ -35,6 +36,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mDictionaryFlags(HeaderReadWriteUtils::getFlags(dictBuf)), mSize(HeaderReadWriteUtils::getHeaderSize(dictBuf)), mAttributeMap(createAttributeMapAndReadAllAttributes(dictBuf)), + mLocale(readLocale()), mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), mRequiresGermanUmlautProcessing(readRequiresGermanUmlautProcessing()), mIsDecayingDict(HeaderReadWriteUtils::readBoolAttributeValue(&mAttributeMap, @@ -54,10 +56,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { // Constructs header information using an attribute map. HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion, + const std::vector<int> locale, const HeaderReadWriteUtils::AttributeMap *const attributeMap) : mDictFormatVersion(dictFormatVersion), mDictionaryFlags(HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap( - attributeMap)), mSize(0), mAttributeMap(*attributeMap), + attributeMap)), mSize(0), mAttributeMap(*attributeMap), mLocale(locale), mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), mRequiresGermanUmlautProcessing(readRequiresGermanUmlautProcessing()), mIsDecayingDict(HeaderReadWriteUtils::readBoolAttributeValue(&mAttributeMap, @@ -68,12 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( - &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)) {} + &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)) { + } // Temporary dummy header. HeaderPolicy() : mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0), - mAttributeMap(), mMultiWordCostMultiplier(0.0f), + mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false) {} @@ -146,6 +150,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mHasHistoricalInfoOfWords; } + AK_FORCE_INLINE bool shouldBoostExactMatches() const { + // TODO: Investigate better ways to handle exact matches for personalized dictionaries. + return !isDecayingDict(); + } + void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; @@ -169,6 +178,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const BIGRAM_COUNT_KEY; static const char *const EXTENDED_REGION_SIZE_KEY; static const char *const HAS_HISTORICAL_INFO_KEY; + static const char *const LOCALE_KEY; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; @@ -176,6 +186,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; const int mSize; HeaderReadWriteUtils::AttributeMap mAttributeMap; + const std::vector<int> mLocale; const float mMultiWordCostMultiplier; const bool mRequiresGermanUmlautProcessing; const bool mIsDecayingDict; @@ -186,6 +197,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const int mExtendedRegionSize; const bool mHasHistoricalInfoOfWords; + const std::vector<int> readLocale() const; float readMultipleWordCostMultiplier() const; bool readRequiresGermanUmlautProcessing() const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index 6b4598642..850b0d87f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -130,6 +130,13 @@ const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; return true; } +/* static */ void HeaderReadWriteUtils::setCodePointVectorAttribute( + AttributeMap *const headerAttributes, const char *const key, const std::vector<int> value) { + AttributeMap::key_type keyVector; + insertCharactersIntoVector(key, &keyVector); + (*headerAttributes)[keyVector] = value; +} + /* static */ void HeaderReadWriteUtils::setBoolAttribute(AttributeMap *const headerAttributes, const char *const key, const bool value) { setIntAttribute(headerAttributes, key, value ? 1 : 0); @@ -151,6 +158,18 @@ const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; (*headerAttributes)[*key] = valueVector; } +/* static */ const std::vector<int> HeaderReadWriteUtils::readCodePointVectorAttributeValue( + const AttributeMap *const headerAttributes, const char *const key) { + AttributeMap::key_type keyVector; + insertCharactersIntoVector(key, &keyVector); + AttributeMap::const_iterator it = headerAttributes->find(keyVector); + if (it == headerAttributes->end()) { + return std::vector<int>(); + } else { + return it->second; + } +} + /* static */ bool HeaderReadWriteUtils::readBoolAttributeValue( const AttributeMap *const headerAttributes, const char *const key, const bool defaultValue) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h index fc24bbdd5..3433c0494 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -63,12 +63,18 @@ class HeaderReadWriteUtils { /** * Methods for header attributes. */ + static void setCodePointVectorAttribute(AttributeMap *const headerAttributes, + const char *const key, const std::vector<int> value); + static void setBoolAttribute(AttributeMap *const headerAttributes, const char *const key, const bool value); static void setIntAttribute(AttributeMap *const headerAttributes, const char *const key, const int value); + static const std::vector<int> readCodePointVectorAttributeValue( + const AttributeMap *const headerAttributes, const char *const key); + static bool readBoolAttributeValue(const AttributeMap *const headerAttributes, const char *const key, const bool defaultValue); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index b918e0765..824d442e4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -28,6 +28,14 @@ const int DynamicPtReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 10000 const int DynamicPtReadingHelper::MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000; const size_t DynamicPtReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH; +bool DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions::onVisitingPtNode( + const PtNodeParams *const ptNodeParams) { + if (ptNodeParams->isTerminal() && !ptNodeParams->isDeleted()) { + mTerminalPositions->push_back(ptNodeParams->getHeadPos()); + } + return true; +} + // Visits all PtNodes in post-order depth first manner. // For example, visits c -> b -> y -> x -> a for the following dictionary: // a _ b _ c diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index a69490943..bcc5c7857 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -59,6 +59,21 @@ class DynamicPtReadingHelper { DISALLOW_COPY_AND_ASSIGN(TraversingEventListener); }; + class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener { + public: + TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> *const terminalPositions) + : mTerminalPositions(terminalPositions) {} + bool onAscend() { return true; } + bool onDescend(const int ptNodeArrayPos) { return true; } + bool onReadingPtNodeArrayTail() { return true; } + bool onVisitingPtNode(const PtNodeParams *const ptNodeParams); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions); + + std::vector<int> *const mTerminalPositions; + }; + DynamicPtReadingHelper(const BufferWithExtendableBuffer *const buffer, const PtNodeReader *const ptNodeReader) : mIsError(false), mReadingState(), mBuffer(buffer), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 1c420e070..75d85988c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -392,10 +392,32 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code historicalInfo->getCount(), &bigrams, &shortcuts); } -int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, - int *const outCodePoints) { - // TODO: Implement. - return 0; +int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { + if (token == 0) { + mTerminalPtNodePositionsForIteratingWords.clear(); + DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( + &mTerminalPtNodePositionsForIteratingWords); + DynamicPtReadingHelper readingHelper(mDictBuffer, &mNodeReader); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy); + } + const int terminalPtNodePositionsVectorSize = + static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size()); + if (token < 0 || token >= terminalPtNodePositionsVectorSize) { + AKLOGE("Given token %d is invalid.", token); + return 0; + } + const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; + int unigramProbability = NOT_A_PROBABILITY; + getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, + outCodePoints, &unigramProbability); + const int nextToken = token + 1; + if (nextToken >= terminalPtNodePositionsVectorSize) { + // All words have been iterated. + mTerminalPtNodePositionsForIteratingWords.clear(); + return 0; + } + return nextToken; } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 1bcd4ceea..9ba5be0c3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -17,6 +17,8 @@ #ifndef LATINIME_VER4_PATRICIA_TRIE_POLICY_H #define LATINIME_VER4_PATRICIA_TRIE_POLICY_H +#include <vector> + #include "defines.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h" @@ -50,7 +52,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()) {}; + mBigramCount(mHeaderPolicy->getBigramCount()), + mTerminalPtNodePositionsForIteratingWords() {}; AK_FORCE_INLINE int getRootPosition() const { return 0; @@ -134,6 +137,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { Ver4PatriciaTrieWritingHelper mWritingHelper; int mUnigramCount; int mBigramCount; + std::vector<int> mTerminalPtNodePositionsForIteratingWords; }; } // namespace latinime #endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp index 84403c807..335ea0de0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -31,11 +31,12 @@ namespace latinime { const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = ".tmp"; /* static */ bool DictFileWritingUtils::createEmptyDictFile(const char *const filePath, - const int dictVersion, const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + const int dictVersion, const std::vector<int> localeAsCodePointVector, + const HeaderReadWriteUtils::AttributeMap *const attributeMap) { TimeKeeper::setCurrentTime(); switch (dictVersion) { case FormatUtils::VERSION_4: - return createEmptyV4DictFile(filePath, attributeMap); + return createEmptyV4DictFile(filePath, localeAsCodePointVector, attributeMap); default: AKLOGE("Cannot create dictionary %s because format version %d is not supported.", filePath, dictVersion); @@ -44,8 +45,9 @@ const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = } /* static */ bool DictFileWritingUtils::createEmptyV4DictFile(const char *const dirPath, + const std::vector<int> localeAsCodePointVector, const HeaderReadWriteUtils::AttributeMap *const attributeMap) { - HeaderPolicy headerPolicy(FormatUtils::VERSION_4, attributeMap); + HeaderPolicy headerPolicy(FormatUtils::VERSION_4, localeAsCodePointVector, attributeMap); Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers = Ver4DictBuffers::createVer4DictBuffers(&headerPolicy); headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h index bdf9fd63c..c2ecff45e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h @@ -31,6 +31,7 @@ class DictFileWritingUtils { static const char *const TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE; static bool createEmptyDictFile(const char *const filePath, const int dictVersion, + const std::vector<int> localeAsCodePointVector, const HeaderReadWriteUtils::AttributeMap *const attributeMap); static bool flushAllHeaderAndBodyToFile(const char *const filePath, @@ -44,6 +45,7 @@ class DictFileWritingUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(DictFileWritingUtils); static bool createEmptyV4DictFile(const char *const filePath, + const std::vector<int> localeAsCodePointVector, const HeaderReadWriteUtils::AttributeMap *const attributeMap); static bool flushBufferToFile(const char *const filePath, diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index c777e7238..8b405e8de 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -50,14 +50,14 @@ class TypingScoring : public Scoring { AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, - const bool forceCommit) const { + const bool forceCommit, const bool boostExactMatches) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance; if (forceCommit) { score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD; } - if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) { + if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; diff --git a/native/jni/src/utils/char_utils.cpp b/native/jni/src/utils/char_utils.cpp index 0e7039610..d41fc8924 100644 --- a/native/jni/src/utils/char_utils.cpp +++ b/native/jni/src/utils/char_utils.cpp @@ -1273,4 +1273,6 @@ static int compare_pair_capital(const void *a, const void *b) { /* U+04F0 */ 0x0423, 0x0443, 0x0423, 0x0443, 0x0427, 0x0447, 0x04F6, 0x04F7, /* U+04F8 */ 0x042B, 0x044B, 0x04FA, 0x04FB, 0x04FC, 0x04FD, 0x04FE, 0x04FF, }; + +/* static */ const std::vector<int> CharUtils::EMPTY_STRING(1 /* size */, '\0' /* value */); } // namespace latinime diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 41663c81a..98b8966df 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -18,6 +18,7 @@ #define LATINIME_CHAR_UTILS_H #include <cctype> +#include <vector> #include "defines.h" @@ -85,7 +86,15 @@ class CharUtils { return spaceCount; } + static AK_FORCE_INLINE std::vector<int> convertShortArrayToIntVector( + const unsigned short *const source, const int length) { + std::vector<int> destination; + destination.insert(destination.end(), source, source + length); + return destination; // Copies the vector + } + static unsigned short latin_tolower(const unsigned short c); + static const std::vector<int> EMPTY_STRING; private: DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); |