diff options
Diffstat (limited to 'native')
156 files changed, 4388 insertions, 3140 deletions
diff --git a/native/jni/Android.mk b/native/jni/Android.mk index 3a2073f03..6003a6f64 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -48,7 +48,7 @@ LOCAL_SRC_FILES := \ ifeq ($(FLAG_DO_PROFILE), true) $(warning Making profiling version of native library) - LOCAL_CFLAGS += -DFLAG_DO_PROFILE -funwind-tables -fno-inline + LOCAL_CFLAGS += -DFLAG_DO_PROFILE -funwind-tables else # FLAG_DO_PROFILE ifeq ($(FLAG_DBG), true) $(warning Making debug version of native library) diff --git a/native/jni/HostUnitTests.mk b/native/jni/HostUnitTests.mk index 6967d9b87..e30d50a2e 100644 --- a/native/jni/HostUnitTests.mk +++ b/native/jni/HostUnitTests.mk @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Host build is never supported in unbundled (NDK/tapas) build +ifeq (,$(TARGET_BUILD_APPS)) + # HACK: Temporarily disable host tool build on Mac until the build system is ready for C++11. LATINIME_HOST_OSNAME := $(shell uname -s) ifneq ($(LATINIME_HOST_OSNAME), Darwin) # TODO: Remove this @@ -26,8 +29,10 @@ include $(LOCAL_PATH)/NativeFileList.mk #################### Host library for unit test # TODO: Remove -std=c++11 once it is set by default on host build. LATIN_IME_SRC_DIR := src +LOCAL_ADDRESS_SANITIZER := true LOCAL_CFLAGS += -std=c++11 -Wno-unused-parameter -Wno-unused-function LOCAL_CLANG := true +LOCAL_CXX_STL := libc++ LOCAL_C_INCLUDES += $(LOCAL_PATH)/$(LATIN_IME_SRC_DIR) LOCAL_MODULE := liblatinime_host_static_for_unittests LOCAL_MODULE_TAGS := optional @@ -37,9 +42,11 @@ include $(BUILD_HOST_STATIC_LIBRARY) #################### Host native tests include $(CLEAR_VARS) LATIN_IME_TEST_SRC_DIR := tests +LOCAL_ADDRESS_SANITIZER := true # TODO: Remove -std=c++11 once it is set by default on host build. 
LOCAL_CFLAGS += -std=c++11 -Wno-unused-parameter -Wno-unused-function LOCAL_CLANG := true +LOCAL_CXX_STL := libc++ LOCAL_C_INCLUDES += $(LOCAL_PATH)/$(LATIN_IME_SRC_DIR) LOCAL_MODULE := liblatinime_host_unittests LOCAL_MODULE_TAGS := tests @@ -47,10 +54,13 @@ LOCAL_SRC_FILES := $(addprefix $(LATIN_IME_TEST_SRC_DIR)/, $(LATIN_IME_CORE_TEST LOCAL_STATIC_LIBRARIES += liblatinime_host_static_for_unittests include $(BUILD_HOST_NATIVE_TEST) +include $(LOCAL_PATH)/CleanupNativeFileList.mk + endif # Darwin - TODO: Remove this +endif # TARGET_BUILD_APPS + #################### Clean up the tmp vars LATINIME_HOST_OSNAME := LATIN_IME_SRC_DIR := LATIN_IME_TEST_SRC_DIR := -include $(LOCAL_PATH)/CleanupNativeFileList.mk diff --git a/native/jni/NativeFileList.mk b/native/jni/NativeFileList.mk index 7a732a588..7299ed3c0 100644 --- a/native/jni/NativeFileList.mk +++ b/native/jni/NativeFileList.mk @@ -40,6 +40,7 @@ LATIN_IME_CORE_SRC_FILES := \ proximity_info_state_utils.cpp) \ suggest/core/policy/weighting.cpp \ suggest/core/session/dic_traverse_session.cpp \ + suggest/core/session/ngram_context.cpp \ $(addprefix suggest/core/result/, \ suggestion_results.cpp \ suggestions_output_utils.cpp) \ @@ -55,13 +56,12 @@ LATIN_IME_CORE_SRC_FILES := \ dynamic_pt_updating_helper.cpp \ dynamic_pt_writing_utils.cpp \ patricia_trie_reading_utils.cpp \ - shortcut/shortcut_list_reading_utils.cpp ) \ + shortcut/shortcut_list_reading_utils.cpp) \ $(addprefix suggest/policyimpl/dictionary/structure/v2/, \ patricia_trie_policy.cpp \ ver2_patricia_trie_node_reader.cpp \ ver2_pt_node_array_reader.cpp) \ $(addprefix suggest/policyimpl/dictionary/structure/v4/, \ - bigram/ver4_bigram_list_policy.cpp \ ver4_dict_buffers.cpp \ ver4_dict_constants.cpp \ ver4_patricia_trie_node_reader.cpp \ @@ -71,7 +71,6 @@ LATIN_IME_CORE_SRC_FILES := \ ver4_patricia_trie_writing_helper.cpp \ ver4_pt_node_array_reader.cpp) \ $(addprefix suggest/policyimpl/dictionary/structure/v4/content/, \ - bigram_dict_content.cpp \ 
language_model_dict_content.cpp \ shortcut_dict_content.cpp \ sparse_table_dict_content.cpp \ @@ -123,11 +122,21 @@ LATIN_IME_CORE_SRC_FILES += $(LATIN_IME_CORE_SRC_FILES_BACKWARD_V402) LATIN_IME_CORE_TEST_FILES := \ defines_test.cpp \ - suggest/core/layout/normal_distribution_2d_test.cpp \ + suggest/core/dicnode/dic_node_pool_test.cpp \ suggest/core/dictionary/bloom_filter_test.cpp \ + suggest/core/layout/geometry_utils_test.cpp \ + suggest/core/layout/normal_distribution_2d_test.cpp \ + suggest/policyimpl/dictionary/header/header_read_write_utils_test.cpp \ suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp \ suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp \ + suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table_test.cpp \ suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer_test.cpp \ + suggest/policyimpl/dictionary/utils/byte_array_utils_test.cpp \ + suggest/policyimpl/dictionary/utils/format_utils_test.cpp \ + suggest/policyimpl/dictionary/utils/sparse_table_test.cpp \ suggest/policyimpl/dictionary/utils/trie_map_test.cpp \ + suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy_test.cpp \ utils/autocorrection_threshold_utils_test.cpp \ - utils/int_array_view_test.cpp + utils/char_utils_test.cpp \ + utils/int_array_view_test.cpp \ + utils/time_keeper_test.cpp diff --git a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp index f88d37ec9..80419b335 100644 --- a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp +++ b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp @@ -25,13 +25,13 @@ namespace latinime { -static jlong latinime_Keyboard_setProximityInfo(JNIEnv *env, jclass clazz, jstring localeJStr, +static jlong latinime_Keyboard_setProximityInfo(JNIEnv *env, jclass clazz, jint displayWidth, jint displayHeight, jint gridWidth, jint gridHeight, 
jint mostCommonkeyWidth, jint mostCommonkeyHeight, jintArray proximityChars, jint keyCount, jintArray keyXCoordinates, jintArray keyYCoordinates, jintArray keyWidths, jintArray keyHeights, jintArray keyCharCodes, jfloatArray sweetSpotCenterXs, jfloatArray sweetSpotCenterYs, jfloatArray sweetSpotRadii) { - ProximityInfo *proximityInfo = new ProximityInfo(env, localeJStr, displayWidth, displayHeight, + ProximityInfo *proximityInfo = new ProximityInfo(env, displayWidth, displayHeight, gridWidth, gridHeight, mostCommonkeyWidth, mostCommonkeyHeight, proximityChars, keyCount, keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, keyCharCodes, sweetSpotCenterXs, sweetSpotCenterYs, sweetSpotRadii); @@ -46,7 +46,7 @@ static void latinime_Keyboard_release(JNIEnv *env, jclass clazz, jlong proximity static const JNINativeMethod sMethods[] = { { const_cast<char *>("setProximityInfoNative"), - const_cast<char *>("(Ljava/lang/String;IIIIII[II[I[I[I[I[I[F[F[F)J"), + const_cast<char *>("(IIIIII[II[I[I[I[I[I[F[F[F)J"), reinterpret_cast<void *>(latinime_Keyboard_setProximityInfo) }, { diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 22ad2d0ab..118f600bb 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -28,10 +28,11 @@ #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" #include "suggest/core/result/suggestion_results.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" #include "utils/jni_data_utils.h" #include "utils/log_utils.h" #include "utils/time_keeper.h" @@ 
-179,9 +180,10 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, jintArray yCoordinatesArray, jintArray timesArray, jintArray pointerIdsArray, jintArray inputCodePointsArray, jint inputSize, jintArray suggestOptions, jobjectArray prevWordCodePointArrays, jbooleanArray isBeginningOfSentenceArray, - jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, - jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray inOutLanguageWeight) { + jint prevWordCount, jintArray outSuggestionCount, jintArray outCodePointsArray, + jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, + jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray inOutWeightOfLangModelVsSpatialModel) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); // Assign 0 to outSuggestionCount here in case of returning earlier in this method. JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, 0); @@ -236,42 +238,47 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, ASSERT(false); return; } - float languageWeight; - env->GetFloatArrayRegion(inOutLanguageWeight, 0, 1 /* len */, &languageWeight); + float weightOfLangModelVsSpatialModel; + env->GetFloatArrayRegion(inOutWeightOfLangModelVsSpatialModel, 0, 1 /* len */, + &weightOfLangModelVsSpatialModel); SuggestionResults suggestionResults(MAX_RESULTS); - const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, - prevWordCodePointArrays, isBeginningOfSentenceArray); + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordCodePointArrays, isBeginningOfSentenceArray, prevWordCount); if (givenSuggestOptions.isGesture() || inputSize > 0) { // TODO: Use SuggestionResults to return suggestions. 
dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, - times, pointerIds, inputCodePoints, inputSize, &prevWordsInfo, - &givenSuggestOptions, languageWeight, &suggestionResults); + times, pointerIds, inputCodePoints, inputSize, &ngramContext, + &givenSuggestOptions, weightOfLangModelVsSpatialModel, &suggestionResults); } else { - dictionary->getPredictions(&prevWordsInfo, &suggestionResults); + dictionary->getPredictions(&ngramContext, &suggestionResults); + } + if (DEBUG_DICT) { + suggestionResults.dumpSuggestions(); } suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, outScoresArray, outSpaceIndicesArray, outTypesArray, - outAutoCommitFirstWordConfidenceArray, inOutLanguageWeight); + outAutoCommitFirstWordConfidenceArray, inOutWeightOfLangModelVsSpatialModel); } static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, jlong dict, jintArray word) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return NOT_A_PROBABILITY; - const jsize wordLength = env->GetArrayLength(word); - int codePoints[wordLength]; - env->GetIntArrayRegion(word, 0, wordLength, codePoints); - return dictionary->getProbability(codePoints, wordLength); + const jsize codePointCount = env->GetArrayLength(word); + int codePoints[codePointCount]; + env->GetIntArrayRegion(word, 0, codePointCount, codePoints); + return dictionary->getProbability(CodePointArrayView(codePoints, codePointCount)); } static jint latinime_BinaryDictionary_getMaxProbabilityOfExactMatches( JNIEnv *env, jclass clazz, jlong dict, jintArray word) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return NOT_A_PROBABILITY; - const jsize wordLength = env->GetArrayLength(word); - int codePoints[wordLength]; - env->GetIntArrayRegion(word, 0, wordLength, codePoints); - return dictionary->getMaxProbabilityOfExactMatches(codePoints, wordLength); + const jsize codePointCount = 
env->GetArrayLength(word); + int codePoints[codePointCount]; + env->GetIntArrayRegion(word, 0, codePointCount, codePoints); + return dictionary->getMaxProbabilityOfExactMatches( + CodePointArrayView(codePoints, codePointCount)); } static jint latinime_BinaryDictionary_getNgramProbability(JNIEnv *env, jclass clazz, @@ -282,9 +289,11 @@ static jint latinime_BinaryDictionary_getNgramProbability(JNIEnv *env, jclass cl const jsize wordLength = env->GetArrayLength(word); int wordCodePoints[wordLength]; env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); - const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, - prevWordCodePointArrays, isBeginningOfSentenceArray); - return dictionary->getNgramProbability(&prevWordsInfo, wordCodePoints, wordLength); + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordCodePointArrays, isBeginningOfSentenceArray, + env->GetArrayLength(prevWordCodePointArrays)); + return dictionary->getNgramProbability(&ngramContext, + CodePointArrayView(wordCodePoints, wordLength)); } // Method to iterate all words in the dictionary for makedict. 
@@ -318,8 +327,9 @@ static jint latinime_BinaryDictionary_getNextWord(JNIEnv *env, jclass clazz, static void latinime_BinaryDictionary_getWordProperty(JNIEnv *env, jclass clazz, jlong dict, jintArray word, jboolean isBeginningOfSentence, jintArray outCodePoints, - jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outBigramTargets, - jobject outBigramProbabilityInfo, jobject outShortcutTargets, + jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outNgramPrevWordsArray, + jobject outNgramPrevWordIsBeginningOfSentenceArray, jobject outNgramTargets, + jobject outNgramProbabilityInfo, jobject outShortcutTargets, jobject outShortcutProbabilities) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return; @@ -339,15 +349,17 @@ static void latinime_BinaryDictionary_getWordProperty(JNIEnv *env, jclass clazz, return; } } - const WordProperty wordProperty = dictionary->getWordProperty(wordCodePoints, codePointCount); + const WordProperty wordProperty = dictionary->getWordProperty( + CodePointArrayView(wordCodePoints, codePointCount)); wordProperty.outputProperties(env, outCodePoints, outFlags, outProbabilityInfo, - outBigramTargets, outBigramProbabilityInfo, outShortcutTargets, + outNgramPrevWordsArray, outNgramPrevWordIsBeginningOfSentenceArray, + outNgramTargets, outNgramProbabilityInfo, outShortcutTargets, outShortcutProbabilities); } static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz, jlong dict, jintArray word, jint probability, jintArray shortcutTarget, jint shortcutProbability, - jboolean isBeginningOfSentence, jboolean isNotAWord, jboolean isBlacklisted, + jboolean isBeginningOfSentence, jboolean isNotAWord, jboolean isPossiblyOffensive, jint timestamp) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) { @@ -357,15 +369,19 @@ static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz, int codePoints[codePointCount]; 
env->GetIntArrayRegion(word, 0, codePointCount, codePoints); std::vector<UnigramProperty::ShortcutProperty> shortcuts; - std::vector<int> shortcutTargetCodePoints; - JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); - if (!shortcutTargetCodePoints.empty()) { - shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); + { + std::vector<int> shortcutTargetCodePoints; + JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); + if (!shortcutTargetCodePoints.empty()) { + shortcuts.emplace_back(std::move(shortcutTargetCodePoints), shortcutProbability); + } } // Use 1 for count to indicate the word has inputted. const UnigramProperty unigramProperty(isBeginningOfSentence, isNotAWord, - isBlacklisted, probability, timestamp, 0 /* level */, 1 /* count */, &shortcuts); - return dictionary->addUnigramEntry(codePoints, codePointCount, &unigramProperty); + isPossiblyOffensive, probability, HistoricalInfo(timestamp, 0 /* level */, + 1 /* count */), std::move(shortcuts)); + return dictionary->addUnigramEntry(CodePointArrayView(codePoints, codePointCount), + &unigramProperty); } static bool latinime_BinaryDictionary_removeUnigramEntry(JNIEnv *env, jclass clazz, jlong dict, @@ -377,7 +393,7 @@ static bool latinime_BinaryDictionary_removeUnigramEntry(JNIEnv *env, jclass cla jsize codePointCount = env->GetArrayLength(word); int codePoints[codePointCount]; env->GetIntArrayRegion(word, 0, codePointCount, codePoints); - return dictionary->removeUnigramEntry(codePoints, codePointCount); + return dictionary->removeUnigramEntry(CodePointArrayView(codePoints, codePointCount)); } static bool latinime_BinaryDictionary_addNgramEntry(JNIEnv *env, jclass clazz, jlong dict, @@ -387,17 +403,17 @@ static bool latinime_BinaryDictionary_addNgramEntry(JNIEnv *env, jclass clazz, j if (!dictionary) { return false; } - const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, - prevWordCodePointArrays, 
isBeginningOfSentenceArray); + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordCodePointArrays, isBeginningOfSentenceArray, + env->GetArrayLength(prevWordCodePointArrays)); jsize wordLength = env->GetArrayLength(word); int wordCodePoints[wordLength]; env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); - const std::vector<int> bigramTargetCodePoints( - wordCodePoints, wordCodePoints + wordLength); - // Use 1 for count to indicate the bigram has inputted. - const BigramProperty bigramProperty(&bigramTargetCodePoints, probability, - timestamp, 0 /* level */, 1 /* count */); - return dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); + // Use 1 for count to indicate the ngram has inputted. + const NgramProperty ngramProperty(ngramContext, + CodePointArrayView(wordCodePoints, wordLength).toVector(), + probability, HistoricalInfo(timestamp, 0 /* level */, 1 /* count */)); + return dictionary->addNgramEntry(&ngramProperty); } static bool latinime_BinaryDictionary_removeNgramEntry(JNIEnv *env, jclass clazz, jlong dict, @@ -407,103 +423,90 @@ static bool latinime_BinaryDictionary_removeNgramEntry(JNIEnv *env, jclass clazz if (!dictionary) { return false; } - const PrevWordsInfo prevWordsInfo = JniDataUtils::constructPrevWordsInfo(env, - prevWordCodePointArrays, isBeginningOfSentenceArray); - jsize wordLength = env->GetArrayLength(word); - int wordCodePoints[wordLength]; - env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); - return dictionary->removeNgramEntry(&prevWordsInfo, wordCodePoints, wordLength); + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordCodePointArrays, isBeginningOfSentenceArray, + env->GetArrayLength(prevWordCodePointArrays)); + jsize codePointCount = env->GetArrayLength(word); + int wordCodePoints[codePointCount]; + env->GetIntArrayRegion(word, 0, codePointCount, wordCodePoints); + return dictionary->removeNgramEntry(&ngramContext, + 
CodePointArrayView(wordCodePoints, codePointCount)); } -// Returns how many language model params are processed. -static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, jclass clazz, - jlong dict, jobjectArray languageModelParams, jint startIndex) { +static bool latinime_BinaryDictionary_updateEntriesForWordWithNgramContext(JNIEnv *env, + jclass clazz, jlong dict, jobjectArray prevWordCodePointArrays, + jbooleanArray isBeginningOfSentenceArray, jintArray word, jboolean isValidWord, jint count, + jint timestamp) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) { + return false; + } + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordCodePointArrays, isBeginningOfSentenceArray, + env->GetArrayLength(prevWordCodePointArrays)); + jsize codePointCount = env->GetArrayLength(word); + int wordCodePoints[codePointCount]; + env->GetIntArrayRegion(word, 0, codePointCount, wordCodePoints); + const HistoricalInfo historicalInfo(timestamp, 0 /* level */, count); + return dictionary->updateEntriesForWordWithNgramContext(&ngramContext, + CodePointArrayView(wordCodePoints, codePointCount), isValidWord == JNI_TRUE, + historicalInfo); +} + +// Returns how many input events are processed. 
+static int latinime_BinaryDictionary_updateEntriesForInputEvents(JNIEnv *env, jclass clazz, + jlong dict, jobjectArray inputEvents, jint startIndex) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) { return 0; } - jsize languageModelParamCount = env->GetArrayLength(languageModelParams); - if (languageModelParamCount == 0 || startIndex >= languageModelParamCount) { + jsize inputEventCount = env->GetArrayLength(inputEvents); + if (inputEventCount == 0 || startIndex >= inputEventCount) { return 0; } - jobject languageModelParam = env->GetObjectArrayElement(languageModelParams, 0); - jclass languageModelParamClass = env->GetObjectClass(languageModelParam); - env->DeleteLocalRef(languageModelParam); - - jfieldID word0FieldId = env->GetFieldID(languageModelParamClass, "mWord0", "[I"); - jfieldID word1FieldId = env->GetFieldID(languageModelParamClass, "mWord1", "[I"); - jfieldID unigramProbabilityFieldId = - env->GetFieldID(languageModelParamClass, "mUnigramProbability", "I"); - jfieldID bigramProbabilityFieldId = - env->GetFieldID(languageModelParamClass, "mBigramProbability", "I"); - jfieldID timestampFieldId = - env->GetFieldID(languageModelParamClass, "mTimestamp", "I"); - jfieldID shortcutTargetFieldId = - env->GetFieldID(languageModelParamClass, "mShortcutTarget", "[I"); - jfieldID shortcutProbabilityFieldId = - env->GetFieldID(languageModelParamClass, "mShortcutProbability", "I"); - jfieldID isNotAWordFieldId = - env->GetFieldID(languageModelParamClass, "mIsNotAWord", "Z"); - jfieldID isBlacklistedFieldId = - env->GetFieldID(languageModelParamClass, "mIsBlacklisted", "Z"); - env->DeleteLocalRef(languageModelParamClass); - - for (int i = startIndex; i < languageModelParamCount; ++i) { - jobject languageModelParam = env->GetObjectArrayElement(languageModelParams, i); - // languageModelParam is a set of params for word1; thus, word1 cannot be null. 
On the - // other hand, word0 can be null and then it means the set of params doesn't contain bigram - // information. - jintArray word0 = static_cast<jintArray>( - env->GetObjectField(languageModelParam, word0FieldId)); - jsize word0Length = word0 ? env->GetArrayLength(word0) : 0; - int word0CodePoints[word0Length]; - if (word0) { - env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); - } - jintArray word1 = static_cast<jintArray>( - env->GetObjectField(languageModelParam, word1FieldId)); - jsize word1Length = env->GetArrayLength(word1); - int word1CodePoints[word1Length]; - env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - jint unigramProbability = env->GetIntField(languageModelParam, unigramProbabilityFieldId); - jint timestamp = env->GetIntField(languageModelParam, timestampFieldId); - jboolean isNotAWord = env->GetBooleanField(languageModelParam, isNotAWordFieldId); - jboolean isBlacklisted = env->GetBooleanField(languageModelParam, isBlacklistedFieldId); - jintArray shortcutTarget = static_cast<jintArray>( - env->GetObjectField(languageModelParam, shortcutTargetFieldId)); - std::vector<UnigramProperty::ShortcutProperty> shortcuts; - std::vector<int> shortcutTargetCodePoints; - JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); - if (!shortcutTargetCodePoints.empty()) { - jint shortcutProbability = - env->GetIntField(languageModelParam, shortcutProbabilityFieldId); - shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); - } + jobject inputEvent = env->GetObjectArrayElement(inputEvents, 0); + jclass wordInputEventClass = env->GetObjectClass(inputEvent); + env->DeleteLocalRef(inputEvent); + + jfieldID targetWordFieldId = env->GetFieldID(wordInputEventClass, "mTargetWord", "[I"); + jfieldID prevWordCountFieldId = env->GetFieldID(wordInputEventClass, "mPrevWordsCount", "I"); + jfieldID prevWordArrayFieldId = env->GetFieldID(wordInputEventClass, "mPrevWordArray", "[[I"); + jfieldID 
isPrevWordBoSArrayFieldId = + env->GetFieldID(wordInputEventClass, "mIsPrevWordBeginningOfSentenceArray", "[Z"); + jfieldID isValidFieldId = env->GetFieldID(wordInputEventClass, "mIsValid", "Z"); + jfieldID timestampFieldId = env->GetFieldID(wordInputEventClass, "mTimestamp", "I"); + env->DeleteLocalRef(wordInputEventClass); + + for (int i = startIndex; i < inputEventCount; ++i) { + jobject inputEvent = env->GetObjectArrayElement(inputEvents, i); + jintArray targetWord = static_cast<jintArray>( + env->GetObjectField(inputEvent, targetWordFieldId)); + jsize wordLength = env->GetArrayLength(targetWord); + int wordCodePoints[wordLength]; + env->GetIntArrayRegion(targetWord, 0, wordLength, wordCodePoints); + env->DeleteLocalRef(targetWord); + + jint prevWordCount = env->GetIntField(inputEvent, prevWordCountFieldId); + jobjectArray prevWordArray = + static_cast<jobjectArray>(env->GetObjectField(inputEvent, prevWordArrayFieldId)); + jbooleanArray isPrevWordBeginningOfSentenceArray = static_cast<jbooleanArray>( + env->GetObjectField(inputEvent, isPrevWordBoSArrayFieldId)); + jboolean isValid = env->GetBooleanField(inputEvent, isValidFieldId); + jint timestamp = env->GetIntField(inputEvent, timestampFieldId); + const NgramContext ngramContext = JniDataUtils::constructNgramContext(env, + prevWordArray, isPrevWordBeginningOfSentenceArray, prevWordCount); // Use 1 for count to indicate the word has inputted. - const UnigramProperty unigramProperty(false /* isBeginningOfSentence */, isNotAWord, - isBlacklisted, unigramProbability, timestamp, 0 /* level */, 1 /* count */, - &shortcuts); - dictionary->addUnigramEntry(word1CodePoints, word1Length, &unigramProperty); - if (word0) { - jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId); - const std::vector<int> bigramTargetCodePoints( - word1CodePoints, word1CodePoints + word1Length); - // Use 1 for count to indicate the bigram has inputted. 
- const BigramProperty bigramProperty(&bigramTargetCodePoints, bigramProbability, - timestamp, 0 /* level */, 1 /* count */); - const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, - false /* isBeginningOfSentence */); - dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); - } + dictionary->updateEntriesForWordWithNgramContext(&ngramContext, + CodePointArrayView(wordCodePoints, wordLength), isValid, + HistoricalInfo(timestamp, 0 /* level */, 1 /* count */)); if (dictionary->needsToRunGC(true /* mindsBlockByGC */)) { return i + 1; } - env->DeleteLocalRef(word0); - env->DeleteLocalRef(word1); - env->DeleteLocalRef(shortcutTarget); - env->DeleteLocalRef(languageModelParam); + env->DeleteLocalRef(prevWordArray); + env->DeleteLocalRef(isPrevWordBeginningOfSentenceArray); + env->DeleteLocalRef(inputEvent); } - return languageModelParamCount; + return inputEventCount; } static jstring latinime_BinaryDictionary_getProperty(JNIEnv *env, jclass clazz, jlong dict, @@ -567,8 +570,8 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j // Add unigrams. do { token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount); - const WordProperty wordProperty = dictionary->getWordProperty(wordCodePoints, - wordCodePointCount); + const WordProperty wordProperty = dictionary->getWordProperty( + CodePointArrayView(wordCodePoints, wordCodePointCount)); if (wordCodePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) { // Skip beginning-of-sentence unigram. 
continue; @@ -581,18 +584,19 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j return false; } } - if (!dictionaryStructureWithBufferPolicy->addUnigramEntry(wordCodePoints, - wordCodePointCount, wordProperty.getUnigramProperty())) { + if (!dictionaryStructureWithBufferPolicy->addUnigramEntry( + CodePointArrayView(wordCodePoints, wordCodePointCount), + wordProperty.getUnigramProperty())) { LogUtils::logToJava(env, "Cannot add unigram to the new dict."); return false; } } while (token != 0); - // Add bigrams. + // Add ngrams. do { token = dictionary->getNextWordAndNextToken(token, wordCodePoints, &wordCodePointCount); - const WordProperty wordProperty = dictionary->getWordProperty(wordCodePoints, - wordCodePointCount); + const WordProperty wordProperty = dictionary->getWordProperty( + CodePointArrayView(wordCodePoints, wordCodePointCount)); if (dictionaryStructureWithBufferPolicy->needsToRunGC(true /* mindsBlockByGC */)) { dictionaryStructureWithBufferPolicy = runGCAndGetNewStructurePolicy( std::move(dictionaryStructureWithBufferPolicy), dictFilePathChars); @@ -601,12 +605,9 @@ static bool latinime_BinaryDictionary_migrateNative(JNIEnv *env, jclass clazz, j return false; } } - const PrevWordsInfo prevWordsInfo(wordCodePoints, wordCodePointCount, - wordProperty.getUnigramProperty()->representsBeginningOfSentence()); - for (const BigramProperty &bigramProperty : *wordProperty.getBigramProperties()) { - if (!dictionaryStructureWithBufferPolicy->addNgramEntry(&prevWordsInfo, - &bigramProperty)) { - LogUtils::logToJava(env, "Cannot add bigram to the new dict."); + for (const NgramProperty &ngramProperty : *wordProperty.getNgramProperties()) { + if (!dictionaryStructureWithBufferPolicy->addNgramEntry(&ngramProperty)) { + LogUtils::logToJava(env, "Cannot add ngram to the new dict."); return false; } } @@ -659,7 +660,7 @@ static const JNINativeMethod sMethods[] = { }, { const_cast<char *>("getSuggestionsNative"), - const_cast<char 
*>("(JJJ[I[I[I[I[II[I[[I[Z[I[I[I[I[I[I[F)V"), + const_cast<char *>("(JJJ[I[I[I[I[II[I[[I[ZI[I[I[I[I[I[I[F)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_getSuggestions) }, { @@ -680,7 +681,8 @@ static const JNINativeMethod sMethods[] = { { const_cast<char *>("getWordPropertyNative"), const_cast<char *>("(J[IZ[I[Z[ILjava/util/ArrayList;Ljava/util/ArrayList;" - "Ljava/util/ArrayList;Ljava/util/ArrayList;)V"), + "Ljava/util/ArrayList;Ljava/util/ArrayList;Ljava/util/ArrayList;" + "Ljava/util/ArrayList;)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_getWordProperty) }, { @@ -709,10 +711,15 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionary_removeNgramEntry) }, { - const_cast<char *>("addMultipleDictionaryEntriesNative"), + const_cast<char *>("updateEntriesForWordWithNgramContextNative"), + const_cast<char *>("(J[[I[Z[IZII)Z"), + reinterpret_cast<void *>(latinime_BinaryDictionary_updateEntriesForWordWithNgramContext) + }, + { + const_cast<char *>("updateEntriesForInputEventsNative"), const_cast<char *>( - "(J[Lcom/android/inputmethod/latin/utils/LanguageModelParam;I)I"), - reinterpret_cast<void *>(latinime_BinaryDictionary_addMultipleDictionaryEntries) + "(J[Lcom/android/inputmethod/latin/utils/WordInputEventForPersonalization;I)I"), + reinterpret_cast<void *>(latinime_BinaryDictionary_updateEntriesForInputEvents) }, { const_cast<char *>("getPropertyNative"), diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionaryUtils.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionaryUtils.cpp index 0a34b783a..68bf417e5 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionaryUtils.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionaryUtils.cpp @@ -68,18 +68,6 @@ static jfloat latinime_BinaryDictionaryUtils_calcNormalizedScore(JNIEnv *env, jc afterCodePoints, afterLength, score); } -static jint latinime_BinaryDictionaryUtils_editDistance(JNIEnv *env, jclass clazz, 
jintArray before, - jintArray after) { - jsize beforeLength = env->GetArrayLength(before); - jsize afterLength = env->GetArrayLength(after); - int beforeCodePoints[beforeLength]; - int afterCodePoints[afterLength]; - env->GetIntArrayRegion(before, 0, beforeLength, beforeCodePoints); - env->GetIntArrayRegion(after, 0, afterLength, afterCodePoints); - return AutocorrectionThresholdUtils::editDistance(beforeCodePoints, beforeLength, - afterCodePoints, afterLength); -} - static int latinime_BinaryDictionaryUtils_setCurrentTimeForTest(JNIEnv *env, jclass clazz, jint currentTime) { if (currentTime >= 0) { @@ -104,11 +92,6 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionaryUtils_calcNormalizedScore) }, { - const_cast<char *>("editDistanceNative"), - const_cast<char *>("([I[I)I"), - reinterpret_cast<void *>(latinime_BinaryDictionaryUtils_editDistance) - }, - { const_cast<char *>("setCurrentTimeForTestNative"), const_cast<char *>("(I)I"), reinterpret_cast<void *>(latinime_BinaryDictionaryUtils_setCurrentTimeForTest) diff --git a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp index 766064153..3c6bff3b6 100644 --- a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp +++ b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp @@ -22,7 +22,7 @@ #include "jni.h" #include "jni_common.h" #include "suggest/core/session/dic_traverse_session.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" namespace latinime { class Dictionary; @@ -40,14 +40,14 @@ static void latinime_initDicTraverseSession(JNIEnv *env, jclass clazz, jlong tra } Dictionary *dict = reinterpret_cast<Dictionary *>(dictionary); if (!previousWord) { - PrevWordsInfo prevWordsInfo; - ts->init(dict, &prevWordsInfo, 0 /* suggestOptions */); + NgramContext emptyNgramContext; + ts->init(dict, &emptyNgramContext, 0 /* 
suggestOptions */); return; } int prevWord[previousWordLength]; env->GetIntArrayRegion(previousWord, 0, previousWordLength, prevWord); - PrevWordsInfo prevWordsInfo(prevWord, previousWordLength, false /* isStartOfSentence */); - ts->init(dict, &prevWordsInfo, 0 /* suggestOptions */); + NgramContext ngramContext(prevWord, previousWordLength, false /* isStartOfSentence */); + ts->init(dict, &ngramContext, 0 /* suggestOptions */); } static void latinime_releaseDicTraverseSession(JNIEnv *env, jclass clazz, jlong traverseSession) { diff --git a/native/jni/run-tests.sh b/native/jni/run-tests.sh index 3da45270d..a7fa82d9b 100755 --- a/native/jni/run-tests.sh +++ b/native/jni/run-tests.sh @@ -48,6 +48,13 @@ if [[ $show_usage == yes ]]; then if [[ ${BASH_SOURCE[0]} != $0 ]]; then return; else exit 1; fi fi +# Host build is never supported in unbundled (NDK/tapas) build +if [[ $enable_host_test == yes && -n $TARGET_BUILD_APPS ]]; then + echo "Host build is never supported in tapas build." 1>&2 + echo "Use lunch command instead." 
1>&2 + if [[ ${BASH_SOURCE[0]} != $0 ]]; then return; else exit 1; fi +fi + target_test_name=liblatinime_target_unittests host_test_name=liblatinime_host_unittests diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 24d04e51f..885118524 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -119,7 +119,7 @@ static inline void dumpWordInfo(const int *word, const int length, const int ran const int probability) { static char charBuf[50]; const int N = intArrayToCharArray(word, length, charBuf, NELEMS(charBuf)); - if (N > 1) { + if (N > 0) { AKLOGI("%2d [ %s ] (%d)", rank, charBuf, probability); } } @@ -299,8 +299,9 @@ static inline void prof_out(void) { #define NOT_AN_INDEX (-1) #define NOT_A_PROBABILITY (-1) #define NOT_A_DICT_POS (S_INT_MIN) +#define NOT_A_WORD_ID (S_INT_MIN) #define NOT_A_TIMESTAMP (-1) -#define NOT_A_LANGUAGE_WEIGHT (-1.0f) +#define NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1.0f) // A special value to mean the first word confidence makes no sense in this case, // e.g. this is not a multi-word suggestion. @@ -337,7 +338,7 @@ static inline void prof_out(void) { #define MAX_POINTER_COUNT_G 2 // (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram is supported. 
-#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 1 +#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 2 #define DISALLOW_DEFAULT_CONSTRUCTOR(TypeName) \ TypeName() = delete diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index d1b2c87be..5214077dc 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -26,6 +26,7 @@ #include "suggest/core/dictionary/error_type_utils.h" #include "suggest/core/layout/proximity_info_state.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ @@ -103,10 +104,10 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // Init for root with prevWordsPtNodePos which is used for n-gram - void initAsRoot(const int rootPtNodeArrayPos, const int *const prevWordsPtNodePos) { + // Init for root with prevWordIds which is used for n-gram + void initAsRoot(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mIsCachedForNextSuggestion = false; - mDicNodeProperties.init(rootPtNodeArrayPos, prevWordsPtNodePos); + mDicNodeProperties.init(rootPtNodeArrayPos, prevWordIds); mDicNodeState.init(); PROF_NODE_RESET(mProfiler); } @@ -114,12 +115,11 @@ class DicNode { // Init for root with previous word void initAsRootWithPreviousWord(const DicNode *const dicNode, const int rootPtNodeArrayPos) { mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; - int newPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - newPrevWordsPtNodePos[0] = dicNode->mDicNodeProperties.getPtNodePos(); - for (size_t i = 1; i < NELEMS(newPrevWordsPtNodePos); ++i) { - newPrevWordsPtNodePos[i] = dicNode->getPrevWordsTerminalPtNodePos()[i - 1]; - } - mDicNodeProperties.init(rootPtNodeArrayPos, newPrevWordsPtNodePos); + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> newPrevWordIds; + newPrevWordIds[0] = dicNode->mDicNodeProperties.getWordId(); + 
dicNode->getPrevWordIds().limit(newPrevWordIds.size() - 1) + .copyToArray(&newPrevWordIds, 1 /* offset */); + mDicNodeProperties.init(rootPtNodeArrayPos, WordIdArrayView::fromArray(newPrevWordIds)); mDicNodeState.initAsRootWithPreviousWord(&dicNode->mDicNodeState, dicNode->mDicNodeProperties.getDepth()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); @@ -135,19 +135,16 @@ class DicNode { PROF_NODE_COPY(&parentDicNode->mProfiler, mProfiler); } - void initAsChild(const DicNode *const dicNode, const int ptNodePos, - const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { + void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos, + const int wordId, const CodePointArrayView mergedCodePoints) { uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast<uint16_t>( - dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); - mDicNodeProperties.init(ptNodePos, childrenPtNodeArrayPos, mergedNodeCodePoints[0], - probability, isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, - newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordsTerminalPtNodePos()); - mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, - mergedNodeCodePoints); + dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size()); + mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0], + wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds()); + mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(), + mergedCodePoints.data()); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -179,9 +176,6 @@ class DicNode { // Check if the current word and the previous word can be considered 
as a valid multiple word // suggestion. bool isValidMultipleWordSuggestion() const { - if (isBlacklistedOrNotAWord()) { - return false; - } // Treat suggestion as invalid if the current and the previous word are single character // words. const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength() @@ -204,13 +198,12 @@ class DicNode { } // Used to get n-gram probability in DicNodeUtils. - int getPtNodePos() const { - return mDicNodeProperties.getPtNodePos(); + int getWordId() const { + return mDicNodeProperties.getWordId(); } - // TODO: Use view class to return PtNodePos array. - const int *getPrevWordsTerminalPtNodePos() const { - return mDicNodeProperties.getPrevWordsTerminalPtNodePos(); + const WordIdArrayView getPrevWordIds() const { + return mDicNodeProperties.getPrevWordIds(); } // Used in DicNodeUtils @@ -218,10 +211,6 @@ class DicNode { return mDicNodeProperties.getChildrenPtNodeArrayPos(); } - int getProbability() const { - return mDicNodeProperties.getProbability(); - } - AK_FORCE_INLINE bool isTerminalDicNode() const { const bool isTerminalPtNode = mDicNodeProperties.isTerminal(); const int currentDicNodeDepth = getNodeCodePointCount(); @@ -306,8 +295,9 @@ class DicNode { } // Used to prune nodes - float getCompoundDistance(const float languageWeight) const { - return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); + float getCompoundDistance(const float weightOfLangModelVsSpatialModel) const { + return mDicNodeState.mDicNodeStateScoring.getCompoundDistance( + weightOfLangModelVsSpatialModel); } AK_FORCE_INLINE const int *getOutputWordBuf() const { @@ -404,10 +394,6 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes(); } - bool isBlacklistedOrNotAWord() const { - return mDicNodeProperties.isBlacklistedOrNotAWord(); - } - inline uint16_t getNodeCodePointCount() const { return mDicNodeProperties.getDepth(); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp 
b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 69ea67418..ea438922f 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -18,7 +18,6 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" -#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" namespace latinime { @@ -29,8 +28,8 @@ namespace latinime { /* static */ void DicNodeUtils::initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordsPtNodePos, DicNode *const newRootDicNode) { - newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordsPtNodePos); + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode) { + newRootDicNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordIds); } /*static */ void DicNodeUtils::initAsRootWithPreviousWord( @@ -73,25 +72,17 @@ namespace latinime { if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode, - multiBigramMap); + const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap); + if (wordAttributes.getProbability() == NOT_A_PROBABILITY + || (dicNode->hasMultipleWords() + && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()))) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
- const float cost = static_cast<float>(MAX_PROBABILITY - probability) + const float cost = static_cast<float>(MAX_PROBABILITY - wordAttributes.getProbability()) / static_cast<float>(MAX_PROBABILITY); return cost; } -/* static */ int DicNodeUtils::getBigramNodeProbability( - const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) { - const int unigramProbability = dicNode->getProbability(); - if (multiBigramMap) { - const int *const prevWordsPtNodePos = dicNode->getPrevWordsTerminalPtNodePos(); - return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, - prevWordsPtNodePos, dicNode->getPtNodePos(), unigramProbability); - } - return dictionaryStructurePolicy->getProbability(unigramProbability, - NOT_A_PROBABILITY); -} - } // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 00e80c604..b891a842a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -18,6 +18,7 @@ #define LATINIME_DIC_NODE_UTILS_H #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DicNodeUtils { public: static void initAsRoot( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const prevWordPtNodePos, DicNode *const newRootDicNode); + const WordIdArrayView prevWordIds, DicNode *const newRootDicNode); static void initAsRootWithPreviousWord( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode); @@ -46,10 +47,6 @@ class DicNodeUtils { DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; - - static int getBigramNodeProbability( - const 
DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const DicNode *const dicNode, MultiBigramMap *const multiBigramMap); }; } // namespace latinime #endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h index 54cde1988..e6b758954 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -21,6 +21,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node.h" +#include "utils/int_array_view.h" namespace latinime { @@ -58,15 +59,11 @@ class DicNodeVector { mDicNodes.back().initAsPassingChild(dicNode); } - void pushLeavingChild(const DicNode *const dicNode, const int ptNodePos, - const int childrenPtNodeArrayPos, const int probability, const bool isTerminal, - const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { + void pushLeavingChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos, + const int wordId, const CodePointArrayView mergedCodePoints) { ASSERT(!mLock); mDicNodes.emplace_back(); - mDicNodes.back().initAsChild(dicNode, ptNodePos, childrenPtNodeArrayPos, probability, - isTerminal, hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, - mergedNodeCodePoints); + mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, wordId, mergedCodePoints); } DicNode *operator[](const int id) { diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index 8202176f7..1b796b5d4 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -18,8 +18,10 @@ #define LATINIME_DIC_NODE_PROPERTIES_H #include <cstdint> +#include <cstdlib> #include "defines.h" +#include 
"utils/int_array_view.h" namespace latinime { @@ -29,84 +31,61 @@ namespace latinime { class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() - : mPtNodePos(NOT_A_DICT_POS), mChildrenPtNodeArrayPos(NOT_A_DICT_POS), - mProbability(NOT_A_PROBABILITY), mDicNodeCodePoint(NOT_A_CODE_POINT), - mIsTerminal(false), mHasChildrenPtNodes(false), - mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} + : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT), + mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0), mPrevWordCount(0) {} ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. - void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, - const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, - const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordsNodePos) { - mPtNodePos = pos; + void init(const int childrenPos, const int nodeCodePoint, const int wordId, + const uint16_t depth, const uint16_t leavingDepth, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = childrenPos; mDicNodeCodePoint = nodeCodePoint; - mProbability = probability; - mIsTerminal = isTerminal; - mHasChildrenPtNodes = hasChildren; - mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; + mWordId = wordId; mDepth = depth; mLeavingDepth = leavingDepth; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } // Init for root with prevWordsPtNodePos which is used for n-gram - void init(const int rootPtNodeArrayPos, const int *const prevWordsNodePos) { - mPtNodePos = NOT_A_DICT_POS; + void init(const int rootPtNodeArrayPos, const WordIdArrayView prevWordIds) { mChildrenPtNodeArrayPos = rootPtNodeArrayPos; mDicNodeCodePoint = NOT_A_CODE_POINT; - mProbability = NOT_A_PROBABILITY; - mIsTerminal = false; 
- mHasChildrenPtNodes = true; - mIsBlacklistedOrNotAWord = false; + mWordId = NOT_A_WORD_ID; mDepth = 0; mLeavingDepth = 0; - memmove(mPrevWordsTerminalPtNodePos, prevWordsNodePos, sizeof(mPrevWordsTerminalPtNodePos)); + prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIds.size(); } void initByCopy(const DicNodeProperties *const dicNodeProp) { - mPtNodePos = dicNodeProp->mPtNodePos; mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos; mDicNodeCodePoint = dicNodeProp->mDicNodeCodePoint; - mProbability = dicNodeProp->mProbability; - mIsTerminal = dicNodeProp->mIsTerminal; - mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes; - mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; + mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth; mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); + const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } // Init as passing child void init(const DicNodeProperties *const dicNodeProp, const int codePoint) { - mPtNodePos = dicNodeProp->mPtNodePos; mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos; mDicNodeCodePoint = codePoint; // Overwrite the node char of a passing child - mProbability = dicNodeProp->mProbability; - mIsTerminal = dicNodeProp->mIsTerminal; - mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes; - mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord; + mWordId = dicNodeProp->mWordId; mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = dicNodeProp->mLeavingDepth; - memmove(mPrevWordsTerminalPtNodePos, dicNodeProp->mPrevWordsTerminalPtNodePos, - sizeof(mPrevWordsTerminalPtNodePos)); - } - - int getPtNodePos() const { - return mPtNodePos; + 
const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds(); + prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */); + mPrevWordCount = prevWordIdArrayView.size(); } int getChildrenPtNodeArrayPos() const { return mChildrenPtNodeArrayPos; } - int getProbability() const { - return mProbability; - } - int getDicNodeCodePoint() const { return mDicNodeCodePoint; } @@ -121,35 +100,32 @@ class DicNodeProperties { } bool isTerminal() const { - return mIsTerminal; + return mWordId != NOT_A_WORD_ID; } bool hasChildren() const { - return mHasChildrenPtNodes || mDepth != mLeavingDepth; + return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth; } - bool isBlacklistedOrNotAWord() const { - return mIsBlacklistedOrNotAWord; + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIds).limit(mPrevWordCount); } - const int *getPrevWordsTerminalPtNodePos() const { - return mPrevWordsTerminalPtNodePos; + int getWordId() const { + return mWordId; } private: // Caution!!! 
// Use a default copy constructor and an assign operator because shallow copies are ok // for this class - int mPtNodePos; int mChildrenPtNodeArrayPos; - int mProbability; int mDicNodeCodePoint; - bool mIsTerminal; - bool mHasChildrenPtNodes; - bool mIsBlacklistedOrNotAWord; + int mWordId; uint16_t mDepth; uint16_t mLeavingDepth; - int mPrevWordsTerminalPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds; + size_t mPrevWordCount; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index c19d48eb9..3a54c2599 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -103,8 +103,10 @@ class DicNodeStateScoring { return getCompoundDistance(1.0f); } - float getCompoundDistance(const float languageWeight) const { - return mSpatialDistance + mLanguageDistance * languageWeight; + float getCompoundDistance( + const float weightOfLangModelVsSpatialModel) const { + return mSpatialDistance + + mLanguageDistance * weightOfLangModelVsSpatialModel; } float getNormalizedCompoundDistance() const { diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h index 558e0a5c3..ee1606b6a 100644 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h @@ -31,6 +31,11 @@ class BinaryDictionaryShortcutIterator { mPos(shortcutStructurePolicy->getStartPos(shortcutPos)), mHasNextShortcutTarget(shortcutPos != NOT_A_DICT_POS) {} + BinaryDictionaryShortcutIterator(const BinaryDictionaryShortcutIterator &&shortcutIterator) + : 
mShortcutStructurePolicy(shortcutIterator.mShortcutStructurePolicy), + mPos(shortcutIterator.mPos), + mHasNextShortcutTarget(shortcutIterator.mHasNextShortcutTarget) {} + AK_FORCE_INLINE bool hasNextShortcutTarget() const { return mHasNextShortcutTarget; } @@ -45,7 +50,8 @@ class BinaryDictionaryShortcutIterator { } private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryShortcutIterator); + DISALLOW_DEFAULT_CONSTRUCTOR(BinaryDictionaryShortcutIterator); + DISALLOW_ASSIGNMENT_OPERATOR(BinaryDictionaryShortcutIterator); const DictionaryShortcutsStructurePolicy *const mShortcutStructurePolicy; int mPos; diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index d62573970..bfe17cc4c 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -23,11 +23,12 @@ #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/result/suggestion_results.h" #include "suggest/core/session/dic_traverse_session.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/suggest.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" #include "suggest/policyimpl/typing/typing_suggest_policy_factory.h" +#include "utils/int_array_view.h" #include "utils/log_utils.h" #include "utils/time_keeper.h" @@ -45,88 +46,84 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const PrevWordsInfo *const prevWordsInfo, - const SuggestOptions *const suggestOptions, const float languageWeight, + int inputSize, const NgramContext *const ngramContext, + const SuggestOptions 
*const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - traverseSession->init(this, prevWordsInfo, suggestOptions); + traverseSession->init(this, ngramContext, suggestOptions); const auto &suggest = suggestOptions->isGesture() ? mGestureSuggest : mTypingSuggest; suggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - languageWeight, outSuggestionResults); - if (DEBUG_DICT) { - outSuggestionResults->dumpSuggestions(); - } + weightOfLangModelVsSpatialModel, outSuggestionResults); } Dictionary::NgramListenerForPrediction::NgramListenerForPrediction( - const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults, + const NgramContext *const ngramContext, const WordIdArrayView prevWordIds, + SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) - : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults), - mDictStructurePolicy(dictStructurePolicy) {} + : mNgramContext(ngramContext), mPrevWordIds(prevWordIds), + mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {} void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability, - const int targetPtNodePos) { - if (targetPtNodePos == NOT_A_DICT_POS) { + const int targetWordId) { + if (targetWordId == NOT_A_WORD_ID) { return; } - if (mPrevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + if (mNgramContext->isNthPrevWordBeginningOfSentence(1 /* n */) && ngramProbability == NOT_A_PROBABILITY) { return; } int targetWordCodePoints[MAX_WORD_LENGTH]; - int unigramProbability = 0; - const int codePointCount = mDictStructurePolicy-> - getCodePointsAndProbabilityAndReturnCodePointCount(targetPtNodePos, - MAX_WORD_LENGTH, targetWordCodePoints, &unigramProbability); + const int codePointCount = 
mDictStructurePolicy->getCodePointsAndReturnCodePointCount( + targetWordId, MAX_WORD_LENGTH, targetWordCodePoints); if (codePointCount <= 0) { return; } - const int probability = mDictStructurePolicy->getProbability( - unigramProbability, ngramProbability); - mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability); + const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext( + mPrevWordIds, targetWordId, nullptr /* multiBigramMap */); + mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, + wordAttributes.getProbability()); } -void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, +void Dictionary::getPredictions(const NgramContext *const ngramContext, SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults, - mDictionaryStructureWithBufferPolicy.get()); - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordsPtNodePos, &listener); + NgramListenerForPrediction listener(ngramContext, prevWordIds, outSuggestionResults, + mDictionaryStructureWithBufferPolicy.get()); + mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener); } -int Dictionary::getProbability(const int *word, int length) const { - return getNgramProbability(nullptr /* prevWordsInfo */, word, length); +int Dictionary::getProbability(const CodePointArrayView codePoints) const { + return getNgramProbability(nullptr /* ngramContext */, codePoints); } -int 
Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) const { +int Dictionary::getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); return DictionaryUtils::getMaxProbabilityOfExactMatches( - mDictionaryStructureWithBufferPolicy.get(), word, length); + mDictionaryStructureWithBufferPolicy.get(), codePoints); } -int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word, - int length) const { +int Dictionary::getNgramProbability(const NgramContext *const ngramContext, + const CodePointArrayView codePoints) const { TimeKeeper::setCurrentTime(); - int nextWordPos = mDictionaryStructureWithBufferPolicy->getTerminalPtNodePositionOfWord(word, - length, false /* forceLowerCaseSearch */); - if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; - if (!prevWordsInfo) { - return getDictionaryStructurePolicy()->getProbabilityOfPtNode( - nullptr /* prevWordsPtNodePos */, nextWordPos); + const int wordId = mDictionaryStructureWithBufferPolicy->getWordId(codePoints, + false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) return NOT_A_PROBABILITY; + if (!ngramContext) { + return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos( - mDictionaryStructureWithBufferPolicy.get(), prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds( + mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray, true /* tryLowerCaseSearch */); - return getDictionaryStructurePolicy()->getProbabilityOfPtNode(prevWordsPtNodePos, nextWordPos); + return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId); } -bool Dictionary::addUnigramEntry(const int *const word, const int length, +bool Dictionary::addUnigramEntry(const 
CodePointArrayView codePoints, const UnigramProperty *const unigramProperty) { if (unigramProperty->representsBeginningOfSentence() && !mDictionaryStructureWithBufferPolicy->getHeaderStructurePolicy() @@ -135,24 +132,31 @@ bool Dictionary::addUnigramEntry(const int *const word, const int length, return false; } TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addUnigramEntry(word, length, unigramProperty); + return mDictionaryStructureWithBufferPolicy->addUnigramEntry(codePoints, unigramProperty); +} + +bool Dictionary::removeUnigramEntry(const CodePointArrayView codePoints) { + TimeKeeper::setCurrentTime(); + return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints); } -bool Dictionary::removeUnigramEntry(const int *const codePoints, const int codePointCount) { +bool Dictionary::addNgramEntry(const NgramProperty *const ngramProperty) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeUnigramEntry(codePoints, codePointCount); + return mDictionaryStructureWithBufferPolicy->addNgramEntry(ngramProperty); } -bool Dictionary::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty) { +bool Dictionary::removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->addNgramEntry(prevWordsInfo, bigramProperty); + return mDictionaryStructureWithBufferPolicy->removeNgramEntry(ngramContext, codePoints); } -bool Dictionary::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { +bool Dictionary::updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView codePoints, const bool isValidWord, + const HistoricalInfo historicalInfo) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->removeNgramEntry(prevWordsInfo, word, length); + 
return mDictionaryStructureWithBufferPolicy->updateEntriesForWordWithNgramContext(ngramContext, + codePoints, isValidWord, historicalInfo); } bool Dictionary::flush(const char *const filePath) { @@ -177,11 +181,9 @@ void Dictionary::getProperty(const char *const query, const int queryLength, cha maxResultLength); } -const WordProperty Dictionary::getWordProperty(const int *const codePoints, - const int codePointCount) { +const WordProperty Dictionary::getWordProperty(const CodePointArrayView codePoints) { TimeKeeper::setCurrentTime(); - return mDictionaryStructureWithBufferPolicy->getWordProperty( - codePoints, codePointCount); + return mDictionaryStructureWithBufferPolicy->getWordProperty(codePoints); } int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints, diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 732d3b199..a5e986d15 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -22,16 +22,18 @@ #include "defines.h" #include "jni.h" #include "suggest/core/dictionary/ngram_listener.h" +#include "suggest/core/dictionary/property/historical_info.h" #include "suggest/core/dictionary/property/word_property.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/suggest_interface.h" +#include "utils/int_array_view.h" namespace latinime { class DictionaryStructureWithBufferPolicy; class DicTraverseSession; -class PrevWordsInfo; +class NgramContext; class ProximityInfo; class SuggestionResults; class SuggestOptions; @@ -64,30 +66,33 @@ class Dictionary { void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const PrevWordsInfo *const prevWordsInfo, - 
const SuggestOptions *const suggestOptions, const float languageWeight, + int inputSize, const NgramContext *const ngramContext, + const SuggestOptions *const suggestOptions, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const; - void getPredictions(const PrevWordsInfo *const prevWordsInfo, + void getPredictions(const NgramContext *const ngramContext, SuggestionResults *const outSuggestionResults) const; - int getProbability(const int *word, int length) const; + int getProbability(const CodePointArrayView codePoints) const; - int getMaxProbabilityOfExactMatches(const int *word, int length) const; + int getMaxProbabilityOfExactMatches(const CodePointArrayView codePoints) const; - int getNgramProbability(const PrevWordsInfo *const prevWordsInfo, - const int *word, int length) const; + int getNgramProbability(const NgramContext *const ngramContext, + const CodePointArrayView codePoints) const; - bool addUnigramEntry(const int *const codePoints, const int codePointCount, + bool addUnigramEntry(const CodePointArrayView codePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const codePoints, const int codePointCount); + bool removeUnigramEntry(const CodePointArrayView codePoints); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty); + bool addNgramEntry(const NgramProperty *const ngramProperty); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, - const int length); + bool removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView codePoints); + + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView codePoints, const bool isValidWord, + const HistoricalInfo historicalInfo); bool flush(const char *const filePath); @@ -98,7 +103,7 @@ class Dictionary { void getProperty(const char *const query, const int 
queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, const int codePointCount); + const WordProperty getWordProperty(const CodePointArrayView codePoints); // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly @@ -117,15 +122,16 @@ class Dictionary { class NgramListenerForPrediction : public NgramListener { public: - NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo, - SuggestionResults *const suggestionResults, + NgramListenerForPrediction(const NgramContext *const ngramContext, + const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults, const DictionaryStructureWithBufferPolicy *const dictStructurePolicy); - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + virtual void onVisitEntry(const int ngramProbability, const int targetWordId); private: DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction); - const PrevWordsInfo *const mPrevWordsInfo; + const NgramContext *const mNgramContext; + const WordIdArrayView mPrevWordIds; SuggestionResults *const mSuggestionResults; const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy; }; diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp index b94966cbe..9573c37bc 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp @@ -21,34 +21,35 @@ #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { /* 
static */ int DictionaryUtils::getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { std::vector<DicNode> current; std::vector<DicNode> next; - // No prev words information. - PrevWordsInfo emptyPrevWordsInfo; - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy, - prevWordsPtNodePos, false /* tryLowerCaseSearch */); + // No ngram context. + NgramContext emptyNgramContext; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = emptyNgramContext.getPrevWordIds( + dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */); current.emplace_back(); - DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, &current.front()); - for (int i = 0; i < codePointCount; ++i) { + DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front()); + for (const int codePoint : codePoints) { // The base-lower input is used to ignore case errors and accent errors. 
- const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]); + const int baseLowerCodePoint = CharUtils::toBaseLowerCase(codePoint); for (const DicNode &dicNode : current) { - if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == codePoint) { + if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == baseLowerCodePoint) { next.emplace_back(dicNode); next.back().advanceDigraphIndex(); continue; } - processChildDicNodes(dictionaryStructurePolicy, codePoint, &dicNode, &next); + processChildDicNodes(dictionaryStructurePolicy, baseLowerCodePoint, &dicNode, &next); } current.clear(); current.swap(next); @@ -59,8 +60,11 @@ namespace latinime { if (!dicNode.isTerminalDicNode()) { continue; } + const WordAttributes wordAttributes = + dictionaryStructurePolicy->getWordAttributesInContext(dicNode.getPrevWordIds(), + dicNode.getWordId(), nullptr /* multiBigramMap */); // dicNode can contain case errors, accent errors, intentional omissions or digraphs. - maxProbability = std::max(maxProbability, dicNode.getProbability()); + maxProbability = std::max(maxProbability, wordAttributes.getProbability()); } return maxProbability; } diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.h b/native/jni/src/suggest/core/dictionary/dictionary_utils.h index 358ebf674..4dd21c9be 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary_utils.h +++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.h @@ -20,6 +20,7 @@ #include <vector> #include "defines.h" +#include "utils/int_array_view.h" namespace latinime { @@ -30,7 +31,7 @@ class DictionaryUtils { public: static int getMaxProbabilityOfExactMatches( const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, - const int *const codePoints, const int codePointCount); + const CodePointArrayView codePoints); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryUtils); diff --git a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp 
b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp index b6bf7a98c..1e2494e92 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.cpp @@ -19,17 +19,18 @@ namespace latinime { const ErrorTypeUtils::ErrorType ErrorTypeUtils::NOT_AN_ERROR = 0x0; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_CASE_ERROR = 0x1; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR = 0x2; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x4; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x8; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x10; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x20; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x40; -const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_CASE = 0x1; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT = 0x2; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT = 0x4; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::MATCH_WITH_DIGRAPH = 0x8; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::INTENTIONAL_OMISSION = 0x10; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::EDIT_CORRECTION = 0x20; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::PROXIMITY_CORRECTION = 0x40; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::COMPLETION = 0x80; +const ErrorTypeUtils::ErrorType ErrorTypeUtils::NEW_WORD = 0x100; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH = - NOT_AN_ERROR | MATCH_WITH_CASE_ERROR | MATCH_WITH_ACCENT_ERROR | MATCH_WITH_DIGRAPH; + NOT_AN_ERROR | MATCH_WITH_WRONG_CASE | MATCH_WITH_MISSING_ACCENT | MATCH_WITH_DIGRAPH; const ErrorTypeUtils::ErrorType ErrorTypeUtils::ERRORS_TREATED_AS_AN_EXACT_MATCH_WITH_INTENTIONAL_OMISSION = diff --git 
a/native/jni/src/suggest/core/dictionary/error_type_utils.h b/native/jni/src/suggest/core/dictionary/error_type_utils.h index e3e76b238..fd1d5fcff 100644 --- a/native/jni/src/suggest/core/dictionary/error_type_utils.h +++ b/native/jni/src/suggest/core/dictionary/error_type_utils.h @@ -30,8 +30,9 @@ class ErrorTypeUtils { typedef uint32_t ErrorType; static const ErrorType NOT_AN_ERROR; - static const ErrorType MATCH_WITH_CASE_ERROR; - static const ErrorType MATCH_WITH_ACCENT_ERROR; + static const ErrorType MATCH_WITH_WRONG_CASE; + static const ErrorType MATCH_WITH_MISSING_ACCENT; + static const ErrorType MATCH_WITH_WRONG_ACCENT; static const ErrorType MATCH_WITH_DIGRAPH; // Treat error as an intentional omission when the CorrectionType is omission and the node can // be intentional omission. diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp index 91f33a8dd..761f51ec8 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -35,39 +35,37 @@ const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = // Also caches the bigrams if there is space remaining and they have not been cached already. 
int MultiBigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability) { - if (!prevWordsPtNodePos || prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) { return structurePolicy->getProbability(unigramProbability, NOT_A_PROBABILITY); } - std::unordered_map<int, BigramMap>::const_iterator mapPosition = - mBigramMaps.find(prevWordsPtNodePos[0]); + const auto mapPosition = mBigramMaps.find(prevWordIds[0]); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + return mapPosition->second.getBigramProbability(structurePolicy, nextWordId, unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(structurePolicy, prevWordsPtNodePos); - return mBigramMaps[prevWordsPtNodePos[0]].getBigramProbability(structurePolicy, - nextWordPosition, unigramProbability); + addBigramsForWord(structurePolicy, prevWordIds); + return mBigramMaps[prevWordIds[0]].getBigramProbability(structurePolicy, + nextWordId, unigramProbability); } - return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordsPtNodePos, - nextWordPosition, unigramProbability); + return readBigramProbabilityFromBinaryDictionary(structurePolicy, prevWordIds, + nextWordId, unigramProbability); } void MultiBigramMap::BigramMap::init( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - structurePolicy->iterateNgramEntries(prevWordsPtNodePos, this /* listener */); + const WordIdArrayView prevWordIds) { + structurePolicy->iterateNgramEntries(prevWordIds, this /* listener */); } int MultiBigramMap::BigramMap::getBigramProbability( const DictionaryStructureWithBufferPolicy *const 
structurePolicy, - const int nextWordPosition, const int unigramProbability) const { + const int nextWordId, const int unigramProbability) const { int bigramProbability = NOT_A_PROBABILITY; - if (mBloomFilter.isInFilter(nextWordPosition)) { - const std::unordered_map<int, int>::const_iterator bigramProbabilityIt = - mBigramMap.find(nextWordPosition); + if (mBloomFilter.isInFilter(nextWordId)) { + const auto bigramProbabilityIt = mBigramMap.find(nextWordId); if (bigramProbabilityIt != mBigramMap.end()) { bigramProbability = bigramProbabilityIt->second; } @@ -75,29 +73,24 @@ int MultiBigramMap::BigramMap::getBigramProbability( return structurePolicy->getProbability(unigramProbability, bigramProbability); } -void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, - const int targetPtNodePos) { - if (targetPtNodePos == NOT_A_DICT_POS) { +void MultiBigramMap::BigramMap::onVisitEntry(const int ngramProbability, const int targetWordId) { + if (targetWordId == NOT_A_WORD_ID) { return; } - mBigramMap[targetPtNodePos] = ngramProbability; - mBloomFilter.setInFilter(targetPtNodePos); + mBigramMap[targetWordId] = ngramProbability; + mBloomFilter.setInFilter(targetWordId); } -void MultiBigramMap::addBigramsForWordPosition( +void MultiBigramMap::addBigramsForWord( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos) { - if (prevWordsPtNodePos) { - mBigramMaps[prevWordsPtNodePos[0]].init(structurePolicy, prevWordsPtNodePos); - } + const WordIdArrayView prevWordIds) { + mBigramMaps[prevWordIds[0]].init(structurePolicy, prevWordIds); } int MultiBigramMap::readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability) { - const int bigramProbability = structurePolicy->getProbabilityOfPtNode(prevWordsPtNodePos, - nextWordPosition); + const WordIdArrayView prevWordIds, const 
int nextWordId, const int unigramProbability) { + const int bigramProbability = structurePolicy->getProbabilityOfWord(prevWordIds, nextWordId); if (bigramProbability != NOT_A_PROBABILITY) { return bigramProbability; } diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index ad36dde83..d2eb5cc32 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -25,6 +25,7 @@ #include "suggest/core/dictionary/bloom_filter.h" #include "suggest/core/dictionary/ngram_listener.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/int_array_view.h" namespace latinime { @@ -39,8 +40,7 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); void clear() { mBigramMaps.clear(); @@ -58,11 +58,11 @@ class MultiBigramMap { virtual ~BigramMap() {} void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + const WordIdArrayView prevWordIds); int getBigramProbability( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int nextWordPosition, const int unigramProbability) const; - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos); + const int nextWordId, const int unigramProbability) const; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId); private: static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; @@ -70,14 +70,12 @@ class 
MultiBigramMap { BloomFilter mBloomFilter; }; - void addBigramsForWordPosition( - const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos); + void addBigramsForWord(const DictionaryStructureWithBufferPolicy *const structurePolicy, + const WordIdArrayView prevWordIds); int readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, - const int *const prevWordsPtNodePos, const int nextWordPosition, - const int unigramProbability); + const WordIdArrayView prevWordIds, const int nextWordId, const int unigramProbability); static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; std::unordered_map<int, BigramMap> mBigramMaps; diff --git a/native/jni/src/suggest/core/dictionary/ngram_listener.h b/native/jni/src/suggest/core/dictionary/ngram_listener.h index 88b88bafb..e9b3c1aaf 100644 --- a/native/jni/src/suggest/core/dictionary/ngram_listener.h +++ b/native/jni/src/suggest/core/dictionary/ngram_listener.h @@ -26,7 +26,7 @@ namespace latinime { */ class NgramListener { public: - virtual void onVisitEntry(const int ngramProbability, const int targetPtNodePos) = 0; + virtual void onVisitEntry(const int ngramProbability, const int targetWordId) = 0; virtual ~NgramListener() {}; protected: diff --git a/native/jni/src/suggest/core/dictionary/property/bigram_property.h b/native/jni/src/suggest/core/dictionary/property/bigram_property.h deleted file mode 100644 index 343af143c..000000000 --- a/native/jni/src/suggest/core/dictionary/property/bigram_property.h +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright (C) 2014 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BIGRAM_PROPERTY_H -#define LATINIME_BIGRAM_PROPERTY_H - -#include <vector> - -#include "defines.h" - -namespace latinime { - -// TODO: Change to NgramProperty. -class BigramProperty { - public: - BigramProperty(const std::vector<int> *const targetCodePoints, - const int probability, const int timestamp, const int level, const int count) - : mTargetCodePoints(*targetCodePoints), mProbability(probability), - mTimestamp(timestamp), mLevel(level), mCount(count) {} - - const std::vector<int> *getTargetCodePoints() const { - return &mTargetCodePoints; - } - - int getProbability() const { - return mProbability; - } - - int getTimestamp() const { - return mTimestamp; - } - - int getLevel() const { - return mLevel; - } - - int getCount() const { - return mCount; - } - - private: - // Default copy constructor and assign operator are used for using in std::vector. - DISALLOW_DEFAULT_CONSTRUCTOR(BigramProperty); - - // TODO: Make members const. 
- std::vector<int> mTargetCodePoints; - int mProbability; - int mTimestamp; - int mLevel; - int mCount; -}; -} // namespace latinime -#endif // LATINIME_WORD_PROPERTY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/historical_info.h b/native/jni/src/suggest/core/dictionary/property/historical_info.h index 428ca8626..f9bd6fd8c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/historical_info.h +++ b/native/jni/src/suggest/core/dictionary/property/historical_info.h @@ -34,7 +34,7 @@ class HistoricalInfo { return mTimestamp != NOT_A_TIMESTAMP; } - int getTimeStamp() const { + int getTimestamp() const { return mTimestamp; } @@ -47,7 +47,7 @@ class HistoricalInfo { } private: - // Copy constructor is public to use this class as a type of return value. + // Default copy constructor is used for using in std::vector. DISALLOW_ASSIGNMENT_OPERATOR(HistoricalInfo); const int mTimestamp; diff --git a/native/jni/src/suggest/core/dictionary/property/ngram_property.h b/native/jni/src/suggest/core/dictionary/property/ngram_property.h new file mode 100644 index 000000000..e67b4da31 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/property/ngram_property.h @@ -0,0 +1,62 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_NGRAM_PROPERTY_H +#define LATINIME_NGRAM_PROPERTY_H + +#include <vector> + +#include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" +#include "suggest/core/session/ngram_context.h" + +namespace latinime { + +class NgramProperty { + public: + NgramProperty(const NgramContext &ngramContext, const std::vector<int> &&targetCodePoints, + const int probability, const HistoricalInfo historicalInfo) + : mNgramContext(ngramContext), mTargetCodePoints(std::move(targetCodePoints)), + mProbability(probability), mHistoricalInfo(historicalInfo) {} + + const NgramContext *getNgramContext() const { + return &mNgramContext; + } + + const std::vector<int> *getTargetCodePoints() const { + return &mTargetCodePoints; + } + + int getProbability() const { + return mProbability; + } + + const HistoricalInfo getHistoricalInfo() const { + return mHistoricalInfo; + } + + private: + // Default copy constructor is used for using in std::vector. + DISALLOW_DEFAULT_CONSTRUCTOR(NgramProperty); + DISALLOW_ASSIGNMENT_OPERATOR(NgramProperty); + + const NgramContext mNgramContext; + const std::vector<int> mTargetCodePoints; + const int mProbability; + const HistoricalInfo mHistoricalInfo; +}; +} // namespace latinime +#endif // LATINIME_NGRAM_PROPERTY_H diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h index 902eb000f..f194f979a 100644 --- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h @@ -20,6 +20,7 @@ #include <vector> #include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" namespace latinime { @@ -27,8 +28,9 @@ class UnigramProperty { public: class ShortcutProperty { public: - ShortcutProperty(const std::vector<int> *const targetCodePoints, const int probability) - : mTargetCodePoints(*targetCodePoints), 
mProbability(probability) {} + ShortcutProperty(const std::vector<int> &&targetCodePoints, const int probability) + : mTargetCodePoints(std::move(targetCodePoints)), + mProbability(probability) {} const std::vector<int> *getTargetCodePoints() const { return &mTargetCodePoints; @@ -39,25 +41,53 @@ class UnigramProperty { } private: - // Default copy constructor and assign operator are used for using in std::vector. + // Default copy constructor is used for using in std::vector. DISALLOW_DEFAULT_CONSTRUCTOR(ShortcutProperty); - // TODO: Make members const. - std::vector<int> mTargetCodePoints; - int mProbability; + const std::vector<int> mTargetCodePoints; + const int mProbability; }; UnigramProperty() - : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), mIsBlacklisted(false), - mProbability(NOT_A_PROBABILITY), mTimestamp(NOT_A_TIMESTAMP), mLevel(0), mCount(0), - mShortcuts() {} + : mRepresentsBeginningOfSentence(false), mIsNotAWord(false), + mIsBlacklisted(false), mIsPossiblyOffensive(false), mProbability(NOT_A_PROBABILITY), + mHistoricalInfo(), mShortcuts() {} + // In contexts which do not support the Blacklisted flag (v2, v4<403) UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, - const bool isBlacklisted, const int probability, const int timestamp, const int level, - const int count, const std::vector<ShortcutProperty> *const shortcuts) + const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts) : mRepresentsBeginningOfSentence(representsBeginningOfSentence), - mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), mProbability(probability), - mTimestamp(timestamp), mLevel(level), mCount(count), mShortcuts(*shortcuts) {} + mIsNotAWord(isNotAWord), mIsBlacklisted(false), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {} + + // Without shortcuts, in 
contexts which do not support the Blacklisted flag (v2, v4<403) + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(false), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts() {} + + // In contexts which DO support the Blacklisted flag (v403) + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isBlacklisted, const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo, const std::vector<ShortcutProperty> &&shortcuts) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts(std::move(shortcuts)) {} + + // Without shortcuts, in contexts which DO support the Blacklisted flag (v403) + UnigramProperty(const bool representsBeginningOfSentence, const bool isNotAWord, + const bool isBlacklisted, const bool isPossiblyOffensive, const int probability, + const HistoricalInfo historicalInfo) + : mRepresentsBeginningOfSentence(representsBeginningOfSentence), + mIsNotAWord(isNotAWord), mIsBlacklisted(isBlacklisted), + mIsPossiblyOffensive(isPossiblyOffensive), mProbability(probability), + mHistoricalInfo(historicalInfo), mShortcuts() {} bool representsBeginningOfSentence() const { return mRepresentsBeginningOfSentence; @@ -67,6 +97,10 @@ class UnigramProperty { return mIsNotAWord; } + bool isPossiblyOffensive() const { + return mIsPossiblyOffensive; + } + bool isBlacklisted() const { return mIsBlacklisted; } @@ -79,16 +113,8 @@ class UnigramProperty { return mProbability; } - int getTimestamp() const { - return mTimestamp; 
- } - - int getLevel() const { - return mLevel; - } - - int getCount() const { - return mCount; + const HistoricalInfo getHistoricalInfo() const { + return mHistoricalInfo; } const std::vector<ShortcutProperty> &getShortcuts() const { @@ -99,16 +125,13 @@ class UnigramProperty { // Default copy constructor is used for using as a return value. DISALLOW_ASSIGNMENT_OPERATOR(UnigramProperty); - // TODO: Make members const. - bool mRepresentsBeginningOfSentence; - bool mIsNotAWord; - bool mIsBlacklisted; - int mProbability; - // Historical information - int mTimestamp; - int mLevel; - int mCount; - std::vector<ShortcutProperty> mShortcuts; + const bool mRepresentsBeginningOfSentence; + const bool mIsNotAWord; + const bool mIsBlacklisted; + const bool mIsPossiblyOffensive; + const int mProbability; + const HistoricalInfo mHistoricalInfo; + const std::vector<ShortcutProperty> mShortcuts; }; } // namespace latinime #endif // LATINIME_UNIGRAM_PROPERTY_H diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp index 5bdd5606b..019f0880f 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp +++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp @@ -17,22 +17,25 @@ #include "suggest/core/dictionary/property/word_property.h" #include "utils/jni_data_utils.h" +#include "suggest/core/dictionary/property/historical_info.h" namespace latinime { void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, - jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outBigramTargets, - jobject outBigramProbabilities, jobject outShortcutTargets, + jbooleanArray outFlags, jintArray outProbabilityInfo, + jobject outNgramPrevWordsArray, jobject outNgramPrevWordIsBeginningOfSentenceArray, + jobject outNgramTargets, jobject outNgramProbabilities, jobject outShortcutTargets, jobject outShortcutProbabilities) const { 
JniDataUtils::outputCodePoints(env, outCodePoints, 0 /* start */, MAX_WORD_LENGTH /* maxLength */, mCodePoints.data(), mCodePoints.size(), false /* needsNullTermination */); - jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isBlacklisted(), - !mBigrams.empty(), mUnigramProperty.hasShortcuts(), + jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isPossiblyOffensive(), + !mNgrams.empty(), mUnigramProperty.hasShortcuts(), mUnigramProperty.representsBeginningOfSentence()}; env->SetBooleanArrayRegion(outFlags, 0 /* start */, NELEMS(flags), flags); - int probabilityInfo[] = {mUnigramProperty.getProbability(), mUnigramProperty.getTimestamp(), - mUnigramProperty.getLevel(), mUnigramProperty.getCount()}; + const HistoricalInfo &historicalInfo = mUnigramProperty.getHistoricalInfo(); + int probabilityInfo[] = {mUnigramProperty.getProbability(), historicalInfo.getTimestamp(), + historicalInfo.getLevel(), historicalInfo.getCount()}; env->SetIntArrayRegion(outProbabilityInfo, 0 /* start */, NELEMS(probabilityInfo), probabilityInfo); @@ -41,23 +44,47 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, jclass arrayListClass = env->FindClass("java/util/ArrayList"); jmethodID addMethodId = env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z"); - // Output bigrams. - for (const auto &bigramProperty : mBigrams) { - const std::vector<int> *const word1CodePoints = bigramProperty.getTargetCodePoints(); - jintArray bigramWord1CodePointArray = env->NewIntArray(word1CodePoints->size()); - JniDataUtils::outputCodePoints(env, bigramWord1CodePointArray, 0 /* start */, - word1CodePoints->size(), word1CodePoints->data(), word1CodePoints->size(), - false /* needsNullTermination */); - env->CallBooleanMethod(outBigramTargets, addMethodId, bigramWord1CodePointArray); - env->DeleteLocalRef(bigramWord1CodePointArray); + // Output ngrams. 
+ jclass intArrayClass = env->FindClass("[I"); + for (const auto &ngramProperty : mNgrams) { + const NgramContext *const ngramContext = ngramProperty.getNgramContext(); + jobjectArray prevWordWordCodePointsArray = env->NewObjectArray( + ngramContext->getPrevWordCount(), intArrayClass, nullptr); + jbooleanArray prevWordIsBeginningOfSentenceArray = + env->NewBooleanArray(ngramContext->getPrevWordCount()); + for (size_t i = 0; i < ngramContext->getPrevWordCount(); ++i) { + const CodePointArrayView codePoints = ngramContext->getNthPrevWordCodePoints(i + 1); + jintArray prevWordCodePoints = env->NewIntArray(codePoints.size()); + JniDataUtils::outputCodePoints(env, prevWordCodePoints, 0 /* start */, + codePoints.size(), codePoints.data(), codePoints.size(), + false /* needsNullTermination */); + env->SetObjectArrayElement(prevWordWordCodePointsArray, i, prevWordCodePoints); + env->DeleteLocalRef(prevWordCodePoints); + JniDataUtils::putBooleanToArray(env, prevWordIsBeginningOfSentenceArray, i, + ngramContext->isNthPrevWordBeginningOfSentence(i + 1)); + } + env->CallBooleanMethod(outNgramPrevWordsArray, addMethodId, prevWordWordCodePointsArray); + env->CallBooleanMethod(outNgramPrevWordIsBeginningOfSentenceArray, addMethodId, + prevWordIsBeginningOfSentenceArray); + env->DeleteLocalRef(prevWordWordCodePointsArray); + env->DeleteLocalRef(prevWordIsBeginningOfSentenceArray); + + const std::vector<int> *const targetWordCodePoints = ngramProperty.getTargetCodePoints(); + jintArray targetWordCodePointArray = env->NewIntArray(targetWordCodePoints->size()); + JniDataUtils::outputCodePoints(env, targetWordCodePointArray, 0 /* start */, + targetWordCodePoints->size(), targetWordCodePoints->data(), + targetWordCodePoints->size(), false /* needsNullTermination */); + env->CallBooleanMethod(outNgramTargets, addMethodId, targetWordCodePointArray); + env->DeleteLocalRef(targetWordCodePointArray); - int bigramProbabilityInfo[] = {bigramProperty.getProbability(), - 
bigramProperty.getTimestamp(), bigramProperty.getLevel(), - bigramProperty.getCount()}; + const HistoricalInfo &ngramHistoricalInfo = ngramProperty.getHistoricalInfo(); + int bigramProbabilityInfo[] = {ngramProperty.getProbability(), + ngramHistoricalInfo.getTimestamp(), ngramHistoricalInfo.getLevel(), + ngramHistoricalInfo.getCount()}; jintArray bigramProbabilityInfoArray = env->NewIntArray(NELEMS(bigramProbabilityInfo)); env->SetIntArrayRegion(bigramProbabilityInfoArray, 0 /* start */, NELEMS(bigramProbabilityInfo), bigramProbabilityInfo); - env->CallBooleanMethod(outBigramProbabilities, addMethodId, bigramProbabilityInfoArray); + env->CallBooleanMethod(outNgramProbabilities, addMethodId, bigramProbabilityInfoArray); env->DeleteLocalRef(bigramProbabilityInfoArray); } @@ -65,8 +92,6 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, for (const auto &shortcut : mUnigramProperty.getShortcuts()) { const std::vector<int> *const targetCodePoints = shortcut.getTargetCodePoints(); jintArray shortcutTargetCodePointArray = env->NewIntArray(targetCodePoints->size()); - env->SetIntArrayRegion(shortcutTargetCodePointArray, 0 /* start */, - targetCodePoints->size(), targetCodePoints->data()); JniDataUtils::outputCodePoints(env, shortcutTargetCodePointArray, 0 /* start */, targetCodePoints->size(), targetCodePoints->data(), targetCodePoints->size(), false /* needsNullTermination */); diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h index aa3e0b68a..b5314faaa 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.h +++ b/native/jni/src/suggest/core/dictionary/property/word_property.h @@ -21,7 +21,7 @@ #include "defines.h" #include "jni.h" -#include "suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" namespace 
latinime { @@ -31,23 +31,25 @@ class WordProperty { public: // Default constructor is used to create an instance that indicates an invalid word. WordProperty() - : mCodePoints(), mUnigramProperty(), mBigrams() {} + : mCodePoints(), mUnigramProperty(), mNgrams() {} - WordProperty(const std::vector<int> *const codePoints, - const UnigramProperty *const unigramProperty, - const std::vector<BigramProperty> *const bigrams) - : mCodePoints(*codePoints), mUnigramProperty(*unigramProperty), mBigrams(*bigrams) {} + WordProperty(const std::vector<int> &&codePoints, const UnigramProperty *const unigramProperty, + const std::vector<NgramProperty> *const ngrams) + : mCodePoints(std::move(codePoints)), mUnigramProperty(*unigramProperty), + mNgrams(*ngrams) {} void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags, - jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities, - jobject outShortcutTargets, jobject outShortcutProbabilities) const; + jintArray outProbabilityInfo, jobject outNgramPrevWordsArray, + jobject outNgramPrevWordIsBeginningOfSentenceArray, jobject outNgramTargets, + jobject outNgramProbabilities, jobject outShortcutTargets, + jobject outShortcutProbabilities) const; const UnigramProperty *getUnigramProperty() const { return &mUnigramProperty; } - const std::vector<BigramProperty> *getBigramProperties() const { - return &mBigrams; + const std::vector<NgramProperty> *getNgramProperties() const { + return &mNgrams; } private: @@ -56,7 +58,7 @@ class WordProperty { const std::vector<int> mCodePoints; const UnigramProperty mUnigramProperty; - const std::vector<BigramProperty> mBigrams; + const std::vector<NgramProperty> mNgrams; }; } // namespace latinime #endif // LATINIME_WORD_PROPERTY_H diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h new file mode 100644 index 000000000..5351e7d7d --- /dev/null +++ 
b/native/jni/src/suggest/core/dictionary/word_attributes.h @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_WORD_ATTRIBUTES_H +#define LATINIME_WORD_ATTRIBUTES_H + +#include "defines.h" + +class WordAttributes { + public: + // Invalid word attributes. + WordAttributes() + : mProbability(NOT_A_PROBABILITY), mIsBlacklisted(false), mIsNotAWord(false), + mIsPossiblyOffensive(false) {} + + WordAttributes(const int probability, const bool isBlacklisted, const bool isNotAWord, + const bool isPossiblyOffensive) + : mProbability(probability), mIsBlacklisted(isBlacklisted), mIsNotAWord(isNotAWord), + mIsPossiblyOffensive(isPossiblyOffensive) {} + + int getProbability() const { + return mProbability; + } + + bool isBlacklisted() const { + return mIsBlacklisted; + } + + bool isNotAWord() const { + return mIsNotAWord; + } + + // Whether or not a word is possibly offensive. + // * Static dictionaries <v202, as well as dynamic dictionaries <v403, will set this based on + // whether or not the probability of the word is zero. + // * Static dictionaries >=v203 will set this based on the IS_POSSIBLY_OFFENSIVE PtNode flag. 
+ // * Dynamic dictionaries >=v403 will set this based on the IS_POSSIBLY_OFFENSIVE language model + // flag (the PtNode flag IS_BLACKLISTED is ignored and kept as zero) + // + // See the ::getWordAttributes function for each of these dictionary policies for more details. + bool isPossiblyOffensive() const { + return mIsPossiblyOffensive; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(WordAttributes); + + int mProbability; + bool mIsBlacklisted; + bool mIsNotAWord; + bool mIsPossiblyOffensive; +}; + + // namespace +#endif /* LATINIME_WORD_ATTRIBUTES_H */ diff --git a/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp b/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp index 34b8b37b0..8b39f7da5 100644 --- a/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp +++ b/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp @@ -19,7 +19,7 @@ namespace latinime { // TODO: Stop using hardcoded additional proximity characters. // TODO: Have proximity character informations in each language's binary dictionary. 
-const char *AdditionalProximityChars::LOCALE_EN_US = "en"; +const int AdditionalProximityChars::LOCALE_EN_US[LOCALE_EN_US_SIZE] = { 'e', 'n' }; const int AdditionalProximityChars::EN_US_ADDITIONAL_A[EN_US_ADDITIONAL_A_SIZE] = { 'e', 'i', 'o', 'u' diff --git a/native/jni/src/suggest/core/layout/additional_proximity_chars.h b/native/jni/src/suggest/core/layout/additional_proximity_chars.h index a88fd6cea..2260be9bd 100644 --- a/native/jni/src/suggest/core/layout/additional_proximity_chars.h +++ b/native/jni/src/suggest/core/layout/additional_proximity_chars.h @@ -18,6 +18,7 @@ #define LATINIME_ADDITIONAL_PROXIMITY_CHARS_H #include <cstring> +#include <vector> #include "defines.h" @@ -26,7 +27,8 @@ namespace latinime { class AdditionalProximityChars { private: DISALLOW_IMPLICIT_CONSTRUCTORS(AdditionalProximityChars); - static const char *LOCALE_EN_US; + static const int LOCALE_EN_US_SIZE = 2; + static const int LOCALE_EN_US[LOCALE_EN_US_SIZE]; static const int EN_US_ADDITIONAL_A_SIZE = 4; static const int EN_US_ADDITIONAL_A[]; static const int EN_US_ADDITIONAL_E_SIZE = 4; @@ -38,15 +40,22 @@ class AdditionalProximityChars { static const int EN_US_ADDITIONAL_U_SIZE = 4; static const int EN_US_ADDITIONAL_U[]; - AK_FORCE_INLINE static bool isEnLocale(const char *localeStr) { - const size_t LOCALE_EN_US_SIZE = strlen(LOCALE_EN_US); - return localeStr && strlen(localeStr) >= LOCALE_EN_US_SIZE - && strncmp(localeStr, LOCALE_EN_US, LOCALE_EN_US_SIZE) == 0; + AK_FORCE_INLINE static bool isEnLocale(const std::vector<int> *locale) { + const int NCHARS = NELEMS(LOCALE_EN_US); + if (locale->size() < NCHARS) { + return false; + } + for (int i = 0; i < NCHARS; ++i) { + if ((*locale)[i] != LOCALE_EN_US[i]) { + return false; + } + } + return true; } public: - static int getAdditionalCharsSize(const char *const localeStr, const int c) { - if (!isEnLocale(localeStr)) { + static int getAdditionalCharsSize(const std::vector<int> *locale, const int c) { + if (!isEnLocale(locale)) { 
return 0; } switch (c) { @@ -65,8 +74,8 @@ class AdditionalProximityChars { } } - static const int *getAdditionalChars(const char *const localeStr, const int c) { - if (!isEnLocale(localeStr)) { + static const int *getAdditionalChars(const std::vector<int> *locale, const int c) { + if (!isEnLocale(locale)) { return 0; } switch (c) { diff --git a/native/jni/src/suggest/core/layout/geometry_utils.h b/native/jni/src/suggest/core/layout/geometry_utils.h index b667df68f..000fcd4a1 100644 --- a/native/jni/src/suggest/core/layout/geometry_utils.h +++ b/native/jni/src/suggest/core/layout/geometry_utils.h @@ -38,13 +38,15 @@ class GeometryUtils { } static AK_FORCE_INLINE float getAngleDiff(const float a1, const float a2) { - const float deltaA = fabsf(a1 - a2); - const float diff = ROUND_FLOAT_10000(deltaA); - if (diff > M_PI_F) { - const float normalizedDiff = 2.0f * M_PI_F - diff; - return ROUND_FLOAT_10000(normalizedDiff); + static const float M_2PI_F = M_PI * 2.0f; + float delta = fabsf(a1 - a2); + if (delta > M_2PI_F) { + delta -= (M_2PI_F * static_cast<int>(delta / M_2PI_F)); } - return diff; + if (delta > M_PI_F) { + delta = M_2PI_F - delta; + } + return ROUND_FLOAT_10000(delta); } static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2, diff --git a/native/jni/src/suggest/core/layout/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp index 4c75a188e..933a5e145 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info.cpp @@ -49,13 +49,13 @@ static AK_FORCE_INLINE void safeGetOrFillZeroFloatArrayRegion(JNIEnv *env, jfloa } } -ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr, - const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, - const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, - const jintArray 
keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, - const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, - const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii) +ProximityInfo::ProximityInfo(JNIEnv *env, const int keyboardWidth, const int keyboardHeight, + const int gridWidth, const int gridHeight, const int mostCommonKeyWidth, + const int mostCommonKeyHeight, const jintArray proximityChars, const int keyCount, + const jintArray keyXCoordinates, const jintArray keyYCoordinates, + const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, + const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs, + const jfloatArray sweetSpotRadii) : GRID_WIDTH(gridWidth), GRID_HEIGHT(gridHeight), MOST_COMMON_KEY_WIDTH(mostCommonKeyWidth), MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth), NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f + @@ -82,13 +82,6 @@ ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr, if (DEBUG_PROXIMITY_INFO) { AKLOGI("Create proximity info array %d", proximityCharsLength); } - const jsize localeCStrUtf8Length = env->GetStringUTFLength(localeJStr); - if (localeCStrUtf8Length >= MAX_LOCALE_STRING_LENGTH) { - AKLOGI("Locale string length too long: length=%d", localeCStrUtf8Length); - ASSERT(false); - } - memset(mLocaleStr, 0, sizeof(mLocaleStr)); - env->GetStringUTFRegion(localeJStr, 0, env->GetStringLength(localeJStr), mLocaleStr); safeGetOrFillZeroIntArrayRegion(env, proximityChars, proximityCharsLength, mProximityCharsArray); safeGetOrFillZeroIntArrayRegion(env, keyXCoordinates, KEY_COUNT, mKeyXCoordinates); diff --git a/native/jni/src/suggest/core/layout/proximity_info.h b/native/jni/src/suggest/core/layout/proximity_info.h index d4e453736..f7c907697 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.h +++ b/native/jni/src/suggest/core/layout/proximity_info.h @@ -18,6 +18,7 @@ #define LATINIME_PROXIMITY_INFO_H 
#include <unordered_map> +#include <vector> #include "defines.h" #include "jni.h" @@ -27,9 +28,9 @@ namespace latinime { class ProximityInfo { public: - ProximityInfo(JNIEnv *env, const jstring localeJStr, - const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, + ProximityInfo(JNIEnv *env, const int keyboardWidth, const int keyboardHeight, + const int gridWidth, const int gridHeight, + const int mostCommonKeyWidth, const int mostCommonKeyHeight, const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, @@ -71,11 +72,11 @@ class ProximityInfo { AK_FORCE_INLINE void initializeProximities(const int *const inputCodes, const int *const inputXCoordinates, const int *const inputYCoordinates, - const int inputSize, int *allInputCodes) const { + const int inputSize, int *allInputCodes, const std::vector<int> *locale) const { ProximityInfoUtils::initializeProximities(inputCodes, inputXCoordinates, inputYCoordinates, inputSize, mKeyXCoordinates, mKeyYCoordinates, mKeyWidths, mKeyHeights, mProximityCharsArray, CELL_HEIGHT, CELL_WIDTH, GRID_WIDTH, MOST_COMMON_KEY_WIDTH, - KEY_COUNT, mLocaleStr, &mLowerCodePointToKeyMap, allInputCodes); + KEY_COUNT, locale, &mLowerCodePointToKeyMap, allInputCodes); } AK_FORCE_INLINE int getKeyIndexOf(const int c) const { @@ -103,9 +104,6 @@ class ProximityInfo { const int KEYBOARD_HEIGHT; const float KEYBOARD_HYPOTENUSE; const bool HAS_TOUCH_POSITION_CORRECTION_DATA; - // Assuming locale strings such as en_US, sr-Latn etc. 
- static const int MAX_LOCALE_STRING_LENGTH = 10; - char mLocaleStr[MAX_LOCALE_STRING_LENGTH]; int *mProximityCharsArray; int mKeyXCoordinates[MAX_KEY_COUNT_IN_A_KEYBOARD]; int mKeyYCoordinates[MAX_KEY_COUNT_IN_A_KEYBOARD]; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp index 91469e26d..d43a0026a 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp @@ -42,7 +42,7 @@ int ProximityInfoState::getPrimaryOriginalCodePointAt(const int index) const { void ProximityInfoState::initInputParams(const int pointerId, const float maxPointToKeyLength, const ProximityInfo *proximityInfo, const int *const inputCodes, const int inputSize, const int *const xCoordinates, const int *const yCoordinates, const int *const times, - const int *const pointerIds, const bool isGeometric) { + const int *const pointerIds, const bool isGeometric, const std::vector<int> *locale) { ASSERT(isGeometric || (inputSize < MAX_WORD_LENGTH)); mIsContinuousSuggestionPossible = (mHasBeenUpdatedByGeometricInput != isGeometric) ? 
false : ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( @@ -66,7 +66,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (!isGeometric && pointerId == 0) { mProximityInfo->initializeProximities(inputCodes, xCoordinates, yCoordinates, - inputSize, mInputProximities); + inputSize, mInputProximities, locale); } /////////////////////// diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index e6180fe17..a2d663544 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -37,7 +37,8 @@ class ProximityInfoState { void initInputParams(const int pointerId, const float maxPointToKeyLength, const ProximityInfo *proximityInfo, const int *const inputCodes, const int inputSize, const int *xCoordinates, const int *yCoordinates, - const int *const times, const int *const pointerIds, const bool isGeometric); + const int *const times, const int *const pointerIds, const bool isGeometric, + const std::vector<int> *locale); ///////////////////////////////////////// // Defined here // diff --git a/native/jni/src/suggest/core/layout/proximity_info_utils.h b/native/jni/src/suggest/core/layout/proximity_info_utils.h index 178aada2d..79d0615b8 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_utils.h @@ -19,6 +19,7 @@ #include <cmath> #include <unordered_map> +#include <vector> #include "defines.h" #include "suggest/core/layout/additional_proximity_chars.h" @@ -51,7 +52,7 @@ class ProximityInfoUtils { const int *const keyYCoordinates, const int *const keyWidths, const int *keyHeights, const int *const proximityCharsArray, const int cellHeight, const int cellWidth, const int gridWidth, const int mostCommonKeyWidth, const int keyCount, - const char *const localeStr, + const std::vector<int> *locale, 
const std::unordered_map<int, int> *const codeToKeyMap, int *inputProximities) { // Initialize // - mInputCodes @@ -64,7 +65,7 @@ class ProximityInfoUtils { int *proximities = &inputProximities[i * MAX_PROXIMITY_CHARS_SIZE]; calculateProximities(keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, proximityCharsArray, cellHeight, cellWidth, gridWidth, mostCommonKeyWidth, - keyCount, x, y, primaryKey, localeStr, codeToKeyMap, proximities); + keyCount, x, y, primaryKey, locale, codeToKeyMap, proximities); } if (DEBUG_PROXIMITY_CHARS) { @@ -143,7 +144,7 @@ class ProximityInfoUtils { const int *const keyYCoordinates, const int *const keyWidths, const int *keyHeights, const int *const proximityCharsArray, const int cellHeight, const int cellWidth, const int gridWidth, const int mostCommonKeyWidth, const int keyCount, - const int x, const int y, const int primaryKey, const char *const localeStr, + const int x, const int y, const int primaryKey, const std::vector<int> *locale, const std::unordered_map<int, int> *const codeToKeyMap, int *proximities) { const int mostCommonKeyWidthSquare = mostCommonKeyWidth * mostCommonKeyWidth; int insertPos = 0; @@ -177,7 +178,7 @@ class ProximityInfoUtils { } } const int additionalProximitySize = - AdditionalProximityChars::getAdditionalCharsSize(localeStr, primaryKey); + AdditionalProximityChars::getAdditionalCharsSize(locale, primaryKey); if (additionalProximitySize > 0) { proximities[insertPos++] = ADDITIONAL_PROXIMITY_CHAR_DELIMITER_CODE; if (insertPos >= MAX_PROXIMITY_CHARS_SIZE) { @@ -188,7 +189,7 @@ class ProximityInfoUtils { } const int *additionalProximityChars = - AdditionalProximityChars::getAdditionalChars(localeStr, primaryKey); + AdditionalProximityChars::getAdditionalChars(locale, primaryKey); for (int j = 0; j < additionalProximitySize; ++j) { const int ac = additionalProximityChars[j]; int k = 0; diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h 
b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index e91f07682..33a0fbc19 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -20,16 +20,20 @@ #include <memory> #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" +#include "suggest/core/dictionary/property/historical_info.h" #include "suggest/core/dictionary/property/word_property.h" +#include "suggest/core/dictionary/word_attributes.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; class DictionaryHeaderStructurePolicy; -class DictionaryShortcutsStructurePolicy; +class MultiBigramMap; class NgramListener; -class PrevWordsInfo; +class NgramContext; class UnigramProperty; /* @@ -47,42 +51,45 @@ class DictionaryStructureWithBufferPolicy { virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const = 0; - virtual int getCodePointsAndProbabilityAndReturnCodePointCount( - const int nodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const = 0; + virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const = 0; - virtual int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const = 0; + virtual int getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const = 0; - virtual int getProbability(const int unigramProbability, - const int bigramProbability) const = 0; + virtual const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const = 0; - virtual int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int nodePos) const = 0; 
+ // TODO: Remove + virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0; - virtual void iterateNgramEntries(const int *const prevWordsPtNodePos, + virtual int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const = 0; + + virtual void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const = 0; - virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; + virtual BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const = 0; virtual const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const = 0; - virtual const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const = 0; - // Returns whether the update was success or not. - virtual bool addUnigramEntry(const int *const word, const int length, + virtual bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) = 0; // Returns whether the update was success or not. - virtual bool removeUnigramEntry(const int *const word, const int length) = 0; + virtual bool removeUnigramEntry(const CodePointArrayView wordCodePoints) = 0; + + // Returns whether the update was success or not. + virtual bool addNgramEntry(const NgramProperty *const ngramProperty) = 0; // Returns whether the update was success or not. - virtual bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty) = 0; + virtual bool removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints) = 0; // Returns whether the update was success or not. 
- virtual bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) = 0; + virtual bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints, const bool isValidWord, + const HistoricalInfo historicalInfo) = 0; // Returns whether the flush was success or not. virtual bool flush(const char *const filePath) = 0; @@ -97,9 +104,7 @@ class DictionaryStructureWithBufferPolicy { virtual void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength) = 0; - // Used for testing. - virtual const WordProperty getWordProperty(const int *const codePonts, - const int codePointCount) const = 0; + virtual const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const = 0; // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 9e75cace4..ce3684a1c 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -32,9 +32,11 @@ class Scoring { const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, const bool boostExactMatches) const = 0; virtual void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const = 0; - virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const = 0; + virtual float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const = 0; virtual float 
getDoubleLetterDemotionDistanceCost( const DicNode *const terminalDicNode) const = 0; virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 8ddaa0514..5b6616d9a 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -44,11 +44,12 @@ class Traversal { virtual bool needsToTraverseAllUserInput() const = 0; virtual float getMaxSpatialDistance() const = 0; virtual int getDefaultExpandDicNodeSize() const = 0; - virtual int getMaxCacheSize(const int inputSize) const = 0; + virtual int getMaxCacheSize(const int inputSize, const float weightForLocale) const = 0; virtual int getTerminalCacheSize() const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const = 0; protected: Traversal() {} diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index c202b81fe..a06e7d070 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -110,10 +110,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getOmissionCost(parentDicNode, dicNode); case CT_ADDITIONAL_PROXIMITY: // only used for typing - return weighting->getAdditionalProximityCost(); + // TODO: Quit calling getMatchedCost(). + return weighting->getAdditionalProximityCost() + + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_SUBSTITUTION: // only used for typing - return weighting->getSubstitutionCost(); + // TODO: Quit calling getMatchedCost(). 
+ return weighting->getSubstitutionCost() + + weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_NEW_WORD_SPACE_OMISSION: return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: @@ -176,9 +180,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_OMISSION: return 0; case CT_ADDITIONAL_PROXIMITY: - return 0; /* 0 because CT_MATCH will be called */ + return 1; case CT_SUBSTITUTION: - return 0; /* 0 because CT_MATCH will be called */ + return 1; case CT_NEW_WORD_SPACE_OMISSION: return 0; case CT_MATCH: diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp index 4c10bd08a..3756d1092 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.cpp +++ b/native/jni/src/suggest/core/result/suggestion_results.cpp @@ -23,7 +23,7 @@ namespace latinime { void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outputCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray, - jfloatArray outLanguageWeight) { + jfloatArray outWeightOfLangModelVsSpatialModel) { int outputIndex = 0; while (!mSuggestedWords.empty()) { const SuggestedWord &suggestedWord = mSuggestedWords.top(); @@ -44,7 +44,8 @@ void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCo mSuggestedWords.pop(); } JniDataUtils::putIntToArray(env, outSuggestionCount, 0 /* index */, outputIndex); - JniDataUtils::putFloatToArray(env, outLanguageWeight, 0 /* index */, mLanguageWeight); + JniDataUtils::putFloatToArray(env, outWeightOfLangModelVsSpatialModel, 0 /* index */, + mWeightOfLangModelVsSpatialModel); } void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount, @@ -89,7 +90,7 @@ void SuggestionResults::getSortedScores(int *const outScores) const { } 
void SuggestionResults::dumpSuggestions() const { - AKLOGE("language weight: %f", mLanguageWeight); + AKLOGE("weight of language model vs spatial model: %f", mWeightOfLangModelVsSpatialModel); std::vector<SuggestedWord> suggestedWords; auto copyOfSuggestedWords = mSuggestedWords; while (!copyOfSuggestedWords.empty()) { diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h index 8e845e2d3..738c78a9f 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.h +++ b/native/jni/src/suggest/core/result/suggestion_results.h @@ -29,13 +29,15 @@ namespace latinime { class SuggestionResults { public: explicit SuggestionResults(const int maxSuggestionCount) - : mMaxSuggestionCount(maxSuggestionCount), mLanguageWeight(NOT_A_LANGUAGE_WEIGHT), + : mMaxSuggestionCount(maxSuggestionCount), + mWeightOfLangModelVsSpatialModel(NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL), mSuggestedWords() {} // Returns suggestion count. void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, - jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray outLanguageWeight); + jintArray outAutoCommitFirstWordConfidenceArray, + jfloatArray outWeightOfLangModelVsSpatialModel); void addPrediction(const int *const codePoints, const int codePointCount, const int score); void addSuggestion(const int *const codePoints, const int codePointCount, const int score, const int type, const int indexToPartialCommit, @@ -43,8 +45,8 @@ class SuggestionResults { void getSortedScores(int *const outScores) const; void dumpSuggestions() const; - void setLanguageWeight(const float languageWeight) { - mLanguageWeight = languageWeight; + void setWeightOfLangModelVsSpatialModel(const float weightOfLangModelVsSpatialModel) { + mWeightOfLangModelVsSpatialModel = weightOfLangModelVsSpatialModel; } int getSuggestionCount() 
const { @@ -55,7 +57,7 @@ class SuggestionResults { DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionResults); const int mMaxSuggestionCount; - float mLanguageWeight; + float mWeightOfLangModelVsSpatialModel; std::priority_queue< SuggestedWord, std::vector<SuggestedWord>, SuggestedWord::Comparator> mSuggestedWords; }; diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index 0b99b75ec..3283f6deb 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -34,7 +34,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; /* static */ void SuggestionsOutputUtils::outputSuggestions( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) { + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -44,12 +45,15 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; for (int index = terminalSize - 1; index >= 0; --index) { traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); } - // Compute a language weight when an invalid language weight is passed. - // NOT_A_LANGUAGE_WEIGHT (-1) is assumed as an invalid language weight. - const float languageWeightToOutputSuggestions = (languageWeight < 0.0f) ? - scoringPolicy->getAdjustedLanguageWeight( - traverseSession, terminals.data(), terminalSize) : languageWeight; - outSuggestionResults->setLanguageWeight(languageWeightToOutputSuggestions); + // Compute a weight of language model when an invalid weight is passed. + // NOT_A_WEIGHT_OF_LANG_MODEL_VS_SPATIAL_MODEL (-1) is taken as an invalid value. 
+ const float weightOfLangModelVsSpatialModelToOutputSuggestions = + (weightOfLangModelVsSpatialModel < 0.0f) + ? scoringPolicy->getAdjustedWeightOfLangModelVsSpatialModel(traverseSession, + terminals.data(), terminalSize) + : weightOfLangModelVsSpatialModel; + outSuggestionResults->setWeightOfLangModelVsSpatialModel( + weightOfLangModelVsSpatialModelToOutputSuggestions); // Force autocorrection for obvious long multi-word suggestions when the top suggestion is // a long multiple words suggestion. // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. @@ -65,16 +69,16 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Output suggestion results here for (auto &terminalDicNode : terminals) { outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode, - languageWeightToOutputSuggestions, boostExactMatches, forceCommitMultiWords, - outputSecondWordFirstLetterInputIndex, outSuggestionResults); + weightOfLangModelVsSpatialModelToOutputSuggestions, boostExactMatches, + forceCommitMultiWords, outputSecondWordFirstLetterInputIndex, outSuggestionResults); } - scoringPolicy->getMostProbableString(traverseSession, languageWeightToOutputSuggestions, - outSuggestionResults); + scoringPolicy->getMostProbableString(traverseSession, + weightOfLangModelVsSpatialModelToOutputSuggestions, outSuggestionResults); } /* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - const DicNode *const terminalDicNode, const float languageWeight, + const DicNode *const terminalDicNode, const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults) { @@ -83,11 +87,12 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; } const float doubleLetterCost = 
scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); - const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) - + doubleLetterCost; - const bool isPossiblyOffensiveWord = - traverseSession->getDictionaryStructurePolicy()->getProbability( - terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; + const float compoundDistance = + terminalDicNode->getCompoundDistance(weightOfLangModelVsSpatialModel) + + doubleLetterCost; + const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy() + ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(), + terminalDicNode->getWordId(), nullptr /* multiBigramMap */); const bool isExactMatch = ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); const bool isExactMatchWithIntentionalOmission = @@ -97,19 +102,19 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. // (e.g. "AMD" and "and") const bool isSafeExactMatch = isExactMatch - && !(isPossiblyOffensiveWord && isFirstCharUppercase); + && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase); const int outputTypeFlags = - (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) + (wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0) | (isExactMatchWithIntentionalOmission ? Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0); // Entries that are blacklisted or do not represent a word should not be output. - const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); + const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord()); // When we have to block offensive words, non-exact matched offensive words should not be // output. 
const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords(); - const bool isBlockedOffensiveWord = blockOffensiveWords && isPossiblyOffensiveWord + const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive() && !isSafeExactMatch; // Increase output score of top typing suggestion to ensure autocorrection. @@ -139,10 +144,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; // Shortcut is not supported for multiple words suggestions. // TODO: Check shortcuts during traversal for multiple words suggestions. if (!terminalDicNode->hasMultipleWords()) { - BinaryDictionaryShortcutIterator shortcutIt( - traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), - traverseSession->getDictionaryStructurePolicy() - ->getShortcutPositionOfPtNode(terminalDicNode->getPtNodePos())); + BinaryDictionaryShortcutIterator shortcutIt = + traverseSession->getDictionaryStructurePolicy()->getShortcutIterator( + terminalDicNode->getWordId()); const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); outputShortcuts(&shortcutIt, finalScore, sameAsTyped, outSuggestionResults); } diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index b099b4776..bf8497828 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -33,7 +33,7 @@ class SuggestionsOutputUtils { * Outputs the final list of suggestions (i.e., terminal nodes). 
*/ static void outputSuggestions(const Scoring *const scoringPolicy, - DicTraverseSession *traverseSession, const float languageWeight, + DicTraverseSession *traverseSession, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults); private: @@ -44,7 +44,7 @@ class SuggestionsOutputUtils { static void outputSuggestionsOfDicNode(const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, - const float languageWeight, const bool boostExactMatches, + const float weightOfLangModelVsSpatialModel, const bool boostExactMatches, const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, SuggestionResults *const outSuggestionResults); static void outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index f1e411f38..52dc2f86c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -20,7 +20,7 @@ #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" namespace latinime { @@ -30,13 +30,13 @@ const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_S 256 * 1024; void DicTraverseSession::init(const Dictionary *const dictionary, - const PrevWordsInfo *const prevWordsInfo, const SuggestOptions *const suggestOptions) { + const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; - 
prevWordsInfo->getPrevWordsTerminalPtNodePos( - getDictionaryStructurePolicy(), mPrevWordsPtNodePos, true /* tryLowerCaseSearch */); + mPrevWordIdCount = ngramContext->getPrevWordIds(getDictionaryStructurePolicy(), + &mPrevWordIdArray, true /* tryLowerCaseSearch */).size(); } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, @@ -69,8 +69,12 @@ void DicTraverseSession::initializeProximityInfoStates(const int *const inputCod for (int i = 0; i < maxPointerCount; ++i) { mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(), inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds, - maxPointerCount == MAX_POINTER_COUNT_G - /* TODO: this is a hack. fix proximity info state */); + // Right now the line below is trying to figure out whether this is a gesture by + // looking at the pointer count and assuming whatever is above the cutoff is + // a gesture and whatever is below is type. This is hacky and incorrect, we + // should pass the correct information instead. 
+ maxPointerCount == MAX_POINTER_COUNT_G, + getDictionaryStructurePolicy()->getHeaderStructurePolicy()->getLocale()); mInputSize += mProximityInfoStates[i].size(); } } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 5a51a112d..bc53167f0 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -24,12 +24,13 @@ #include "suggest/core/dicnode/dic_nodes_cache.h" #include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/layout/proximity_info_state.h" +#include "utils/int_array_view.h" namespace latinime { class Dictionary; class DictionaryStructureWithBufferPolicy; -class PrevWordsInfo; +class NgramContext; class ProximityInfo; class SuggestOptions; @@ -50,20 +51,17 @@ class DicTraverseSession { } AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) - : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), - mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), - mMultiWordCostMultiplier(1.0f) { + : mPrevWordIdCount(0), mProximityInfo(nullptr), mDictionary(nullptr), + mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), mMultiBigramMap(), + mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. 
- for (size_t i = 0; i < NELEMS(mPrevWordsPtNodePos); ++i) { - mPrevWordsPtNodePos[i] = NOT_A_DICT_POS; - } } // Non virtual inline destructor -- never inherit this class AK_FORCE_INLINE ~DicTraverseSession() {} - void init(const Dictionary *dictionary, const PrevWordsInfo *const prevWordsInfo, + void init(const Dictionary *dictionary, const NgramContext *const ngramContext, const SuggestOptions *const suggestOptions); // TODO: Remove and merge into init void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, @@ -79,7 +77,9 @@ class DicTraverseSession { //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } - const int *getPrevWordsPtNodePos() const { return mPrevWordsPtNodePos; } + const WordIdArrayView getPrevWordIds() const { + return WordIdArrayView::fromArray(mPrevWordIdArray).limit(mPrevWordIdCount); + } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { @@ -166,7 +166,8 @@ class DicTraverseSession { const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); - int mPrevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIdArray; + size_t mPrevWordIdCount; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; const SuggestOptions *mSuggestOptions; diff --git a/native/jni/src/suggest/core/session/ngram_context.cpp b/native/jni/src/suggest/core/session/ngram_context.cpp new file mode 100644 index 000000000..17ef9ae60 --- /dev/null +++ b/native/jni/src/suggest/core/session/ngram_context.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/session/ngram_context.h" + +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "utils/char_utils.h" + +namespace latinime { + +NgramContext::NgramContext() : mPrevWordCount(0) {} + +NgramContext::NgramContext(const NgramContext &ngramContext) + : mPrevWordCount(ngramContext.mPrevWordCount) { + for (size_t i = 0; i < mPrevWordCount; ++i) { + mPrevWordCodePointCount[i] = ngramContext.mPrevWordCodePointCount[i]; + memmove(mPrevWordCodePoints[i], ngramContext.mPrevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); + mIsBeginningOfSentence[i] = ngramContext.mIsBeginningOfSentence[i]; + } +} + +NgramContext::NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH], + const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, + const size_t prevWordCount) + : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) { + clear(); + for (size_t i = 0; i < mPrevWordCount; ++i) { + if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { + continue; + } + memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], + sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); + mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; + mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; + } +} + +NgramContext::NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount, + const bool 
isBeginningOfSentence) : mPrevWordCount(1) { + clear(); + if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { + return; + } + memmove(mPrevWordCodePoints[0], prevWordCodePoints, + sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); + mPrevWordCodePointCount[0] = prevWordCodePointCount; + mIsBeginningOfSentence[0] = isBeginningOfSentence; +} + +bool NgramContext::isValid() const { + if (mPrevWordCodePointCount[0] > 0) { + return true; + } + if (mIsBeginningOfSentence[0]) { + return true; + } + return false; +} + +const CodePointArrayView NgramContext::getNthPrevWordCodePoints(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { + return CodePointArrayView(); + } + return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]); +} + +bool NgramContext::isNthPrevWordBeginningOfSentence(const size_t n) const { + if (n <= 0 || n > mPrevWordCount) { + return false; + } + return mIsBeginningOfSentence[n - 1]; +} + +/* static */ int NgramContext::getWordId( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + const int *const wordCodePoints, const int wordCodePointCount, + const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { + if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { + return NOT_A_WORD_ID; + } + int codePoints[MAX_WORD_LENGTH]; + int codePointCount = wordCodePointCount; + memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); + if (isBeginningOfSentence) { + codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, codePointCount, + MAX_WORD_LENGTH); + if (codePointCount <= 0) { + return NOT_A_WORD_ID; + } + } + const CodePointArrayView codePointArrayView(codePoints, codePointCount); + const int wordId = dictStructurePolicy->getWordId(codePointArrayView, + false /* forceLowerCaseSearch */); + if (wordId != NOT_A_WORD_ID || !tryLowerCaseSearch) { + // Return the id when when the word was found or doesn't try 
lower case search. + return wordId; + } + // Check bigrams for lower-cased previous word if original was not found. Useful for + // auto-capitalized words like "The [current_word]". + return dictStructurePolicy->getWordId(codePointArrayView, true /* forceLowerCaseSearch */); +} + +void NgramContext::clear() { + for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { + mPrevWordCodePointCount[i] = 0; + mIsBeginningOfSentence[i] = false; + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/session/ngram_context.h b/native/jni/src/suggest/core/session/ngram_context.h new file mode 100644 index 000000000..9b36199c9 --- /dev/null +++ b/native/jni/src/suggest/core/session/ngram_context.h @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_NGRAM_CONTEXT_H +#define LATINIME_NGRAM_CONTEXT_H + +#include <array> + +#include "defines.h" +#include "utils/int_array_view.h" + +namespace latinime { + +class DictionaryStructureWithBufferPolicy; + +class NgramContext { + public: + // No prev word information. + NgramContext(); + // Copy constructor to use this class with std::vector and use this class as a return value. + NgramContext(const NgramContext &ngramContext); + // Construct from previous words. 
+ NgramContext(const int prevWordCodePoints[][MAX_WORD_LENGTH], + const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, + const size_t prevWordCount); + // Construct from a previous word. + NgramContext(const int *const prevWordCodePoints, const int prevWordCodePointCount, + const bool isBeginningOfSentence); + + size_t getPrevWordCount() const { + return mPrevWordCount; + } + bool isValid() const; + + template<size_t N> + const WordIdArrayView getPrevWordIds( + const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + WordIdArray<N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const { + for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) { + prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy, mPrevWordCodePoints[i], + mPrevWordCodePointCount[i], mIsBeginningOfSentence[i], tryLowerCaseSearch); + } + return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount); + } + + // n is 1-indexed. + const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const; + // n is 1-indexed. 
+ bool isNthPrevWordBeginningOfSentence(const size_t n) const; + + private: + DISALLOW_ASSIGNMENT_OPERATOR(NgramContext); + + static int getWordId(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, + const int *const wordCodePoints, const int wordCodePointCount, + const bool isBeginningOfSentence, const bool tryLowerCaseSearch); + void clear(); + + const size_t mPrevWordCount; + int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; + int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; +}; +} // namespace latinime +#endif // LATINIME_NGRAM_CONTEXT_H diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h deleted file mode 100644 index e44e876e9..000000000 --- a/native/jni/src/suggest/core/session/prev_words_info.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (C) 2014 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_PREV_WORDS_INFO_H -#define LATINIME_PREV_WORDS_INFO_H - -#include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" -#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" -#include "utils/char_utils.h" - -namespace latinime { - -// TODO: Support n-gram. -class PrevWordsInfo { - public: - // No prev word information. 
- PrevWordsInfo() { - clear(); - } - - PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i]; - memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i], - sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]); - mIsBeginningOfSentence[i] = prevWordsInfo.mIsBeginningOfSentence[i]; - } - } - - // Construct from previous words. - PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH], - const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence, - const size_t prevWordCount) { - clear(); - for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) { - if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) { - continue; - } - memmove(mPrevWordCodePoints[i], prevWordCodePoints[i], - sizeof(mPrevWordCodePoints[i][0]) * prevWordCodePointCount[i]); - mPrevWordCodePointCount[i] = prevWordCodePointCount[i]; - mIsBeginningOfSentence[i] = isBeginningOfSentence[i]; - } - } - - // Construct from a previous word. 
- PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount, - const bool isBeginningOfSentence) { - clear(); - if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) { - return; - } - memmove(mPrevWordCodePoints[0], prevWordCodePoints, - sizeof(mPrevWordCodePoints[0][0]) * prevWordCodePointCount); - mPrevWordCodePointCount[0] = prevWordCodePointCount; - mIsBeginningOfSentence[0] = isBeginningOfSentence; - } - - bool isValid() const { - if (mPrevWordCodePointCount[0] > 0) { - return true; - } - if (mIsBeginningOfSentence[0]) { - return true; - } - return false; - } - - void getPrevWordsTerminalPtNodePos( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy, - mPrevWordCodePoints[i], mPrevWordCodePointCount[i], - mIsBeginningOfSentence[i], tryLowerCaseSearch); - } - } - - // n is 1-indexed. - const int *getNthPrevWordCodePoints(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return nullptr; - } - return mPrevWordCodePoints[n - 1]; - } - - // n is 1-indexed. - int getNthPrevWordCodePointCount(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return 0; - } - return mPrevWordCodePointCount[n - 1]; - } - - // n is 1-indexed. 
- bool isNthPrevWordBeginningOfSentence(const int n) const { - if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { - return false; - } - return mIsBeginningOfSentence[n - 1]; - } - - private: - DISALLOW_COPY_AND_ASSIGN(PrevWordsInfo); - - static int getTerminalPtNodePosOfWord( - const DictionaryStructureWithBufferPolicy *const dictStructurePolicy, - const int *const wordCodePoints, const int wordCodePointCount, - const bool isBeginningOfSentence, const bool tryLowerCaseSearch) { - if (!dictStructurePolicy || !wordCodePoints || wordCodePointCount > MAX_WORD_LENGTH) { - return NOT_A_DICT_POS; - } - int codePoints[MAX_WORD_LENGTH]; - int codePointCount = wordCodePointCount; - memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount); - if (isBeginningOfSentence) { - codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints, - codePointCount, MAX_WORD_LENGTH); - if (codePointCount <= 0) { - return NOT_A_DICT_POS; - } - } - const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord( - codePoints, codePointCount, false /* forceLowerCaseSearch */); - if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) { - // Return the position when when the word was found or doesn't try lower case - // search. - return wordPtNodePos; - } - // Check bigrams for lower-cased previous word if original was not found. Useful for - // auto-capitalized words like "The [current_word]". 
- return dictStructurePolicy->getTerminalPtNodePositionOfWord( - codePoints, codePointCount, true /* forceLowerCaseSearch */); - } - - void clear() { - for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) { - mPrevWordCodePointCount[i] = 0; - mIsBeginningOfSentence[i] = false; - } - } - - int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; - int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; -}; -} // namespace latinime -#endif // LATINIME_PREV_WORDS_INFO_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 0cd305f5a..c71526293 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -21,12 +21,14 @@ #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" +#include "suggest/core/dictionary/word_attributes.h" #include "suggest/core/layout/proximity_info.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/result/suggestions_output_utils.h" #include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/suggest_options.h" namespace latinime { @@ -44,7 +46,7 @@ const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; */ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, const float languageWeight, + int inputSize, const float weightOfLangModelVsSpatialModel, SuggestionResults *const outSuggestionResults) const { PROF_OPEN; PROF_START(0); @@ -67,7 +69,7 @@ void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, PROF_END(1); PROF_START(2); SuggestionsOutputUtils::outputSuggestions( - SCORING, tSession, 
languageWeight, outSuggestionResults); + SCORING, tSession, weightOfLangModelVsSpatialModel, outSuggestionResults); PROF_END(2); PROF_CLOSE; } @@ -87,12 +89,13 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession) const { traverseSession->getDicTraverseCache()->continueSearch(); } else { // Restart recognition at the root. - traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()), + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize(), + traverseSession->getSuggestOptions()->weightForLocale()), TRAVERSAL->getTerminalCacheSize()); // Create a new dic node here DicNode rootNode; DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), - traverseSession->getPrevWordsPtNodePos(), &rootNode); + traverseSession->getPrevWordIds(), &rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); } } @@ -281,7 +284,6 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver // not treat the node as a terminal. There is no need to pass the bigram map in these cases. 
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); - weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -289,7 +291,6 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); - weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -400,7 +401,7 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN if (dicNode->isCompletion(inputSize)) { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); - } else { // completion + } else { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } @@ -412,7 +413,11 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN */ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const { - if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { + const WordAttributes wordAttributes = + traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext( + dicNode->getPrevWordIds(), dicNode->getWordId(), + traverseSession->getMultiBigramMap()); + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) { return; } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 788e0314b..65d5918cf 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -49,7 +49,8 @@ class Suggest : public SuggestInterface { AK_FORCE_INLINE virtual ~Suggest() {} 
void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float languageWeight, SuggestionResults *const outSuggestionResults) const; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index a6e5aefae..a05aa9c80 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -28,7 +28,8 @@ class SuggestInterface { public: virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - const float languageWeight, SuggestionResults *const suggestionResults) const = 0; + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const suggestionResults) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: diff --git a/native/jni/src/suggest/core/suggest_options.h b/native/jni/src/suggest/core/suggest_options.h index d456680dd..4d331292b 100644 --- a/native/jni/src/suggest/core/suggest_options.h +++ b/native/jni/src/suggest/core/suggest_options.h @@ -42,6 +42,12 @@ class SuggestOptions{ return getBoolOption(SPACE_AWARE_GESTURE_ENABLED); } + AK_FORCE_INLINE float weightForLocale() const { + // The weight is in thousands and we want the real value, so we divide by 1000. + // NativeSuggestOptions#setWeightForLocale does the opposite processing in Java. 
+ return static_cast<float>(getIntOption(WEIGHT_FOR_LOCALE_IN_THOUSANDS)) / 1000.0f; + } + AK_FORCE_INLINE bool getAdditionalFeaturesBoolOption(const int key) const { return getBoolOption(key + ADDITIONAL_FEATURES_OPTIONS); } @@ -55,9 +61,10 @@ class SuggestOptions{ static const int USE_FULL_EDIT_DISTANCE = 1; static const int BLOCK_OFFENSIVE_WORDS = 2; static const int SPACE_AWARE_GESTURE_ENABLED = 3; + static const int WEIGHT_FOR_LOCALE_IN_THOUSANDS = 4; // Additional features options are stored after the other options and used as setting values of // experimental features. - static const int ADDITIONAL_FEATURES_OPTIONS = 4; + static const int ADDITIONAL_FEATURES_OPTIONS = 5; const int *const mOptions; const int mLength; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index 6ed65d921..300e96c4e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -30,31 +30,26 @@ const char *const HeaderPolicy::DATE_KEY = "date"; const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME"; const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT"; const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT"; +const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT"; const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE"; // Historical info is information that is needed to support decaying such as timestamp, level and // count. 
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO"; const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration -const char *const HeaderPolicy::FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY = - "FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP"; const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY = "FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID"; -const char *const HeaderPolicy::FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY = - "FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS"; -const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_COUNT"; -const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_COUNT"; +const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT"; +const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT"; const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; -const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP = 2; const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3; -// 30 days -const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS = - 30 * 24 * 60 * 60; const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000; -const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 10000; +const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000; +const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000; // Used for logging. Question mark is used to indicate that the key is not found. 
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, @@ -100,12 +95,11 @@ bool HeaderPolicy::readRequiresGermanUmlautProcessing() const { } bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, - const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const { + const EntryCounts &entryCounts, const int extendedRegionSize, + BufferWithExtendableBuffer *const outBuffer) const { int writingPos = 0; DictionaryHeaderStructurePolicy::AttributeMap attributeMapToWrite(mAttributeMap); - fillInHeader(updatesLastDecayedTime, unigramCount, bigramCount, - extendedRegionSize, &attributeMapToWrite); + fillInHeader(updatesLastDecayedTime, entryCounts, extendedRegionSize, &attributeMapToWrite); if (!HeaderReadWriteUtils::writeDictionaryVersion(outBuffer, mDictFormatVersion, &writingPos)) { return false; @@ -132,11 +126,15 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim return true; } -void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int unigramCount, - const int bigramCount, const int extendedRegionSize, +void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, + const EntryCounts &entryCounts, const int extendedRegionSize, DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const { - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, unigramCount); - HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, bigramCount); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY, + entryCounts.getUnigramCount()); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY, + entryCounts.getBigramCount()); + HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY, + entryCounts.getTrigramCount()); HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY, 
extendedRegionSize); // Set the current time as the generation time. diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index 87cf0cd3b..44c2f443f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "utils/char_utils.h" #include "utils/time_keeper.h" @@ -49,23 +50,22 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { UNIGRAM_COUNT_KEY, 0 /* defaultValue */)), mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, BIGRAM_COUNT_KEY, 0 /* defaultValue */)), + mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + TRIGRAM_COUNT_KEY, 0 /* defaultValue */)), mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), - mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY, - DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)), mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), - mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY, - 
DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)), mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)) {} + &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( + &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), + mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Constructs header information using an attribute map. HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion, @@ -82,22 +82,19 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)), - mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0), + mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue( &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)), - mForgettingCurveOccurrencesToLevelUp(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY, - DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP)), mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY, DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)), - mForgettingCurveDurationToLevelDown(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY, - DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS)), mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue( &mAttributeMap, 
MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)), mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue( - &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)) {} + &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)), + mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue( + &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)), + mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {} // Copy header information HeaderPolicy(const HeaderPolicy *const headerPolicy) @@ -109,26 +106,25 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { mIsDecayingDict(headerPolicy->mIsDecayingDict), mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime), mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount), + mTrigramCount(headerPolicy->mTrigramCount), mExtendedRegionSize(headerPolicy->mExtendedRegionSize), mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords), - mForgettingCurveOccurrencesToLevelUp( - headerPolicy->mForgettingCurveOccurrencesToLevelUp), mForgettingCurveProbabilityValuesTableId( headerPolicy->mForgettingCurveProbabilityValuesTableId), - mForgettingCurveDurationToLevelDown( - headerPolicy->mForgettingCurveDurationToLevelDown), mMaxUnigramCount(headerPolicy->mMaxUnigramCount), - mMaxBigramCount(headerPolicy->mMaxBigramCount) {} + mMaxBigramCount(headerPolicy->mMaxBigramCount), + mMaxTrigramCount(headerPolicy->mMaxTrigramCount), + mCodePointTable(headerPolicy->mCodePointTable) {} // Temporary dummy header. 
HeaderPolicy() : mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0), mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f), mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false), - mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), + mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false), - mForgettingCurveOccurrencesToLevelUp(0), mForgettingCurveProbabilityValuesTableId(0), - mForgettingCurveDurationToLevelDown(0), mMaxUnigramCount(0), mMaxBigramCount(0) {} + mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0), + mMaxTrigramCount(0), mCodePointTable(nullptr) {} ~HeaderPolicy() {} @@ -139,6 +135,8 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { switch (mDictFormatVersion) { case FormatUtils::VERSION_2: return FormatUtils::VERSION_2; + case FormatUtils::VERSION_201: + return FormatUtils::VERSION_201; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: return FormatUtils::VERSION_4_ONLY_FOR_TESTING; case FormatUtils::VERSION_4: @@ -194,6 +192,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mBigramCount; } + AK_FORCE_INLINE int getTrigramCount() const { + return mTrigramCount; + } + AK_FORCE_INLINE int getExtendedRegionSize() const { return mExtendedRegionSize; } @@ -211,18 +213,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return &mAttributeMap; } - AK_FORCE_INLINE int getForgettingCurveOccurrencesToLevelUp() const { - return mForgettingCurveOccurrencesToLevelUp; - } - AK_FORCE_INLINE int getForgettingCurveProbabilityValuesTableId() const { return mForgettingCurveProbabilityValuesTableId; } - AK_FORCE_INLINE int getForgettingCurveDurationToLevelDown() const { - return mForgettingCurveDurationToLevelDown; - } - AK_FORCE_INLINE int getMaxUnigramCount() const { return mMaxUnigramCount; } @@ -231,15 +225,19 @@ class 
HeaderPolicy : public DictionaryHeaderStructurePolicy { return mMaxBigramCount; } + AK_FORCE_INLINE int getMaxTrigramCount() const { + return mMaxTrigramCount; + } + void readHeaderValueOrQuestionMark(const char *const key, int *outValue, int outValueSize) const; bool fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, - const int extendedRegionSize, BufferWithExtendableBuffer *const outBuffer) const; + const EntryCounts &entryCounts, const int extendedRegionSize, + BufferWithExtendableBuffer *const outBuffer) const; - void fillInHeader(const bool updatesLastDecayedTime, - const int unigramCount, const int bigramCount, const int extendedRegionSize, + void fillInHeader(const bool updatesLastDecayedTime, const EntryCounts &entryCounts, + const int extendedRegionSize, DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const; AK_FORCE_INLINE const std::vector<int> *getLocale() const { @@ -250,6 +248,10 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { return mDictFormatVersion >= FormatUtils::VERSION_4; } + const int *getCodePointTable() const { + return mCodePointTable; + } + private: DISALLOW_COPY_AND_ASSIGN(HeaderPolicy); @@ -260,6 +262,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const LAST_DECAYED_TIME_KEY; static const char *const UNIGRAM_COUNT_KEY; static const char *const BIGRAM_COUNT_KEY; + static const char *const TRIGRAM_COUNT_KEY; static const char *const EXTENDED_REGION_SIZE_KEY; static const char *const HAS_HISTORICAL_INFO_KEY; static const char *const LOCALE_KEY; @@ -268,13 +271,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY; static const char *const MAX_UNIGRAM_COUNT_KEY; static const char *const MAX_BIGRAM_COUNT_KEY; + static const char *const MAX_TRIGRAM_COUNT_KEY; static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; 
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; - static const int DEFAULT_FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP; static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID; - static const int DEFAULT_FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS; static const int DEFAULT_MAX_UNIGRAM_COUNT; static const int DEFAULT_MAX_BIGRAM_COUNT; + static const int DEFAULT_MAX_TRIGRAM_COUNT; const FormatUtils::FORMAT_VERSION mDictFormatVersion; const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; @@ -288,13 +291,14 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { const int mLastDecayedTime; const int mUnigramCount; const int mBigramCount; + const int mTrigramCount; const int mExtendedRegionSize; const bool mHasHistoricalInfoOfWords; - const int mForgettingCurveOccurrencesToLevelUp; const int mForgettingCurveProbabilityValuesTableId; - const int mForgettingCurveDurationToLevelDown; const int mMaxUnigramCount; const int mMaxBigramCount; + const int mMaxTrigramCount; + const int *const mCodePointTable; const std::vector<int> readLocale() const; float readMultipleWordCostMultiplier() const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp index a8f8f284b..41a8b13b8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -18,6 +18,7 @@ #include <cctype> #include <cstdio> +#include <memory> #include <vector> #include "defines.h" @@ -34,12 +35,13 @@ namespace latinime { const int HeaderReadWriteUtils::LARGEST_INT_DIGIT_COUNT = 11; const int HeaderReadWriteUtils::MAX_ATTRIBUTE_KEY_LENGTH = 256; -const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 256; +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 2048; const int HeaderReadWriteUtils::HEADER_MAGIC_NUMBER_SIZE = 4; 
const int HeaderReadWriteUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; const int HeaderReadWriteUtils::HEADER_FLAG_SIZE = 2; const int HeaderReadWriteUtils::HEADER_SIZE_FIELD_SIZE = 4; +const char *const HeaderReadWriteUtils::CODE_POINT_TABLE_KEY = "codePointTable"; const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; @@ -73,20 +75,32 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; return; } int keyBuffer[MAX_ATTRIBUTE_KEY_LENGTH]; - int valueBuffer[MAX_ATTRIBUTE_VALUE_LENGTH]; + std::unique_ptr<int[]> valueBuffer(new int[MAX_ATTRIBUTE_VALUE_LENGTH]); while (pos < headerSize) { + // The values in the header don't use the code point table for their encoding. const int keyLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, - MAX_ATTRIBUTE_KEY_LENGTH, keyBuffer, &pos); + MAX_ATTRIBUTE_KEY_LENGTH, nullptr /* codePointTable */, keyBuffer, &pos); std::vector<int> key; key.insert(key.end(), keyBuffer, keyBuffer + keyLength); const int valueLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, - MAX_ATTRIBUTE_VALUE_LENGTH, valueBuffer, &pos); + MAX_ATTRIBUTE_VALUE_LENGTH, nullptr /* codePointTable */, valueBuffer.get(), &pos); std::vector<int> value; - value.insert(value.end(), valueBuffer, valueBuffer + valueLength); + value.insert(value.end(), valueBuffer.get(), valueBuffer.get() + valueLength); headerAttributes->insert(AttributeMap::value_type(key, value)); } } +/* static */ const int *HeaderReadWriteUtils::readCodePointTable( + AttributeMap *const headerAttributes) { + AttributeMap::key_type keyVector; + insertCharactersIntoVector(CODE_POINT_TABLE_KEY, &keyVector); + AttributeMap::const_iterator it = headerAttributes->find(keyVector); + if (it == headerAttributes->end()) { + return nullptr; + } + return it->second.data(); +} + /* static */ bool HeaderReadWriteUtils::writeDictionaryVersion( BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, int *const writingPos) { @@ 
-96,7 +110,8 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; } switch (version) { case FormatUtils::VERSION_2: - // Version 2 dictionary writing is not supported. + case FormatUtils::VERSION_201: + // Version 2 or 201 dictionary writing is not supported. return false; case FormatUtils::VERSION_4_ONLY_FOR_TESTING: case FormatUtils::VERSION_4: @@ -142,7 +157,8 @@ typedef DictionaryHeaderStructurePolicy::AttributeMap AttributeMap; } /* static */ void HeaderReadWriteUtils::setCodePointVectorAttribute( - AttributeMap *const headerAttributes, const char *const key, const std::vector<int> value) { + AttributeMap *const headerAttributes, const char *const key, + const std::vector<int> &value) { AttributeMap::key_type keyVector; insertCharactersIntoVector(key, &keyVector); (*headerAttributes)[keyVector] = value; diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h index 9b90488fc..5dd91b26c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -46,6 +46,9 @@ class HeaderReadWriteUtils { static void fetchAllHeaderAttributes(const uint8_t *const dictBuf, DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes); + static const int *readCodePointTable( + DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes); + static bool writeDictionaryVersion(BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, int *const writingPos); @@ -64,7 +67,7 @@ class HeaderReadWriteUtils { */ static void setCodePointVectorAttribute( DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes, - const char *const key, const std::vector<int> value); + const char *const key, const std::vector<int> &value); static void setBoolAttribute( 
DictionaryHeaderStructurePolicy::AttributeMap *const headerAttributes, @@ -101,6 +104,8 @@ class HeaderReadWriteUtils { static const int HEADER_FLAG_SIZE; static const int HEADER_SIZE_FIELD_SIZE; + static const char *const CODE_POINT_TABLE_KEY; + // Value for the "flags" field. It's unused at the moment. static const DictionaryFlags NO_FLAGS; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.cpp index 3e8e059f2..bc0f47f79 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.cpp @@ -24,7 +24,7 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h" -#include "suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.h" @@ -60,7 +60,7 @@ void Ver4BigramListPolicy::getNextBigram(int *const outBigramPos, int *const out } bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry) { // 1. The word has no bigrams yet. // 2. The word has bigrams, and there is the target in the list. // 3. The word has bigrams, and there is an invalid entry that can be reclaimed. 
@@ -79,7 +79,7 @@ bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTarget const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, newTargetTerminalId); const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom(&newBigramEntry, - bigramProperty); + ngramProperty); // Write an entry. const int writingPos = mBigramDictContent->getBigramListHeadPos(terminalId); if (!mBigramDictContent->writeBigramEntry(&bigramEntryToWrite, writingPos)) { @@ -112,7 +112,7 @@ bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTarget const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, newTargetTerminalId); const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &newBigramEntry, bigramProperty); + &newBigramEntry, ngramProperty); if (!mBigramDictContent->writeBigramEntryAtTail(&bigramEntryToWrite)) { return false; } @@ -138,7 +138,7 @@ bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTarget const BigramEntry updatedBigramEntry = originalBigramEntry.updateTargetTerminalIdAndGetEntry(newTargetTerminalId); const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &updatedBigramEntry, bigramProperty); + &updatedBigramEntry, ngramProperty); return mBigramDictContent->writeBigramEntry(&bigramEntryToWrite, entryPosToUpdate); } @@ -264,18 +264,17 @@ int Ver4BigramListPolicy::getEntryPosToUpdate(const int targetTerminalIdToFind, const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom( const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const { + const NgramProperty *const ngramProperty) const { // TODO: Consolidate historical info and probability. 
if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(bigramProperty->getTimestamp(), - bigramProperty->getLevel(), bigramProperty->getCount()); + const HistoricalInfo &historicalInfoForUpdate = ngramProperty->getHistoricalInfo(); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalBigramEntry->getHistoricalInfo(), bigramProperty->getProbability(), + originalBigramEntry->getHistoricalInfo(), ngramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo); } else { - return originalBigramEntry->updateProbabilityAndGetEntry(bigramProperty->getProbability()); + return originalBigramEntry->updateProbabilityAndGetEntry(ngramProperty->getProbability()); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h index 50a4c9743..aac6f5470 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/bigram/ver4_bigram_list_policy.h @@ -36,7 +36,7 @@ namespace v402 { class BigramDictContent; } // namespace v402 } // namespace backward -class BigramProperty; +class NgramProperty; namespace backward { namespace v402 { } // namespace v402 @@ -64,7 +64,7 @@ class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { } bool addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry); bool removeEntry(const int terminalId, const int targetTerminalId); @@ -80,7 +80,7 @@ class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { 
int *const outTailEntryPos) const; const BigramEntry createUpdatedBigramEntryFrom(const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const; + const NgramProperty *const ngramProperty) const; bool updateHasNextFlag(const bool hasNext, const int bigramEntryPos); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp index e2dd93c5e..9e1adff70 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_dict_content.cpp @@ -83,10 +83,10 @@ bool BigramDictContent::writeBigramEntryAndAdvancePosition( } if (mHasHistoricalInfo) { const HistoricalInfo *const historicalInfo = bigramEntryToWrite->getHistoricalInfo(); - if (!bigramListBuffer->writeUintAndAdvancePosition(historicalInfo->getTimeStamp(), + if (!bigramListBuffer->writeUintAndAdvancePosition(historicalInfo->getTimestamp(), Ver4DictConstants::TIME_STAMP_FIELD_SIZE, entryWritingPos)) { AKLOGE("Cannot write bigram timestamps. 
pos: %d, timestamp: %d", *entryWritingPos, - historicalInfo->getTimeStamp()); + historicalInfo->getTimestamp()); return false; } if (!bigramListBuffer->writeUintAndAdvancePosition(historicalInfo->getLevel(), diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_entry.h index 40968b4d8..480095a2f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/bigram_entry.h @@ -25,8 +25,8 @@ #define LATINIME_BACKWARD_V402_BIGRAM_ENTRY_H #include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" namespace latinime { namespace backward { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp index c671647d4..ef6166ffd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.cpp @@ -74,8 +74,8 @@ bool ProbabilityDictContent::setProbabilityEntry(const int terminalId, return false; } writingPos += getEntrySize(); - mSize++; } + mSize = terminalId + 1; } return writeEntry(probabilityEntry, entryPos); } @@ -100,7 +100,6 @@ bool ProbabilityDictContent::flushToFile(const char *const dictPath) const { bool ProbabilityDictContent::runGC( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, const ProbabilityDictContent *const originalProbabilityDictContent) { - mSize = 0; for 
(TerminalPositionLookupTable::TerminalIdMap::const_iterator it = terminalIdMap->begin(); it != terminalIdMap->end(); ++it) { const ProbabilityEntry probabilityEntry = @@ -109,7 +108,6 @@ bool ProbabilityDictContent::runGC( AKLOGE("Cannot set probability entry in runGC. terminalId: %d", it->second); return false; } - mSize++; } return true; } @@ -147,7 +145,7 @@ bool ProbabilityDictContent::writeEntry(const ProbabilityEntry *const probabilit } if (mHasHistoricalInfo) { const HistoricalInfo *const historicalInfo = probabilityEntry->getHistoricalInfo(); - if (!bufferToWrite->writeUintAndAdvancePosition(historicalInfo->getTimeStamp(), + if (!bufferToWrite->writeUintAndAdvancePosition(historicalInfo->getTimestamp(), Ver4DictConstants::TIME_STAMP_FIELD_SIZE, &writingPos)) { AKLOGE("Cannot write timestamp in probability dict content. pos: %d", writingPos); return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_entry.h index 8ccfa33dc..4111a49c0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/content/probability_entry.h @@ -25,8 +25,8 @@ #define LATINIME_BACKWARD_V402_PROBABILITY_ENTRY_H #include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" namespace latinime { namespace backward { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp index 82399f190..5c639b19c 100644 --- 
a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.cpp @@ -23,6 +23,7 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/probability_dict_content.h" @@ -59,8 +60,8 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce const int parentPos = DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos); int codePoints[MAX_WORD_LENGTH]; - const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, codePoints, &pos); + const int codePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictBuf, flags, MAX_WORD_LENGTH, mHeaderPolicy->getCodePointTable(), codePoints, &pos); int terminalIdFieldPos = NOT_A_DICT_POS; int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; int probability = NOT_A_PROBABILITY; @@ -98,7 +99,7 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce // The destination position is stored at the same place as the parent position. 
return fetchPtNodeInfoFromBufferAndProcessMovedPtNode(parentPos, newSiblingNodePos); } else { - return PtNodeParams(headPos, flags, parentPos, codePonitCount, codePoints, + return PtNodeParams(headPos, flags, parentPos, codePointCount, codePoints, terminalIdFieldPos, terminalId, probability, childrenPosFieldPos, childrenPos, newSiblingNodePos); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp index 278f2b199..d558b949a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp @@ -232,10 +232,10 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( } bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { - if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) { - AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d", - sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId()); + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry) { + if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, ngramProperty, outAddedNewEntry)) { + AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d", + prevWordIds[0], wordId); return false; } const int ptNodePos = @@ -245,7 +245,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds if (!sourcePtNodeParams.hasBigrams()) { // Update has bigrams flag. 
return updatePtNodeFlags(sourcePtNodeParams.getHeadPos(), - sourcePtNodeParams.isBlacklisted(), sourcePtNodeParams.isNotAWord(), + sourcePtNodeParams.isPossiblyOffensive(), sourcePtNodeParams.isNotAWord(), sourcePtNodeParams.isTerminal(), sourcePtNodeParams.hasShortcutTargets(), true /* hasBigrams */, sourcePtNodeParams.getCodePointCount() > 1 /* hasMultipleChars */); @@ -310,13 +310,13 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN const int shortcutProbability) { if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), targetCodePoints, targetCodePointCount, shortcutProbability)) { - AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); + AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId()); return false; } if (!ptNodeParams->hasShortcutTargets()) { // Update has shortcut targets flag. return updatePtNodeFlags(ptNodeParams->getHeadPos(), - ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), + ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), true /* hasShortcutTargets */, ptNodeParams->hasBigrams(), ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); @@ -330,7 +330,7 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeHasBigramsAndShortcutTargetsFlags( ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; const bool hasShortcutTargets = mBuffers->getShortcutDictContent()->getShortcutListHeadPos( ptNodeParams->getTerminalId()) != NOT_A_DICT_POS; - return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isBlacklisted(), + return updatePtNodeFlags(ptNodeParams->getHeadPos(), ptNodeParams->isPossiblyOffensive(), ptNodeParams->isNotAWord(), ptNodeParams->isTerminal(), hasShortcutTargets, hasBigrams, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } @@ -386,8 +386,9 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( 
ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { return false; } - return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), - isTerminal, ptNodeParams->hasShortcutTargets(), ptNodeParams->hasBigrams(), + return updatePtNodeFlags(nodePos, ptNodeParams->isPossiblyOffensive(), + ptNodeParams->isNotAWord(), isTerminal, ptNodeParams->hasShortcutTargets(), + ptNodeParams->hasBigrams(), ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } @@ -396,8 +397,7 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( const UnigramProperty *const unigramProperty) const { // TODO: Consolidate historical info and probability. if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(), - unigramProperty->getLevel(), unigramProperty->getCount()); + const HistoricalInfo &historicalInfoForUpdate = unigramProperty->getHistoricalInfo(); const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( originalProbabilityEntry->getHistoricalInfo(), @@ -425,6 +425,18 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, return true; } +bool Ver4PatriciaTrieNodeWriter::suppressUnigramEntry(const PtNodeParams *const ptNodeParams) { + if (!mHeaderPolicy->hasHistoricalInfoOfWords()) { + // Require historical info to suppress unigram entry. 
+ return false; + } + const HistoricalInfo suppressedHistorycalInfo(0 /* timestamp */, 0 /* level */, 0 /* count */); + const ProbabilityEntry probabilityEntryToWrite = + ProbabilityEntry().createEntryWithUpdatedHistoricalInfo(&suppressedHistorycalInfo); + return mBuffers->getMutableProbabilityDictContent()->setProbabilityEntry( + ptNodeParams->getTerminalId(), &probabilityEntryToWrite); +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h index d49d9a666..d0bab50f8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.h @@ -94,7 +94,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos); virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry); virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId); @@ -111,6 +111,11 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { bool updatePtNodeHasBigramsAndShortcutTargetsFlags(const PtNodeParams *const ptNodeParams); + // Suppress unigram not to use the word for generating suggestions. So, this method can be used + // only for dictionaries with historical info. Also, suppressed entries are included in unigram + // count. They will be removed from the dictionary during GC. 
+ bool suppressUnigramEntry(const PtNodeParams *const ptNodeParams); + private: DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeWriter); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 1296b8acd..08e39ce43 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -28,11 +28,12 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" -#include "suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -51,6 +52,7 @@ const char *const Ver4PatriciaTriePolicy::MAX_BIGRAM_COUNT_QUERY = "MAX_BIGRAM_C const int Ver4PatriciaTriePolicy::MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS = 1024; const int Ver4PatriciaTriePolicy::MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS = Ver4DictConstants::MAX_DICTIONARY_SIZE - MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS; +const int Ver4PatriciaTriePolicy::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const { @@ -76,12 +78,9 @@ void 
Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d // Skip PtNodes that represent non-word information. continue; } - childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(), - ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal, - ptNodeParams.hasChildren(), - ptNodeParams.isBlacklisted() - || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */, - ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints()); + const int wordId = isTerminal ? ptNodeParams.getHeadPos() : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(), + wordId, ptNodeParams.getCodePointArrayView()); } if (readingHelper.isError()) { mIsCorrupted = true; @@ -89,13 +88,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } } -int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const { +int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); readingHelper.initWithPtNodePos(ptNodePos); - const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( - maxCodePointCount, outCodePoints, outUnigramProbability); + const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount( + maxCodePointCount, outCodePoints); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); @@ -103,72 +102,143 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const 
inWord, - const int length, const bool forceLowerCaseSearch) const { +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); + } + return getWordIdFromTerminalPtNodePos(ptNodePos); +} + +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); } - return ptNodePos; + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); + if (multiBigramMap) { + const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */, + prevWordIds, wordId, ptNodeParams.getProbability()); + return getWordAttributes(probability, ptNodeParams); + } + if (!prevWordIds.empty()) { + const int probability = getProbabilityOfWord(prevWordIds, wordId); + if (probability != NOT_A_PROBABILITY) { + return getWordAttributes(probability, ptNodeParams); + } + } + return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY), + ptNodeParams); +} + +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const { + return WordAttributes(probability, ptNodeParams.isBlacklisted(), 
ptNodeParams.isNotAWord(), + ptNodeParams.getProbability() == 0); } int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, const int bigramProbability) const { - if (mHeaderPolicy->isDecayingDict()) { - // Both probabilities are encoded. Decode them and get probability. - return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); - } else { - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return ProbabilityUtils::backoff(unigramProbability); - } else { - return bigramProbability; - } + // In the v4 format, bigramProbability is a conditional probability. + const int bigramConditionalProbability = bigramProbability; + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; } + if (bigramConditionalProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } + return bigramConditionalProbability; } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == ptNodePos - && bigramsIt.getProbability() != NOT_A_PROBABILITY) { - return getProbability(ptNodeParams.getProbability(), 
bigramsIt.getProbability()); - } - } + if (prevWordIds.empty()) { + return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); + } + if (prevWordIds[0] == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } - return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); + const PtNodeParams prevWordPtNodeParams = + mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]); + if (prevWordPtNodeParams.isDeleted()) { + return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); + } + const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos( + prevWordPtNodeParams.getTerminalId()); + BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == ptNodePos + && bigramsIt.getProbability() != NOT_A_PROBABILITY) { + const int bigramConditionalProbability = getBigramConditionalProbability( + prevWordPtNodeParams.getProbability(), + prevWordPtNodeParams.representsBeginningOfSentence(), + bigramsIt.getProbability()); + return getProbability(ptNodeParams.getProbability(), bigramConditionalProbability); + } + } + return NOT_A_PROBABILITY; } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.firstOrDefault(NOT_A_DICT_POS) == NOT_A_DICT_POS) { + return; + } + const PtNodeParams prevWordPtNodeParams = + mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(prevWordIds[0]); + if (prevWordPtNodeParams.isDeleted()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = mBuffers->getBigramDictContent()->getBigramListHeadPos( + prevWordPtNodeParams.getTerminalId()); BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); while 
(bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + const int bigramConditionalProbability = getBigramConditionalProbability( + prevWordPtNodeParams.getProbability(), + prevWordPtNodeParams.representsBeginningOfSentence(), bigramsIt.getProbability()); + listener->onVisitEntry(bigramConditionalProbability, + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); + } +} + +int Ver4PatriciaTriePolicy::getBigramConditionalProbability(const int prevWordUnigramProbability, + const bool isInBeginningOfSentenceContext, const int bigramProbability) const { + if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + if (isInBeginningOfSentenceContext) { + return bigramProbability; + } + // Calculate conditional probability. + return std::min(MAX_PROBABILITY - prevWordUnigramProbability + bigramProbability, + MAX_PROBABILITY); + } else { + // bigramProbability is a conditional probability. + return bigramProbability; } } +BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( + const int wordId) const { + const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); + return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos); +} + int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -193,7 +263,7 @@ int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) cons ptNodeParams.getTerminalId()); } -bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); @@ -204,13 +274,14 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le 
mDictBuffer->getTailPosition()); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("The word is too long to insert to the dictionary, length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("The word is too long to insert to the dictionary, length: %zd", + wordCodePoints.size()); return false; } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -219,8 +290,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le readingHelper.initWithPtNodeArrayPos(getRootPosition()); bool addedNewUnigram = false; int codePointsToAdd[MAX_WORD_LENGTH]; - int codePointCountToAdd = length; - memmove(codePointsToAdd, word, sizeof(int) * length); + int codePointCountToAdd = wordCodePoints.size(); + memmove(codePointsToAdd, wordCodePoints.data(), sizeof(int) * codePointCountToAdd); if (unigramProperty->representsBeginningOfSentence()) { codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd, codePointCountToAdd, MAX_WORD_LENGTH); @@ -228,24 +299,25 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (codePointCountToAdd <= 0) { return false; } - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd, - unigramProperty, &addedNewUnigram)) { + const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) { - mUnigramCount++; + mEntryCounters.incrementUnigramCount(); } if (unigramProperty->getShortcuts().size() > 0) { // Add 
shortcut target. - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId( + getWordId(codePointArrayView, false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { AKLOGE("Cannot find terminal PtNode position to add shortcut target."); return false; } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { + AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -258,8 +330,21 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } } -bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty) { +bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCodePoints) { + if (!mBuffers->isUpdatable()) { + AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); + return false; + } + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); + if (ptNodePos == NOT_A_DICT_POS) { + return false; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + return mNodeWriter.suppressUnigramEntry(&ptNodeParams); +} + +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; @@ -269,50 +354,50 @@ bool 
Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary."); + const NgramContext *const ngramContext = ngramProperty->getNgramContext(); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; } - if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { + if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", ngramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - // TODO: Support N-gram. 
- if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const std::vector<UnigramProperty::ShortcutProperty> shortcuts; + if (prevWordIds.empty()) { + return false; + } + if (prevWordIds[0] == NOT_A_WORD_ID) { + if (ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { const UnigramProperty beginningOfSentenceUnigramProperty( true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - prevWordsInfo->getNthPrevWordCodePointCount(1 /* n */), + false /* isBlacklisted */, MAX_PROBABILITY /* probability */, HistoricalInfo()); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), &beginningOfSentenceUnigramProperty)) { AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); return false; } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); + // Refresh word ids. 
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } else { return false; } } - const int word1Pos = getTerminalPtNodePositionOfWord( - bigramProperty->getTargetCodePoints()->data(), - bigramProperty->getTargetCodePoints()->size(), false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + const int wordPos = getTerminalPtNodePosFromWordId(getWordId( + CodePointArrayView(*ngramProperty->getTargetCodePoints()), + false /* forceLowerCaseSearch */)); + if (wordPos == NOT_A_DICT_POS) { return false; } bool addedNewBigram = false; - if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(prevWordsPtNodePos), - word1Pos, bigramProperty, &addedNewBigram)) { + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); + if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos), + wordPos, ngramProperty, &addedNewBigram)) { if (addedNewBigram) { - mBigramCount++; + mEntryCounters.incrementBigramCount(); } return true; } else { @@ -320,8 +405,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } } -bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { +bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; @@ -331,40 +416,68 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary."); return false; } - if (length > MAX_WORD_LENGTH) { - 
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd", + wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); + const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints, + false /* forceLowerCaseSearch */)); if (wordPos == NOT_A_DICT_POS) { return false; } + const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]); if (mUpdatingHelper.removeNgramEntry( - PtNodePosArrayView::fromObject(prevWordsPtNodePos), wordPos)) { - mBigramCount--; + PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) { + mEntryCounters.decrementBigramCount(); return true; } else { return false; } } + +bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( + const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, + const bool isValidWord, const HistoricalInfo historicalInfo) { + if (!mBuffers->isUpdatable()) { + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); + return false; + } + const int probability = isValidWord ? 
DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY; + const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, + false /* isNotAWord */, false /*isBlacklisted*/, probability, historicalInfo); + if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { + AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext()."); + return false; + } + const int probabilityForNgram = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) + ? NOT_A_PROBABILITY : probability; + const NgramProperty ngramProperty(*ngramContext, wordCodePoints.toVector(), probabilityForNgram, + historicalInfo); + if (!addNgramEntry(&ngramProperty)) { + AKLOGE("Cannot update unigarm entry in updateEntriesForWordWithNgramContext()."); + return false; + } + return true; +} + bool Ver4PatriciaTriePolicy::flush(const char *const filePath) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: flush() is called for non-updatable dictionary. filePath: %s", filePath); return false; } - if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) { + if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) { AKLOGE("Cannot flush the dictionary to file."); mIsCorrupted = true; return false; @@ -402,7 +515,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { // Needs to reduce dictionary size. 
return true; } else if (mHeaderPolicy->isDecayingDict()) { - return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount, + return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(), mHeaderPolicy); } return false; @@ -412,41 +525,39 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mUnigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mBigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getUnigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxUnigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? 
- ForgettingCurveUtils::getBigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxBigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } } -const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); +const WordProperty Ver4PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int ptNodePos = getTerminalPtNodePosFromWordId( + getWordId(wordCodePoints, false /* forceLowerCaseSearch */)); if (ptNodePos == NOT_A_DICT_POS) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); } const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); const ProbabilityEntry probabilityEntry = mBuffers->getProbabilityDictContent()->getProbabilityEntry( ptNodeParams.getTerminalId()); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); // Fetch bigram information. 
- std::vector<BigramProperty> bigrams; + std::vector<NgramProperty> ngrams; const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); if (bigramListPos != NOT_A_DICT_POS) { int bigramWord1CodePoints[MAX_WORD_LENGTH]; @@ -465,21 +576,21 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code if (word1TerminalPtNodePos == NOT_A_DICT_POS) { continue; } - // Word (unigram) probability - int word1Probability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + codePointCount); + const int codePointCount = getCodePointsAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, + bigramWord1CodePoints); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - bigramEntry.getProbability(); - bigrams.emplace_back(&word1, probability, - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount()); + const int rawBigramProbability = bigramEntry.hasHistoricalInfo() + ? ForgettingCurveUtils::decodeProbability( + bigramEntry.getHistoricalInfo(), mHeaderPolicy) + : bigramEntry.getProbability(); + const int probability = getBigramConditionalProbability(ptNodeParams.getProbability(), + ptNodeParams.representsBeginningOfSentence(), rawBigramProbability); + ngrams.emplace_back( + NgramContext(wordCodePoints.data(), wordCodePoints.size(), + ptNodeParams.representsBeginningOfSentence()), + CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), + probability, *historicalInfo); } } // Fetch shortcut information. 
@@ -495,15 +606,15 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code int shortcutProbability = NOT_A_PROBABILITY; shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget, &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos); - const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength); - shortcuts.emplace_back(&target, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(), + shortcutProbability); } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount(), &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(), + ptNodeParams.getProbability(), *historicalInfo, std::move(shortcuts)); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, @@ -524,9 +635,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + *outCodePointCount = getCodePointsAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. 
@@ -536,6 +646,14 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return nextToken; } +int Ver4PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId; +} + } // namespace v402 } // namespace backward } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 9e989b268..0480876ed 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -28,6 +28,8 @@ #include <vector> #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h" @@ -39,6 +41,8 @@ #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" +#include "utils/int_array_view.h" namespace latinime { namespace backward { @@ -55,6 +59,8 @@ class DicNodeVector; namespace backward { namespace v402 { +// Word id = Position of a PtNode that represents the word. +// Max supported n-gram is bigram. 
class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: Ver4PatriciaTriePolicy(Ver4DictBuffers::Ver4DictBuffersPtr buffers) @@ -70,54 +76,51 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()), + mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), + mHeaderPolicy->getTrigramCount()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; - AK_FORCE_INLINE int getRootPosition() const { + virtual int getRootPosition() const { return 0; } void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; + + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) 
const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const word, const int length) { - // Removing unigram entry is not supported. - return false; - } + bool removeUnigramEntry(const CodePointArrayView wordCodePoints); + + bool addNgramEntry(const NgramProperty *const ngramProperty); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty); + bool removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, - const int length1); + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints, const bool isValidWord, + const HistoricalInfo historicalInfo); bool flush(const char *const filePath); @@ -128,8 +131,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -149,6 +151,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { // prevent the dictionary from 
overflowing. static const int MARGIN_TO_REFUSE_DYNAMIC_OPERATIONS; static const int MIN_DICT_SIZE_TO_REFUSE_DYNAMIC_OPERATIONS; + static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers; const HeaderPolicy *const mHeaderPolicy; @@ -160,12 +163,18 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { Ver4PatriciaTrieNodeWriter mNodeWriter; DynamicPtUpdatingHelper mUpdatingHelper; Ver4PatriciaTrieWritingHelper mWritingHelper; - int mUnigramCount; - int mBigramCount; + MutableEntryCounters mEntryCounters; std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; int getBigramsPositionOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; + const WordAttributes getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const; + int getBigramConditionalProbability(const int prevWordUnigramProbability, + const bool isInBeginningOfSentenceContext, const int bigramProbability) const; }; } // namespace v402 } // namespace backward diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp index 3fb4caa08..a033d396b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp @@ -43,18 +43,18 @@ namespace backward { namespace v402 { bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath, - const int unigramCount, const int bigramCount) const { + const EntryCounts &entryCounts) const { const HeaderPolicy *const headerPolicy = 
mBuffers->getHeaderPolicy(); BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); const int extendedRegionSize = headerPolicy->getExtendedRegionSize() + mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize(); if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */, - unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) { + entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. " "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " - "extendedRegionSize: %d", false, unigramCount, bigramCount, - extendedRegionSize); + "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), + entryCounts.getBigramCount(), extendedRegionSize); return false; } return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -74,7 +74,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) { + EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */), + 0 /* extendedRegionSize */, &headerBuffer)) { return false; } return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -216,7 +217,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( probabilityEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : probabilityEntry.getProbability(); priorityQueue.push(DictProbability(terminalPos, probability, - probabilityEntry.getHistoricalInfo()->getTimeStamp())); + probabilityEntry.getHistoricalInfo()->getTimestamp())); } // Delete unigrams. 
@@ -263,7 +264,7 @@ bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { bigramEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : bigramEntry.getProbability(); priorityQueue.push(DictProbability(entryPos, probability, - bigramEntry.getHistoricalInfo()->getTimeStamp())); + bigramEntry.getHistoricalInfo()->getTimestamp())); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h index 9034ee656..1aad33e38 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.h @@ -27,6 +27,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h" #include "suggest/policyimpl/dictionary/structure/backward/v402/content/terminal_position_lookup_table.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { namespace backward { @@ -46,8 +47,7 @@ class Ver4PatriciaTrieWritingHelper { Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers) : mBuffers(buffers) {} - bool writeToDictFile(const char *const dictDirPath, const int unigramCount, - const int bigramCount) const; + bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const; // This method cannot be const because the original dictionary buffer will be updated to detect // useless PtNodes during GC. 
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp index e4ea3da16..372c9e36f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.cpp @@ -111,11 +111,11 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str return nullptr; } const FormatUtils::FORMAT_VERSION formatVersion = FormatUtils::detectFormatVersion( - mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size()); + mmappedBuffer->getReadOnlyByteArrayView()); switch (formatVersion) { case FormatUtils::VERSION_2: - AKLOGE("Given path is a directory but the format is version 2. path: %s", path); + case FormatUtils::VERSION_201: + AKLOGE("Given path is a directory but the format is version 2 or 201. 
path: %s", path); break; case FormatUtils::VERSION_4: { return newPolicyForV4Dict<backward::v402::Ver4DictConstants, @@ -174,9 +174,9 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr, class Str if (!mmappedBuffer) { return nullptr; } - switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView().data(), - mmappedBuffer->getReadOnlyByteArrayView().size())) { + switch (FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())) { case FormatUtils::VERSION_2: + case FormatUtils::VERSION_201: return DictionaryStructureWithBufferPolicy::StructurePolicyPtr( new PatriciaTriePolicy(std::move(mmappedBuffer))); case FormatUtils::VERSION_4_ONLY_FOR_TESTING: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp index f7fd5c071..1b2f857ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.cpp @@ -39,32 +39,31 @@ const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; /* static */ bool BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( - const uint8_t *const bigramsBuf, const int bufSize, BigramFlags *const outBigramFlags, + const ReadOnlyByteArrayView buffer, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos) { - if (bufSize <= *bigramEntryPos) { - AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). bufSize: %d, " - "bigramEntryPos: %d.", bufSize, *bigramEntryPos); + if (static_cast<int>(buffer.size()) <= *bigramEntryPos) { + AKLOGE("Read invalid pos in getBigramEntryPropertiesAndAdvancePosition(). 
bufSize: %zd, " + "bigramEntryPos: %d.", buffer.size(), *bigramEntryPos); return false; } - const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, + const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), bigramEntryPos); if (outBigramFlags) { *outBigramFlags = bigramFlags; } - const int targetPos = getBigramAddressAndAdvancePosition(bigramsBuf, bigramFlags, - bigramEntryPos); + const int targetPos = getBigramAddressAndAdvancePosition(buffer, bigramFlags, bigramEntryPos); if (outTargetPtNodePos) { *outTargetPtNodePos = targetPos; } return true; } -/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, - const int bufSize, int *const bigramListPos) { +/* static */ bool BigramListReadWriteUtils::skipExistingBigrams(const ReadOnlyByteArrayView buffer, + int *const bigramListPos) { BigramFlags flags; do { - if (!getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, bufSize, &flags, - 0 /* outTargetPtNodePos */, bigramListPos)) { + if (!getBigramEntryPropertiesAndAdvancePosition(buffer, &flags, 0 /* outTargetPtNodePos */, + bigramListPos)) { return false; } } while(hasNext(flags)); @@ -72,18 +71,18 @@ const BigramListReadWriteUtils::BigramFlags } /* static */ int BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( - const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { + const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos) { int offset = 0; const int origin = *pos; switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); break; case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos); break; 
case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); + offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer.data(), pos); break; } if (isOffsetNegative(flags)) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h index 10f93fb7a..a0f7d5e83 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h @@ -21,6 +21,7 @@ #include <cstdlib> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,8 +31,8 @@ class BigramListReadWriteUtils { public: typedef uint8_t BigramFlags; - static bool getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, - const int bufSize, BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + static bool getBigramEntryPropertiesAndAdvancePosition(const ReadOnlyByteArrayView buffer, + BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, int *const bigramEntryPos); static AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { @@ -43,8 +44,7 @@ public: } // Bigrams reading methods - static bool skipExistingBigrams(const uint8_t *const bigramsBuf, const int bufSize, - int *const bigramListPos); + static bool skipExistingBigrams(const ReadOnlyByteArrayView buffer, int *const bigramListPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); @@ -61,7 +61,7 @@ private: return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; } - static int getBigramAddressAndAdvancePosition(const uint8_t *const bigramsBuf, + static int getBigramAddressAndAdvancePosition(const ReadOnlyByteArrayView buffer, const BigramFlags flags, int *const pos); }; } // namespace 
latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h index 2aa402748..b8a4a92e8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h @@ -76,6 +76,7 @@ class DynamicPtGcEventListeners { int mValidUnigramCount; }; + // TODO: Remove when we stop supporting v402 format. // Updates all bigram entries that are held by valid PtNodes. This removes useless bigram // entries. class TraversePolicyToUpdateBigramProbability diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index 086d98b4a..5e4a4b166 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi return !isError(); } -int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( - const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) { +int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount, + int *const outCodePoints) { // This method traverses parent nodes from the terminal by following parent pointers; thus, // node code points are stored in the buffer in the reverse order. int reverseCodePoints[maxCodePointCount]; @@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( // First, read the terminal node and get its probability. 
if (!isValidTerminalNode(terminalPtNodeParams)) { // Node at the ptNodePos is not a valid terminal node. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } - // Store terminal node probability. - *outUnigramProbability = terminalPtNodeParams.getProbability(); // Then, following parent node link to the dictionary root and fetch node code points. int totalCodePointCount = 0; while (!isEnd()) { @@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( totalCodePointCount = getTotalCodePointCount(ptNodeParams); if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) { // The ptNodePos is not a valid terminal node position in the dictionary. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } // Store node code points to buffer in the reverse order. @@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( } if (isError()) { // The node position or the dictionary is invalid. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } // Reverse the stored code points to output them. @@ -218,9 +213,9 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( } int DynamicPtReadingHelper::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) { + const size_t length, const bool forceLowerCaseSearch) { int searchCodePoints[length]; - for (int i = 0; i < length; ++i) { + for (size_t i = 0; i < length; ++i) { searchCodePoints[i] = forceLowerCaseSearch ? 
CharUtils::toLowerCase(inWord[i]) : inWord[i]; } while (!isEnd()) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index b7262581a..21c287fdc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -138,12 +138,12 @@ class DynamicPtReadingHelper { } // Return code point count exclude the last read node's code points. - AK_FORCE_INLINE int getPrevTotalCodePointCount() const { + AK_FORCE_INLINE size_t getPrevTotalCodePointCount() const { return mReadingState.mTotalCodePointCountSinceInitialization; } // Return code point count include the last read node's code points. - AK_FORCE_INLINE int getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { + AK_FORCE_INLINE size_t getTotalCodePointCount(const PtNodeParams &ptNodeParams) const { return mReadingState.mTotalCodePointCountSinceInitialization + ptNodeParams.getCodePointCount(); } @@ -211,10 +211,9 @@ class DynamicPtReadingHelper { bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( TraversingEventListener *const listener); - int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount, - int *const outCodePoints, int *const outUnigramProbability); + int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints); - int getTerminalPtNodePositionOfWord(const int *const inWord, const int length, + int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, const bool forceLowerCaseSearch); private: @@ -234,7 +233,7 @@ class DynamicPtReadingHelper { int mPos; // Remaining node count in the current array. 
int mRemainingPtNodeCountInThisArray; - int mTotalCodePointCountSinceInitialization; + size_t mTotalCodePointCountSinceInitialization; // Counter of PtNodes used to avoid infinite loops caused by broken or malicious links. int mTotalPtNodeIndexInThisArrayChain; // Counter of PtNode arrays used to avoid infinite loops caused by cyclic links of empty diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp index 3c62e2e56..e524e86e5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.cpp @@ -28,17 +28,16 @@ namespace latinime { const int DynamicPtUpdatingHelper::CHILDREN_POSITION_FIELD_SIZE = 3; -bool DynamicPtUpdatingHelper::addUnigramWord( - DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram) { +bool DynamicPtUpdatingHelper::addUnigramWord(DynamicPtReadingHelper *const readingHelper, + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram) { int parentPos = NOT_A_DICT_POS; while (!readingHelper->isEnd()) { const PtNodeParams ptNodeParams(readingHelper->getPtNodeParams()); if (!ptNodeParams.isValid()) { break; } - const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); + const size_t matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); if (!readingHelper->isMatchedCodePoint(ptNodeParams, 0 /* index */, wordCodePoints[matchedCodePointCount])) { // The first code point is different from target code point. 
Skip this node and read @@ -47,26 +46,25 @@ bool DynamicPtUpdatingHelper::addUnigramWord( continue; } // Check following merged node code points. - const int nodeCodePointCount = ptNodeParams.getCodePointCount(); - for (int j = 1; j < nodeCodePointCount; ++j) { - const int nextIndex = matchedCodePointCount + j; - if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(ptNodeParams, j, - wordCodePoints[matchedCodePointCount + j])) { + const size_t nodeCodePointCount = ptNodeParams.getCodePointArrayView().size(); + for (size_t j = 1; j < nodeCodePointCount; ++j) { + const size_t nextIndex = matchedCodePointCount + j; + if (nextIndex >= wordCodePoints.size() + || !readingHelper->isMatchedCodePoint(ptNodeParams, j, + wordCodePoints[matchedCodePointCount + j])) { *outAddedNewUnigram = true; return reallocatePtNodeAndAddNewPtNodes(&ptNodeParams, j, unigramProperty, - wordCodePoints + matchedCodePointCount, - codePointCount - matchedCodePointCount); + wordCodePoints.skip(matchedCodePointCount)); } } // All characters are matched. - if (codePointCount == readingHelper->getTotalCodePointCount(ptNodeParams)) { + if (wordCodePoints.size() == readingHelper->getTotalCodePointCount(ptNodeParams)) { return setPtNodeProbability(&ptNodeParams, unigramProperty, outAddedNewUnigram); } if (!ptNodeParams.hasChildren()) { *outAddedNewUnigram = true; return createChildrenPtNodeArrayAndAChildPtNode(&ptNodeParams, unigramProperty, - wordCodePoints + readingHelper->getTotalCodePointCount(ptNodeParams), - codePointCount - readingHelper->getTotalCodePointCount(ptNodeParams)); + wordCodePoints.skip(readingHelper->getTotalCodePointCount(ptNodeParams))); } // Advance to the children nodes. 
parentPos = ptNodeParams.getHeadPos(); @@ -79,13 +77,12 @@ bool DynamicPtUpdatingHelper::addUnigramWord( int pos = readingHelper->getPosOfLastForwardLinkField(); *outAddedNewUnigram = true; return createAndInsertNodeIntoPtNodeArray(parentPos, - wordCodePoints + readingHelper->getPrevTotalCodePointCount(), - codePointCount - readingHelper->getPrevTotalCodePointCount(), - unigramProperty, &pos); + wordCodePoints.skip(readingHelper->getPrevTotalCodePointCount()), unigramProperty, + &pos); } bool DynamicPtUpdatingHelper::addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, - const int wordPos, const BigramProperty *const bigramProperty, + const int wordPos, const NgramProperty *const ngramProperty, bool *const outAddedNewEntry) { if (prevWordsPtNodePos.empty()) { return false; @@ -99,7 +96,7 @@ bool DynamicPtUpdatingHelper::addNgramEntry(const PtNodePosArrayView prevWordsPt const WordIdArrayView prevWordIds(prevWordTerminalIds, prevWordsPtNodePos.size()); const int wordId = mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos).getTerminalId(); - return mPtNodeWriter->addNgramEntry(prevWordIds, wordId, bigramProperty, outAddedNewEntry); + return mPtNodeWriter->addNgramEntry(prevWordIds, wordId, ngramProperty, outAddedNewEntry); } bool DynamicPtUpdatingHelper::removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, @@ -120,23 +117,21 @@ bool DynamicPtUpdatingHelper::removeNgramEntry(const PtNodePosArrayView prevWord } bool DynamicPtUpdatingHelper::addShortcutTarget(const int wordPos, - const int *const targetCodePoints, const int targetCodePointCount, - const int shortcutProbability) { + const CodePointArrayView targetCodePoints, const int shortcutProbability) { const PtNodeParams ptNodeParams(mPtNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(wordPos)); - return mPtNodeWriter->addShortcutTarget(&ptNodeParams, targetCodePoints, targetCodePointCount, - shortcutProbability); + return mPtNodeWriter->addShortcutTarget(&ptNodeParams, 
targetCodePoints.data(), + targetCodePoints.size(), shortcutProbability); } bool DynamicPtUpdatingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, - const int *const nodeCodePoints, const int nodeCodePointCount, - const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos) { + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, + int *const forwardLinkFieldPos) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, newPtNodeArrayPos, forwardLinkFieldPos)) { return false; } - return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, - unigramProperty); + return createNewPtNodeArrayWithAChildPtNode(parentPos, ptNodeCodePoints, unigramProperty); } bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, @@ -151,10 +146,9 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori const int movedPos = mBuffer->getTailPosition(); int writingPos = movedPos; const PtNodeParams ptNodeParamsToWrite(getUpdatedPtNodeParams(originalPtNodeParams, - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, originalPtNodeParams->getParentPos(), - originalPtNodeParams->getCodePointCount(), originalPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + originalPtNodeParams->getCodePointArrayView(), unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -168,17 +162,17 @@ bool DynamicPtUpdatingHelper::setPtNodeProbability(const PtNodeParams *const ori bool DynamicPtUpdatingHelper::createChildrenPtNodeArrayAndAChildPtNode( const PtNodeParams *const parentPtNodeParams, const UnigramProperty *const 
unigramProperty, - const int *const codePoints, const int codePointCount) { + const CodePointArrayView codePoints) { const int newPtNodeArrayPos = mBuffer->getTailPosition(); if (!mPtNodeWriter->updateChildrenPosition(parentPtNodeParams, newPtNodeArrayPos)) { return false; } return createNewPtNodeArrayWithAChildPtNode(parentPtNodeParams->getHeadPos(), codePoints, - codePointCount, unigramProperty); + unigramProperty); } bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( - const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int parentPtNodePos, const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty) { int writingPos = mBuffer->getTailPosition(); if (!DynamicPtWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, @@ -186,8 +180,8 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( return false; } const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), true /* isTerminal */, - parentPtNodePos, nodeCodePointCount, nodeCodePoints, + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), + true /* isTerminal */, parentPtNodePos, ptNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { @@ -202,9 +196,9 @@ bool DynamicPtUpdatingHelper::createNewPtNodeArrayWithAChildPtNode( // Returns whether the dictionary updating was succeeded or not. 
bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount) { + const PtNodeParams *const reallocatingPtNodeParams, const size_t overlappingCodePointCount, + const UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints) { // When addsExtraChild is true, split the reallocating PtNode and add new child. // Reallocating PtNode: abcde, newNode: abcxy. // abc (1st, not terminal) __ de (2nd) @@ -212,25 +206,26 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( // Otherwise, this method makes 1st part terminal and write information in unigramProperty. // Reallocating PtNode: abcde, newNode: abc. // abc (1st, terminal) __ de (2nd) - const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const bool addsExtraChild = newPtNodeCodePoints.size() > overlappingCodePointCount; const int firstPartOfReallocatedPtNodePos = mBuffer->getTailPosition(); int writingPos = firstPartOfReallocatedPtNodePos; // Write the 1st part of the reallocating node. The children position will be updated later // with actual children position. 
+ const CodePointArrayView firstPtNodeCodePoints = + reallocatingPtNodeParams->getCodePointArrayView().limit(overlappingCodePointCount); if (addsExtraChild) { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - false /* isNotAWord */, false /* isBlacklisted */, false /* isTerminal */, - reallocatingPtNodeParams->getParentPos(), overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints(), NOT_A_PROBABILITY)); + false /* isNotAWord */, false /* isPossiblyOffensive */, false /* isTerminal */, + reallocatingPtNodeParams->getParentPos(), firstPtNodeCodePoints, + NOT_A_PROBABILITY)); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&ptNodeParamsToWrite, &writingPos)) { return false; } } else { const PtNodeParams ptNodeParamsToWrite(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, reallocatingPtNodeParams->getParentPos(), - overlappingCodePointCount, reallocatingPtNodeParams->getCodePoints(), - unigramProperty->getProbability())); + firstPtNodeCodePoints, unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&ptNodeParamsToWrite, unigramProperty, &writingPos)) { return false; @@ -246,20 +241,19 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( // Write the 2nd part of the reallocating node. 
const int secondPartOfReallocatedPtNodePos = writingPos; const PtNodeParams childPartPtNodeParams(getUpdatedPtNodeParams(reallocatingPtNodeParams, - reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isBlacklisted(), + reallocatingPtNodeParams->isNotAWord(), reallocatingPtNodeParams->isPossiblyOffensive(), reallocatingPtNodeParams->isTerminal(), firstPartOfReallocatedPtNodePos, - reallocatingPtNodeParams->getCodePointCount() - overlappingCodePointCount, - reallocatingPtNodeParams->getCodePoints() + overlappingCodePointCount, + reallocatingPtNodeParams->getCodePointArrayView().skip(overlappingCodePointCount), reallocatingPtNodeParams->getProbability())); if (!mPtNodeWriter->writePtNodeAndAdvancePosition(&childPartPtNodeParams, &writingPos)) { return false; } if (addsExtraChild) { const PtNodeParams extraChildPtNodeParams(getPtNodeParamsForNewPtNode( - unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isNotAWord(), unigramProperty->isPossiblyOffensive(), true /* isTerminal */, firstPartOfReallocatedPtNodePos, - newNodeCodePointCount - overlappingCodePointCount, - newNodeCodePoints + overlappingCodePointCount, unigramProperty->getProbability())); + newPtNodeCodePoints.skip(overlappingCodePointCount), + unigramProperty->getProbability())); if (!mPtNodeWriter->writeNewTerminalPtNodeAndAdvancePosition(&extraChildPtNodeParams, unigramProperty, &writingPos)) { return false; @@ -282,26 +276,24 @@ bool DynamicPtUpdatingHelper::reallocatePtNodeAndAddNewPtNodes( } const PtNodeParams DynamicPtUpdatingHelper::getUpdatedPtNodeParams( - const PtNodeParams *const originalPtNodeParams, - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const { + const PtNodeParams *const originalPtNodeParams, const bool isNotAWord, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, + const 
CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( - isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(originalPtNodeParams, flags, parentPos, codePointCount, codePoints, - probability); + return PtNodeParams(originalPtNodeParams, flags, parentPos, codePoints, probability); } -const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode( - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, const int *const codePoints, - const int probability) const { +const PtNodeParams DynamicPtUpdatingHelper::getPtNodeParamsForNewPtNode(const bool isNotAWord, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const { const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::createAndGetFlags( - isBlacklisted, isNotAWord, isTerminal, false /* hasShortcutTargets */, - false /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + isPossiblyOffensive, isNotAWord, isTerminal, false /* hasShortcutTargets */, + false /* hasBigrams */, codePoints.size() > 1u /* hasMultipleChars */, CHILDREN_POSITION_FIELD_SIZE); - return PtNodeParams(flags, parentPos, codePointCount, codePoints, probability); + return PtNodeParams(flags, parentPos, codePoints, probability); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h index 
97c05c1ea..db5f6ab17 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h @@ -23,7 +23,7 @@ namespace latinime { -class BigramProperty; +class NgramProperty; class BufferWithExtendableBuffer; class DynamicPtReadingHelper; class PtNodeReader; @@ -40,19 +40,21 @@ class DynamicPtUpdatingHelper { // Add a word to the dictionary. If the word already exists, update the probability. bool addUnigramWord(DynamicPtReadingHelper *const readingHelper, - const int *const wordCodePoints, const int codePointCount, - const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); + const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty, + bool *const outAddedNewUnigram); + // TODO: Remove after stopping supporting v402. // Add an n-gram entry. bool addNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry); + // TODO: Remove after stopping supporting v402. // Remove an n-gram entry. bool removeNgramEntry(const PtNodePosArrayView prevWordsPtNodePos, const int wordPos); // Add a shortcut target. 
- bool addShortcutTarget(const int wordPos, const int *const targetCodePoints, - const int targetCodePointCount, const int shortcutProbability); + bool addShortcutTarget(const int wordPos, const CodePointArrayView targetCodePoints, + const int shortcutProbability); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPtUpdatingHelper); @@ -63,33 +65,32 @@ class DynamicPtUpdatingHelper { const PtNodeReader *const mPtNodeReader; PtNodeWriter *const mPtNodeWriter; - bool createAndInsertNodeIntoPtNodeArray(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty, + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, + const CodePointArrayView ptNodeCodePoints, const UnigramProperty *const unigramProperty, int *const forwardLinkFieldPos); bool setPtNodeProbability(const PtNodeParams *const originalPtNodeParams, const UnigramProperty *const unigramProperty, bool *const outAddedNewUnigram); bool createChildrenPtNodeArrayAndAChildPtNode(const PtNodeParams *const parentPtNodeParams, - const UnigramProperty *const unigramProperty, const int *const codePoints, - const int codePointCount); + const UnigramProperty *const unigramProperty, + const CodePointArrayView remainingCodePoints); - bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, - const int nodeCodePointCount, const UnigramProperty *const unigramProperty); + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, + const CodePointArrayView ptNodeCodePoints, + const UnigramProperty *const unigramProperty); - bool reallocatePtNodeAndAddNewPtNodes( - const PtNodeParams *const reallocatingPtNodeParams, const int overlappingCodePointCount, - const UnigramProperty *const unigramProperty, const int *const newNodeCodePoints, - const int newNodeCodePointCount); + bool reallocatePtNodeAndAddNewPtNodes(const PtNodeParams *const reallocatingPtNodeParams, + const size_t overlappingCodePointCount, const 
UnigramProperty *const unigramProperty, + const CodePointArrayView newPtNodeCodePoints); const PtNodeParams getUpdatedPtNodeParams(const PtNodeParams *const originalPtNodeParams, - const bool isNotAWord, const bool isBlacklisted, const bool isTerminal, - const int parentPos, const int codePointCount, - const int *const codePoints, const int probability) const; + const bool isNotAWord, const bool isPossiblyOffensive, const bool isTerminal, + const int parentPos, const CodePointArrayView codePoints, const int probability) const; - const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, const bool isBlacklisted, - const bool isTerminal, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) const; + const PtNodeParams getPtNodeParamsForNewPtNode(const bool isNotAWord, + const bool isPossiblyOffensive, const bool isTerminal, const int parentPos, + const CodePointArrayView codePoints, const int probability) const; }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_UPDATING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp index e64a13cc4..b8d78bf10 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.cpp @@ -41,8 +41,8 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08 const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04; // Flag for non-words (typically, shortcut only entries) const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02; -// Flag for blacklist -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; +// Flag for possibly offensive words +const PtReadingUtils::NodeFlags 
PtReadingUtils::FLAG_IS_POSSIBLY_OFFENSIVE = 0x01; /* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition( const uint8_t *const buffer, int *const pos) { @@ -61,19 +61,20 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; } /* static */ int PtReadingUtils::getCodePointAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); + const int *const codePointTable, int *const pos) { + return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, codePointTable, pos); } // Returns the number of read characters. /* static */ int PtReadingUtils::getCharsAndAdvancePosition(const uint8_t *const buffer, - const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { + const NodeFlags flags, const int maxLength, const int *const codePointTable, + int *const outBuffer, int *const pos) { int length = 0; if (hasMultipleChars(flags)) { - length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, - pos); + length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, codePointTable, + outBuffer, pos); } else { - const int codePoint = getCodePointAndAdvancePosition(buffer, pos); + const int codePoint = getCodePointAndAdvancePosition(buffer, codePointTable, pos); if (codePoint == NOT_A_CODE_POINT) { // CAVEAT: codePoint == NOT_A_CODE_POINT means the code point is // CHARACTER_ARRAY_TERMINATOR. The code point must not be CHARACTER_ARRAY_TERMINATOR @@ -92,12 +93,12 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; // Returns the number of skipped characters. 
/* static */ int PtReadingUtils::skipCharacters(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const pos) { + const int maxLength, const int *const codePointTable, int *const pos) { if (hasMultipleChars(flags)) { return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); } else { if (maxLength > 0) { - getCodePointAndAdvancePosition(buffer, pos); + getCodePointAndAdvancePosition(buffer, codePointTable, pos); return 1; } else { return 0; @@ -134,7 +135,7 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; /* static */ void PtReadingUtils::readPtNodeInfo(const uint8_t *const dictBuf, const int ptNodePos, const DictionaryShortcutsStructurePolicy *const shortcutPolicy, - const DictionaryBigramsStructurePolicy *const bigramPolicy, + const DictionaryBigramsStructurePolicy *const bigramPolicy, const int *const codePointTable, NodeFlags *const outFlags, int *const outCodePointCount, int *const outCodePoint, int *const outProbability, int *const outChildrenPos, int *const outShortcutPos, int *const outBigramPos, int *const outSiblingPos) { @@ -142,7 +143,7 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; const NodeFlags flags = getFlagsAndAdvancePosition(dictBuf, &readingPos); *outFlags = flags; *outCodePointCount = getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, outCodePoint, &readingPos); + dictBuf, flags, MAX_WORD_LENGTH, codePointTable, outCodePoint, &readingPos); *outProbability = isTerminal(flags) ? readProbabilityAndAdvancePosition(dictBuf, &readingPos) : NOT_A_PROBABILITY; *outChildrenPos = hasChildrenInFlags(flags) ? 
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h index c3f09c3b1..6a2bf5d3c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h @@ -34,15 +34,17 @@ class PatriciaTrieReadingUtils { static NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, int *const pos); - static int getCodePointAndAdvancePosition(const uint8_t *const buffer, int *const pos); + static int getCodePointAndAdvancePosition(const uint8_t *const buffer, + const int *const codePointTable, int *const pos); // Returns the number of read characters. static int getCharsAndAdvancePosition(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const outBuffer, int *const pos); + const int maxLength, const int *const codePointTable, int *const outBuffer, + int *const pos); // Returns the number of skipped characters. 
static int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const pos); + const int maxLength, const int *const codePointTable, int *const pos); static int readProbabilityAndAdvancePosition(const uint8_t *const buffer, int *const pos); @@ -52,8 +54,8 @@ class PatriciaTrieReadingUtils { /** * Node Flags */ - static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) { - return (flags & FLAG_IS_BLACKLISTED) != 0; + static AK_FORCE_INLINE bool isPossiblyOffensive(const NodeFlags flags) { + return (flags & FLAG_IS_POSSIBLY_OFFENSIVE) != 0; } static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) { @@ -80,12 +82,12 @@ class PatriciaTrieReadingUtils { return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags); } - static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted, + static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isPossiblyOffensive, const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets, const bool hasBigrams, const bool hasMultipleChars, const int childrenPositionFieldSize) { NodeFlags nodeFlags = 0; - nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags; + nodeFlags = isPossiblyOffensive ? (nodeFlags | FLAG_IS_POSSIBLY_OFFENSIVE) : nodeFlags; nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags; nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags; nodeFlags = hasShortcutTargets ? 
(nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags; @@ -106,9 +108,10 @@ class PatriciaTrieReadingUtils { static void readPtNodeInfo(const uint8_t *const dictBuf, const int ptNodePos, const DictionaryShortcutsStructurePolicy *const shortcutPolicy, const DictionaryBigramsStructurePolicy *const bigramPolicy, - NodeFlags *const outFlags, int *const outCodePointCount, int *const outCodePoint, - int *const outProbability, int *const outChildrenPos, int *const outShortcutPos, - int *const outBigramPos, int *const outSiblingPos); + const int *const codePointTable, NodeFlags *const outFlags, + int *const outCodePointCount, int *const outCodePoint, int *const outProbability, + int *const outChildrenPos, int *const outShortcutPos, int *const outBigramPos, + int *const outSiblingPos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); @@ -124,7 +127,7 @@ class PatriciaTrieReadingUtils { static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS; static const NodeFlags FLAG_HAS_BIGRAMS; static const NodeFlags FLAG_IS_NOT_A_WORD; - static const NodeFlags FLAG_IS_BLACKLISTED; + static const NodeFlags FLAG_IS_POSSIBLY_OFFENSIVE; }; } // namespace latinime #endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h index b2e60a837..585e87a24 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h @@ -24,6 +24,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" #include "utils/char_utils.h" +#include "utils/int_array_view.h" namespace latinime { @@ -88,9 +89,9 @@ class PtNodeParams { // Construct new params by updating existing PtNode params. 
PtNodeParams(const PtNodeParams *const ptNodeParams, const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(ptNodeParams->getHeadPos()), mFlags(flags), mHasMovedFlag(true), - mParentPos(parentPos), mCodePointCount(codePointCount), mCodePoints(), + mParentPos(parentPos), mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(ptNodeParams->getTerminalIdFieldPos()), mTerminalId(ptNodeParams->getTerminalId()), mProbabilityFieldPos(ptNodeParams->getProbabilityFieldPos()), @@ -101,20 +102,20 @@ class PtNodeParams { mShortcutPos(ptNodeParams->getShortcutPos()), mBigramPos(ptNodeParams->getBigramsPos()), mSiblingPos(ptNodeParams->getSiblingNodePos()) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } PtNodeParams(const PatriciaTrieReadingUtils::NodeFlags flags, const int parentPos, - const int codePointCount, const int *const codePoints, const int probability) + const CodePointArrayView codePoints, const int probability) : mHeadPos(NOT_A_DICT_POS), mFlags(flags), mHasMovedFlag(true), mParentPos(parentPos), - mCodePointCount(codePointCount), mCodePoints(), + mCodePointCount(codePoints.size()), mCodePoints(), mTerminalIdFieldPos(NOT_A_DICT_POS), mTerminalId(Ver4DictConstants::NOT_A_TERMINAL_ID), mProbabilityFieldPos(NOT_A_DICT_POS), mProbability(probability), mChildrenPosFieldPos(NOT_A_DICT_POS), mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_DICT_POS) { - memcpy(mCodePoints, codePoints, sizeof(int) * mCodePointCount); + memcpy(mCodePoints, codePoints.data(), sizeof(int) * mCodePointCount); } AK_FORCE_INLINE bool isValid() const { @@ -144,7 +145,18 @@ class PtNodeParams { } AK_FORCE_INLINE bool 
isBlacklisted() const { - return PatriciaTrieReadingUtils::isBlacklisted(mFlags); + // Note: this method will be removed in the next change. + // It is used in getProbabilityOfWord and getWordAttributes for both v402 and v403. + // * getProbabilityOfWord will be changed to no longer return NOT_A_PROBABILITY + // when isBlacklisted (i.e. to only check if isNotAWord or isDeleted) + // * getWordAttributes will be changed to always return blacklisted=false and + // isPossiblyOffensive according to the function below (instead of the current + // behaviour of checking if the probability is zero) + return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags); + } + + AK_FORCE_INLINE bool isPossiblyOffensive() const { + return PatriciaTrieReadingUtils::isPossiblyOffensive(mFlags); } AK_FORCE_INLINE bool isNotAWord() const { @@ -174,11 +186,17 @@ class PtNodeParams { return mParentPos; } + AK_FORCE_INLINE const CodePointArrayView getCodePointArrayView() const { + return CodePointArrayView(mCodePoints, mCodePointCount); + } + + // TODO: Remove // Number of code points AK_FORCE_INLINE uint8_t getCodePointCount() const { return mCodePointCount; } + // TODO: Remove AK_FORCE_INLINE const int *getCodePoints() const { return mCodePoints; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h index 955d779ac..954db9b0a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_writer.h @@ -25,7 +25,7 @@ namespace latinime { -class BigramProperty; +class NgramProperty; class UnigramProperty; // Interface class used to write PtNode information. 
@@ -72,7 +72,7 @@ class PtNodeWriter { const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos) = 0; virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) = 0; + const NgramProperty *const ngramProperty, bool *const outAddedNewEntry) = 0; virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp index 91c76941c..40b872055 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.cpp @@ -31,21 +31,23 @@ const int ShortcutListReadingUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; const int ShortcutListReadingUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; /* static */ ShortcutListReadingUtils::ShortcutFlags - ShortcutListReadingUtils::getFlagsAndForwardPointer(const uint8_t *const dictRoot, + ShortcutListReadingUtils::getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); + return ByteArrayUtils::readUint8AndAdvancePosition(buffer.data(), pos); } /* static */ int ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer( - const uint8_t *const dictRoot, int *const pos) { + const ReadOnlyByteArrayView buffer, int *const pos) { // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. 
- return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) + return ByteArrayUtils::readUint16AndAdvancePosition(buffer.data(), pos) - SHORTCUT_LIST_SIZE_FIELD_SIZE; } -/* static */ int ShortcutListReadingUtils::readShortcutTarget( - const uint8_t *const dictRoot, const int maxLength, int *const outWord, int *const pos) { - return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); +/* static */ int ShortcutListReadingUtils::readShortcutTarget(const ReadOnlyByteArrayView buffer, + const int maxLength, int *const outWord, int *const pos) { + // TODO: Use codePointTable for shortcuts. + return ByteArrayUtils::readStringAndAdvancePosition(buffer.data(), maxLength, + nullptr /* codePointTable */, outWord, pos); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h index d065bf7fd..71cb8cc2c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -27,7 +28,8 @@ class ShortcutListReadingUtils { public: typedef uint8_t ShortcutFlags; - static ShortcutFlags getFlagsAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static ShortcutFlags getFlagsAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getProbabilityFromFlags(const ShortcutFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; @@ -39,14 +41,15 @@ class ShortcutListReadingUtils { // This method returns the size of the shortcut list region excluding the shortcut list size // field at the beginning. 
- static int getShortcutListSizeAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + static int getShortcutListSizeAndForwardPointer(const ReadOnlyByteArrayView buffer, + int *const pos); static AK_FORCE_INLINE int getShortcutListSizeFieldSize() { return SHORTCUT_LIST_SIZE_FIELD_SIZE; } - static AK_FORCE_INLINE void skipShortcuts(const uint8_t *const dictRoot, int *const pos) { - const int shortcutListSize = getShortcutListSizeAndForwardPointer(dictRoot, pos); + static AK_FORCE_INLINE void skipShortcuts(const ReadOnlyByteArrayView buffer, int *const pos) { + const int shortcutListSize = getShortcutListSizeAndForwardPointer(buffer, pos); *pos += shortcutListSize; } @@ -54,7 +57,7 @@ class ShortcutListReadingUtils { return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; } - static int readShortcutTarget(const uint8_t *const dictRoot, const int maxLength, + static int readShortcutTarget(const ReadOnlyByteArrayView buffer, const int maxLength, int *const outWord, int *const pos); private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h index 73e291ec2..e2608435c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/bigram/bigram_list_policy.h @@ -22,22 +22,22 @@ #include "defines.h" #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class BigramListPolicy : public DictionaryBigramsStructurePolicy { public: - BigramListPolicy(const uint8_t *const bigramsBuf, const int bufSize) - : mBigramsBuf(bigramsBuf), mBufSize(bufSize) {} + BigramListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~BigramListPolicy() {} 
void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const { BigramListReadWriteUtils::BigramFlags flags; - if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, - mBufSize, &flags, outBigramPos, pos)) { - AKLOGE("Cannot read bigram entry. mBufSize: %d, pos: %d. ", mBufSize, *pos); + if (!BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBuffer, &flags, + outBigramPos, pos)) { + AKLOGE("Cannot read bigram entry. bufSize: %zd, pos: %d. ", mBuffer.size(), *pos); *outProbability = NOT_A_PROBABILITY; *outHasNext = false; return; @@ -47,14 +47,13 @@ class BigramListPolicy : public DictionaryBigramsStructurePolicy { } bool skipAllBigrams(int *const pos) const { - return BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, mBufSize, pos); + return BigramListReadWriteUtils::skipExistingBigrams(mBuffer, pos); } private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListPolicy); - const uint8_t *const mBigramsBuf; - const int mBufSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index ea32eb2a9..66fd18a52 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -21,8 +21,9 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include 
"suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/utils/probability_utils.h" @@ -36,19 +37,19 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo return; } int nextPos = dicNode->getChildrenPtNodeArrayPos(); - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", - nextPos, mDictBufferSize); + if (!isValidPos(nextPos)) { + AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %zd", + nextPos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); return; } const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &nextPos); + mBuffer.data(), &nextPos); for (int i = 0; i < childCount; i++) { - if (nextPos < 0 || nextPos >= mDictBufferSize) { - AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", - nextPos, mDictBufferSize, i, childCount); + if (!isValidPos(nextPos)) { + AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %zd, childCount: %d / %d", + nextPos, mBuffer.size(), i, childCount); mIsCorrupted = true; ASSERT(false); return; @@ -57,7 +58,12 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo } } -// This retrieves code points and the probability of the word by its terminal position. +int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { + return getCodePointsAndProbabilityAndReturnCodePointCount(wordId, maxCodePointCount, + outCodePoints, nullptr /* outUnigramProbability */); +} +// This retrieves code points and the probability of the word by its id. 
// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, // it is possible to check for this with advantageous complexity. For each PtNode array, we search // for PtNodes with children and compare the children position with the position we look for. @@ -68,18 +74,22 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo // with a z, it's the last PtNode of the root array, so all children addresses will be smaller // than the position we look for, and we have to descend the z PtNode). /* Parameters : - * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is - * what is stored as the "bigram position" in each bigram) + * wordId: Id of the word we are searching for. * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size. * outUnigramProbability: a pointer to an int to write the probability into. * Return value : the code point count, of 0 if the word was not found. */ // TODO: Split this function to be more readable int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + const int wordId, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); int pos = getRootPosition(); int wordPos = 0; + const int *const codePointTable = mHeaderPolicy.getCodePointTable(); + if (outUnigramProbability) { + *outUnigramProbability = NOT_A_PROBABILITY; + } // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will // only traverse PtNodes that are actually a part of the terminal we are searching, so each // time we enter this loop we are one depth level further than last time. 
@@ -90,56 +100,57 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int lastCandidatePtNodePos = 0; // Let's loop through PtNodes in this PtNode array searching for either the terminal // or one of its ascendants. - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode array position is invalid. pos: %d, dict size: %d", - pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode array position is invalid. pos: %d, dict size: %zd", + pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { + mBuffer.data(), &pos); ptNodeCount > 0; --ptNodeCount) { const int startPos = pos; - if (pos < 0 || pos >= mDictBufferSize) { - AKLOGE("PtNode position is invalid. pos: %d, dict size: %d", pos, mDictBufferSize); + if (!isValidPos(pos)) { + AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } const PatriciaTrieReadingUtils::NodeFlags flags = - PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mBuffer.data(), &pos); const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); if (ptNodePos == startPos) { // We found the position. Copy the rest of the code points in the buffer and return // the length. 
outCodePoints[wordPos] = character; if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); // We count code points in order to avoid infinite loops if the file is broken // or if there is some other bug int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &pos); + mBuffer.data(), codePointTable, &pos); } } - *outUnigramProbability = - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, - &pos); + if (outUnigramProbability) { + *outUnigramProbability = + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition( + mBuffer.data(), &pos); + } return ++wordPos; } // We need to skip past this PtNode, so skip any remaining code points after the // first and possibly the probability. if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { - PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + PatriciaTrieReadingUtils::skipCharacters(mBuffer.data(), flags, MAX_WORD_LENGTH, + codePointTable, &pos); } if (PatriciaTrieReadingUtils::isTerminal(flags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &pos); } // The fact that this PtNode has children is very important. Since we already know // that this PtNode does not match, if it has no children we know it is irrelevant @@ -154,7 +165,8 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int currentPos = pos; // Here comes the tricky part. First, read the children position. 
const int childrenPos = PatriciaTrieReadingUtils - ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, ¤tPos); + ::readChildrenPositionAndAdvancePosition(mBuffer.data(), flags, + ¤tPos); if (childrenPos > ptNodePos) { // If the children pos is greater than the position, it means the previous // PtNode, which position is stored in lastCandidatePtNodePos, was the right @@ -184,30 +196,30 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( if (0 != lastCandidatePtNodePos) { const PatriciaTrieReadingUtils::NodeFlags lastFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), &lastCandidatePtNodePos); const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); // We copy all the characters in this PtNode to the buffer outCodePoints[wordPos] = lastChar; if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); int charCount = maxCodePointCount; while (-1 != nextChar && --charCount > 0) { outCodePoints[++wordPos] = nextChar; nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( - mDictRoot, &lastCandidatePtNodePos); + mBuffer.data(), codePointTable, &lastCandidatePtNodePos); } } ++wordPos; // Now we only need to branch to the children address. Skip the probability if // it's there, read pos, and break to resume the search at pos. 
if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), &lastCandidatePtNodePos); } pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, lastFlags, &lastCandidatePtNodePos); + mBuffer.data(), lastFlags, &lastCandidatePtNodePos); break; } else { // Here is a little tricky part: we come here if we found out that all children @@ -219,18 +231,17 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // ready to start the next one. if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, + AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } } @@ -243,17 +254,16 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( // our pos is after the end of this PtNode, at the start of the next one. if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, flags, &pos); + mBuffer.data(), flags, &pos); } if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { mShortcutListPolicy.skipAllShortcuts(&pos); } if (PatriciaTrieReadingUtils::hasBigrams(flags)) { if (!mBigramListPolicy.skipAllBigrams(&pos)) { - AKLOGE("Cannot skip bigrams. BufSize: %d, pos: %d.", mDictBufferSize, pos); + AKLOGE("Cannot skip bigrams. 
BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } } @@ -267,18 +277,48 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( } // This function gets the position of the terminal PtNode of the exact matching word in the -// dictionary. If no match is found, it returns NOT_A_DICT_POS. -int PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const { +// dictionary. If no match is found, it returns NOT_A_WORD_ID. +int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mPtNodeReader, &mPtNodeArrayReader); readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; - AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); + AKLOGE("Dictionary reading error in getWordId()."); + } + return getWordIdFromTerminalPtNodePos(ptNodePos); +} + +const WordAttributes PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); } - return ptNodePos; + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); + const PtNodeParams ptNodeParams = + mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (multiBigramMap) { + const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */, + prevWordIds, wordId, ptNodeParams.getProbability()); + return getWordAttributes(probability, ptNodeParams); + 
} + if (!prevWordIds.empty()) { + const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId); + if (bigramProbability != NOT_A_PROBABILITY) { + return getWordAttributes(bigramProbability, ptNodeParams); + } + } + return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY), + ptNodeParams); +} + +const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const { + return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), + ptNodeParams.getProbability() == 0); } int PatriciaTriePolicy::getProbability(const int unigramProbability, @@ -297,11 +337,12 @@ int PatriciaTriePolicy::getProbability(const int unigramProbability, } } -int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_PROBABILITY; } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (ptNodeParams.isNotAWord() || ptNodeParams.isBlacklisted()) { @@ -310,8 +351,9 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP // for shortcuts). 
return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + if (!prevWordIds.empty()) { + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); @@ -325,19 +367,26 @@ int PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodeP return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +void PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.empty()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); + const int bigramsPosition = getBigramsPositionOfPtNode( + getTerminalPtNodePosFromWordId(prevWordIds[0])); BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramsPosition); while (bigramsIt.hasNext()) { bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), bigramsIt.getBigramPos()); + listener->onVisitEntry(bigramsIt.getProbability(), + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos())); } } +BinaryDictionaryShortcutIterator PatriciaTriePolicy::getShortcutIterator(const int wordId) const { + const int shortcutPos = getShortcutPositionOfPtNode(getTerminalPtNodePosFromWordId(wordId)); + return BinaryDictionaryShortcutIterator(&mShortcutListPolicy, shortcutPos); +} + int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; @@ -362,35 +411,32 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - 
PatriciaTrieReadingUtils::readPtNodeInfo(mDictRoot, ptNodePos, getShortcutsStructurePolicy(), - &mBigramListPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, - &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); + const int *const codePointTable = mHeaderPolicy.getCodePointTable(); + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, &mShortcutListPolicy, + &mBigramListPolicy, codePointTable, &flags, &mergedNodeCodePointCount, + mergedNodeCodePoints, &probability, &childrenPos, &shortcutPos, &bigramPos, + &siblingPos); // Skip PtNodes don't start with Unicode code point because they represent non-word information. if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) { - childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability, - PatriciaTrieReadingUtils::isTerminal(flags), - PatriciaTrieReadingUtils::hasChildrenInFlags(flags), - PatriciaTrieReadingUtils::isBlacklisted(flags) - || PatriciaTrieReadingUtils::isNotAWord(flags), - mergedNodeCodePointCount, mergedNodeCodePoints); + const int wordId = PatriciaTrieReadingUtils::isTerminal(flags) ? 
ptNodePos : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, childrenPos, wordId, + CodePointArrayView(mergedNodeCodePoints, mergedNodeCodePointCount)); } return siblingPos; } -const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { +const WordProperty PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty was called for invalid word."); return WordProperty(); } + const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); // Fetch bigram information. 
- std::vector<BigramProperty> bigrams; + std::vector<NgramProperty> ngrams; const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); int bigramWord1CodePoints[MAX_WORD_LENGTH]; BinaryDictionaryBigramsIterator bigramsIt(&mBigramListPolicy, bigramListPos); @@ -401,13 +447,14 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin if (bigramsIt.getBigramPos() != NOT_A_DICT_POS) { int word1Probability = NOT_A_PROBABILITY; const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + word1CodePointCount); + getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()), MAX_WORD_LENGTH, + bigramWord1CodePoints, &word1Probability); const int probability = getProbability(word1Probability, bigramsIt.getProbability()); - bigrams.emplace_back(&word1, probability, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */); + ngrams.emplace_back( + NgramContext(wordCodePoints.data(), wordCodePoints.size(), + ptNodeParams.representsBeginningOfSentence()), + CodePointArrayView(bigramWord1CodePoints, word1CodePointCount).toVector(), + probability, HistoricalInfo()); } } // Fetch shortcut information. 
@@ -415,25 +462,25 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); if (shortcutPos != NOT_A_DICT_POS) { int shortcutTargetCodePoints[MAX_WORD_LENGTH]; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &shortcutPos); bool hasNext = true; while (hasNext) { const ShortcutListReadingUtils::ShortcutFlags shortcutFlags = - ShortcutListReadingUtils::getFlagsAndForwardPointer(mDictRoot, &shortcutPos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, &shortcutPos); hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags); const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget( - mDictRoot, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); - const std::vector<int> shortcutTarget(shortcutTargetCodePoints, - shortcutTargetCodePoints + shortcutTargetLength); + mBuffer, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); const int shortcutProbability = ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags); - shortcuts.emplace_back(&shortcutTarget, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTargetCodePoints, shortcutTargetLength).toVector(), + shortcutProbability); } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + ptNodeParams.isNotAWord(), ptNodeParams.isPossiblyOffensive(), + ptNodeParams.getProbability(), HistoricalInfo(), std::move(shortcuts)); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const 
outCodePoints, @@ -455,9 +502,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, - MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + *outCodePointCount = getCodePointsAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. @@ -467,4 +513,16 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC return nextToken; } +int PatriciaTriePolicy::getWordIdFromTerminalPtNodePos(const int ptNodePos) const { + return ptNodePos == NOT_A_DICT_POS ? NOT_A_WORD_ID : ptNodePos; +} + +int PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) const { + return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId; +} + +bool PatriciaTriePolicy::isValidPos(const int pos) const { + return pos >= 0 && pos < static_cast<int>(mBuffer.size()); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 70351d147..8933962ab 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -30,26 +30,27 @@ #include "suggest/policyimpl/dictionary/utils/format_utils.h" #include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" #include "utils/byte_array_view.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; +// Word id = Position of a PtNode that represents the word. +// Max supported n-gram is bigram. 
class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer) : mMmappedBuffer(std::move(mmappedBuffer)), mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(), - FormatUtils::VERSION_2), - mDictRoot(mMmappedBuffer->getReadOnlyByteArrayView().data() - + mHeaderPolicy.getSize()), - mDictBufferSize(mMmappedBuffer->getReadOnlyByteArrayView().size() - - mHeaderPolicy.getSize()), - mBigramListPolicy(mDictRoot, mDictBufferSize), mShortcutListPolicy(mDictRoot), - mPtNodeReader(mDictRoot, mDictBufferSize, &mBigramListPolicy, &mShortcutListPolicy), - mPtNodeArrayReader(mDictRoot, mDictBufferSize), - mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {} + FormatUtils::detectFormatVersion(mMmappedBuffer->getReadOnlyByteArrayView())), + mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())), + mBigramListPolicy(mBuffer), mShortcutListPolicy(mBuffer), + mPtNodeReader(mBuffer, &mBigramListPolicy, &mShortcutListPolicy, + mHeaderPolicy.getCodePointTable()), + mPtNodeArrayReader(mBuffer), mTerminalPtNodePositionsForIteratingWords(), + mIsCorrupted(false) {} AK_FORCE_INLINE int getRootPosition() const { return 0; @@ -58,57 +59,62 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; + + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + const 
WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; int getProbability(const int unigramProbability, const int bigramProbability) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return &mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutListPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: addUnigramEntry() is called for non-updatable dictionary."); return false; } - bool removeUnigramEntry(const int *const word, const int length) { + bool removeUnigramEntry(const CodePointArrayView wordCodePoints) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; } - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty) { + bool addNgramEntry(const NgramProperty *const ngramProperty) { // This method should not be called for non-updatable dictionary. 
AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; } - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word, - const int length) { + bool removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; } + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints, const bool isValidWord, + const HistoricalInfo historicalInfo) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); + return false; + } + bool flush(const char *const filePath) { // This method should not be called for non-updatable dictionary. AKLOGI("Warning: flush() is called for non-updatable dictionary."); @@ -135,8 +141,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { } } - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -150,8 +155,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const MmappedBuffer::MmappedBufferPtr mMmappedBuffer; const HeaderPolicy mHeaderPolicy; - const uint8_t *const mDictRoot; - const int mDictBufferSize; + const ReadOnlyByteArrayView mBuffer; const BigramListPolicy mBigramListPolicy; const ShortcutListPolicy mShortcutListPolicy; const Ver2ParticiaTrieNodeReader mPtNodeReader; @@ -159,9 +163,18 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable 
bool mIsCorrupted; + int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const; int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, DicNodeVector *const childDicNodes) const; + int getWordIdFromTerminalPtNodePos(const int ptNodePos) const; + int getTerminalPtNodePosFromWordId(const int wordId) const; + const WordAttributes getWordAttributes(const int probability, + const PtNodeParams &ptNodeParams) const; + bool isValidPos(const int pos) const; }; } // namespace latinime #endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h index 8e16ccc05..5319dd26c 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/shortcut/shortcut_list_policy.h @@ -22,13 +22,13 @@ #include "defines.h" #include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/shortcut/shortcut_list_reading_utils.h" +#include "utils/byte_array_view.h" namespace latinime { class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { public: - explicit ShortcutListPolicy(const uint8_t *const shortcutBuf) - : mShortcutsBuf(shortcutBuf) {} + explicit ShortcutListPolicy(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {} ~ShortcutListPolicy() {} @@ -37,7 +37,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { return NOT_A_DICT_POS; } int listPos = pos; - ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mShortcutsBuf, &listPos); + 
ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mBuffer, &listPos); return listPos; } @@ -45,7 +45,7 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, int *const pos) const { const ShortcutListReadingUtils::ShortcutFlags flags = - ShortcutListReadingUtils::getFlagsAndForwardPointer(mShortcutsBuf, pos); + ShortcutListReadingUtils::getFlagsAndForwardPointer(mBuffer, pos); if (outHasNext) { *outHasNext = ShortcutListReadingUtils::hasNext(flags); } @@ -54,20 +54,20 @@ class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { } if (outCodePoint) { *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( - mShortcutsBuf, maxCodePointCount, outCodePoint, pos); + mBuffer, maxCodePointCount, outCodePoint, pos); } } void skipAllShortcuts(int *const pos) const { const int shortcutListSize = ShortcutListReadingUtils - ::getShortcutListSizeAndForwardPointer(mShortcutsBuf, pos); + ::getShortcutListSizeAndForwardPointer(mBuffer, pos); *pos += shortcutListSize; } private: DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListPolicy); - const uint8_t *const mShortcutsBuf; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif // LATINIME_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp index c1e938710..90d4687dd 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.cpp @@ -22,10 +22,10 @@ namespace latinime { const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNodePos( const int ptNodePos) const { - if (ptNodePos < 0 || ptNodePos >= mDictSize) { + if (ptNodePos < 0 || ptNodePos >= 
static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. - AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %d", - ptNodePos, mDictSize); + AKLOGE("Fetching PtNode info from invalid dictionary position: %d, dictionary size: %zd", + ptNodePos, mBuffer.size()); ASSERT(false); return PtNodeParams(); } @@ -37,9 +37,9 @@ const PtNodeParams Ver2ParticiaTrieNodeReader::fetchPtNodeParamsInBufferFromPtNo int shortcutPos = NOT_A_DICT_POS; int bigramPos = NOT_A_DICT_POS; int siblingPos = NOT_A_DICT_POS; - PatriciaTrieReadingUtils::readPtNodeInfo(mDictBuffer, ptNodePos, mShortuctPolicy, - mBigramPolicy, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, &probability, - &childrenPos, &shortcutPos, &bigramPos, &siblingPos); + PatriciaTrieReadingUtils::readPtNodeInfo(mBuffer.data(), ptNodePos, mShortcutPolicy, + mBigramPolicy, mCodePointTable, &flags, &mergedNodeCodePointCount, mergedNodeCodePoints, + &probability, &childrenPos, &shortcutPos, &bigramPos, &siblingPos); if (mergedNodeCodePointCount <= 0) { AKLOGE("Empty PtNode is not allowed. 
Code point count: %d", mergedNodeCodePointCount); ASSERT(false); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h index f0725b66d..838d37314 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_patricia_trie_node_reader.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_reader.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -30,21 +31,22 @@ class DictionaryShortcutsStructurePolicy; class Ver2ParticiaTrieNodeReader : public PtNodeReader { public: - Ver2ParticiaTrieNodeReader(const uint8_t *const dictBuffer, const int dictSize, + Ver2ParticiaTrieNodeReader(const ReadOnlyByteArrayView buffer, const DictionaryBigramsStructurePolicy *const bigramPolicy, - const DictionaryShortcutsStructurePolicy *const shortcutPolicy) - : mDictBuffer(dictBuffer), mDictSize(dictSize), mBigramPolicy(bigramPolicy), - mShortuctPolicy(shortcutPolicy) {} + const DictionaryShortcutsStructurePolicy *const shortcutPolicy, + const int *const codePointTable) + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy), + mCodePointTable(codePointTable) {} virtual const PtNodeParams fetchPtNodeParamsInBufferFromPtNodePos(const int ptNodePos) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Ver2ParticiaTrieNodeReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; const DictionaryBigramsStructurePolicy *const mBigramPolicy; - const DictionaryShortcutsStructurePolicy *const mShortuctPolicy; + const DictionaryShortcutsStructurePolicy *const mShortcutPolicy; + const int *const mCodePointTable; }; } // namespace latinime 
#endif /* LATINIME_VER2_PATRICIA_TRIE_NODE_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp index b46617d96..72ad1eb66 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.cpp @@ -22,16 +22,16 @@ namespace latinime { bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const { - if (ptNodeArrayPos < 0 || ptNodeArrayPos >= mDictSize) { + if (ptNodeArrayPos < 0 || ptNodeArrayPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of a bug or a broken dictionary. - AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %d", - ptNodeArrayPos, mDictSize); + AKLOGE("Reading PtNode array info from invalid dictionary position: %d, dict size: %zd", + ptNodeArrayPos, mBuffer.size()); ASSERT(false); return false; } int readingPos = ptNodeArrayPos; const int ptNodeCountInArray = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( - mDictBuffer, &readingPos); + mBuffer.data(), &readingPos); *outPtNodeCount = ptNodeCountInArray; *outFirstPtNodePos = readingPos; return true; @@ -39,10 +39,10 @@ bool Ver2PtNodeArrayReader::readPtNodeArrayInfoAndReturnIfValid(const int ptNode bool Ver2PtNodeArrayReader::readForwardLinkAndReturnIfValid(const int forwordLinkPos, int *const outNextPtNodeArrayPos) const { - if (forwordLinkPos < 0 || forwordLinkPos >= mDictSize) { + if (forwordLinkPos < 0 || forwordLinkPos >= static_cast<int>(mBuffer.size())) { // Reading invalid position because of bug or broken dictionary. 
- AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %d", - forwordLinkPos, mDictSize); + AKLOGE("Reading forward link from invalid dictionary position: %d, dict size: %zd", + forwordLinkPos, mBuffer.size()); ASSERT(false); return false; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h index 548272148..548f36bf3 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/ver2_pt_node_array_reader.h @@ -21,13 +21,13 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/pt_node_array_reader.h" +#include "utils/byte_array_view.h" namespace latinime { class Ver2PtNodeArrayReader : public PtNodeArrayReader { public: - Ver2PtNodeArrayReader(const uint8_t *const dictBuffer, const int dictSize) - : mDictBuffer(dictBuffer), mDictSize(dictSize) {}; + Ver2PtNodeArrayReader(const ReadOnlyByteArrayView buffer) : mBuffer(buffer) {}; virtual bool readPtNodeArrayInfoAndReturnIfValid(const int ptNodeArrayPos, int *const outPtNodeCount, int *const outFirstPtNodePos) const; @@ -37,8 +37,7 @@ class Ver2PtNodeArrayReader : public PtNodeArrayReader { private: DISALLOW_COPY_AND_ASSIGN(Ver2PtNodeArrayReader); - const uint8_t *const mDictBuffer; - const int mDictSize; + const ReadOnlyByteArrayView mBuffer; }; } // namespace latinime #endif /* LATINIME_VER2_PT_NODE_ARRAY_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp deleted file mode 100644 index 08dc107ab..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.cpp +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Copyright (C) 2013 The 
Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" - -#include "suggest/core/dictionary/property/bigram_property.h" -#include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/pt_common/bigram/bigram_list_read_write_utils.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" - -namespace latinime { - -void Ver4BigramListPolicy::getNextBigram(int *const outBigramPos, int *const outProbability, - bool *const outHasNext, int *const bigramEntryPos) const { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(bigramEntryPos); - if (outBigramPos) { - // Lookup target PtNode position. 
- *outBigramPos = mTerminalPositionLookupTable->getTerminalPtNodePosition( - bigramEntry.getTargetTerminalId()); - } - if (outProbability) { - if (bigramEntry.hasHistoricalInfo()) { - *outProbability = - ForgettingCurveUtils::decodeProbability(bigramEntry.getHistoricalInfo(), - mHeaderPolicy); - } else { - *outProbability = bigramEntry.getProbability(); - } - } - if (outHasNext) { - *outHasNext = bigramEntry.hasNext(); - } -} - -bool Ver4BigramListPolicy::addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) { - // 1. The word has no bigrams yet. - // 2. The word has bigrams, and there is the target in the list. - // 3. The word has bigrams, and there is an invalid entry that can be reclaimed. - // 4. The word has bigrams. We have to append new bigram entry to the list. - // 5. Same as 4, but the list is the last entry of the content file. - if (outAddedNewEntry) { - *outAddedNewEntry = false; - } - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Case 1. PtNode that doesn't have a bigram list. - // Create new bigram list. - if (!mBigramDictContent->createNewBigramList(terminalId)) { - return false; - } - const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom(&newBigramEntry, - bigramProperty); - // Write an entry. - int writingPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (!mBigramDictContent->writeBigramEntryAndAdvancePosition(&bigramEntryToWrite, - &writingPos)) { - AKLOGE("Cannot write bigram entry. pos: %d.", writingPos); - return false; - } - if (!mBigramDictContent->writeTerminator(writingPos)) { - AKLOGE("Cannot write bigram list terminator. 
pos: %d.", writingPos); - return false; - } - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - return true; - } - - int tailEntryPos = NOT_A_DICT_POS; - const int entryPosToUpdate = getEntryPosToUpdate(newTargetTerminalId, bigramListPos, - &tailEntryPos); - if (entryPosToUpdate == NOT_A_DICT_POS) { - // Case 4, 5. Add new entry to the bigram list. - const int contentTailPos = mBigramDictContent->getContentTailPos(); - // If the tail entry is at the tail of content buffer, the new entry can be written without - // link (Case 5). - const bool canAppendEntry = - contentTailPos == tailEntryPos + mBigramDictContent->getBigramEntrySize(); - const int newEntryPos = canAppendEntry ? tailEntryPos : contentTailPos; - int writingPos = newEntryPos; - // Write new entry at the tail position of the bigram content. - const BigramEntry newBigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &newBigramEntry, bigramProperty); - if (!mBigramDictContent->writeBigramEntryAndAdvancePosition(&bigramEntryToWrite, - &writingPos)) { - AKLOGE("Cannot write bigram entry. pos: %d.", writingPos); - return false; - } - if (!mBigramDictContent->writeTerminator(writingPos)) { - AKLOGE("Cannot write bigram list terminator. pos: %d.", writingPos); - return false; - } - if (!canAppendEntry) { - // Update link of the current tail entry. - if (!mBigramDictContent->writeLink(newEntryPos, tailEntryPos)) { - AKLOGE("Cannot update bigram entry link. pos: %d, linked entry pos: %d.", - tailEntryPos, newEntryPos); - return false; - } - } - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - return true; - } - - // Case 2. Overwrite the existing entry. Case 3. Reclaim and reuse the existing invalid entry. - const BigramEntry originalBigramEntry = mBigramDictContent->getBigramEntry(entryPosToUpdate); - if (!originalBigramEntry.isValid()) { - // Case 3. Reuse the existing invalid entry. 
outAddedNewEntry is false when an existing - // entry is updated. - if (outAddedNewEntry) { - *outAddedNewEntry = true; - } - } - const BigramEntry updatedBigramEntry = - originalBigramEntry.updateTargetTerminalIdAndGetEntry(newTargetTerminalId); - const BigramEntry bigramEntryToWrite = createUpdatedBigramEntryFrom( - &updatedBigramEntry, bigramProperty); - return mBigramDictContent->writeBigramEntry(&bigramEntryToWrite, entryPosToUpdate); -} - -bool Ver4BigramListPolicy::removeEntry(const int terminalId, const int targetTerminalId) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. - return false; - } - const int entryPosToUpdate = getEntryPosToUpdate(targetTerminalId, bigramListPos, - nullptr /* outTailEntryPos */); - if (entryPosToUpdate == NOT_A_DICT_POS) { - // Bigram entry doesn't exist. - return false; - } - const BigramEntry bigramEntry = mBigramDictContent->getBigramEntry(entryPosToUpdate); - if (targetTerminalId != bigramEntry.getTargetTerminalId()) { - // Bigram entry doesn't exist. - return false; - } - // Remove bigram entry by marking it as invalid entry and overwriting the original entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - return mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPosToUpdate); -} - -bool Ver4BigramListPolicy::updateAllBigramEntriesAndDeleteUselessEntries(const int terminalId, - int *const outBigramCount) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. 
- return true; - } - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - mBigramDictContent->getBigramEntrySize(); - hasNext = bigramEntry.hasNext(); - if (!bigramEntry.isValid()) { - continue; - } - const int targetPtNodePos = mTerminalPositionLookupTable->getTerminalPtNodePosition( - bigramEntry.getTargetTerminalId()); - if (targetPtNodePos == NOT_A_DICT_POS) { - // Invalidate bigram entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - } else if (bigramEntry.hasHistoricalInfo()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - bigramEntry.getHistoricalInfo(), mHeaderPolicy); - if (ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy)) { - const BigramEntry updatedBigramEntry = - bigramEntry.updateHistoricalInfoAndGetEntry(&historicalInfo); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - *outBigramCount += 1; - } else { - // Remove entry. - const BigramEntry updatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!mBigramDictContent->writeBigramEntry(&updatedBigramEntry, entryPos)) { - return false; - } - } - } else { - *outBigramCount += 1; - } - } - return true; -} - -int Ver4BigramListPolicy::getBigramEntryConut(const int terminalId) { - const int bigramListPos = mBigramDictContent->getBigramListHeadPos(terminalId); - if (bigramListPos == NOT_A_DICT_POS) { - // Bigram list doesn't exist. 
- return 0; - } - int bigramCount = 0; - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = bigramEntry.hasNext(); - if (bigramEntry.isValid()) { - bigramCount++; - } - } - return bigramCount; -} - -int Ver4BigramListPolicy::getEntryPosToUpdate(const int targetTerminalIdToFind, - const int bigramListPos, int *const outTailEntryPos) const { - if (outTailEntryPos) { - *outTailEntryPos = NOT_A_DICT_POS; - } - int invalidEntryPos = NOT_A_DICT_POS; - int readingPos = bigramListPos; - while (true) { - const BigramEntry bigramEntry = - mBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - mBigramDictContent->getBigramEntrySize(); - if (!bigramEntry.hasNext()) { - if (outTailEntryPos) { - *outTailEntryPos = entryPos; - } - break; - } - if (bigramEntry.getTargetTerminalId() == targetTerminalIdToFind) { - // Entry with same target is found. - return entryPos; - } else if (!bigramEntry.isValid()) { - // Invalid entry that can be reused is found. - invalidEntryPos = entryPos; - } - } - return invalidEntryPos; -} - -const BigramEntry Ver4BigramListPolicy::createUpdatedBigramEntryFrom( - const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const { - // TODO: Consolidate historical info and probability. 
- if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(bigramProperty->getTimestamp(), - bigramProperty->getLevel(), bigramProperty->getCount()); - const HistoricalInfo updatedHistoricalInfo = - ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalBigramEntry->getHistoricalInfo(), bigramProperty->getProbability(), - &historicalInfoForUpdate, mHeaderPolicy); - return originalBigramEntry->updateHistoricalInfoAndGetEntry(&updatedHistoricalInfo); - } else { - return originalBigramEntry->updateProbabilityAndGetEntry(bigramProperty->getProbability()); - } -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h deleted file mode 100644 index 4b3bb3725..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LATINIME_VER4_BIGRAM_LIST_POLICY_H -#define LATINIME_VER4_BIGRAM_LIST_POLICY_H - -#include "defines.h" -#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h" - -namespace latinime { - -class BigramDictContent; -class BigramProperty; -class HeaderPolicy; -class TerminalPositionLookupTable; - -class Ver4BigramListPolicy : public DictionaryBigramsStructurePolicy { - public: - Ver4BigramListPolicy(BigramDictContent *const bigramDictContent, - const TerminalPositionLookupTable *const terminalPositionLookupTable, - const HeaderPolicy *const headerPolicy) - : mBigramDictContent(bigramDictContent), - mTerminalPositionLookupTable(terminalPositionLookupTable), - mHeaderPolicy(headerPolicy) {} - - void getNextBigram(int *const outBigramPos, int *const outProbability, - bool *const outHasNext, int *const bigramEntryPos) const; - - bool skipAllBigrams(int *const pos) const { - // Do nothing because we don't need to skip bigram lists in ver4 dictionaries. 
- return true; - } - - bool addNewEntry(const int terminalId, const int newTargetTerminalId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); - - bool removeEntry(const int terminalId, const int targetTerminalId); - - bool updateAllBigramEntriesAndDeleteUselessEntries(const int terminalId, - int *const outBigramCount); - - int getBigramEntryConut(const int terminalId); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4BigramListPolicy); - - int getEntryPosToUpdate(const int targetTerminalIdToFind, const int bigramListPos, - int *const outTailEntryPos) const; - - const BigramEntry createUpdatedBigramEntryFrom(const BigramEntry *const originalBigramEntry, - const BigramProperty *const bigramProperty) const; - - BigramDictContent *const mBigramDictContent; - const TerminalPositionLookupTable *const mTerminalPositionLookupTable; - const HeaderPolicy *const mHeaderPolicy; -}; -} // namespace latinime -#endif /* LATINIME_VER4_BIGRAM_LIST_POLICY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp deleted file mode 100644 index d7e1952b5..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.cpp +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" - -#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" - -namespace latinime { - -const int BigramDictContent::INVALID_LINKED_ENTRY_POS = Ver4DictConstants::NOT_A_TERMINAL_ID; - -const BigramEntry BigramDictContent::getBigramEntryAndAdvancePosition( - int *const bigramEntryPos) const { - const BufferWithExtendableBuffer *const bigramListBuffer = getContentBuffer(); - const int bigramEntryTailPos = (*bigramEntryPos) + getBigramEntrySize(); - if (*bigramEntryPos < 0 || bigramEntryTailPos > bigramListBuffer->getTailPosition()) { - AKLOGE("Invalid bigram entry position. bigramEntryPos: %d, bigramEntryTailPos: %d, " - "bufSize: %d", *bigramEntryPos, bigramEntryTailPos, - bigramListBuffer->getTailPosition()); - ASSERT(false); - return BigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - Ver4DictConstants::NOT_A_TERMINAL_ID); - } - const int bigramFlags = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE, bigramEntryPos); - const bool isLink = (bigramFlags & Ver4DictConstants::BIGRAM_IS_LINK_MASK) != 0; - int probability = NOT_A_PROBABILITY; - int timestamp = NOT_A_TIMESTAMP; - int level = 0; - int count = 0; - if (mHasHistoricalInfo) { - timestamp = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, bigramEntryPos); - level = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, bigramEntryPos); - count = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, bigramEntryPos); - } else { - probability = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::PROBABILITY_SIZE, bigramEntryPos); - } - const int encodedTargetTerminalId = bigramListBuffer->readUintAndAdvancePosition( - Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE, bigramEntryPos); - const int targetTerminalId = - 
(encodedTargetTerminalId == Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID) ? - Ver4DictConstants::NOT_A_TERMINAL_ID : encodedTargetTerminalId; - if (isLink) { - const int linkedEntryPos = targetTerminalId; - if (linkedEntryPos == INVALID_LINKED_ENTRY_POS) { - // Bigram list terminator is found. - return BigramEntry(false /* hasNext */, NOT_A_PROBABILITY, - Ver4DictConstants::NOT_A_TERMINAL_ID); - } - *bigramEntryPos = linkedEntryPos; - return getBigramEntryAndAdvancePosition(bigramEntryPos); - } - // hasNext is always true because we should continue to read the next entry until the terminator - // is found. - if (mHasHistoricalInfo) { - const HistoricalInfo historicalInfo(timestamp, level, count); - return BigramEntry(true /* hasNext */, probability, &historicalInfo, targetTerminalId); - } else { - return BigramEntry(true /* hasNext */, probability, targetTerminalId); - } -} - -bool BigramDictContent::writeBigramEntryAndAdvancePosition( - const BigramEntry *const bigramEntryToWrite, int *const entryWritingPos) { - return writeBigramEntryAttributesAndAdvancePosition(false /* isLink */, - bigramEntryToWrite->getProbability(), bigramEntryToWrite->getTargetTerminalId(), - bigramEntryToWrite->getHistoricalInfo()->getTimeStamp(), - bigramEntryToWrite->getHistoricalInfo()->getLevel(), - bigramEntryToWrite->getHistoricalInfo()->getCount(), - entryWritingPos); -} - -bool BigramDictContent::writeBigramEntryAttributesAndAdvancePosition( - const bool isLink, const int probability, const int targetTerminalId, - const int timestamp, const int level, const int count, int *const entryWritingPos) { - BufferWithExtendableBuffer *const bigramListBuffer = getWritableContentBuffer(); - const int bigramFlags = isLink ? Ver4DictConstants::BIGRAM_IS_LINK_MASK : 0; - if (!bigramListBuffer->writeUintAndAdvancePosition(bigramFlags, - Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram flags. 
pos: %d, flags: %x", *entryWritingPos, bigramFlags); - return false; - } - if (mHasHistoricalInfo) { - if (!bigramListBuffer->writeUintAndAdvancePosition(timestamp, - Ver4DictConstants::TIME_STAMP_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram timestamps. pos: %d, timestamp: %d", *entryWritingPos, - timestamp); - return false; - } - if (!bigramListBuffer->writeUintAndAdvancePosition(level, - Ver4DictConstants::WORD_LEVEL_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram level. pos: %d, level: %d", *entryWritingPos, - level); - return false; - } - if (!bigramListBuffer->writeUintAndAdvancePosition(count, - Ver4DictConstants::WORD_COUNT_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram count. pos: %d, count: %d", *entryWritingPos, - count); - return false; - } - } else { - if (!bigramListBuffer->writeUintAndAdvancePosition(probability, - Ver4DictConstants::PROBABILITY_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram probability. pos: %d, probability: %d", *entryWritingPos, - probability); - return false; - } - } - const int targetTerminalIdToWrite = (targetTerminalId == Ver4DictConstants::NOT_A_TERMINAL_ID) ? - Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID : targetTerminalId; - if (!bigramListBuffer->writeUintAndAdvancePosition(targetTerminalIdToWrite, - Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE, entryWritingPos)) { - AKLOGE("Cannot write bigram target terminal id. 
pos: %d, target terminal id: %d", - *entryWritingPos, targetTerminalId); - return false; - } - return true; -} - -bool BigramDictContent::writeLink(const int linkedEntryPos, const int writingPos) { - const int targetTerminalId = linkedEntryPos; - int pos = writingPos; - return writeBigramEntryAttributesAndAdvancePosition(true /* isLink */, - NOT_A_PROBABILITY /* probability */, targetTerminalId, NOT_A_TIMESTAMP, 0 /* level */, - 0 /* count */, &pos); -} - -bool BigramDictContent::runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const BigramDictContent *const originalBigramDictContent, - int *const outBigramEntryCount) { - for (TerminalPositionLookupTable::TerminalIdMap::const_iterator it = terminalIdMap->begin(); - it != terminalIdMap->end(); ++it) { - const int originalBigramListPos = - originalBigramDictContent->getBigramListHeadPos(it->first); - if (originalBigramListPos == NOT_A_DICT_POS) { - // This terminal does not have a bigram list. - continue; - } - const int bigramListPos = getContentBuffer()->getTailPosition(); - int bigramEntryCount = 0; - // Copy bigram list with GC from original content. - if (!runGCBigramList(originalBigramListPos, originalBigramDictContent, bigramListPos, - terminalIdMap, &bigramEntryCount)) { - AKLOGE("Cannot complete GC for the bigram list. original pos: %d, pos: %d", - originalBigramListPos, bigramListPos); - return false; - } - if (bigramEntryCount == 0) { - // All bigram entries are useless. This terminal does not have a bigram list. - continue; - } - *outBigramEntryCount += bigramEntryCount; - // Set bigram list position to the lookup table. - if (!getUpdatableAddressLookupTable()->set(it->second, bigramListPos)) { - AKLOGE("Cannot set bigram list position. terminal id: %d, pos: %d", - it->second, bigramListPos); - return false; - } - } - return true; -} - -// Returns whether GC for the bigram list was succeeded or not. 
-bool BigramDictContent::runGCBigramList(const int bigramListPos, - const BigramDictContent *const sourceBigramDictContent, const int toPos, - const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - int *const outEntryCount) { - bool hasNext = true; - int readingPos = bigramListPos; - int writingPos = toPos; - while (hasNext) { - const BigramEntry originalBigramEntry = - sourceBigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = originalBigramEntry.hasNext(); - if (!originalBigramEntry.isValid()) { - continue; - } - TerminalPositionLookupTable::TerminalIdMap::const_iterator it = - terminalIdMap->find(originalBigramEntry.getTargetTerminalId()); - if (it == terminalIdMap->end()) { - // Target word has been removed. - continue; - } - const BigramEntry updatedBigramEntry = - originalBigramEntry.updateTargetTerminalIdAndGetEntry(it->second); - if (!writeBigramEntryAndAdvancePosition(&updatedBigramEntry, &writingPos)) { - AKLOGE("Cannot write bigram entry to run GC. pos: %d", writingPos); - return false; - } - *outEntryCount += 1; - } - if (*outEntryCount > 0) { - if (!writeTerminator(writingPos)) { - AKLOGE("Cannot write terminator to run GC. pos: %d", writingPos); - return false; - } - } - return true; -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h deleted file mode 100644 index 361dd2c74..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BIGRAM_DICT_CONTENT_H -#define LATINIME_BIGRAM_DICT_CONTENT_H - -#include <cstdint> -#include <cstdio> - -#include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" - -namespace latinime { - -class BigramDictContent : public SparseTableDictContent { - public: - BigramDictContent(uint8_t *const *buffers, const int *bufferSizes, const bool hasHistoricalInfo) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE), - mHasHistoricalInfo(hasHistoricalInfo) {} - - BigramDictContent(const bool hasHistoricalInfo) - : SparseTableDictContent(Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE, - Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE), - mHasHistoricalInfo(hasHistoricalInfo) {} - - int getContentTailPos() const { - return getContentBuffer()->getTailPosition(); - } - - const BigramEntry getBigramEntry(const int bigramEntryPos) const { - int readingPos = bigramEntryPos; - return getBigramEntryAndAdvancePosition(&readingPos); - } - - const BigramEntry getBigramEntryAndAdvancePosition(int *const bigramEntryPos) const; - - // Returns head position of bigram list for a PtNode specified by terminalId. 
- int getBigramListHeadPos(const int terminalId) const { - const SparseTable *const addressLookupTable = getAddressLookupTable(); - if (!addressLookupTable->contains(terminalId)) { - return NOT_A_DICT_POS; - } - return addressLookupTable->get(terminalId); - } - - bool writeBigramEntryAtTail(const BigramEntry *const bigramEntryToWrite) { - int writingPos = getContentBuffer()->getTailPosition(); - return writeBigramEntryAndAdvancePosition(bigramEntryToWrite, &writingPos); - } - - bool writeBigramEntry(const BigramEntry *const bigramEntryToWrite, const int entryWritingPos) { - int writingPos = entryWritingPos; - return writeBigramEntryAndAdvancePosition(bigramEntryToWrite, &writingPos); - } - - bool writeBigramEntryAndAdvancePosition(const BigramEntry *const bigramEntryToWrite, - int *const entryWritingPos); - - bool writeTerminator(const int writingPos) { - // Terminator is a link to the invalid position. - return writeLink(INVALID_LINKED_ENTRY_POS, writingPos); - } - - bool writeLink(const int linkedPos, const int writingPos); - - bool createNewBigramList(const int terminalId) { - const int bigramListPos = getContentBuffer()->getTailPosition(); - return getUpdatableAddressLookupTable()->set(terminalId, bigramListPos); - } - - bool flushToFile(FILE *const file) const { - return flush(file); - } - - bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const BigramDictContent *const originalBigramDictContent, - int *const outBigramEntryCount); - - int getBigramEntrySize() const { - if (mHasHistoricalInfo) { - return Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE - + Ver4DictConstants::TIME_STAMP_FIELD_SIZE - + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE - + Ver4DictConstants::WORD_COUNT_FIELD_SIZE - + Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - } else { - return Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE - + Ver4DictConstants::PROBABILITY_SIZE - + Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - } - } - - private: - 
DISALLOW_COPY_AND_ASSIGN(BigramDictContent); - - static const int INVALID_LINKED_ENTRY_POS; - - bool writeBigramEntryAttributesAndAdvancePosition( - const bool isLink, const int probability, const int targetTerminalId, - const int timestamp, const int level, const int count, int *const entryWritingPos); - - bool runGCBigramList(const int bigramListPos, - const BigramDictContent *const sourceBigramDictContent, const int toPos, - const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - int *const outEntryCount); - - bool mHasHistoricalInfo; -}; -} // namespace latinime -#endif /* LATINIME_BIGRAM_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h deleted file mode 100644 index 2b0cbd93b..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/bigram_entry.h +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LATINIME_BIGRAM_ENTRY_H -#define LATINIME_BIGRAM_ENTRY_H - -#include "defines.h" -#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" - -namespace latinime { - -class BigramEntry { - public: - BigramEntry(const BigramEntry& bigramEntry) - : mHasNext(bigramEntry.mHasNext), mProbability(bigramEntry.mProbability), - mHistoricalInfo(), mTargetTerminalId(bigramEntry.mTargetTerminalId) {} - - // Entry with historical information. - BigramEntry(const bool hasNext, const int probability, const int targetTerminalId) - : mHasNext(hasNext), mProbability(probability), mHistoricalInfo(), - mTargetTerminalId(targetTerminalId) {} - - // Entry with historical information. - BigramEntry(const bool hasNext, const int probability, - const HistoricalInfo *const historicalInfo, const int targetTerminalId) - : mHasNext(hasNext), mProbability(probability), mHistoricalInfo(*historicalInfo), - mTargetTerminalId(targetTerminalId) {} - - const BigramEntry getInvalidatedEntry() const { - return updateTargetTerminalIdAndGetEntry(Ver4DictConstants::NOT_A_TERMINAL_ID); - } - - const BigramEntry updateHasNextAndGetEntry(const bool hasNext) const { - return BigramEntry(hasNext, mProbability, &mHistoricalInfo, mTargetTerminalId); - } - - const BigramEntry updateTargetTerminalIdAndGetEntry(const int newTargetTerminalId) const { - return BigramEntry(mHasNext, mProbability, &mHistoricalInfo, newTargetTerminalId); - } - - const BigramEntry updateProbabilityAndGetEntry(const int probability) const { - return BigramEntry(mHasNext, probability, &mHistoricalInfo, mTargetTerminalId); - } - - const BigramEntry updateHistoricalInfoAndGetEntry( - const HistoricalInfo *const historicalInfo) const { - return BigramEntry(mHasNext, mProbability, historicalInfo, mTargetTerminalId); - } - - bool isValid() const { - return mTargetTerminalId != Ver4DictConstants::NOT_A_TERMINAL_ID; - } - - bool hasNext() const { - 
return mHasNext; - } - - int getProbability() const { - return mProbability; - } - - bool hasHistoricalInfo() const { - return mHistoricalInfo.isValid(); - } - - const HistoricalInfo *getHistoricalInfo() const { - return &mHistoricalInfo; - } - - int getTargetTerminalId() const { - return mTargetTerminalId; - } - - private: - // Copy constructor is public to use this class as a type of return value. - DISALLOW_DEFAULT_CONSTRUCTOR(BigramEntry); - DISALLOW_ASSIGNMENT_OPERATOR(BigramEntry); - - const bool mHasNext; - const int mProbability; - const HistoricalInfo mHistoricalInfo; - const int mTargetTerminalId; -}; -} // namespace latinime -#endif /* LATINIME_BIGRAM_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 5dc91ba10..509bd683b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -16,18 +16,86 @@ #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" +#include <algorithm> +#include <cstring> + +#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" + namespace latinime { +const int LanguageModelDictContent::DUMMY_PROBABILITY_FOR_VALID_WORDS = 1; + bool LanguageModelDictContent::save(FILE *const file) const { return mTrieMap.save(file); } bool LanguageModelDictContent::runGC( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const LanguageModelDictContent *const originalContent, - int *const outNgramCount) { + const LanguageModelDictContent *const originalContent) { return runGCInner(terminalIdMap, originalContent->mTrieMap.getEntriesInRootLevel(), - 0 /* nextLevelBitmapEntryIndex */, outNgramCount); + 0 /* nextLevelBitmapEntryIndex */); +} + +const 
WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, + const int wordId, const HeaderPolicy *const headerPolicy) const { + int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); + int maxPrevWordCount = 0; + for (size_t i = 0; i < prevWordIds.size(); ++i) { + const int nextBitmapEntryIndex = + mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex; + if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) { + break; + } + maxPrevWordCount = i + 1; + bitmapEntryIndices[i + 1] = nextBitmapEntryIndex; + } + + for (int i = maxPrevWordCount; i >= 0; --i) { + const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]); + if (!result.mIsValid) { + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); + int probability = NOT_A_PROBABILITY; + if (mHasHistoricalInfo) { + const int rawProbability = ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), headerPolicy); + if (rawProbability == NOT_A_PROBABILITY) { + // The entry should not be treated as a valid entry. 
+ continue; + } + if (i == 0) { + // unigram + probability = rawProbability; + } else { + const ProbabilityEntry prevWordProbabilityEntry = getNgramProbabilityEntry( + prevWordIds.skip(1 /* n */).limit(i - 1), prevWordIds[0]); + if (!prevWordProbabilityEntry.isValid()) { + continue; + } + if (prevWordProbabilityEntry.representsBeginningOfSentence()) { + probability = rawProbability; + } else { + const int prevWordRawProbability = ForgettingCurveUtils::decodeProbability( + prevWordProbabilityEntry.getHistoricalInfo(), headerPolicy); + probability = std::min(MAX_PROBABILITY - prevWordRawProbability + + rawProbability, MAX_PROBABILITY); + } + } + } else { + probability = probabilityEntry.getProbability(); + } + // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in + // probabilityEntry. + const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); + return WordAttributes(probability, unigramProbabilityEntry.isBlacklisted(), + unigramProbabilityEntry.isNotAWord(), + unigramProbabilityEntry.isPossiblyOffensive()); + } + // Cannot find the word. 
+ return WordAttributes(); } ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( @@ -45,18 +113,142 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( } bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds, - const int terminalId, const ProbabilityEntry *const probabilityEntry) { + const int wordId, const ProbabilityEntry *const probabilityEntry) { + if (wordId == Ver4DictConstants::NOT_A_TERMINAL_ID) { + return false; + } + const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds); + if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + return false; + } + return mTrieMap.put(wordId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); +} + +bool LanguageModelDictContent::removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, + const int wordId) { const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); if (bitmapEntryIndex == TrieMap::INVALID_INDEX) { + // Cannot find bitmap entry for the probability entry. The entry doesn't exist. + return false; + } + return mTrieMap.remove(wordId, bitmapEntryIndex); +} + +LanguageModelDictContent::EntryRange LanguageModelDictContent::getProbabilityEntries( + const WordIdArrayView prevWordIds) const { + const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds); + return EntryRange(mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex), mHasHistoricalInfo); +} + +std::vector<LanguageModelDictContent::DumppedFullEntryInfo> + LanguageModelDictContent::exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const { + const TrieMap::Result result = mTrieMap.getRoot(wordId); + if (!result.mIsValid || result.mNextLevelBitmapEntryIndex == TrieMap::INVALID_INDEX) { + // The word doesn't have any related ngram entries. 
+ return std::vector<DumppedFullEntryInfo>(); + } + std::vector<int> prevWordIds = { wordId }; + std::vector<DumppedFullEntryInfo> entries; + exportAllNgramEntriesRelatedToWordInner(headerPolicy, result.mNextLevelBitmapEntryIndex, + &prevWordIds, &entries); + return entries; +} + +void LanguageModelDictContent::exportAllNgramEntriesRelatedToWordInner( + const HeaderPolicy *const headerPolicy, const int bitmapEntryIndex, + std::vector<int> *const prevWordIds, + std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + const int wordId = entry.key(); + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (probabilityEntry.isValid()) { + const WordAttributes wordAttributes = getWordAttributes( + WordIdArrayView(*prevWordIds), wordId, headerPolicy); + outBummpedFullEntryInfo->emplace_back(*prevWordIds, wordId, + wordAttributes, probabilityEntry); + } + if (entry.hasNextLevelMap()) { + prevWordIds->push_back(wordId); + exportAllNgramEntriesRelatedToWordInner(headerPolicy, + entry.getNextLevelBitmapEntryIndex(), prevWordIds, outBummpedFullEntryInfo); + prevWordIds->pop_back(); + } + } +} + +bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCounts, + const EntryCounts &maxEntryCounts, const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const outEntryCounters) { + for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) { + const int totalWordCount = prevWordCount + 1; + if (currentEntryCounts.getNgramCount(totalWordCount) + <= maxEntryCounts.getNgramCount(totalWordCount)) { + outEntryCounters->setNgramCount(totalWordCount, + currentEntryCounts.getNgramCount(totalWordCount)); + continue; + } + int entryCount = 0; + if (!turncateEntriesInSpecifiedLevel(headerPolicy, + maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) { +
return false; + } + outEntryCounters->setNgramCount(totalWordCount, entryCount); + } + return true; +} + +bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, + const int wordId, const bool isValid, const HistoricalInfo historicalInfo, + const HeaderPolicy *const headerPolicy, MutableEntryCounters *const entryCountersToUpdate) { + if (!mHasHistoricalInfo) { + AKLOGE("updateAllEntriesOnInputWord is called for dictionary without historical info."); return false; } - return mTrieMap.put(terminalId, probabilityEntry->encode(mHasHistoricalInfo), bitmapEntryIndex); + const ProbabilityEntry originalUnigramProbabilityEntry = getProbabilityEntry(wordId); + const ProbabilityEntry updatedUnigramProbabilityEntry = createUpdatedEntryFrom( + originalUnigramProbabilityEntry, isValid, historicalInfo, headerPolicy); + if (!setProbabilityEntry(wordId, &updatedUnigramProbabilityEntry)) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] == NOT_A_WORD_ID) { + break; + } + // TODO: Optimize this code. 
+ const WordIdArrayView limitedPrevWordIds = prevWordIds.limit(i + 1); + const ProbabilityEntry originalNgramProbabilityEntry = getNgramProbabilityEntry( + limitedPrevWordIds, wordId); + const ProbabilityEntry updatedNgramProbabilityEntry = createUpdatedEntryFrom( + originalNgramProbabilityEntry, isValid, historicalInfo, headerPolicy); + if (!setNgramProbabilityEntry(limitedPrevWordIds, wordId, &updatedNgramProbabilityEntry)) { + return false; + } + if (!originalNgramProbabilityEntry.isValid()) { + entryCountersToUpdate->incrementNgramCount(i + 2); + } + } + return true; +} + +const ProbabilityEntry LanguageModelDictContent::createUpdatedEntryFrom( + const ProbabilityEntry &originalProbabilityEntry, const bool isValid, + const HistoricalInfo historicalInfo, const HeaderPolicy *const headerPolicy) const { + const HistoricalInfo updatedHistoricalInfo = ForgettingCurveUtils::createUpdatedHistoricalInfo( + originalProbabilityEntry.getHistoricalInfo(), isValid ? + DUMMY_PROBABILITY_FOR_VALID_WORDS : NOT_A_PROBABILITY, + &historicalInfo, headerPolicy); + if (originalProbabilityEntry.isValid()) { + return ProbabilityEntry(originalProbabilityEntry.getFlags(), &updatedHistoricalInfo); + } else { + return ProbabilityEntry(0 /* flags */, &updatedHistoricalInfo); + } } bool LanguageModelDictContent::runGCInner( const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const TrieMap::TrieMapRange trieMapRange, - const int nextLevelBitmapEntryIndex, int *const outNgramCount) { + const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex) { for (auto &entry : trieMapRange) { const auto it = terminalIdMap->find(entry.key()); if (it == terminalIdMap->end() || it->second == Ver4DictConstants::NOT_A_TERMINAL_ID) { @@ -66,13 +258,9 @@ bool LanguageModelDictContent::runGCInner( if (!mTrieMap.put(it->second, entry.value(), nextLevelBitmapEntryIndex)) { return false; } - if (outNgramCount) { - *outNgramCount += 1; - } if (entry.hasNextLevelMap()) { if 
(!runGCInner(terminalIdMap, entry.getEntriesInNextLevel(), - mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex), - outNgramCount)) { + mTrieMap.getNextLevelBitmapEntryIndex(it->second, nextLevelBitmapEntryIndex))) { return false; } } @@ -80,6 +268,28 @@ bool LanguageModelDictContent::runGCInner( return true; } +int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) { + int lastBitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); + for (const int wordId : prevWordIds) { + const TrieMap::Result result = mTrieMap.get(wordId, lastBitmapEntryIndex); + if (result.mIsValid && result.mNextLevelBitmapEntryIndex != TrieMap::INVALID_INDEX) { + lastBitmapEntryIndex = result.mNextLevelBitmapEntryIndex; + continue; + } + if (!result.mIsValid) { + if (!mTrieMap.put(wordId, ProbabilityEntry().encode(mHasHistoricalInfo), + lastBitmapEntryIndex)) { + AKLOGE("Failed to update trie map. wordId: %d, lastBitmapEntryIndex %d", wordId, + lastBitmapEntryIndex); + return TrieMap::INVALID_INDEX; + } + } + lastBitmapEntryIndex = mTrieMap.getNextLevelBitmapEntryIndex(wordId, + lastBitmapEntryIndex); + } + return lastBitmapEntryIndex; +} + int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const { int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex(); for (const int wordId : prevWordIds) { @@ -92,4 +302,143 @@ int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWord return bitmapEntryIndex; } +bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, + const int prevWordCount, const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const outEntryCounters) { + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (prevWordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM) { + AKLOGE("Invalid prevWordCount. 
prevWordCount: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.", + prevWordCount, MAX_PREV_WORD_COUNT_FOR_N_GRAM); + return false; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + if (prevWordCount > 0 && probabilityEntry.isValid() + && !mTrieMap.getRoot(entry.key()).mIsValid) { + // The entry is related to a word that has been removed. Remove the entry. + if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence() + && probabilityEntry.isValid()) { + const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( + probabilityEntry.getHistoricalInfo(), headerPolicy); + if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) { + // Update the entry. + const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo); + if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo), + bitmapEntryIndex)) { + return false; + } + } else { + // Remove the entry. 
+ if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) { + return false; + } + continue; + } + } + if (!probabilityEntry.representsBeginningOfSentence()) { + outEntryCounters->incrementNgramCount(prevWordCount + 1); + } + if (!entry.hasNextLevelMap()) { + continue; + } + if (!updateAllProbabilityEntriesForGCInner(entry.getNextLevelBitmapEntryIndex(), + prevWordCount + 1, headerPolicy, outEntryCounters)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel( + const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel, + int *const outEntryCount) { + std::vector<int> prevWordIds; + std::vector<EntryInfoToTurncate> entryInfoVector; + if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(), + &prevWordIds, &entryInfoVector)) { + return false; + } + if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) { + *outEntryCount = static_cast<int>(entryInfoVector.size()); + return true; + } + *outEntryCount = maxEntryCount; + const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount; + std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove, + entryInfoVector.end(), + EntryInfoToTurncate::Comparator()); + for (int i = 0; i < entryCountToRemove; ++i) { + const EntryInfoToTurncate &entryInfo = entryInfoVector[i]; + if (!removeNgramProbabilityEntry( + WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) { + return false; + } + } + return true; +} + +bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy, + const int targetLevel, const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const { + const int prevWordCount = prevWordIds->size(); + for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) { + if (prevWordCount < targetLevel) { + if 
(!entry.hasNextLevelMap()) { + continue; + } + prevWordIds->push_back(entry.key()); + if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(), + prevWordIds, outEntryInfo)) { + return false; + } + prevWordIds->pop_back(); + continue; + } + const ProbabilityEntry probabilityEntry = + ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo); + const int probability = (mHasHistoricalInfo) ? + ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + headerPolicy) : probabilityEntry.getProbability(); + outEntryInfo->emplace_back(probability, + probabilityEntry.getHistoricalInfo()->getTimestamp(), + entry.key(), targetLevel, prevWordIds->data()); + } + return true; +} + +bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()( + const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const { + if (left.mProbability != right.mProbability) { + return left.mProbability < right.mProbability; + } + if (left.mTimestamp != right.mTimestamp) { + return left.mTimestamp > right.mTimestamp; + } + if (left.mKey != right.mKey) { + return left.mKey < right.mKey; + } + if (left.mPrevWordCount != right.mPrevWordCount) { + return left.mPrevWordCount > right.mPrevWordCount; + } + for (int i = 0; i < left.mPrevWordCount; ++i) { + if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) { + return left.mPrevWordIds[i] < right.mPrevWordIds[i]; + } + } + // left and rigth represent the same entry. 
+ return false; +} + +LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability, + const int timestamp, const int key, const int prevWordCount, const int *const prevWordIds) + : mProbability(probability), mTimestamp(timestamp), mKey(key), + mPrevWordCount(prevWordCount) { + memmove(mPrevWordIds, prevWordIds, mPrevWordCount * sizeof(mPrevWordIds[0])); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index 18f2e0170..1cccf92d2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -18,17 +18,22 @@ #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H #include <cstdio> +#include <vector> #include "defines.h" +#include "suggest/core/dictionary/word_attributes.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/trie_map.h" #include "utils/byte_array_view.h" #include "utils/int_array_view.h" namespace latinime { +class HeaderPolicy; + /** * Class representing language model. * @@ -36,6 +41,96 @@ namespace latinime { */ class LanguageModelDictContent { public: + // Pair of word id and probability entry used for iteration. 
+ class WordIdAndProbabilityEntry { + public: + WordIdAndProbabilityEntry(const int wordId, const ProbabilityEntry &probabilityEntry) + : mWordId(wordId), mProbabilityEntry(probabilityEntry) {} + + int getWordId() const { return mWordId; } + const ProbabilityEntry getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(WordIdAndProbabilityEntry); + DISALLOW_ASSIGNMENT_OPERATOR(WordIdAndProbabilityEntry); + + const int mWordId; + const ProbabilityEntry mProbabilityEntry; + }; + + // Iterator. + class EntryIterator { + public: + EntryIterator(const TrieMap::TrieMapIterator &trieMapIterator, + const bool hasHistoricalInfo) + : mTrieMapIterator(trieMapIterator), mHasHistoricalInfo(hasHistoricalInfo) {} + + const WordIdAndProbabilityEntry operator*() const { + const TrieMap::TrieMapIterator::IterationResult &result = *mTrieMapIterator; + return WordIdAndProbabilityEntry( + result.key(), ProbabilityEntry::decode(result.value(), mHasHistoricalInfo)); + } + + bool operator!=(const EntryIterator &other) const { + return mTrieMapIterator != other.mTrieMapIterator; + } + + const EntryIterator &operator++() { + ++mTrieMapIterator; + return *this; + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryIterator); + DISALLOW_ASSIGNMENT_OPERATOR(EntryIterator); + + TrieMap::TrieMapIterator mTrieMapIterator; + const bool mHasHistoricalInfo; + }; + + // Class represents range to use range base for loops. 
+ class EntryRange { + public: + EntryRange(const TrieMap::TrieMapRange trieMapRange, const bool hasHistoricalInfo) + : mTrieMapRange(trieMapRange), mHasHistoricalInfo(hasHistoricalInfo) {} + + EntryIterator begin() const { + return EntryIterator(mTrieMapRange.begin(), mHasHistoricalInfo); + } + + EntryIterator end() const { + return EntryIterator(mTrieMapRange.end(), mHasHistoricalInfo); + } + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryRange); + DISALLOW_ASSIGNMENT_OPERATOR(EntryRange); + + const TrieMap::TrieMapRange mTrieMapRange; + const bool mHasHistoricalInfo; + }; + + class DumppedFullEntryInfo { + public: + DumppedFullEntryInfo(std::vector<int> &prevWordIds, const int targetWordId, + const WordAttributes &wordAttributes, const ProbabilityEntry &probabilityEntry) + : mPrevWordIds(prevWordIds), mTargetWordId(targetWordId), + mWordAttributes(wordAttributes), mProbabilityEntry(probabilityEntry) {} + + const WordIdArrayView getPrevWordIds() const { return WordIdArrayView(mPrevWordIds); } + int getTargetWordId() const { return mTargetWordId; } + const WordAttributes &getWordAttributes() const { return mWordAttributes; } + const ProbabilityEntry &getProbabilityEntry() const { return mProbabilityEntry; } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(DumppedFullEntryInfo); + + const std::vector<int> mPrevWordIds; + const int mTargetWordId; + const WordAttributes mWordAttributes; + const ProbabilityEntry mProbabilityEntry; + }; + LanguageModelDictContent(const ReadWriteByteArrayView trieMapBuffer, const bool hasHistoricalInfo) : mTrieMap(trieMapBuffer), mHasHistoricalInfo(hasHistoricalInfo) {} @@ -50,8 +145,10 @@ class LanguageModelDictContent { bool save(FILE *const file) const; bool runGC(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const LanguageModelDictContent *const originalContent, - int *const outNgramCount); + const LanguageModelDictContent *const originalContent); + + const WordAttributes getWordAttributes(const WordIdArrayView 
prevWordIds, const int wordId, + const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { return getNgramProbabilityEntry(WordIdArrayView(), wordId); @@ -61,23 +158,87 @@ class LanguageModelDictContent { return setNgramProbabilityEntry(WordIdArrayView(), wordId, probabilityEntry); } + bool removeProbabilityEntry(const int wordId) { + return removeNgramProbabilityEntry(WordIdArrayView(), wordId); + } + ProbabilityEntry getNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId) const; bool setNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId, const ProbabilityEntry *const probabilityEntry); + bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId); + + EntryRange getProbabilityEntries(const WordIdArrayView prevWordIds) const; + + std::vector<DumppedFullEntryInfo> exportAllNgramEntriesRelatedToWord( + const HeaderPolicy *const headerPolicy, const int wordId) const; + + bool updateAllProbabilityEntriesForGC(const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const outEntryCounters) { + return updateAllProbabilityEntriesForGCInner(mTrieMap.getRootBitmapEntryIndex(), + 0 /* prevWordCount */, headerPolicy, outEntryCounters); + } + + // entryCounts should be created by updateAllProbabilityEntries. 
+ bool truncateEntries(const EntryCounts &currentEntryCounts, const EntryCounts &maxEntryCounts, + const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); + + bool updateAllEntriesOnInputWord(const WordIdArrayView prevWordIds, const int wordId, + const bool isValid, const HistoricalInfo historicalInfo, + const HeaderPolicy *const headerPolicy, + MutableEntryCounters *const entryCountersToUpdate); + private: DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent); + class EntryInfoToTurncate { + public: + class Comparator { + public: + bool operator()(const EntryInfoToTurncate &left, + const EntryInfoToTurncate &right) const; + private: + DISALLOW_ASSIGNMENT_OPERATOR(Comparator); + }; + + EntryInfoToTurncate(const int probability, const int timestamp, const int key, + const int prevWordCount, const int *const prevWordIds); + + int mProbability; + int mTimestamp; + int mKey; + int mPrevWordCount; + int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; + + private: + DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate); + }; + + // TODO: Remove + static const int DUMMY_PROBABILITY_FOR_VALID_WORDS; + TrieMap mTrieMap; const bool mHasHistoricalInfo; bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap, - const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex, - int *const outNgramCount); - + const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex); + int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds); int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const; + bool updateAllProbabilityEntriesForGCInner(const int bitmapEntryIndex, const int prevWordCount, + const HeaderPolicy *const headerPolicy, MutableEntryCounters *const outEntryCounters); + bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy, + const int maxEntryCount, const int targetLevel, int *const outEntryCount); + bool getEntryInfo(const HeaderPolicy *const
headerPolicy, const int targetLevel, + const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<EntryInfoToTurncate> *const outEntryInfo) const; + const ProbabilityEntry createUpdatedEntryFrom(const ProbabilityEntry &originalProbabilityEntry, + const bool isValid, const HistoricalInfo historicalInfo, + const HeaderPolicy *const headerPolicy) const; + void exportAllNgramEntriesRelatedToWordInner(const HeaderPolicy *const headerPolicy, + const int bitmapEntryIndex, std::vector<int> *const prevWordIds, + std::vector<DumppedFullEntryInfo> *const outBummpedFullEntryInfo) const; }; } // namespace latinime #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index feff6b57f..f4d340f86 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -21,8 +21,10 @@ #include <cstdint> #include "defines.h" +#include "suggest/core/dictionary/property/historical_info.h" +#include "suggest/core/dictionary/property/ngram_property.h" +#include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" namespace latinime { @@ -34,31 +36,40 @@ class ProbabilityEntry { // Dummy entry ProbabilityEntry() - : mFlags(0), mProbability(NOT_A_PROBABILITY), mHistoricalInfo() {} + : mFlags(Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY), mProbability(NOT_A_PROBABILITY), + mHistoricalInfo() {} // Entry without historical information ProbabilityEntry(const int flags, const int probability) : mFlags(flags), mProbability(probability), mHistoricalInfo() {} // Entry with historical information. 
- ProbabilityEntry(const int flags, const int probability, - const HistoricalInfo *const historicalInfo) - : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {} - - const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const { - return ProbabilityEntry(mFlags, probability, &mHistoricalInfo); - } - - const ProbabilityEntry createEntryWithUpdatedHistoricalInfo( - const HistoricalInfo *const historicalInfo) const { - return ProbabilityEntry(mFlags, mProbability, historicalInfo); + ProbabilityEntry(const int flags, const HistoricalInfo *const historicalInfo) + : mFlags(flags), mProbability(NOT_A_PROBABILITY), mHistoricalInfo(*historicalInfo) {} + + // Create from unigram property. + ProbabilityEntry(const UnigramProperty *const unigramProperty) + : mFlags(createFlags(unigramProperty->representsBeginningOfSentence(), + unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isPossiblyOffensive())), + mProbability(unigramProperty->getProbability()), + mHistoricalInfo(unigramProperty->getHistoricalInfo()) {} + + // Create from ngram property. + // TODO: Set flags. 
+ ProbabilityEntry(const NgramProperty *const ngramProperty) + : mFlags(0), mProbability(ngramProperty->getProbability()), + mHistoricalInfo(ngramProperty->getHistoricalInfo()) {} + + bool isValid() const { + return (mFlags & Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY) == 0; } bool hasHistoricalInfo() const { return mHistoricalInfo.isValid(); } - int getFlags() const { + uint8_t getFlags() const { return mFlags; } @@ -70,18 +81,34 @@ class ProbabilityEntry { return &mHistoricalInfo; } + bool representsBeginningOfSentence() const { + return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0; + } + + bool isNotAWord() const { + return (mFlags & Ver4DictConstants::FLAG_NOT_A_WORD) != 0; + } + + bool isBlacklisted() const { + return (mFlags & Ver4DictConstants::FLAG_BLACKLISTED) != 0; + } + + bool isPossiblyOffensive() const { + return (mFlags & Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE) != 0; + } + uint64_t encode(const bool hasHistoricalInfo) const { - uint64_t encodedEntry = static_cast<uint64_t>(mFlags); + uint64_t encodedEntry = static_cast<uint8_t>(mFlags); if (hasHistoricalInfo) { encodedEntry = (encodedEntry << (Ver4DictConstants::TIME_STAMP_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getTimeStamp()); + | static_cast<uint32_t>(mHistoricalInfo.getTimestamp()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_LEVEL_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getLevel()); + | static_cast<uint8_t>(mHistoricalInfo.getLevel()); encodedEntry = (encodedEntry << (Ver4DictConstants::WORD_COUNT_FIELD_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mHistoricalInfo.getCount()); + | static_cast<uint8_t>(mHistoricalInfo.getCount()); } else { encodedEntry = (encodedEntry << (Ver4DictConstants::PROBABILITY_SIZE * CHAR_BIT)) - ^ static_cast<uint64_t>(mProbability); + | static_cast<uint8_t>(mProbability); } return encodedEntry; } @@ -89,7 +116,7 @@ class ProbabilityEntry { static ProbabilityEntry decode(const 
uint64_t encodedEntry, const bool hasHistoricalInfo) { if (hasHistoricalInfo) { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::TIME_STAMP_FIELD_SIZE + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE + Ver4DictConstants::WORD_COUNT_FIELD_SIZE); @@ -103,10 +130,10 @@ class ProbabilityEntry { const int count = readFromEncodedEntry(encodedEntry, Ver4DictConstants::WORD_COUNT_FIELD_SIZE, 0 /* pos */); const HistoricalInfo historicalInfo(timestamp, level, count); - return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo); + return ProbabilityEntry(flags, &historicalInfo); } else { const int flags = readFromEncodedEntry(encodedEntry, - Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE, + Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE, Ver4DictConstants::PROBABILITY_SIZE); const int probability = readFromEncodedEntry(encodedEntry, Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */); @@ -118,7 +145,7 @@ class ProbabilityEntry { // Copy constructor is public to use this class as a type of return value. 
DISALLOW_ASSIGNMENT_OPERATOR(ProbabilityEntry); - const int mFlags; + const uint8_t mFlags; const int mProbability; const HistoricalInfo mHistoricalInfo; @@ -126,6 +153,24 @@ class ProbabilityEntry { return static_cast<int>( (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); } + + static uint8_t createFlags(const bool representsBeginningOfSentence, + const bool isNotAWord, const bool isBlacklisted, const bool isPossiblyOffensive) { + uint8_t flags = 0; + if (representsBeginningOfSentence) { + flags |= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + } + if (isNotAWord) { + flags |= Ver4DictConstants::FLAG_NOT_A_WORD; + } + if (isBlacklisted) { + flags |= Ver4DictConstants::FLAG_BLACKLISTED; + } + if (isPossiblyOffensive) { + flags |= Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE; + } + return flags; + } }; } // namespace latinime #endif /* LATINIME_PROBABILITY_ENTRY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h index 7b12aff16..85c9ce8d8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SHORTCUT_DICT_CONTENT_H #define LATINIME_SHORTCUT_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -27,11 +26,12 @@ namespace latinime { +class ReadWriteByteArrayView; + class ShortcutDictContent : public SparseTableDictContent { public: - ShortcutDictContent(uint8_t *const *buffers, const int *bufferSizes) - : SparseTableDictContent(buffers, bufferSizes, - Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, + ShortcutDictContent(const ReadWriteByteArrayView *const buffers) + : SparseTableDictContent(buffers, Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE, 
Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE) {} ShortcutDictContent() diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h index 921774181..309c434cf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SINGLE_DICT_CONTENT_H #define LATINIME_SINGLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -30,9 +29,9 @@ namespace latinime { class SingleDictContent { public: - SingleDictContent(uint8_t *const buffer, const int bufferSize) - : mExpandableContentBuffer(ReadWriteByteArrayView(buffer, bufferSize), - BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} + SingleDictContent(const ReadWriteByteArrayView buffer) + : mExpandableContentBuffer(buffer, + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE) {} SingleDictContent() : mExpandableContentBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h index c98dd11fd..0ce2da7bf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/sparse_table_dict_content.h @@ -17,7 +17,6 @@ #ifndef LATINIME_SPARSE_TABLE_DICT_CONTENT_H #define LATINIME_SPARSE_TABLE_DICT_CONTENT_H -#include <cstdint> #include <cstdio> #include "defines.h" @@ -31,19 +30,13 @@ namespace latinime { // TODO: Support multiple contents. 
class SparseTableDictContent { public: - AK_FORCE_INLINE SparseTableDictContent(uint8_t *const *buffers, const int *bufferSizes, + AK_FORCE_INLINE SparseTableDictContent(const ReadWriteByteArrayView *const buffers, const int sparseTableBlockSize, const int sparseTableDataSize) - : mExpandableLookupTableBuffer( - ReadWriteByteArrayView(buffers[LOOKUP_TABLE_BUFFER_INDEX], - bufferSizes[LOOKUP_TABLE_BUFFER_INDEX]), + : mExpandableLookupTableBuffer(buffers[LOOKUP_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableAddressTableBuffer( - ReadWriteByteArrayView(buffers[ADDRESS_TABLE_BUFFER_INDEX], - bufferSizes[ADDRESS_TABLE_BUFFER_INDEX]), + mExpandableAddressTableBuffer(buffers[ADDRESS_TABLE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableContentBuffer( - ReadWriteByteArrayView(buffers[CONTENT_BUFFER_INDEX], - bufferSizes[CONTENT_BUFFER_INDEX]), + mExpandableContentBuffer(buffers[CONTENT_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mAddressLookupTable(&mExpandableLookupTableBuffer, &mExpandableAddressTableBuffer, sparseTableBlockSize, sparseTableDataSize) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp index cf238ee5f..2bdf07752 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.cpp @@ -34,7 +34,7 @@ int TerminalPositionLookupTable::getTerminalPtNodePosition(const int terminalId) bool TerminalPositionLookupTable::setTerminalPtNodePosition( const int terminalId, const int terminalPtNodePos) { if (terminalId < 0) { - return NOT_A_DICT_POS; + return false; } while (terminalId >= mSize) { // Write new entry. 
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h index b2262bf1e..febcbe5b4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h @@ -17,13 +17,13 @@ #ifndef LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H #define LATINIME_TERMINAL_POSITION_LOOKUP_TABLE_H -#include <cstdint> #include <cstdio> #include <unordered_map> #include "defines.h" #include "suggest/policyimpl/dictionary/structure/v4/content/single_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -31,8 +31,8 @@ class TerminalPositionLookupTable : public SingleDictContent { public: typedef std::unordered_map<int, int> TerminalIdMap; - TerminalPositionLookupTable(uint8_t *const buffer, const int bufferSize) - : SingleDictContent(buffer, bufferSize), + TerminalPositionLookupTable(const ReadWriteByteArrayView buffer) + : SingleDictContent(buffer), mSize(getBuffer()->getTailPosition() / Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE) {} diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp index 3c8008dc4..45f88e9b2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.cpp @@ -45,16 +45,13 @@ namespace latinime { if (!bodyBuffer) { return Ver4DictBuffersPtr(nullptr); } - std::vector<uint8_t *> buffers; - std::vector<int> bufferSizes; + std::vector<ReadWriteByteArrayView> buffers; const ReadWriteByteArrayView buffer = 
bodyBuffer->getReadWriteByteArrayView(); int position = 0; while (position < static_cast<int>(buffer.size())) { const int bufferSize = ByteArrayUtils::readUint32AndAdvancePosition( buffer.data(), &position); - const ReadWriteByteArrayView subBuffer = buffer.subView(position, bufferSize); - buffers.push_back(subBuffer.data()); - bufferSizes.push_back(subBuffer.size()); + buffers.push_back(buffer.subView(position, bufferSize)); position += bufferSize; if (bufferSize < 0 || position < 0 || position > static_cast<int>(buffer.size())) { AKLOGE("The dict body file is corrupted."); @@ -66,7 +63,7 @@ namespace latinime { return Ver4DictBuffersPtr(nullptr); } return Ver4DictBuffersPtr(new Ver4DictBuffers(std::move(headerBuffer), std::move(bodyBuffer), - formatVersion, buffers, bufferSizes)); + formatVersion, buffers)); } bool Ver4DictBuffers::flushHeaderAndDictBuffers(const char *const dictDirPath, @@ -162,11 +159,6 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { AKLOGE("Language model dict content cannot be written."); return false; } - // Write bigram dict content. - if (!mBigramDictContent.flushToFile(file)) { - AKLOGE("Bigram dict content cannot be written."); - return false; - } // Write shortcut dict content. 
if (!mShortcutDictContent.flushToFile(file)) { AKLOGE("Shortcut dict content cannot be written."); @@ -178,29 +170,18 @@ bool Ver4DictBuffers::flushDictBuffers(FILE *const file) const { Ver4DictBuffers::Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, const std::vector<int> &contentBufferSizes) + const std::vector<ReadWriteByteArrayView> &contentBuffers) : mHeaderBuffer(std::move(headerBuffer)), mDictBuffer(std::move(bodyBuffer)), mHeaderPolicy(mHeaderBuffer->getReadOnlyByteArrayView().data(), formatVersion), mExpandableHeaderBuffer(mHeaderBuffer->getReadWriteByteArrayView(), BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), - mExpandableTrieBuffer( - ReadWriteByteArrayView(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::TRIE_BUFFER_INDEX]), + mExpandableTrieBuffer(contentBuffers[Ver4DictConstants::TRIE_BUFFER_INDEX], BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE), mTerminalPositionLookupTable( - contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX], - contentBufferSizes[ - Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), - mLanguageModelDictContent( - ReadWriteByteArrayView( - contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], - contentBufferSizes[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX]), - mHeaderPolicy.hasHistoricalInfoOfWords()), - mBigramDictContent(&contentBuffers[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], - &contentBufferSizes[Ver4DictConstants::BIGRAM_BUFFERS_INDEX], + contentBuffers[Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX]), + mLanguageModelDictContent(contentBuffers[Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX], mHeaderPolicy.hasHistoricalInfoOfWords()), - mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX], - 
&contentBufferSizes[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), + mShortcutDictContent(&contentBuffers[Ver4DictConstants::SHORTCUT_BUFFERS_INDEX]), mIsUpdatable(mDictBuffer->isUpdatable()) {} Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize) @@ -208,7 +189,6 @@ Ver4DictBuffers::Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const i mExpandableHeaderBuffer(Ver4DictConstants::MAX_DICTIONARY_SIZE), mExpandableTrieBuffer(maxTrieSize), mTerminalPositionLookupTable(), mLanguageModelDictContent(headerPolicy->hasHistoricalInfoOfWords()), - mBigramDictContent(headerPolicy->hasHistoricalInfoOfWords()), mShortcutDictContent(), - mIsUpdatable(true) {} + mShortcutDictContent(), mIsUpdatable(true) {} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h index 68027dcb8..5407525af 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h @@ -22,7 +22,6 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/content/bigram_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/shortcut_dict_content.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" @@ -53,7 +52,6 @@ class Ver4DictBuffers { return mExpandableTrieBuffer.isNearSizeLimit() || mTerminalPositionLookupTable.isNearSizeLimit() || mLanguageModelDictContent.isNearSizeLimit() - || mBigramDictContent.isNearSizeLimit() || mShortcutDictContent.isNearSizeLimit(); } @@ -89,14 +87,6 @@ class Ver4DictBuffers { return &mLanguageModelDictContent; } - AK_FORCE_INLINE BigramDictContent 
*getMutableBigramDictContent() { - return &mBigramDictContent; - } - - AK_FORCE_INLINE const BigramDictContent *getBigramDictContent() const { - return &mBigramDictContent; - } - AK_FORCE_INLINE ShortcutDictContent *getMutableShortcutDictContent() { return &mShortcutDictContent; } @@ -122,8 +112,7 @@ class Ver4DictBuffers { Ver4DictBuffers(MmappedBuffer::MmappedBufferPtr &&headerBuffer, MmappedBuffer::MmappedBufferPtr &&bodyBuffer, const FormatUtils::FORMAT_VERSION formatVersion, - const std::vector<uint8_t *> &contentBuffers, - const std::vector<int> &contentBufferSizes); + const std::vector<ReadWriteByteArrayView> &contentBuffers); Ver4DictBuffers(const HeaderPolicy *const headerPolicy, const int maxTrieSize); @@ -136,7 +125,6 @@ class Ver4DictBuffers { BufferWithExtendableBuffer mExpandableTrieBuffer; TerminalPositionLookupTable mTerminalPositionLookupTable; LanguageModelDictContent mLanguageModelDictContent; - BigramDictContent mBigramDictContent; ShortcutDictContent mShortcutDictContent; const int mIsUpdatable; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 93d4e562d..8e6cb974b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -29,24 +29,22 @@ const int Ver4DictConstants::MAX_DICT_EXTENDED_REGION_SIZE = 1 * 1024 * 1024; // NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT for Trie and TerminalAddressLookupTable. // NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT for language model. -// NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT for bigram and shortcut. +// NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT for shortcut. 
const size_t Ver4DictConstants::NUM_OF_CONTENT_BUFFERS_IN_BODY_FILE = NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT * 2 + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT - + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT * 2; + + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; const int Ver4DictConstants::TRIE_BUFFER_INDEX = 0; const int Ver4DictConstants::TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX = TRIE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; const int Ver4DictConstants::LANGUAGE_MODEL_BUFFER_INDEX = TERMINAL_ADDRESS_LOOKUP_TABLE_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_SINGLE_DICT_CONTENT; -const int Ver4DictConstants::BIGRAM_BUFFERS_INDEX = - LANGUAGE_MODEL_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX = - BIGRAM_BUFFERS_INDEX + NUM_OF_BUFFERS_FOR_SPARSE_TABLE_DICT_CONTENT; + LANGUAGE_MODEL_BUFFER_INDEX + NUM_OF_BUFFERS_FOR_LANGUAGE_MODEL_DICT_CONTENT; const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1; const int Ver4DictConstants::PROBABILITY_SIZE = 1; -const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1; +const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1; const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3; const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0; const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4; @@ -54,21 +52,15 @@ const int Ver4DictConstants::TIME_STAMP_FIELD_SIZE = 4; const int Ver4DictConstants::WORD_LEVEL_FIELD_SIZE = 1; const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; -const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_BLOCK_SIZE = 16; -const int Ver4DictConstants::BIGRAM_ADDRESS_TABLE_DATA_SIZE = 4; +const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; +const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2; +const uint8_t Ver4DictConstants::FLAG_NOT_A_WORD = 0x4; +const uint8_t Ver4DictConstants::FLAG_BLACKLISTED = 0x8; +const uint8_t Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE = 0x10; + 
const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4; -const int Ver4DictConstants::BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE = 3; -// Unsigned int max value of BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE-byte is used for representing -// invalid terminal ID in bigram lists. -const int Ver4DictConstants::INVALID_BIGRAM_TARGET_TERMINAL_ID = - (1 << (BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE * 8)) - 1; -const int Ver4DictConstants::BIGRAM_FLAGS_FIELD_SIZE = 1; -const int Ver4DictConstants::BIGRAM_PROBABILITY_MASK = 0x0F; -const int Ver4DictConstants::BIGRAM_IS_LINK_MASK = 0x80; -const int Ver4DictConstants::BIGRAM_LARGE_PROBABILITY_FIELD_SIZE = 1; - const int Ver4DictConstants::SHORTCUT_FLAGS_FIELD_SIZE = 1; const int Ver4DictConstants::SHORTCUT_PROBABILITY_MASK = 0x0F; const int Ver4DictConstants::SHORTCUT_HAS_NEXT_MASK = 0x80; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index 6950ca70f..600b5ffe4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -20,6 +20,7 @@ #include "defines.h" #include <cstddef> +#include <cstdint> namespace latinime { @@ -41,27 +42,23 @@ class Ver4DictConstants { static const int NOT_A_TERMINAL_ID; static const int PROBABILITY_SIZE; - static const int FLAGS_IN_PROBABILITY_FILE_SIZE; + static const int FLAGS_IN_LANGUAGE_MODEL_SIZE; static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE; static const int NOT_A_TERMINAL_ADDRESS; static const int TERMINAL_ID_FIELD_SIZE; static const int TIME_STAMP_FIELD_SIZE; static const int WORD_LEVEL_FIELD_SIZE; static const int WORD_COUNT_FIELD_SIZE; + // Flags in probability entry. 
+ static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + static const uint8_t FLAG_NOT_A_VALID_ENTRY; + static const uint8_t FLAG_NOT_A_WORD; + static const uint8_t FLAG_BLACKLISTED; + static const uint8_t FLAG_POSSIBLY_OFFENSIVE; - static const int BIGRAM_ADDRESS_TABLE_BLOCK_SIZE; - static const int BIGRAM_ADDRESS_TABLE_DATA_SIZE; static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE; static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE; - static const int BIGRAM_FLAGS_FIELD_SIZE; - static const int BIGRAM_TARGET_TERMINAL_ID_FIELD_SIZE; - static const int INVALID_BIGRAM_TARGET_TERMINAL_ID; - static const int BIGRAM_IS_LINK_MASK; - static const int BIGRAM_PROBABILITY_MASK; - // Used when bigram list has time stamp. - static const int BIGRAM_LARGE_PROBABILITY_FIELD_SIZE; - static const int SHORTCUT_FLAGS_FIELD_SIZE; static const int SHORTCUT_PROBABILITY_MASK; static const int SHORTCUT_HAS_NEXT_MASK; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp index 731092efd..4110d6036 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp @@ -16,6 +16,7 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h" @@ -50,26 +51,17 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce const int parentPos = DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos); int 
codePoints[MAX_WORD_LENGTH]; - const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, codePoints, &pos); + // Code point table is not used for ver4 dictionaries. + const int codePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictBuf, flags, MAX_WORD_LENGTH, nullptr /* codePointTable */, codePoints, &pos); int terminalIdFieldPos = NOT_A_DICT_POS; int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; - int probability = NOT_A_PROBABILITY; if (PatriciaTrieReadingUtils::isTerminal(flags)) { terminalIdFieldPos = pos; if (usesAdditionalBuffer) { terminalIdFieldPos += mBuffer->getOriginalBufferSize(); } terminalId = Ver4PatriciaTrieReadingUtils::getTerminalIdAndAdvancePosition(dictBuf, &pos); - // TODO: Quit reading probability here. - const ProbabilityEntry probabilityEntry = - mLanguageModelDictContent->getProbabilityEntry(terminalId); - if (probabilityEntry.hasHistoricalInfo()) { - probability = ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy); - } else { - probability = probabilityEntry.getProbability(); - } } int childrenPosFieldPos = pos; if (usesAdditionalBuffer) { @@ -90,8 +82,8 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce // The destination position is stored at the same place as the parent position. 
return fetchPtNodeInfoFromBufferAndProcessMovedPtNode(parentPos, newSiblingNodePos); } else { - return PtNodeParams(headPos, flags, parentPos, codePonitCount, codePoints, - terminalIdFieldPos, terminalId, probability, childrenPosFieldPos, childrenPos, + return PtNodeParams(headPos, flags, parentPos, codePointCount, codePoints, + terminalIdFieldPos, terminalId, NOT_A_PROBABILITY, childrenPosFieldPos, childrenPos, newSiblingNodePos); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h index a91ad5728..f4df544e2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h @@ -29,15 +29,12 @@ class LanguageModelDictContent; /* * This class is used for helping to read nodes of ver4 patricia trie. This class handles moved - * node and reads node attributes including probability form language model. + * node and reads node attributes. 
*/ class Ver4PatriciaTrieNodeReader : public PtNodeReader { public: - Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer, - const LanguageModelDictContent *const languageModelDictContent, - const HeaderPolicy *const headerPolicy) - : mBuffer(buffer), mLanguageModelDictContent(languageModelDictContent), - mHeaderPolicy(headerPolicy) {} + explicit Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer) + : mBuffer(buffer) {} ~Ver4PatriciaTrieNodeReader() {} @@ -50,8 +47,6 @@ class Ver4PatriciaTrieNodeReader : public PtNodeReader { DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeReader); const BufferWithExtendableBuffer *const mBuffer; - const LanguageModelDictContent *const mLanguageModelDictContent; - const HeaderPolicy *const mHeaderPolicy; const PtNodeParams fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos, const int siblingNodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 857222f5d..3488f7d2a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -21,7 +21,6 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h" #include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" @@ -62,6 +61,7 @@ bool 
Ver4PatriciaTrieNodeWriter::markPtNodeAsDeleted( } } +// TODO: Quit using bigramLinkedNodePos. bool Ver4PatriciaTrieNodeWriter::markPtNodeAsMoved( const PtNodeParams *const toBeUpdatedPtNodeParams, const int movedPos, const int bigramLinkedNodePos) { @@ -142,13 +142,9 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty( if (!toBeUpdatedPtNodeParams->isTerminal()) { return false; } - const ProbabilityEntry originalProbabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId()); - const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry, - unigramProperty); + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry); + toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntryOfUnigramProperty); } bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC( @@ -160,29 +156,15 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeA const ProbabilityEntry originalProbabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( toBeUpdatedPtNodeParams->getTerminalId()); - if (originalProbabilityEntry.hasHistoricalInfo()) { - const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave( - originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy); - const ProbabilityEntry probabilityEntry = - originalProbabilityEntry.createEntryWithUpdatedHistoricalInfo(&historicalInfo); - if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) { - AKLOGE("Cannot write updated probability entry. 
terminalId: %d", - toBeUpdatedPtNodeParams->getTerminalId()); - return false; - } - const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy); - if (!isValid) { - if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { - AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); - return false; - } - } - *outNeedsToKeepPtNode = isValid; - } else { - // No need to update probability. + if (originalProbabilityEntry.isValid()) { *outNeedsToKeepPtNode = true; + return true; } + if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) { + AKLOGE("Cannot mark PtNode as willBecomeNonTerminal."); + return false; + } + *outNeedsToKeepPtNode = false; return true; } @@ -205,7 +187,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition( ptNodeWritingPos); } - bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos) { @@ -216,31 +197,43 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( } // Write probability. ProbabilityEntry newProbabilityEntry; - const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom( - &newProbabilityEntry, unigramProperty); + const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty); return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry( - terminalId, &probabilityEntryToWrite); + terminalId, &probabilityEntryOfUnigramProperty); } +// TODO: Support counting ngram entries. bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, - const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) { - if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) { - AKLOGE("Cannot add new bigram entry. 
terminalId: %d, targetTerminalId: %d", - prevWordIds[0], wordId); + const NgramProperty *const ngramProperty, bool *const outAddedNewBigram) { + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + const ProbabilityEntry probabilityEntry = + languageModelDictContent->getNgramProbabilityEntry(prevWordIds, wordId); + const ProbabilityEntry probabilityEntryOfNgramProperty(ngramProperty); + if (!languageModelDictContent->setNgramProbabilityEntry( + prevWordIds, wordId, &probabilityEntryOfNgramProperty)) { + AKLOGE("Cannot add new ngram entry. prevWordId[0]: %d, prevWordId.size(): %zd, wordId: %d", + prevWordIds[0], prevWordIds.size(), wordId); return false; } + if (!probabilityEntry.isValid() && outAddedNewBigram) { + *outAddedNewBigram = true; + } return true; } bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId) { - return mBigramPolicy->removeEntry(prevWordIds[0], wordId); + LanguageModelDictContent *const languageModelDictContent = + mBuffers->getMutableLanguageModelDictContent(); + return languageModelDictContent->removeNgramProbabilityEntry(prevWordIds, wordId); } +// TODO: Remove when we stop supporting v402 format. bool Ver4PatriciaTrieNodeWriter::updateAllBigramEntriesAndDeleteUselessEntries( const PtNodeParams *const sourcePtNodeParams, int *const outBigramEntryCount) { - return mBigramPolicy->updateAllBigramEntriesAndDeleteUselessEntries( - sourcePtNodeParams->getTerminalId(), outBigramEntryCount); + // Do nothing. + return true; } bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( @@ -275,12 +268,6 @@ bool Ver4PatriciaTrieNodeWriter::updateAllPositionFields( if (!updateChildrenPosition(toBeUpdatedPtNodeParams, childrenPos)) { return false; } - - // Counts bigram entries. 
- if (outBigramEntryCount) { - *outBigramEntryCount = mBigramPolicy->getBigramEntryConut( - toBeUpdatedPtNodeParams->getTerminalId()); - } return true; } @@ -289,7 +276,7 @@ bool Ver4PatriciaTrieNodeWriter::addShortcutTarget(const PtNodeParams *const ptN const int shortcutProbability) { if (!mShortcutPolicy->addNewShortcut(ptNodeParams->getTerminalId(), targetCodePoints, targetCodePointCount, shortcutProbability)) { - AKLOGE("Cannot add new shortuct entry. terminalId: %d", ptNodeParams->getTerminalId()); + AKLOGE("Cannot add new shortcut entry. terminalId: %d", ptNodeParams->getTerminalId()); return false; } return true; @@ -346,37 +333,17 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { return false; } - return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), - isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); -} - -const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( - const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const { - // TODO: Consolidate historical info and probability. 
- if (mHeaderPolicy->hasHistoricalInfoOfWords()) { - const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(), - unigramProperty->getLevel(), unigramProperty->getCount()); - const HistoricalInfo updatedHistoricalInfo = - ForgettingCurveUtils::createUpdatedHistoricalInfo( - originalProbabilityEntry->getHistoricalInfo(), - unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy); - return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo( - &updatedHistoricalInfo); - } else { - return originalProbabilityEntry->createEntryWithUpdatedProbability( - unigramProperty->getProbability()); - } + return updatePtNodeFlags(nodePos, isTerminal, + ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } -bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, - const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, +bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars) { // Create node flags and write them. PatriciaTrieReadingUtils::NodeFlags nodeFlags = - PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal, - false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars, - CHILDREN_POSITION_FIELD_SIZE); + PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */, + false /* isPossiblyOffensive */, isTerminal, false /* hasShortcutTargets */, + false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE); if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) { AKLOGE("Cannot write PtNode flags. 
flags: %x, pos: %d", nodeFlags, ptNodePos); return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 6703dba04..4ecf88729 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -27,7 +27,6 @@ namespace latinime { class BufferWithExtendableBuffer; class HeaderPolicy; -class Ver4BigramListPolicy; class Ver4DictBuffers; class Ver4PatriciaTrieNodeReader; class Ver4PtNodeArrayReader; @@ -39,13 +38,11 @@ class Ver4ShortcutListPolicy; class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { public: Ver4PatriciaTrieNodeWriter(BufferWithExtendableBuffer *const trieBuffer, - Ver4DictBuffers *const buffers, const HeaderPolicy *const headerPolicy, - const PtNodeReader *const ptNodeReader, + Ver4DictBuffers *const buffers, const PtNodeReader *const ptNodeReader, const PtNodeArrayReader *const ptNodeArrayReader, - Ver4BigramListPolicy *const bigramPolicy, Ver4ShortcutListPolicy *const shortcutPolicy) - : mTrieBuffer(trieBuffer), mBuffers(buffers), mHeaderPolicy(headerPolicy), - mReadingHelper(ptNodeReader, ptNodeArrayReader), mBigramPolicy(bigramPolicy), - mShortcutPolicy(shortcutPolicy) {} + Ver4ShortcutListPolicy *const shortcutPolicy) + : mTrieBuffer(trieBuffer), mBuffers(buffers), + mReadingHelper(ptNodeReader, ptNodeArrayReader), mShortcutPolicy(shortcutPolicy) {} virtual ~Ver4PatriciaTrieNodeWriter() {} @@ -76,7 +73,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos); virtual bool addNgramEntry(const WordIdArrayView prevWordIds, const int wordId, - const BigramProperty *const bigramProperty, bool *const outAddedNewEntry); + const NgramProperty *const ngramProperty, bool *const 
outAddedNewEntry); virtual bool removeNgramEntry(const WordIdArrayView prevWordIds, const int wordId); @@ -98,23 +95,13 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const PtNodeParams *const ptNodeParams, int *const outTerminalId, int *const ptNodeWritingPos); - // Create updated probability entry using given unigram property. In addition to the - // probability, this method updates historical information if needed. - // TODO: Update flags belonging to the unigram property. - const ProbabilityEntry createUpdatedEntryFrom( - const ProbabilityEntry *const originalProbabilityEntry, - const UnigramProperty *const unigramProperty) const; - - bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, - const bool isTerminal, const bool hasMultipleChars); + bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars); static const int CHILDREN_POSITION_FIELD_SIZE; BufferWithExtendableBuffer *const mTrieBuffer; Ver4DictBuffers *const mBuffers; - const HeaderPolicy *const mHeaderPolicy; DynamicPtReadingHelper mReadingHelper; - Ver4BigramListPolicy *const mBigramPolicy; Ver4ShortcutListPolicy *const mShortcutPolicy; }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 723808399..d3de322f9 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -16,15 +16,17 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h" +#include <array> #include <vector> #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/multi_bigram_map.h" #include "suggest/core/dictionary/ngram_listener.h" -#include 
"suggest/core/dictionary/property/bigram_property.h" +#include "suggest/core/dictionary/property/ngram_property.h" #include "suggest/core/dictionary/property/unigram_property.h" #include "suggest/core/dictionary/property/word_property.h" -#include "suggest/core/session/prev_words_info.h" +#include "suggest/core/session/ngram_context.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h" @@ -54,24 +56,11 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d if (!ptNodeParams.isValid()) { break; } - bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted(); - if (isTerminal && mHeaderPolicy->isDecayingDict()) { - // A DecayingDict may have a terminal PtNode that has a terminal DicNode whose - // probability is NOT_A_PROBABILITY. In such case, we don't want to treat it as a - // valid terminal DicNode. - isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY; - } + const bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted(); + const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID; + childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(), + wordId, ptNodeParams.getCodePointArrayView()); readingHelper.readNextSiblingNode(ptNodeParams); - if (ptNodeParams.representsNonWordInfo()) { - // Skip PtNodes that represent non-word information. 
- continue; - } - childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getHeadPos(), - ptNodeParams.getChildrenPos(), ptNodeParams.getProbability(), isTerminal, - ptNodeParams.hasChildren(), - ptNodeParams.isBlacklisted() - || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */, - ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints()); } if (readingHelper.isError()) { mIsCorrupted = true; @@ -79,13 +68,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } } -int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const { +int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); readingHelper.initWithPtNodePos(ptNodePos); - const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( - maxCodePointCount, outCodePoints, outUnigramProbability); + const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount( + maxCodePointCount, outCodePoints); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); @@ -93,76 +83,89 @@ int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( return codePointCount; } -int Ver4PatriciaTriePolicy::getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const { +int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints, + const bool forceLowerCaseSearch) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); 
readingHelper.initWithPtNodeArrayPos(getRootPosition()); - const int ptNodePos = - readingHelper.getTerminalPtNodePositionOfWord(inWord, length, forceLowerCaseSearch); + const int ptNodePos = readingHelper.getTerminalPtNodePositionOfWord(wordCodePoints.data(), + wordCodePoints.size(), forceLowerCaseSearch); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes()."); } - return ptNodePos; + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_WORD_ID; + } + const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); + if (ptNodeParams.isDeleted()) { + return NOT_A_WORD_ID; + } + return ptNodeParams.getTerminalId(); } -int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability, - const int bigramProbability) const { - if (mHeaderPolicy->isDecayingDict()) { - // Both probabilities are encoded. Decode them and get probability. - return ForgettingCurveUtils::getProbability(unigramProbability, bigramProbability); - } else { - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return ProbabilityUtils::backoff(unigramProbability); - } else { - return bigramProbability; - } - } +const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( + const WordIdArrayView prevWordIds, const int wordId, + MultiBigramMap *const multiBigramMap) const { + if (wordId == NOT_A_WORD_ID) { + return WordAttributes(); + } + return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, + mHeaderPolicy); } -int Ver4PatriciaTriePolicy::getProbabilityOfPtNode(const int *const prevWordsPtNodePos, - const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, + const int wordId) const { + if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } - 
const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { + const ProbabilityEntry probabilityEntry = + mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); + if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted() + || probabilityEntry.isNotAWord()) { return NOT_A_PROBABILITY; } - if (prevWordsPtNodePos) { - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - if (bigramsIt.getBigramPos() == ptNodePos - && bigramsIt.getProbability() != NOT_A_PROBABILITY) { - return getProbability(ptNodeParams.getProbability(), bigramsIt.getProbability()); - } - } - return NOT_A_PROBABILITY; + if (mHeaderPolicy->hasHistoricalInfoOfWords()) { + return ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(), + mHeaderPolicy); + } else { + return probabilityEntry.getProbability(); } - return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY); } -void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordsPtNodePos, +BinaryDictionaryShortcutIterator Ver4PatriciaTriePolicy::getShortcutIterator( + const int wordId) const { + const int shortcutPos = getShortcutPositionOfWord(wordId); + return BinaryDictionaryShortcutIterator(&mShortcutPolicy, shortcutPos); +} + +void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const { - if (!prevWordsPtNodePos) { + if (prevWordIds.empty()) { return; } - const int bigramsPosition = getBigramsPositionOfPtNode(prevWordsPtNodePos[0]); - BinaryDictionaryBigramsIterator bigramsIt(&mBigramPolicy, bigramsPosition); - while (bigramsIt.hasNext()) { - bigramsIt.next(); - listener->onVisitEntry(bigramsIt.getProbability(), 
bigramsIt.getBigramPos()); + const auto languageModelDictContent = mBuffers->getLanguageModelDictContent(); + for (size_t i = 1; i <= prevWordIds.size(); ++i) { + for (const auto entry : languageModelDictContent->getProbabilityEntries( + prevWordIds.limit(i))) { + const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); + if (!probabilityEntry.isValid()) { + continue; + } + const int probability = probabilityEntry.hasHistoricalInfo() ? + ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), mHeaderPolicy) : + probabilityEntry.getProbability(); + listener->onVisitEntry(probability, entry.getWordId()); + } } } -int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { +int Ver4PatriciaTriePolicy::getShortcutPositionOfWord(const int wordId) const { + if (wordId == NOT_A_WORD_ID) { return NOT_A_DICT_POS; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); if (ptNodeParams.isDeleted()) { return NOT_A_DICT_POS; @@ -171,19 +174,7 @@ int Ver4PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) con ptNodeParams.getTerminalId()); } -int Ver4PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { - if (ptNodePos == NOT_A_DICT_POS) { - return NOT_A_DICT_POS; - } - const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos)); - if (ptNodeParams.isDeleted()) { - return NOT_A_DICT_POS; - } - return mBuffers->getBigramDictContent()->getBigramListHeadPos( - ptNodeParams.getTerminalId()); -} - -bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int length, +bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: 
addUnigramEntry() is called for non-updatable dictionary."); @@ -194,13 +185,14 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le mDictBuffer->getTailPosition()); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("The word is too long to insert to the dictionary, length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("The word is too long to insert to the dictionary, length: %zd", + wordCodePoints.size()); return false; } for (const auto &shortcut : unigramProperty->getShortcuts()) { if (shortcut.getTargetCodePoints()->size() > MAX_WORD_LENGTH) { - AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %d", + AKLOGE("One of shortcut targets is too long to insert to the dictionary, length: %zd", shortcut.getTargetCodePoints()->size()); return false; } @@ -209,8 +201,8 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le readingHelper.initWithPtNodeArrayPos(getRootPosition()); bool addedNewUnigram = false; int codePointsToAdd[MAX_WORD_LENGTH]; - int codePointCountToAdd = length; - memmove(codePointsToAdd, word, sizeof(int) * length); + int codePointCountToAdd = wordCodePoints.size(); + memmove(codePointsToAdd, wordCodePoints.data(), sizeof(int) * codePointCountToAdd); if (unigramProperty->representsBeginningOfSentence()) { codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd, codePointCountToAdd, MAX_WORD_LENGTH); @@ -218,24 +210,26 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le if (codePointCountToAdd <= 0) { return false; } - if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd, - unigramProperty, &addedNewUnigram)) { + const CodePointArrayView codePointArrayView(codePointsToAdd, codePointCountToAdd); + if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty, + &addedNewUnigram)) { if (addedNewUnigram && 
!unigramProperty->representsBeginningOfSentence()) { - mUnigramCount++; + mEntryCounters.incrementUnigramCount(); } if (unigramProperty->getShortcuts().size() > 0) { // Add shortcut target. - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { - AKLOGE("Cannot find terminal PtNode position to add shortcut target."); + const int wordId = getWordId(codePointArrayView, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { + AKLOGE("Cannot find word id to add shortcut target."); return false; } + const int wordPos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); for (const auto &shortcut : unigramProperty->getShortcuts()) { if (!mUpdatingHelper.addShortcutTarget(wordPos, - shortcut.getTargetCodePoints()->data(), - shortcut.getTargetCodePoints()->size(), shortcut.getProbability())) { - AKLOGE("Cannot add new shortcut target. PtNodePos: %d, length: %d, " + CodePointArrayView(*shortcut.getTargetCodePoints()), + shortcut.getProbability())) { + AKLOGE("Cannot add new shortcut target. 
PtNodePos: %d, length: %zd, " "probability: %d", wordPos, shortcut.getTargetCodePoints()->size(), shortcut.getProbability()); return false; @@ -248,29 +242,32 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const int *const word, const int le } } -bool Ver4PatriciaTriePolicy::removeUnigramEntry(const int *const word, const int length) { +bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeUnigramEntry() is called for non-updatable dictionary."); return false; } - const int ptNodePos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } + const int ptNodePos = + mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); if (!mNodeWriter.markPtNodeAsDeleted(&ptNodeParams)) { AKLOGE("Cannot remove unigram. 
ptNodePos: %d", ptNodePos); return false; } + if (!mBuffers->getMutableLanguageModelDictContent()->removeProbabilityEntry(wordId)) { + return false; + } if (!ptNodeParams.representsNonWordInfo()) { - mUnigramCount--; + mEntryCounters.decrementUnigramCount(); } return true; } -bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty) { +bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramProperty) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: addNgramEntry() is called for non-updatable dictionary."); return false; @@ -280,51 +277,50 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for adding n-gram entry to the dictionary."); + const NgramContext *const ngramContext = ngramProperty->getNgramContext(); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for adding n-gram entry to the dictionary."); return false; } - if (bigramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { + if (ngramProperty->getTargetCodePoints()->size() > MAX_WORD_LENGTH) { AKLOGE("The word is too long to insert the ngram to the dictionary. " - "length: %d", bigramProperty->getTargetCodePoints()->size()); + "length: %zd", ngramProperty->getTargetCodePoints()->size()); return false; } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); - // TODO: Support N-gram. 
- if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { - if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) { - const std::vector<UnigramProperty::ShortcutProperty> shortcuts; - const UnigramProperty beginningOfSentenceUnigramProperty( - true /* representsBeginningOfSentence */, true /* isNotAWord */, - false /* isBlacklisted */, MAX_PROBABILITY /* probability */, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */), - prevWordsInfo->getNthPrevWordCodePointCount(1 /* n */), - &beginningOfSentenceUnigramProperty)) { - AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); - return false; - } - // Refresh Terminal PtNode positions. - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, - false /* tryLowerCaseSearch */); - } else { + if (prevWordIds.empty()) { + return false; + } + for (size_t i = 0; i < prevWordIds.size(); ++i) { + if (prevWordIds[i] != NOT_A_WORD_ID) { + continue; + } + if (!ngramContext->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) { return false; } + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, true /* isNotAWord */, + false /* isBlacklisted */, false /* isPossiblyOffensive */, + MAX_PROBABILITY /* probability */, HistoricalInfo()); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add unigram entry for the beginning-of-sentence."); + return false; + } + // Refresh word ids. 
+ ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); } - const int word1Pos = getTerminalPtNodePositionOfWord( - bigramProperty->getTargetCodePoints()->data(), - bigramProperty->getTargetCodePoints()->size(), false /* forceLowerCaseSearch */); - if (word1Pos == NOT_A_DICT_POS) { + const int wordId = getWordId(CodePointArrayView(*ngramProperty->getTargetCodePoints()), + false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } bool addedNewEntry = false; - if (mUpdatingHelper.addNgramEntry(prevWordsPtNodePosView, word1Pos, bigramProperty, - &addedNewEntry)) { + if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) { if (addedNewEntry) { - mBigramCount++; + mEntryCounters.incrementNgramCount(prevWordIds.size() + 1); } return true; } else { @@ -332,8 +328,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI } } -bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const int *const word, const int length) { +bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: removeNgramEntry() is called for non-updatable dictionary."); return false; @@ -343,40 +339,86 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor mDictBuffer->getTailPosition()); return false; } - if (!prevWordsInfo->isValid()) { - AKLOGE("prev words info is not valid for removing n-gram entry form the dictionary."); + if (!ngramContext->isValid()) { + AKLOGE("Ngram context is not valid for removing n-gram entry form the dictionary."); return false; } - if (length > MAX_WORD_LENGTH) { - AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %d", length); + if (wordCodePoints.size() > MAX_WORD_LENGTH) { + AKLOGE("word is too long to remove n-gram entry form the dictionary. 
length: %zd", + wordCodePoints.size()); } - int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - prevWordsInfo->getPrevWordsTerminalPtNodePos(this, prevWordsPtNodePos, + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSerch */); - const auto prevWordsPtNodePosView = PtNodePosArrayView::fromFixedSizeArray(prevWordsPtNodePos); - // TODO: Support N-gram. - if (prevWordsPtNodePos[0] == NOT_A_DICT_POS) { + if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) { return false; } - const int wordPos = getTerminalPtNodePositionOfWord(word, length, - false /* forceLowerCaseSearch */); - if (wordPos == NOT_A_DICT_POS) { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { return false; } - if (mUpdatingHelper.removeNgramEntry(prevWordsPtNodePosView, wordPos)) { - mBigramCount--; + if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) { + mEntryCounters.decrementNgramCount(prevWordIds.size()); return true; } else { return false; } } +bool Ver4PatriciaTriePolicy::updateEntriesForWordWithNgramContext( + const NgramContext *const ngramContext, const CodePointArrayView wordCodePoints, + const bool isValidWord, const HistoricalInfo historicalInfo) { + if (!mBuffers->isUpdatable()) { + AKLOGI("Warning: updateEntriesForWordWithNgramContext() is called for non-updatable " + "dictionary."); + return false; + } + const bool updateAsAValidWord = ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */) ? + false : isValidWord; + int wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { + // The word is not in the dictionary. 
+ const UnigramProperty unigramProperty(false /* representsBeginningOfSentence */, + false /* isNotAWord */, false /* isBlacklisted */, false /* isPossiblyOffensive */, + NOT_A_PROBABILITY, HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, + 0 /* count */)); + if (!addUnigramEntry(wordCodePoints, &unigramProperty)) { + AKLOGE("Cannot add unigarm entry in updateEntriesForWordWithNgramContext()."); + return false; + } + wordId = getWordId(wordCodePoints, false /* tryLowerCaseSearch */); + } + + WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray; + const WordIdArrayView prevWordIds = ngramContext->getPrevWordIds(this, &prevWordIdArray, + false /* tryLowerCaseSearch */); + if (prevWordIds.firstOrDefault(NOT_A_WORD_ID) == NOT_A_WORD_ID + && ngramContext->isNthPrevWordBeginningOfSentence(1 /* n */)) { + const UnigramProperty beginningOfSentenceUnigramProperty( + true /* representsBeginningOfSentence */, + true /* isNotAWord */, false /* isPossiblyOffensive */, NOT_A_PROBABILITY, + HistoricalInfo(historicalInfo.getTimestamp(), 0 /* level */, 0 /* count */)); + if (!addUnigramEntry(ngramContext->getNthPrevWordCodePoints(1 /* n */), + &beginningOfSentenceUnigramProperty)) { + AKLOGE("Cannot add BoS entry in updateEntriesForWordWithNgramContext()."); + return false; + } + // Refresh word ids. + ngramContext->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */); + } + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllEntriesOnInputWord(prevWordIds, + wordId, updateAsAValidWord, historicalInfo, mHeaderPolicy, &mEntryCounters)) { + return false; + } + return true; +} + bool Ver4PatriciaTriePolicy::flush(const char *const filePath) { if (!mBuffers->isUpdatable()) { AKLOGI("Warning: flush() is called for non-updatable dictionary. 
filePath: %s", filePath); return false; } - if (!mWritingHelper.writeToDictFile(filePath, mUnigramCount, mBigramCount)) { + if (!mWritingHelper.writeToDictFile(filePath, mEntryCounters.getEntryCounts())) { AKLOGE("Cannot flush the dictionary to file."); mIsCorrupted = true; return false; @@ -414,7 +456,7 @@ bool Ver4PatriciaTriePolicy::needsToRunGC(const bool mindsBlockByGC) const { // Needs to reduce dictionary size. return true; } else if (mHeaderPolicy->isDecayingDict()) { - return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mUnigramCount, mBigramCount, + return ForgettingCurveUtils::needsToDecay(mindsBlockByGC, mEntryCounters.getEntryCounts(), mHeaderPolicy); } return false; @@ -424,79 +466,66 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer char *const outResult, const int maxResultLength) { const int compareLength = queryLength + 1 /* terminator */; if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mUnigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount()); } else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) { - snprintf(outResult, maxResultLength, "%d", mBigramCount); + snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount()); } else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? - ForgettingCurveUtils::getUnigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxUnigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) { snprintf(outResult, maxResultLength, "%d", mHeaderPolicy->isDecayingDict() ? 
- ForgettingCurveUtils::getBigramCountHardLimit( + ForgettingCurveUtils::getEntryCountHardLimit( mHeaderPolicy->getMaxBigramCount()) : static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE)); } } -const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const codePoints, - const int codePointCount) const { - const int ptNodePos = getTerminalPtNodePositionOfWord(codePoints, codePointCount, - false /* forceLowerCaseSearch */); - if (ptNodePos == NOT_A_DICT_POS) { +const WordProperty Ver4PatriciaTriePolicy::getWordProperty( + const CodePointArrayView wordCodePoints) const { + const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */); + if (wordId == NOT_A_WORD_ID) { AKLOGE("getWordProperty is called for invalid word."); return WordProperty(); } - const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry( - ptNodeParams.getTerminalId()); - const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); - // Fetch bigram information. 
- std::vector<BigramProperty> bigrams; - const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); - if (bigramListPos != NOT_A_DICT_POS) { - int bigramWord1CodePoints[MAX_WORD_LENGTH]; - const BigramDictContent *const bigramDictContent = mBuffers->getBigramDictContent(); - const TerminalPositionLookupTable *const terminalPositionLookupTable = - mBuffers->getTerminalPositionLookupTable(); - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - hasNext = bigramEntry.hasNext(); - const int word1TerminalId = bigramEntry.getTargetTerminalId(); - const int word1TerminalPtNodePos = - terminalPositionLookupTable->getTerminalPtNodePosition(word1TerminalId); - if (word1TerminalPtNodePos == NOT_A_DICT_POS) { - continue; + const LanguageModelDictContent *const languageModelDictContent = + mBuffers->getLanguageModelDictContent(); + // Fetch ngram information. + std::vector<NgramProperty> ngrams; + int ngramTargetCodePoints[MAX_WORD_LENGTH]; + int ngramPrevWordsCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; + int ngramPrevWordsCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + bool ngramPrevWordIsBeginningOfSentense[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; + for (const auto entry : languageModelDictContent->exportAllNgramEntriesRelatedToWord( + mHeaderPolicy, wordId)) { + const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getTargetWordId(), + MAX_WORD_LENGTH, ngramTargetCodePoints); + const WordIdArrayView prevWordIds = entry.getPrevWordIds(); + for (size_t i = 0; i < prevWordIds.size(); ++i) { + ngramPrevWordsCodePointCount[i] = getCodePointsAndReturnCodePointCount(prevWordIds[i], + MAX_WORD_LENGTH, ngramPrevWordsCodePoints[i]); + ngramPrevWordIsBeginningOfSentense[i] = languageModelDictContent->getProbabilityEntry( + prevWordIds[i]).representsBeginningOfSentence(); + if (ngramPrevWordIsBeginningOfSentense[i]) { + 
ngramPrevWordsCodePointCount[i] = CharUtils::removeBeginningOfSentenceMarker( + ngramPrevWordsCodePoints[i], ngramPrevWordsCodePointCount[i]); } - // Word (unigram) probability - int word1Probability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - word1TerminalPtNodePos, MAX_WORD_LENGTH, bigramWord1CodePoints, - &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + codePointCount); - const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mHeaderPolicy) : - bigramEntry.getProbability(); - bigrams.emplace_back(&word1, probability, - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount()); } + const NgramContext ngramContext(ngramPrevWordsCodePoints, ngramPrevWordsCodePointCount, + ngramPrevWordIsBeginningOfSentense, prevWordIds.size()); + const ProbabilityEntry ngramProbabilityEntry = entry.getProbabilityEntry(); + const HistoricalInfo *const historicalInfo = ngramProbabilityEntry.getHistoricalInfo(); + // TODO: Output flags in WordAttributes. + ngrams.emplace_back(ngramContext, + CodePointArrayView(ngramTargetCodePoints, codePointCount).toVector(), + entry.getWordAttributes().getProbability(), *historicalInfo); } // Fetch shortcut information. 
std::vector<UnigramProperty::ShortcutProperty> shortcuts; - int shortcutPos = getShortcutPositionOfPtNode(ptNodePos); + int shortcutPos = getShortcutPositionOfWord(wordId); if (shortcutPos != NOT_A_DICT_POS) { int shortcutTarget[MAX_WORD_LENGTH]; const ShortcutDictContent *const shortcutDictContent = @@ -507,15 +536,20 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code int shortcutProbability = NOT_A_PROBABILITY; shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget, &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos); - const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength); - shortcuts.emplace_back(&target, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(), + shortcutProbability); } } - const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount(), &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + const WordAttributes wordAttributes = languageModelDictContent->getWordAttributes( + WordIdArrayView(), wordId, mHeaderPolicy); + const ProbabilityEntry probabilityEntry = languageModelDictContent->getProbabilityEntry(wordId); + const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); + const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), + wordAttributes.isNotAWord(), wordAttributes.isBlacklisted(), + wordAttributes.isPossiblyOffensive(), wordAttributes.getProbability(), + *historicalInfo, std::move(shortcuts)); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &ngrams); } int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, @@ -536,9 
+570,10 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + const PtNodeParams ptNodeParams = + mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos); + *outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(), + MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index faad4290d..13700b390 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -23,7 +23,6 @@ #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_updating_helper.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h" @@ -31,29 +30,29 @@ #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_pt_node_array_reader.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" 
+#include "suggest/policyimpl/dictionary/utils/entry_counters.h" +#include "utils/int_array_view.h" namespace latinime { class DicNode; class DicNodeVector; +// Word id = Artificial id that is stored in the PtNode looked up by the word. class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: Ver4PatriciaTriePolicy(Ver4DictBuffers::Ver4DictBuffersPtr buffers) : mBuffers(std::move(buffers)), mHeaderPolicy(mBuffers->getHeaderPolicy()), mDictBuffer(mBuffers->getWritableTrieBuffer()), - mBigramPolicy(mBuffers->getMutableBigramDictContent(), - mBuffers->getTerminalPositionLookupTable(), mHeaderPolicy), mShortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()), - mNodeReader(mDictBuffer, mBuffers->getLanguageModelDictContent(), mHeaderPolicy), - mPtNodeArrayReader(mDictBuffer), - mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader, - &mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy), + mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer), + mNodeWriter(mDictBuffer, mBuffers.get(), &mNodeReader, &mPtNodeArrayReader, + &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), mWritingHelper(mBuffers.get()), - mUnigramCount(mHeaderPolicy->getUnigramCount()), - mBigramCount(mHeaderPolicy->getBigramCount()), + mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(), + mHeaderPolicy->getTrigramCount()), mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {}; AK_FORCE_INLINE int getRootPosition() const { @@ -63,40 +62,44 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int 
getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; - int getTerminalPtNodePositionOfWord(const int *const inWord, - const int length, const bool forceLowerCaseSearch) const; + int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; - int getProbability(const int unigramProbability, const int bigramProbability) const; + const WordAttributes getWordAttributesInContext(const WordIdArrayView prevWordIds, + const int wordId, MultiBigramMap *const multiBigramMap) const; - int getProbabilityOfPtNode(const int *const prevWordsPtNodePos, const int ptNodePos) const; + // TODO: Remove + int getProbability(const int unigramProbability, const int bigramProbability) const { + // Not used. + return NOT_A_PROBABILITY; + } + + int getProbabilityOfWord(const WordIdArrayView prevWordIds, const int wordId) const; - void iterateNgramEntries(const int *const prevWordsPtNodePos, + void iterateNgramEntries(const WordIdArrayView prevWordIds, NgramListener *const listener) const; - int getShortcutPositionOfPtNode(const int ptNodePos) const; + BinaryDictionaryShortcutIterator getShortcutIterator(const int wordId) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return mHeaderPolicy; } - const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { - return &mShortcutPolicy; - } - - bool addUnigramEntry(const int *const word, const int length, + bool addUnigramEntry(const CodePointArrayView wordCodePoints, const UnigramProperty *const unigramProperty); - bool removeUnigramEntry(const int *const word, const int length); + bool removeUnigramEntry(const CodePointArrayView wordCodePoints); + + bool addNgramEntry(const NgramProperty *const ngramProperty); - bool addNgramEntry(const PrevWordsInfo *const prevWordsInfo, - const BigramProperty *const bigramProperty); + bool removeNgramEntry(const NgramContext *const ngramContext, + const 
CodePointArrayView wordCodePoints); - bool removeNgramEntry(const PrevWordsInfo *const prevWordsInfo, const int *const word1, - const int length1); + bool updateEntriesForWordWithNgramContext(const NgramContext *const ngramContext, + const CodePointArrayView wordCodePoints, const bool isValidWord, + const HistoricalInfo historicalInfo); bool flush(const char *const filePath); @@ -107,8 +110,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void getProperty(const char *const query, const int queryLength, char *const outResult, const int maxResultLength); - const WordProperty getWordProperty(const int *const codePoints, - const int codePointCount) const; + const WordProperty getWordProperty(const CodePointArrayView wordCodePoints) const; int getNextWordAndNextToken(const int token, int *const outCodePoints, int *const outCodePointCount); @@ -132,19 +134,17 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const Ver4DictBuffers::Ver4DictBuffersPtr mBuffers; const HeaderPolicy *const mHeaderPolicy; BufferWithExtendableBuffer *const mDictBuffer; - Ver4BigramListPolicy mBigramPolicy; Ver4ShortcutListPolicy mShortcutPolicy; Ver4PatriciaTrieNodeReader mNodeReader; Ver4PtNodeArrayReader mPtNodeArrayReader; Ver4PatriciaTrieNodeWriter mNodeWriter; DynamicPtUpdatingHelper mUpdatingHelper; Ver4PatriciaTrieWritingHelper mWritingHelper; - int mUnigramCount; - int mBigramCount; + MutableEntryCounters mEntryCounters; std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; - int getBigramsPositionOfPtNode(const int ptNodePos) const; + int getShortcutPositionOfWord(const int wordId) const; }; } // namespace latinime #endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 4220312e0..7f0604ce8 
100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -20,7 +20,6 @@ #include <queue> #include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/structure/v4/bigram/ver4_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/shortcut/ver4_shortcut_list_policy.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" @@ -34,17 +33,18 @@ namespace latinime { bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPath, - const int unigramCount, const int bigramCount) const { + const EntryCounts &entryCounts) const { const HeaderPolicy *const headerPolicy = mBuffers->getHeaderPolicy(); BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); const int extendedRegionSize = headerPolicy->getExtendedRegionSize() + mBuffers->getTrieBuffer()->getUsedAdditionalBufferSize(); if (!headerPolicy->fillInAndWriteHeaderToBuffer(false /* updatesLastDecayedTime */, - unigramCount, bigramCount, extendedRegionSize, &headerBuffer)) { + entryCounts, extendedRegionSize, &headerBuffer)) { AKLOGE("Cannot write header structure to buffer. 
" - "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, " - "extendedRegionSize: %d", false, unigramCount, bigramCount, + "updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d," + "extendedRegionSize: %d", false, entryCounts.getUnigramCount(), + entryCounts.getBigramCount(), entryCounts.getTrigramCount(), extendedRegionSize); return false; } @@ -57,15 +57,14 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers( Ver4DictBuffers::createVer4DictBuffers(headerPolicy, Ver4DictConstants::MAX_DICTIONARY_SIZE)); - int unigramCount = 0; - int bigramCount = 0; - if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &unigramCount, &bigramCount)) { + MutableEntryCounters entryCounters; + if (!runGC(rootPtNodeArrayPos, headerPolicy, dictBuffers.get(), &entryCounters)) { return false; } BufferWithExtendableBuffer headerBuffer( BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - unigramCount, bigramCount, 0 /* extendedRegionSize */, &headerBuffer)) { + entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) { return false; } return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer); @@ -73,61 +72,46 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, - int *const outUnigramCount, int *const outBigramCount) { - Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(), - mBuffers->getLanguageModelDictContent(), headerPolicy); + MutableEntryCounters *const outEntryCounters) { + Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer()); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); - Ver4BigramListPolicy 
bigramPolicy(mBuffers->getMutableBigramDictContent(), - mBuffers->getTerminalPositionLookupTable(), headerPolicy); Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter ptNodeWriter(mBuffers->getWritableTrieBuffer(), - mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, - &shortcutPolicy); + mBuffers, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); - DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); - readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); - DynamicPtGcEventListeners - ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted - traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( - &ptNodeWriter); - if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( - &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { + if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntriesForGC( + headerPolicy, outEntryCounters)) { + AKLOGE("Failed to update probabilities in language model dict content."); return false; } - const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted - .getValidUnigramCount(); - const int maxUnigramCount = headerPolicy->getMaxUnigramCount(); - if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) { - if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) { - AKLOGE("Cannot remove unigrams. 
current: %d, max: %d", unigramCount, - maxUnigramCount); + if (headerPolicy->isDecayingDict()) { + const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(), + headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount()); + if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries( + outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy, + outEntryCounters)) { + AKLOGE("Failed to truncate entries in language model dict content."); return false; } } + DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader); readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); - DynamicPtGcEventListeners::TraversePolicyToUpdateBigramProbability - traversePolicyToUpdateBigramProbability(&ptNodeWriter); + DynamicPtGcEventListeners + ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( + &ptNodeWriter); if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( - &traversePolicyToUpdateBigramProbability)) { + &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { return false; } - const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount(); - const int maxBigramCount = headerPolicy->getMaxBigramCount(); - if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) { - if (!truncateBigrams(maxBigramCount)) { - AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount); - return false; - } - } // Mapping from positions in mBuffer to positions in bufferToWrite. 
PtNodeWriter::DictPositionRelocationMap dictPositionRelocationMap; readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); Ver4PatriciaTrieNodeWriter ptNodeWriterForNewBuffers(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy, - &shortcutPolicy); + buffersToWrite, &ptNodeReader, &ptNodeArrayReader, &shortcutPolicy); DynamicPtGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer traversePolicyToPlaceAndWriteValidPtNodesToBuffer(&ptNodeWriterForNewBuffers, buffersToWrite->getWritableTrieBuffer(), &dictPositionRelocationMap); @@ -137,15 +121,12 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, } // Create policy instances for the GCed dictionary. - Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(), - buffersToWrite->getLanguageModelDictContent(), headerPolicy); + Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer()); Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer()); - Ver4BigramListPolicy newBigramPolicy(buffersToWrite->getMutableBigramDictContent(), - buffersToWrite->getTerminalPositionLookupTable(), headerPolicy); Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(), buffersToWrite->getTerminalPositionLookupTable()); Ver4PatriciaTrieNodeWriter newPtNodeWriter(buffersToWrite->getWritableTrieBuffer(), - buffersToWrite, headerPolicy, &newPtNodeReader, &newPtNodeArrayreader, &newBigramPolicy, + buffersToWrite, &newPtNodeReader, &newPtNodeArrayreader, &newShortcutPolicy); // Re-assign terminal IDs for valid terminal PtNodes. TerminalPositionLookupTable::TerminalIdMap terminalIdMap; @@ -153,14 +134,9 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &terminalIdMap)) { return false; } - // Run GC for probability dict content. + // Run GC for language model dict content. 
if (!buffersToWrite->getMutableLanguageModelDictContent()->runGC(&terminalIdMap, - mBuffers->getLanguageModelDictContent(), nullptr /* outNgramCount */)) { - return false; - } - // Run GC for bigram dict content. - if(!buffersToWrite->getMutableBigramDictContent()->runGC(&terminalIdMap, - mBuffers->getBigramDictContent(), outBigramCount)) { + mBuffers->getLanguageModelDictContent())) { return false; } // Run GC for shortcut dict content. @@ -183,92 +159,6 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, &traversePolicyToUpdateAllPtNodeFlagsAndTerminalIds)) { return false; } - *outUnigramCount = traversePolicyToUpdateAllPositionFields.getUnigramCount(); - return true; -} - -bool Ver4PatriciaTrieWritingHelper::truncateUnigrams( - const Ver4PatriciaTrieNodeReader *const ptNodeReader, - Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) { - const TerminalPositionLookupTable *const terminalPosLookupTable = - mBuffers->getTerminalPositionLookupTable(); - const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); - std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator> - priorityQueue; - for (int i = 0; i < nextTerminalId; ++i) { - const int terminalPos = terminalPosLookupTable->getTerminalPtNodePosition(i); - if (terminalPos == NOT_A_DICT_POS) { - continue; - } - const ProbabilityEntry probabilityEntry = - mBuffers->getLanguageModelDictContent()->getProbabilityEntry(i); - const int probability = probabilityEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : - probabilityEntry.getProbability(); - priorityQueue.push(DictProbability(terminalPos, probability, - probabilityEntry.getHistoricalInfo()->getTimeStamp())); - } - - // Delete unigrams. 
- while (static_cast<int>(priorityQueue.size()) > maxUnigramCount) { - const int ptNodePos = priorityQueue.top().getDictPos(); - priorityQueue.pop(); - const PtNodeParams ptNodeParams = - ptNodeReader->fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - if (ptNodeParams.representsNonWordInfo()) { - continue; - } - if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) { - AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos); - return false; - } - } - return true; -} - -bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) { - const TerminalPositionLookupTable *const terminalPosLookupTable = - mBuffers->getTerminalPositionLookupTable(); - const int nextTerminalId = terminalPosLookupTable->getNextTerminalId(); - std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator> - priorityQueue; - BigramDictContent *const bigramDictContent = mBuffers->getMutableBigramDictContent(); - for (int i = 0; i < nextTerminalId; ++i) { - const int bigramListPos = bigramDictContent->getBigramListHeadPos(i); - if (bigramListPos == NOT_A_DICT_POS) { - continue; - } - bool hasNext = true; - int readingPos = bigramListPos; - while (hasNext) { - const BigramEntry bigramEntry = - bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos); - const int entryPos = readingPos - bigramDictContent->getBigramEntrySize(); - hasNext = bigramEntry.hasNext(); - if (!bigramEntry.isValid()) { - continue; - } - const int probability = bigramEntry.hasHistoricalInfo() ? - ForgettingCurveUtils::decodeProbability( - bigramEntry.getHistoricalInfo(), mBuffers->getHeaderPolicy()) : - bigramEntry.getProbability(); - priorityQueue.push(DictProbability(entryPos, probability, - bigramEntry.getHistoricalInfo()->getTimeStamp())); - } - } - - // Delete bigrams. 
- while (static_cast<int>(priorityQueue.size()) > maxBigramCount) { - const int entryPos = priorityQueue.top().getDictPos(); - const BigramEntry bigramEntry = bigramDictContent->getBigramEntry(entryPos); - const BigramEntry invalidatedBigramEntry = bigramEntry.getInvalidatedEntry(); - if (!bigramDictContent->writeBigramEntry(&invalidatedBigramEntry, entryPos)) { - AKLOGE("Cannot write bigram entry to remove. pos: %d", entryPos); - return false; - } - priorityQueue.pop(); - } return true; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h index bb464ad28..c56cea5cf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h @@ -20,6 +20,7 @@ #include "defines.h" #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_gc_event_listeners.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { @@ -33,8 +34,7 @@ class Ver4PatriciaTrieWritingHelper { Ver4PatriciaTrieWritingHelper(Ver4DictBuffers *const buffers) : mBuffers(buffers) {} - bool writeToDictFile(const char *const dictDirPath, const int unigramCount, - const int bigramCount) const; + bool writeToDictFile(const char *const dictDirPath, const EntryCounts &entryCounts) const; // This method cannot be const because the original dictionary buffer will be updated to detect // useless PtNodes during GC. @@ -66,57 +66,8 @@ class Ver4PatriciaTrieWritingHelper { const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap; }; - // For truncateUnigrams() and truncateBigrams(). 
- class DictProbability { - public: - DictProbability(const int dictPos, const int probability, const int timestamp) - : mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {} - - int getDictPos() const { - return mDictPos; - } - - int getProbability() const { - return mProbability; - } - - int getTimestamp() const { - return mTimestamp; - } - - private: - DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability); - - int mDictPos; - int mProbability; - int mTimestamp; - }; - - // For truncateUnigrams() and truncateBigrams(). - class DictProbabilityComparator { - public: - bool operator()(const DictProbability &left, const DictProbability &right) { - if (left.getProbability() != right.getProbability()) { - return left.getProbability() > right.getProbability(); - } - if (left.getTimestamp() != right.getTimestamp()) { - return left.getTimestamp() < right.getTimestamp(); - } - return left.getDictPos() > right.getDictPos(); - } - - private: - DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator); - }; - bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, - Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, - int *const outBigramCount); - - bool truncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader, - Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount); - - bool truncateBigrams(const int maxBigramCount); + Ver4DictBuffers *const buffersToWrite, MutableEntryCounters *const outEntryCounters); Ver4DictBuffers *const mBuffers; }; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp index 833063c17..da2c30cd6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -31,7 +31,7 @@ uint32_t 
BufferWithExtendableBuffer::readUint(const int size, const int pos) con uint32_t BufferWithExtendableBuffer::readUintAndAdvancePosition(const int size, int *const pos) const { - const int value = readUint(size, *pos); + const uint32_t value = readUint(size, *pos); *pos += size; return value; } @@ -42,8 +42,10 @@ void BufferWithExtendableBuffer::readCodePointsAndAdvancePosition(const int maxC if (readingPosIsInAdditionalBuffer) { *pos -= mOriginalBuffer.size(); } + // Code point table is not used for dynamic format. *outCodePointCount = ByteArrayUtils::readStringAndAdvancePosition( - getBuffer(readingPosIsInAdditionalBuffer), maxCodePointCount, outCodePoints, pos); + getBuffer(readingPosIsInAdditionalBuffer), maxCodePointCount, + nullptr /* codePointTable */, outCodePoints, pos); if (readingPosIsInAdditionalBuffer) { *pos += mOriginalBuffer.size(); } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index c0a9fcb1d..abb979050 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -114,7 +114,7 @@ class ByteArrayUtils { return buffer[(*pos)++]; } - static AK_FORCE_INLINE int readUint(const uint8_t *const buffer, + static AK_FORCE_INLINE uint32_t readUint(const uint8_t *const buffer, const int size, const int pos) { // size must be in 1 to 4. 
ASSERT(size >= 1 && size <= 4); @@ -147,11 +147,18 @@ class ByteArrayUtils { */ static AK_FORCE_INLINE int readCodePoint(const uint8_t *const buffer, const int pos) { int p = pos; - return readCodePointAndAdvancePosition(buffer, &p); + return readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, &p); } static AK_FORCE_INLINE int readCodePointAndAdvancePosition( - const uint8_t *const buffer, int *const pos) { + const uint8_t *const buffer, const int *const codePointTable, int *const pos) { + /* + * codePointTable is an array to convert the most frequent characters in this dictionary to + * 1 byte code points. It is only made of the original code points of the most frequent + * characters used in this dictionary. 0x20 - 0xFF is used for the 1 byte characters. + * The original code points are restored by picking the code points at the indices of the + * codePointTable. The indices are calculated by subtracting 0x20 from the firstByte. + */ const uint8_t firstByte = readUint8(buffer, *pos); if (firstByte < MINIMUM_ONE_BYTE_CHARACTER_VALUE) { if (firstByte == CHARACTER_ARRAY_TERMINATOR) { @@ -162,6 +169,9 @@ class ByteArrayUtils { } } else { *pos += 1; + if (codePointTable) { + return codePointTable[firstByte - MINIMUM_ONE_BYTE_CHARACTER_VALUE]; + } return firstByte; } } @@ -173,12 +183,13 @@ class ByteArrayUtils { */ // Returns the length of the string. 
static int readStringAndAdvancePosition(const uint8_t *const buffer, - const int maxLength, int *const outBuffer, int *const pos) { + const int maxLength, const int *const codePointTable, int *const outBuffer, + int *const pos) { int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); + int codePoint = readCodePointAndAdvancePosition(buffer, codePointTable, pos); while (NOT_A_CODE_POINT != codePoint && length < maxLength) { outBuffer[length++] = codePoint; - codePoint = readCodePointAndAdvancePosition(buffer, pos); + codePoint = readCodePointAndAdvancePosition(buffer, codePointTable, pos); } return length; } @@ -187,9 +198,9 @@ class ByteArrayUtils { static int advancePositionToBehindString( const uint8_t *const buffer, const int maxLength, int *const pos) { int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); + int codePoint = readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, pos); while (NOT_A_CODE_POINT != codePoint && length < maxLength) { - codePoint = readCodePointAndAdvancePosition(buffer, pos); + codePoint = readCodePointAndAdvancePosition(buffer, nullptr /* codePointTable */, pos); length++; } return length; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp index b7e2a7278..9d8e86675 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -27,6 +27,7 @@ #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_writing_utils.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_buffers.h" #include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" #include "suggest/policyimpl/dictionary/utils/file_utils.h" #include 
"suggest/policyimpl/dictionary/utils/format_utils.h" #include "utils/time_keeper.h" @@ -69,8 +70,7 @@ template<class DictConstants, class DictBuffers, class DictBuffersPtr> DictBuffersPtr dictBuffers = DictBuffers::createVer4DictBuffers(&headerPolicy, DictConstants::MAX_DICT_EXTENDED_REGION_SIZE); headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */, - 0 /* unigramCount */, 0 /* bigramCount */, - 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer()); + EntryCounts(), 0 /* extendedRegionSize */, dictBuffers->getWritableHeaderBuffer()); if (!DynamicPtWritingUtils::writeEmptyDictionary( dictBuffers->getWritableTrieBuffer(), 0 /* rootPos */)) { AKLOGE("Empty ver4 dictionary structure cannot be created on memory."); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h new file mode 100644 index 000000000..73dc42a18 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2014, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_ENTRY_COUNTERS_H +#define LATINIME_ENTRY_COUNTERS_H + +#include <array> + +#include "defines.h" + +namespace latinime { + +// Copyable but immutable +class EntryCounts final { + public: + EntryCounts() : mEntryCounts({{0, 0, 0}}) {} + + EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount) + : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {} + + explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters) + : mEntryCounts(counters) {} + + int getUnigramCount() const { + return mEntryCounts[0]; + } + + int getBigramCount() const { + return mEntryCounts[1]; + } + + int getTrigramCount() const { + return mEntryCounts[2]; + } + + int getNgramCount(const size_t n) const { + if (n < 1 || n > mEntryCounts.size()) { + return 0; + } + return mEntryCounts[n - 1]; + } + + private: + DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts); + + const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts; +}; + +class MutableEntryCounters final { + public: + MutableEntryCounters() { + mEntryCounters.fill(0); + } + + MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount) + : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {} + + const EntryCounts getEntryCounts() const { + return EntryCounts(mEntryCounters); + } + + int getUnigramCount() const { + return mEntryCounters[0]; + } + + int getBigramCount() const { + return mEntryCounters[1]; + } + + int getTrigramCount() const { + return mEntryCounters[2]; + } + + void incrementUnigramCount() { + ++mEntryCounters[0]; + } + + void decrementUnigramCount() { + ASSERT(mEntryCounters[0] != 0); + --mEntryCounters[0]; + } + + void incrementBigramCount() { + ++mEntryCounters[1]; + } + + void decrementBigramCount() { + ASSERT(mEntryCounters[1] != 0); + --mEntryCounters[1]; + } + + void incrementNgramCount(const size_t n) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + ++mEntryCounters[n - 
1]; + } + + void decrementNgramCount(const size_t n) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + ASSERT(mEntryCounters[n - 1] != 0); + --mEntryCounters[n - 1]; + } + + void setNgramCount(const size_t n, const int count) { + if (n < 1 || n > mEntryCounters.size()) { + return; + } + mEntryCounters[n - 1] = count; + } + + private: + DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters); + + std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters; +}; +} // namespace latinime +#endif /* LATINIME_ENTRY_COUNTERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp index fed0ae77e..9055f7bfc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp @@ -29,13 +29,16 @@ namespace latinime { const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8; const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60; -const int ForgettingCurveUtils::MAX_LEVEL = 3; -const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1; -const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15; -const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 14; +const int ForgettingCurveUtils::MAX_LEVEL = 15; +const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 2; +const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 31; +const int ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = 30; +const int ForgettingCurveUtils::OCCURRENCES_TO_RAISE_THE_LEVEL = 1; +// TODO: Evaluate whether this should be 7.5 days. 
+// 15 days +const int ForgettingCurveUtils::DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS = 15 * 24 * 60 * 60; -const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; -const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; +const float ForgettingCurveUtils::ENTRY_COUNT_HARD_LIMIT_WEIGHT = 1.2; const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityTable; @@ -43,7 +46,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT /* static */ const HistoricalInfo ForgettingCurveUtils::createUpdatedHistoricalInfo( const HistoricalInfo *const originalHistoricalInfo, const int newProbability, const HistoricalInfo *const newHistoricalInfo, const HeaderPolicy *const headerPolicy) { - const int timestamp = newHistoricalInfo->getTimeStamp(); + const int timestamp = newHistoricalInfo->getTimestamp(); if (newProbability != NOT_A_PROBABILITY && originalHistoricalInfo->getLevel() == 0) { // Add entry as a valid word. const int level = clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel()); @@ -54,19 +57,23 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT || (originalHistoricalInfo->getLevel() == newHistoricalInfo->getLevel() && originalHistoricalInfo->getCount() < newHistoricalInfo->getCount())) { // Initial information. 
+ int count = newHistoricalInfo->getCount(); + if (count >= OCCURRENCES_TO_RAISE_THE_LEVEL) { + const int level = clampToValidLevelRange(newHistoricalInfo->getLevel() + 1); + return HistoricalInfo(timestamp, level, 0 /* count */); + } const int level = clampToValidLevelRange(newHistoricalInfo->getLevel()); - const int count = clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); - return HistoricalInfo(timestamp, level, count); + return HistoricalInfo(timestamp, level, clampToValidCountRange(count, headerPolicy)); } else { const int updatedCount = originalHistoricalInfo->getCount() + 1; - if (updatedCount >= headerPolicy->getForgettingCurveOccurrencesToLevelUp()) { + if (updatedCount >= OCCURRENCES_TO_RAISE_THE_LEVEL) { // The count exceeds the max value the level can be incremented. if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) { // The level is already max. return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), originalHistoricalInfo->getCount()); } else { - // Level up. + // Raise the level. 
return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1, 0 /* count */); } @@ -78,67 +85,62 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT /* static */ int ForgettingCurveUtils::decodeProbability( const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) { - const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimeStamp(), - headerPolicy->getForgettingCurveDurationToLevelDown()); + const int elapsedTimeStepCount = getElapsedTimeStepCount(historicalInfo->getTimestamp(), + DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS); return sProbabilityTable.getProbability( headerPolicy->getForgettingCurveProbabilityValuesTableId(), clampToValidLevelRange(historicalInfo->getLevel()), clampToValidTimeStepCountRange(elapsedTimeStepCount)); } -/* static */ int ForgettingCurveUtils::getProbability(const int unigramProbability, - const int bigramProbability) { - if (unigramProbability == NOT_A_PROBABILITY) { - return NOT_A_PROBABILITY; - } else if (bigramProbability == NOT_A_PROBABILITY) { - return std::min(backoff(unigramProbability), MAX_PROBABILITY); - } else { - // TODO: Investigate better way to handle bigram probability. 
- return std::min(std::max(unigramProbability, - bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), MAX_PROBABILITY); - } -} - /* static */ bool ForgettingCurveUtils::needsToKeep(const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy) { return historicalInfo->getLevel() > 0 - || getElapsedTimeStepCount(historicalInfo->getTimeStamp(), - headerPolicy->getForgettingCurveDurationToLevelDown()) + || getElapsedTimeStepCount(historicalInfo->getTimestamp(), + DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS) < DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; } /* static */ const HistoricalInfo ForgettingCurveUtils::createHistoricalInfoToSave( const HistoricalInfo *const originalHistoricalInfo, const HeaderPolicy *const headerPolicy) { - if (originalHistoricalInfo->getTimeStamp() == NOT_A_TIMESTAMP) { + if (originalHistoricalInfo->getTimestamp() == NOT_A_TIMESTAMP) { return HistoricalInfo(); } - const int durationToLevelDownInSeconds = headerPolicy->getForgettingCurveDurationToLevelDown(); + const int durationToLevelDownInSeconds = DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS; const int elapsedTimeStep = getElapsedTimeStepCount( - originalHistoricalInfo->getTimeStamp(), durationToLevelDownInSeconds); + originalHistoricalInfo->getTimestamp(), durationToLevelDownInSeconds); if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) { // No need to update historical info. return *originalHistoricalInfo; } - // Level down. + // Lower the level. const int maxLevelDownAmonut = elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1); const int levelDownAmount = (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) ? 
originalHistoricalInfo->getLevel() : maxLevelDownAmonut; - const int adjustedTimestampInSeconds = originalHistoricalInfo->getTimeStamp() + + const int adjustedTimestampInSeconds = originalHistoricalInfo->getTimestamp() + levelDownAmount * durationToLevelDownInSeconds; return HistoricalInfo(adjustedTimestampInSeconds, originalHistoricalInfo->getLevel() - levelDownAmount, 0 /* count */); } /* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay, - const int unigramCount, const int bigramCount, const HeaderPolicy *const headerPolicy) { - if (unigramCount >= getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) { + const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) { + if (entryCounts.getUnigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) { // Unigram count exceeds the limit. return true; - } else if (bigramCount >= getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) { + } + if (entryCounts.getBigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) { // Bigram count exceeds the limit. return true; } + if (entryCounts.getTrigramCount() + >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) { + // Trigram count exceeds the limit. 
+ return true; + } if (mindsBlockByDecay) { return false; } @@ -170,7 +172,7 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT /* static */ int ForgettingCurveUtils::clampToValidCountRange(const int count, const HeaderPolicy *const headerPolicy) { - return std::min(std::max(count, 0), headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1); + return std::min(std::max(count, 0), OCCURRENCES_TO_RAISE_THE_LEVEL - 1); } /* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) { @@ -187,9 +189,9 @@ const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = 2; const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3; const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127; -const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32; -const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35; -const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 40; +const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 8; +const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 9; +const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = 10; ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() { @@ -202,7 +204,7 @@ ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() { const float endProbability = getBaseProbabilityForLevel(tableId, level - 1); for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT; ++timeStepCount) { - if (level == 0) { + if (level < MIN_VISIBLE_LEVEL) { mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY; continue; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h 
b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h index 9910777b8..06dcae8a1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h @@ -20,7 +20,8 @@ #include <vector> #include "defines.h" -#include "suggest/policyimpl/dictionary/utils/historical_info.h" +#include "suggest/core/dictionary/property/historical_info.h" +#include "suggest/policyimpl/dictionary/utils/entry_counters.h" namespace latinime { @@ -39,23 +40,20 @@ class ForgettingCurveUtils { static int decodeProbability(const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy); - static int getProbability(const int encodedUnigramProbability, - const int encodedBigramProbability); - static bool needsToKeep(const HistoricalInfo *const historicalInfo, const HeaderPolicy *const headerPolicy); - static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount, - const int bigramCount, const HeaderPolicy *const headerPolicy); + static bool needsToDecay(const bool mindsBlockByDecay, const EntryCounts &entryCounters, + const HeaderPolicy *const headerPolicy); - AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) { - return static_cast<int>(static_cast<float>(maxUnigramCount) - * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT); + // TODO: Improve probability computation method and remove this. 
+ static int getProbabilityBiasForNgram(const int n) { + return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE; } - AK_FORCE_INLINE static int getBigramCountHardLimit(const int maxBigramCount) { - return static_cast<int>(static_cast<float>(maxBigramCount) - * BIGRAM_COUNT_HARD_LIMIT_WEIGHT); + AK_FORCE_INLINE static int getEntryCountHardLimit(const int maxEntryCount) { + return static_cast<int>(static_cast<float>(maxEntryCount) + * ENTRY_COUNT_HARD_LIMIT_WEIGHT); } private: @@ -96,9 +94,10 @@ class ForgettingCurveUtils { static const int MIN_VISIBLE_LEVEL; static const int MAX_ELAPSED_TIME_STEP_COUNT; static const int DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; + static const int OCCURRENCES_TO_RAISE_THE_LEVEL; + static const int DURATION_TO_LOWER_THE_LEVEL_IN_SECONDS; - static const float UNIGRAM_COUNT_HARD_LIMIT_WEIGHT; - static const float BIGRAM_COUNT_HARD_LIMIT_WEIGHT; + static const float ENTRY_COUNT_HARD_LIMIT_WEIGHT; static const ProbabilityTable sProbabilityTable; diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 1916ea560..0cffe569d 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -23,12 +23,14 @@ namespace latinime { const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE; // Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12 -const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; +const size_t FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::getFormatVersion(const int formatVersion) { switch (formatVersion) { case VERSION_2: return VERSION_2; + case VERSION_201: + return VERSION_201; case VERSION_4_ONLY_FOR_TESTING: return VERSION_4_ONLY_FOR_TESTING; case VERSION_4: @@ -40,14 +42,14 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; } } /* static */ 
FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion( - const uint8_t *const dict, const int dictSize) { + const ReadOnlyByteArrayView dictBuffer) { // The magic number is stored big-endian. // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) { + if (dictBuffer.size() < DICTIONARY_MINIMUM_SIZE) { return UNKNOWN_VERSION; } - const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); + const uint32_t magicNumber = ByteArrayUtils::readUint32(dictBuffer.data(), 0); switch (magicNumber) { case MAGIC_NUMBER: // The layout of the header is as follows: @@ -58,7 +60,7 @@ const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; // Conceptually this converts the hardcoded value of the bytes in the file into // the symbolic value we use in the code. But we want the constants to be the // same so we use them for both here. - return getFormatVersion(ByteArrayUtils::readUint16(dict, 4)); + return getFormatVersion(ByteArrayUtils::readUint16(dictBuffer.data(), 4)); default: return UNKNOWN_VERSION; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index 55ad5799f..96310086b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -20,6 +20,7 @@ #include <cstdint> #include "defines.h" +#include "utils/byte_array_view.h" namespace latinime { @@ -31,6 +32,7 @@ class FormatUtils { enum FORMAT_VERSION { // These MUST have the same values as the relevant constants in FormatSpec.java. 
VERSION_2 = 2, + VERSION_201 = 201, VERSION_4_ONLY_FOR_TESTING = 399, VERSION_4 = 402, VERSION_4_DEV = 403, @@ -42,12 +44,12 @@ class FormatUtils { static const uint32_t MAGIC_NUMBER; static FORMAT_VERSION getFormatVersion(const int formatVersion); - static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); + static FORMAT_VERSION detectFormatVersion(const ReadOnlyByteArrayView dictBuffer); private: DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils); - static const int DICTIONARY_MINIMUM_SIZE; + static const size_t DICTIONARY_MINIMUM_SIZE; }; } // namespace latinime #endif /* LATINIME_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h b/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h index fca8120f1..e1a96c6f7 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/sparse_table.h @@ -24,7 +24,6 @@ namespace latinime { -// Note that there is a corresponding implementation in SparseTable.java. // TODO: Support multiple content buffers. 
class SparseTable { public: diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp index 407b8efd0..39f417ebb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.cpp @@ -26,6 +26,7 @@ const int TrieMap::FIELD1_SIZE = 3; const int TrieMap::ENTRY_SIZE = FIELD0_SIZE + FIELD1_SIZE; const uint32_t TrieMap::VALUE_FLAG = 0x400000; const uint32_t TrieMap::VALUE_MASK = 0x3FFFFF; +const uint32_t TrieMap::INVALID_VALUE_IN_KEY_VALUE_ENTRY = VALUE_MASK; const uint32_t TrieMap::TERMINAL_LINK_FLAG = 0x800000; const uint32_t TrieMap::TERMINAL_LINK_MASK = 0x7FFFFF; const int TrieMap::NUM_OF_BITS_USED_FOR_ONE_LEVEL = 5; @@ -34,6 +35,7 @@ const int TrieMap::MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL = 1 << NUM_OF_BITS_USED_FOR_O const int TrieMap::ROOT_BITMAP_ENTRY_INDEX = 0; const int TrieMap::ROOT_BITMAP_ENTRY_POS = MAX_NUM_OF_ENTRIES_IN_ONE_LEVEL * FIELD0_SIZE; const TrieMap::Entry TrieMap::EMPTY_BITMAP_ENTRY = TrieMap::Entry(0, 0); +const int TrieMap::TERMINAL_LINKED_ENTRY_COUNT = 2; // Value entry and bitmap entry. const uint64_t TrieMap::MAX_VALUE = (static_cast<uint64_t>(1) << ((FIELD0_SIZE + FIELD1_SIZE) * CHAR_BIT)) - 1; const int TrieMap::MAX_BUFFER_SIZE = TERMINAL_LINK_MASK * ENTRY_SIZE; @@ -76,14 +78,14 @@ int TrieMap::getNextLevelBitmapEntryIndex(const int key, const int bitmapEntryIn return terminalEntry.getValueEntryIndex() + 1; } // Create a value entry and a bitmap entry. 
- const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(0, terminalEntry.getValue()), valueEntryIndex)) { return INVALID_INDEX; } if (!writeEntry(EMPTY_BITMAP_ENTRY, valueEntryIndex + 1)) { return INVALID_INDEX; } - if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, valueEntryIndex)) { + if (!writeField1(valueEntryIndex | TERMINAL_LINK_FLAG, terminalEntryIndex)) { return INVALID_INDEX; } return valueEntryIndex + 1; @@ -108,6 +110,31 @@ bool TrieMap::save(FILE *const file) const { return DictFileWritingUtils::writeBufferToFileTail(file, &mBuffer); } +bool TrieMap::remove(const int key, const int bitmapEntryIndex) { + const Entry bitmapEntry = readEntry(bitmapEntryIndex); + const uint32_t unsignedKey = static_cast<uint32_t>(key); + const int terminalEntryIndex = getTerminalEntryIndex( + unsignedKey, getBitShuffledKey(unsignedKey), bitmapEntry, 0 /* level */); + if (terminalEntryIndex == INVALID_INDEX) { + // Not found. + return false; + } + const Entry terminalEntry = readEntry(terminalEntryIndex); + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , terminalEntryIndex)) { + return false; + } + if (terminalEntry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(terminalEntry.getValueEntryIndex() + 1); + if (!freeTable(terminalEntry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)){ + return false; + } + } + return true; +} + /** * Iterate next entry in a certain level. * @@ -129,7 +156,7 @@ const TrieMap::Result TrieMap::iterateNext(std::vector<TableIterationState> *con if (entry.isBitmapEntry()) { // Move to child. 
iterationState->emplace_back(popCount(entry.getBitmap()), entry.getTableIndex()); - } else { + } else if (entry.isValidTerminalEntry()) { if (outKey) { *outKey = entry.getKey(); } @@ -162,12 +189,12 @@ uint32_t TrieMap::getBitShuffledKey(const uint32_t key) const { } bool TrieMap::writeValue(const uint64_t value, const int terminalEntryIndex) { - if (value <= VALUE_MASK) { + if (value < VALUE_MASK) { // Write value into the terminal entry. return writeField1(value | VALUE_FLAG, terminalEntryIndex); } // Create value entry and write value. - const int valueEntryIndex = allocateTable(2 /* entryCount */); + const int valueEntryIndex = allocateTable(TERMINAL_LINKED_ENTRY_COUNT); if (!writeEntry(Entry(value >> (FIELD1_SIZE * CHAR_BIT), value), valueEntryIndex)) { return false; } @@ -227,6 +254,9 @@ int TrieMap::getTerminalEntryIndex(const uint32_t key, const uint32_t hashedKey, // Move to the next level. return getTerminalEntryIndex(key, hashedKey, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + return INVALID_INDEX; + } if (entry.getKey() == key) { // Terminal entry is found. return entryIndex; @@ -287,6 +317,10 @@ bool TrieMap::putInternal(const uint32_t key, const uint64_t value, const uint32 // Bitmap entry is found. Go to the next level. return putInternal(key, value, hashedKey, entryIndex, entry, level + 1); } + if (!entry.isValidTerminalEntry()) { + // Overwrite invalid terminal entry. + return writeTerminalEntry(key, value, entryIndex); + } if (entry.getKey() == key) { // Terminal entry for the key is found. Update the value. return updateValue(entry, value, entryIndex); @@ -384,4 +418,37 @@ bool TrieMap::addNewEntryByExpandingTable(const uint32_t key, const uint64_t val return true; } +bool TrieMap::removeInner(const Entry &bitmapEntry) { + const int tableSize = popCount(bitmapEntry.getBitmap()); + if (tableSize <= 0) { + // The table is empty. No need to remove any entries. 
+ return true; + } + for (int i = 0; i < tableSize; ++i) { + const int entryIndex = bitmapEntry.getTableIndex() + i; + const Entry entry = readEntry(entryIndex); + if (entry.isBitmapEntry()) { + // Delete next bitmap entry recursively. + if (!removeInner(entry)) { + return false; + } + } else { + // Invalidate terminal entry just in case. + if (!writeField1(VALUE_FLAG ^ INVALID_VALUE_IN_KEY_VALUE_ENTRY , entryIndex)) { + return false; + } + if (entry.hasTerminalLink()) { + const Entry nextLevelBitmapEntry = readEntry(entry.getValueEntryIndex() + 1); + if (!freeTable(entry.getValueEntryIndex(), TERMINAL_LINKED_ENTRY_COUNT)) { + return false; + } + if (!removeInner(nextLevelBitmapEntry)) { + return false; + } + } + } + } + return true; +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h index 3e5c4010c..00765888b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h @@ -84,6 +84,10 @@ class TrieMap { return mValue; } + AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const { + return mNextLevelBitmapEntryIndex; + } + private: const TrieMap *const mTrieMap; const int mKey; @@ -94,7 +98,7 @@ class TrieMap { TrieMapIterator(const TrieMap *const trieMap, const int bitmapEntryIndex) : mTrieMap(trieMap), mStateStack(), mBaseBitmapEntryIndex(bitmapEntryIndex), mKey(0), mValue(0), mIsValid(false), mNextLevelBitmapEntryIndex(INVALID_INDEX) { - if (!trieMap) { + if (!trieMap || mBaseBitmapEntryIndex == INVALID_INDEX) { return; } const Entry bitmapEntry = mTrieMap->readEntry(mBaseBitmapEntryIndex); @@ -202,6 +206,8 @@ class TrieMap { bool save(FILE *const file) const; + bool remove(const int key, const int bitmapEntryIndex); + private: DISALLOW_COPY_AND_ASSIGN(TrieMap); @@ -245,6 +251,11 @@ class TrieMap { } // For terminal entry. 
+ AK_FORCE_INLINE bool isValidTerminalEntry() const { + return hasTerminalLink() || ((mData1 & VALUE_MASK) != INVALID_VALUE_IN_KEY_VALUE_ENTRY); + } + + // For terminal entry. AK_FORCE_INLINE uint32_t getValueEntryIndex() const { return mData1 & TERMINAL_LINK_MASK; } @@ -272,6 +283,7 @@ class TrieMap { static const int ENTRY_SIZE; static const uint32_t VALUE_FLAG; static const uint32_t VALUE_MASK; + static const uint32_t INVALID_VALUE_IN_KEY_VALUE_ENTRY; static const uint32_t TERMINAL_LINK_FLAG; static const uint32_t TERMINAL_LINK_MASK; static const int NUM_OF_BITS_USED_FOR_ONE_LEVEL; @@ -280,6 +292,7 @@ class TrieMap { static const int ROOT_BITMAP_ENTRY_INDEX; static const int ROOT_BITMAP_ENTRY_POS; static const Entry EMPTY_BITMAP_ENTRY; + static const int TERMINAL_LINKED_ENTRY_COUNT; static const int MAX_BUFFER_SIZE; uint32_t getBitShuffledKey(const uint32_t key) const; @@ -378,6 +391,8 @@ class TrieMap { AK_FORCE_INLINE int getTailEntryIndex() const { return (mBuffer.getTailPosition() - ROOT_BITMAP_ENTRY_POS) / ENTRY_SIZE; } + + bool removeInner(const Entry &bitmapEntry); }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 3fc566e7a..6a2db687d 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -31,6 +31,7 @@ const float ScoringParams::DIGRAPH_PENALTY_FOR_EXACT_MATCH = 0.03f; // TODO: Unlimit max cache dic node size const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE = 50; const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.1524f; @@ -48,7 +49,7 @@ const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.674f; const float 
ScoringParams::INSERTION_COST_FIRST_CHAR = 0.639f; const float ScoringParams::TRANSPOSITION_COST = 0.5608f; const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.334f; -const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.4576f; +const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.37972f; const float ScoringParams::SUBSTITUTION_COST = 0.3806f; const float ScoringParams::COST_NEW_WORD = 0.0314f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.3224f; @@ -61,4 +62,7 @@ const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.4182f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.095f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION = 0.99f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION = 0.99f; +const float ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE = 0.99f; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index b12de6d87..731424f3d 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -30,6 +30,7 @@ class ScoringParams { static const float AUTOCORRECT_OUTPUT_THRESHOLD; static const int MAX_CACHE_DIC_NODE_SIZE; static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; + static const int MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE; static const int THRESHOLD_SHORT_WORD_LENGTH; static const float EXACT_MATCH_PROMOTION; @@ -68,6 +69,9 @@ class ScoringParams { static const float TYPING_BASE_OUTPUT_SCORE; static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT; static const float NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT; + static const float LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION; + static const float 
LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION; + static const float LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE; private: DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams); diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 04cb6603a..0240bcf54 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -33,10 +33,12 @@ class TypingScoring : public Scoring { static const TypingScoring *getInstance() { return &sInstance; } AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession, - const float languageWeight, SuggestionResults *const outSuggestionResults) const {} + const float weightOfLangModelVsSpatialModel, + SuggestionResults *const outSuggestionResults) const {} - AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, - DicNode *const terminals, const int size) const { + AK_FORCE_INLINE float getAdjustedWeightOfLangModelVsSpatialModel( + DicTraverseSession *const traverseSession, DicNode *const terminals, + const int size) const { return 1.0f; } @@ -51,10 +53,10 @@ class TypingScoring : public Scoring { } if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) { score += ScoringParams::EXACT_MATCH_PROMOTION; - if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_WRONG_CASE & containedErrorTypes) != 0) { score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH; } - if ((ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR & containedErrorTypes) != 0) { + if ((ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT & containedErrorTypes) != 0) { score -= ScoringParams::ACCENT_ERROR_PENALTY_FOR_EXACT_MATCH; } if ((ErrorTypeUtils::MATCH_WITH_DIGRAPH & containedErrorTypes) != 0) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h 
b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index cb3dfac70..b9b6314ae 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -26,6 +26,7 @@ #include "suggest/core/layout/proximity_info_utils.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/suggest_options.h" #include "suggest/policyimpl/typing/scoring_params.h" #include "utils/char_utils.h" @@ -77,6 +78,13 @@ class TypingTraversal : public Traversal { if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) { return false; } + if (traverseSession->getSuggestOptions()->weightForLocale() + < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_SUBSTITUTION) { + // Space substitution is heavy, so we skip doing it if the weight for this language + // is low because we anticipate the suggestions out of this dictionary are not for + // the language the user intends to type in. + return false; + } if (!canDoLookAheadCorrection(traverseSession, dicNode)) { return false; } @@ -91,6 +99,13 @@ class TypingTraversal : public Traversal { if (!CORRECT_NEW_WORD_SPACE_OMISSION) { return false; } + if (traverseSession->getSuggestOptions()->weightForLocale() + < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SPACE_OMISSION) { + // Space omission is heavy, so we skip doing it if the weight for this language + // is low because we anticipate the suggestions out of this dictionary are not for + // the language the user intends to type in. + return false; + } const int inputSize = traverseSession->getInputSize(); // TODO: Don't refer to isCompletion? if (dicNode->isCompletion(inputSize)) { @@ -141,9 +156,14 @@ class TypingTraversal : public Traversal { return DicNodeVector::DEFAULT_NODES_SIZE_FOR_OPTIMIZATION; } - AK_FORCE_INLINE int getMaxCacheSize(const int inputSize) const { - return (inputSize <= 1) ? 
ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT - : ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + AK_FORCE_INLINE int getMaxCacheSize(const int inputSize, const float weightForLocale) const { + if (inputSize <= 1) { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; + } + if (weightForLocale < ScoringParams::LOCALE_WEIGHT_THRESHOLD_FOR_SMALL_CACHE_SIZE) { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_LOW_PROBABILITY_LOCALE; + } + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; } AK_FORCE_INLINE int getTerminalCacheSize() const { @@ -161,8 +181,8 @@ class TypingTraversal : public Traversal { return true; } - AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const { - const int probability = dicNode->getProbability(); + AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode, + const int probability) const { if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) { return false; } diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 54f65c786..db7a39efb 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -36,30 +36,40 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor // Compare the node code point with original primary code point on the keyboard. const ProximityInfoState *const pInfoState = traverseSession->getProximityInfoState(0); - const int primaryOriginalCodePoint = pInfoState->getPrimaryOriginalCodePointAt( + const int primaryCodePoint = pInfoState->getPrimaryCodePointAt( dicNode->getInputIndex(0)); const int nodeCodePoint = dicNode->getNodeCodePoint(); - if (primaryOriginalCodePoint == nodeCodePoint) { + // TODO: Check whether the input code point is on the keyboard. 
+ if (primaryCodePoint == nodeCodePoint) { // Node code point is same as original code point on the keyboard. return ErrorTypeUtils::NOT_AN_ERROR; - } else if (CharUtils::toLowerCase(primaryOriginalCodePoint) == + } else if (CharUtils::toLowerCase(primaryCodePoint) == CharUtils::toLowerCase(nodeCodePoint)) { // Only cases of the code points are different. - return ErrorTypeUtils::MATCH_WITH_CASE_ERROR; - } else if (CharUtils::toBaseCodePoint(primaryOriginalCodePoint) == - CharUtils::toBaseCodePoint(nodeCodePoint)) { + return ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else if (primaryCodePoint == CharUtils::toBaseCodePoint(nodeCodePoint)) { // Node code point is a variant of original code point. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR; - } else { + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT; + } else if (CharUtils::toBaseCodePoint(primaryCodePoint) + == CharUtils::toBaseCodePoint(nodeCodePoint)) { + // Base code points are the same but the code point is intentionally input. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT; + } else if (CharUtils::toLowerCase(primaryCodePoint) + == CharUtils::toBaseLowerCase(nodeCodePoint)) { // Node code point is a variant of original code point and the cases are also // different. - return ErrorTypeUtils::MATCH_WITH_ACCENT_ERROR - | ErrorTypeUtils::MATCH_WITH_CASE_ERROR; + return ErrorTypeUtils::MATCH_WITH_MISSING_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; + } else { + // Base code points are the same and the cases are different. + return ErrorTypeUtils::MATCH_WITH_WRONG_ACCENT + | ErrorTypeUtils::MATCH_WITH_WRONG_CASE; } } break; case CT_ADDITIONAL_PROXIMITY: - return ErrorTypeUtils::PROXIMITY_CORRECTION; + // TODO: Change to EDIT_CORRECTION. 
+ return ErrorTypeUtils::PROXIMITY_CORRECTION; case CT_OMISSION: if (parentDicNode->canBeIntentionalOmission()) { return ErrorTypeUtils::INTENTIONAL_OMISSION; @@ -68,6 +78,8 @@ ErrorTypeUtils::ErrorType TypingWeighting::getErrorType(const CorrectionType cor } break; case CT_SUBSTITUTION: + // TODO: Quit settng PROXIMITY_CORRECTION. + return ErrorTypeUtils::EDIT_CORRECTION | ErrorTypeUtils::PROXIMITY_CORRECTION; case CT_INSERTION: case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: diff --git a/native/jni/src/utils/byte_array_view.h b/native/jni/src/utils/byte_array_view.h index 2c97c6d58..2b778af6f 100644 --- a/native/jni/src/utils/byte_array_view.h +++ b/native/jni/src/utils/byte_array_view.h @@ -42,6 +42,13 @@ class ReadOnlyByteArrayView { return mPtr; } + AK_FORCE_INLINE const ReadOnlyByteArrayView skip(const size_t n) const { + if (mSize <= n) { + return ReadOnlyByteArrayView(); + } + return ReadOnlyByteArrayView(mPtr + n, mSize - n); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(ReadOnlyByteArrayView); @@ -77,10 +84,12 @@ class ReadWriteByteArrayView { } private: - DISALLOW_ASSIGNMENT_OPERATOR(ReadWriteByteArrayView); + // Default copy constructor and assignment operator are used for using this class with STL + // containers. - uint8_t *const mPtr; - const size_t mSize; + // These members cannot be const to have the assignment operator. 
+ uint8_t *mPtr; + size_t mSize; }; } // namespace latinime diff --git a/native/jni/src/utils/char_utils.cpp b/native/jni/src/utils/char_utils.cpp index b17e0847d..3bb9055b2 100644 --- a/native/jni/src/utils/char_utils.cpp +++ b/native/jni/src/utils/char_utils.cpp @@ -1057,11 +1057,11 @@ static int compare_pair_capital(const void *a, const void *b) { - static_cast<int>((static_cast<const struct LatinCapitalSmallPair *>(b))->capital); } -/* static */ unsigned short CharUtils::latin_tolower(const unsigned short c) { +/* static */ int CharUtils::latin_tolower(const int c) { struct LatinCapitalSmallPair *p = static_cast<struct LatinCapitalSmallPair *>(bsearch(&c, SORTED_CHAR_MAP, NELEMS(SORTED_CHAR_MAP), sizeof(SORTED_CHAR_MAP[0]), compare_pair_capital)); - return p ? p->small : c; + return p ? static_cast<int>(p->small) : c; } /* diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index 63786502b..7871c26ef 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -27,20 +27,14 @@ namespace latinime { class CharUtils { public: + static const std::vector<int> EMPTY_STRING; + static AK_FORCE_INLINE bool isAsciiUpper(int c) { // Note: isupper(...) reports false positives for some Cyrillic characters, causing them to // be incorrectly lower-cased using toAsciiLower(...) rather than latin_tolower(...). 
return (c >= 'A' && c <= 'Z'); } - static AK_FORCE_INLINE int toAsciiLower(int c) { - return c - 'A' + 'a'; - } - - static AK_FORCE_INLINE bool isAscii(int c) { - return isascii(c) != 0; - } - static AK_FORCE_INLINE int toLowerCase(const int c) { if (isAsciiUpper(c)) { return toAsciiLower(c); @@ -48,7 +42,7 @@ class CharUtils { if (isAscii(c)) { return c; } - return static_cast<int>(latin_tolower(static_cast<unsigned short>(c))); + return latin_tolower(c); } static AK_FORCE_INLINE int toBaseLowerCase(const int c) { @@ -59,7 +53,6 @@ class CharUtils { // TODO: Do not hardcode here return codePoint == KEYCODE_SINGLE_QUOTE || codePoint == KEYCODE_HYPHEN_MINUS; } - static AK_FORCE_INLINE int getCodePointCount(const int arraySize, const int *const codePoints) { int size = 0; for (; size < arraySize; ++size) { @@ -91,9 +84,6 @@ class CharUtils { return codePoint >= MIN_UNICODE_CODE_POINT && codePoint <= MAX_UNICODE_CODE_POINT; } - static unsigned short latin_tolower(const unsigned short c); - static const std::vector<int> EMPTY_STRING; - // Returns updated code point count. Returns 0 when the code points cannot be marked as a // Beginning-of-Sentence. static AK_FORCE_INLINE int attachBeginningOfSentenceMarker(int *const codePoints, @@ -111,6 +101,17 @@ class CharUtils { return codePointCount + 1; } + // Returns updated code point count. 
+ static AK_FORCE_INLINE int removeBeginningOfSentenceMarker(int *const codePoints, + const int codePointCount) { + if (codePointCount <= 0 || codePoints[0] != CODE_POINT_BEGINNING_OF_SENTENCE) { + return codePointCount; + } + const int newCodePointCount = codePointCount - 1; + memmove(codePoints, codePoints + 1, sizeof(int) * newCodePointCount); + return newCodePointCount; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); @@ -125,6 +126,16 @@ class CharUtils { */ static const int BASE_CHARS_SIZE = 0x0500; static const unsigned short BASE_CHARS[BASE_CHARS_SIZE]; + + static AK_FORCE_INLINE bool isAscii(int c) { + return isascii(c) != 0; + } + + static AK_FORCE_INLINE int toAsciiLower(int c) { + return c - 'A' + 'a'; + } + + static int latin_tolower(const int c); }; } // namespace latinime #endif // LATINIME_CHAR_UTILS_H diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index c1ddc9812..408373176 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -17,8 +17,10 @@ #ifndef LATINIME_INT_ARRAY_VIEW_H #define LATINIME_INT_ARRAY_VIEW_H +#include <algorithm> +#include <array> #include <cstdint> -#include <cstdlib> +#include <cstring> #include <vector> #include "defines.h" @@ -56,14 +58,14 @@ class IntArrayView { explicit IntArrayView(const std::vector<int> &vector) : mPtr(vector.data()), mSize(vector.size()) {} - template <int N> - AK_FORCE_INLINE static IntArrayView fromFixedSizeArray(const int (&array)[N]) { - return IntArrayView(array, N); + template <size_t N> + AK_FORCE_INLINE static IntArrayView fromArray(const std::array<int, N> &array) { + return IntArrayView(array.data(), array.size()); } - // Returns a view that points one int object. Does not take ownership of the given object. - AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) { - return IntArrayView(object, 1); + // Returns a view that points one int object. 
+ AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) { + return IntArrayView(ptr, 1); } AK_FORCE_INLINE int operator[](const size_t index) const { @@ -91,6 +93,46 @@ class IntArrayView { return mPtr + mSize; } + AK_FORCE_INLINE bool contains(const int value) const { + return std::find(begin(), end(), value) != end(); + } + + // Returns the view whose size is smaller than or equal to the given count. + AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const { + return IntArrayView(mPtr, std::min(maxSize, mSize)); + } + + AK_FORCE_INLINE const IntArrayView skip(const size_t n) const { + if (mSize <= n) { + return IntArrayView(); + } + return IntArrayView(mPtr + n, mSize - n); + } + + template <size_t N> + void copyToArray(std::array<int, N> *const buffer, const size_t offset) const { + ASSERT(mSize + offset <= N); + memmove(buffer->data() + offset, mPtr, sizeof(int) * mSize); + } + + AK_FORCE_INLINE int firstOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[0]; + } + + AK_FORCE_INLINE int lastOrDefault(const int defaultValue) const { + if (empty()) { + return defaultValue; + } + return mPtr[mSize - 1]; + } + + AK_FORCE_INLINE std::vector<int> toVector() const { + return std::vector<int>(begin(), end()); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); @@ -100,6 +142,9 @@ class IntArrayView { using WordIdArrayView = IntArrayView; using PtNodePosArrayView = IntArrayView; +using CodePointArrayView = IntArrayView; +template <size_t size> +using WordIdArray = std::array<int, size>; } // namespace latinime #endif // LATINIME_MEMORY_VIEW_H diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index cb82d3c3b..a259e1cd0 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -21,7 +21,7 @@ #include "defines.h" #include "jni.h" -#include "suggest/core/session/prev_words_info.h" +#include 
"suggest/core/session/ngram_context.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" #include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" #include "utils/char_utils.h" @@ -50,6 +50,7 @@ class JniDataUtils { const jsize keyUtf8Length = env->GetStringUTFLength(keyString); char keyChars[keyUtf8Length + 1]; env->GetStringUTFRegion(keyString, 0, env->GetStringLength(keyString), keyChars); + env->DeleteLocalRef(keyString); keyChars[keyUtf8Length] = '\0'; DictionaryHeaderStructurePolicy::AttributeMap::key_type key; HeaderReadWriteUtils::insertCharactersIntoVector(keyChars, &key); @@ -59,6 +60,7 @@ class JniDataUtils { const jsize valueUtf8Length = env->GetStringUTFLength(valueString); char valueChars[valueUtf8Length + 1]; env->GetStringUTFRegion(valueString, 0, env->GetStringLength(valueString), valueChars); + env->DeleteLocalRef(valueString); valueChars[valueUtf8Length] = '\0'; DictionaryHeaderStructurePolicy::AttributeMap::mapped_type value; HeaderReadWriteUtils::insertCharactersIntoVector(valueChars, &value); @@ -96,18 +98,14 @@ class JniDataUtils { } } - static PrevWordsInfo constructPrevWordsInfo(JNIEnv *env, jobjectArray prevWordCodePointArrays, - jbooleanArray isBeginningOfSentenceArray) { + static NgramContext constructNgramContext(JNIEnv *env, jobjectArray prevWordCodePointArrays, + jbooleanArray isBeginningOfSentenceArray, const size_t prevWordCount) { int prevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH]; int prevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; bool isBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM]; - jsize prevWordsCount = env->GetArrayLength(prevWordCodePointArrays); - for (size_t i = 0; i < NELEMS(prevWordCodePoints); ++i) { + for (size_t i = 0; i < prevWordCount; ++i) { prevWordCodePointCount[i] = 0; isBeginningOfSentence[i] = false; - if (prevWordsCount <= static_cast<int>(i)) { - continue; - } jintArray prevWord = 
(jintArray)env->GetObjectArrayElement(prevWordCodePointArrays, i); if (!prevWord) { continue; @@ -117,14 +115,15 @@ class JniDataUtils { continue; } env->GetIntArrayRegion(prevWord, 0, prevWordLength, prevWordCodePoints[i]); + env->DeleteLocalRef(prevWord); prevWordCodePointCount[i] = prevWordLength; jboolean isBeginningOfSentenceBoolean = JNI_FALSE; env->GetBooleanArrayRegion(isBeginningOfSentenceArray, i, 1 /* len */, &isBeginningOfSentenceBoolean); isBeginningOfSentence[i] = isBeginningOfSentenceBoolean == JNI_TRUE; } - return PrevWordsInfo(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, - MAX_PREV_WORD_COUNT_FOR_N_GRAM); + return NgramContext(prevWordCodePoints, prevWordCodePointCount, isBeginningOfSentence, + prevWordCount); } static void putBooleanToArray(JNIEnv *env, jbooleanArray array, const int index, diff --git a/native/jni/tests/suggest/core/dicnode/dic_node_pool_test.cpp b/native/jni/tests/suggest/core/dicnode/dic_node_pool_test.cpp new file mode 100644 index 000000000..854efdfe6 --- /dev/null +++ b/native/jni/tests/suggest/core/dicnode/dic_node_pool_test.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/core/dicnode/dic_node_pool.h" + +#include <gtest/gtest.h> + +namespace latinime { +namespace { + +TEST(DicNodePoolTest, TestGet) { + static const int CAPACITY = 10; + DicNodePool dicNodePool(CAPACITY); + + for (int i = 0; i < CAPACITY; ++i) { + EXPECT_NE(nullptr, dicNodePool.getInstance()); + } + EXPECT_EQ(nullptr, dicNodePool.getInstance()); +} + +TEST(DicNodePoolTest, TestPlaceBack) { + static const int CAPACITY = 1; + DicNodePool dicNodePool(CAPACITY); + + DicNode *const dicNode = dicNodePool.getInstance(); + EXPECT_NE(nullptr, dicNode); + EXPECT_EQ(nullptr, dicNodePool.getInstance()); + dicNodePool.placeBackInstance(dicNode); + EXPECT_EQ(dicNode, dicNodePool.getInstance()); +} + +TEST(DicNodePoolTest, TestReset) { + static const int CAPACITY_SMALL = 2; + static const int CAPACITY_LARGE = 10; + DicNodePool dicNodePool(CAPACITY_SMALL); + + for (int i = 0; i < CAPACITY_SMALL; ++i) { + EXPECT_NE(nullptr, dicNodePool.getInstance()); + } + EXPECT_EQ(nullptr, dicNodePool.getInstance()); + + dicNodePool.reset(CAPACITY_LARGE); + for (int i = 0; i < CAPACITY_LARGE; ++i) { + EXPECT_NE(nullptr, dicNodePool.getInstance()); + } + EXPECT_EQ(nullptr, dicNodePool.getInstance()); + + dicNodePool.reset(CAPACITY_SMALL); + for (int i = 0; i < CAPACITY_SMALL; ++i) { + EXPECT_NE(nullptr, dicNodePool.getInstance()); + } + EXPECT_EQ(nullptr, dicNodePool.getInstance()); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/core/layout/geometry_utils_test.cpp b/native/jni/tests/suggest/core/layout/geometry_utils_test.cpp new file mode 100644 index 000000000..f5f89ede1 --- /dev/null +++ b/native/jni/tests/suggest/core/layout/geometry_utils_test.cpp @@ -0,0 +1,83 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/layout/geometry_utils.h" + +#include <gtest/gtest.h> + +namespace latinime { +namespace { + +::testing::AssertionResult ExpectAngleDiffEq(const char* expectedExpression, + const char* actualExpression, float expected, float actual) { + if (actual < 0.0f || M_PI_F < actual) { + return ::testing::AssertionFailure() + << "Must be in the range of [0.0f, M_PI_F]." + << " expected: " << expected + << " actual: " << actual; + } + return ::testing::internal::CmpHelperFloatingPointEQ<float>( + expectedExpression, actualExpression, expected, actual); +} + +#define EXPECT_ANGLE_DIFF_EQ(expected, actual) \ + EXPECT_PRED_FORMAT2(ExpectAngleDiffEq, expected, actual); + +TEST(GeometryUtilsTest, testSquareFloat) { + const float test_data[] = { 0.0f, 1.0f, 123.456f, -1.0f, -9876.54321f }; + for (const float value : test_data) { + EXPECT_FLOAT_EQ(value * value, GeometryUtils::SQUARE_FLOAT(value)); + } +} + +TEST(GeometryUtilsTest, testGetAngle) { + EXPECT_FLOAT_EQ(0.0f, GeometryUtils::getAngle(0, 0, 0, 0)); + EXPECT_FLOAT_EQ(0.0f, GeometryUtils::getAngle(100, -10, 100, -10)); + + EXPECT_FLOAT_EQ(M_PI_F / 4.0f, GeometryUtils::getAngle(1, 1, 0, 0)); + EXPECT_FLOAT_EQ(M_PI_F, GeometryUtils::getAngle(-1, 0, 0, 0)); + + EXPECT_FLOAT_EQ(GeometryUtils::getAngle(0, 0, -1, 0), GeometryUtils::getAngle(1, 0, 0, 0)); + EXPECT_FLOAT_EQ(GeometryUtils::getAngle(1, 2, 3, 4), + GeometryUtils::getAngle(100, 200, 300, 400)); +} + +TEST(GeometryUtilsTest, testGetAngleDiff) { + EXPECT_ANGLE_DIFF_EQ(0.0f, GeometryUtils::getAngleDiff(0.0f, 0.0f)); + 
EXPECT_ANGLE_DIFF_EQ(0.0f, GeometryUtils::getAngleDiff(10000.0f, 10000.0f)); + EXPECT_ANGLE_DIFF_EQ(ROUND_FLOAT_10000(M_PI_F), + GeometryUtils::getAngleDiff(0.0f, M_PI_F)); + EXPECT_ANGLE_DIFF_EQ(ROUND_FLOAT_10000(M_PI_F / 6.0f), + GeometryUtils::getAngleDiff(M_PI_F / 3.0f, M_PI_F / 2.0f)); + EXPECT_ANGLE_DIFF_EQ(ROUND_FLOAT_10000(M_PI_F / 2.0f), + GeometryUtils::getAngleDiff(0.0f, M_PI_F * 1.5f)); + EXPECT_ANGLE_DIFF_EQ(0.0f, GeometryUtils::getAngleDiff(0.0f, M_PI_F * 1024.0f)); + EXPECT_ANGLE_DIFF_EQ(0.0f, GeometryUtils::getAngleDiff(-M_PI_F, M_PI_F)); +} + +TEST(GeometryUtilsTest, testGetDistanceInt) { + EXPECT_EQ(0, GeometryUtils::getDistanceInt(0, 0, 0, 0)); + EXPECT_EQ(0, GeometryUtils::getAngle(100, -10, 100, -10)); + + EXPECT_EQ(5, GeometryUtils::getDistanceInt(0, 0, 5, 0)); + EXPECT_EQ(5, GeometryUtils::getDistanceInt(0, 0, 3, 4)); + EXPECT_EQ(5, GeometryUtils::getDistanceInt(0, -4, 3, 0)); + EXPECT_EQ(5, GeometryUtils::getDistanceInt(0, 0, -3, -4)); + EXPECT_EQ(500, GeometryUtils::getDistanceInt(0, 0, 300, -400)); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/header/header_read_write_utils_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/header/header_read_write_utils_test.cpp new file mode 100644 index 000000000..da6a2af27 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/header/header_read_write_utils_test.cpp @@ -0,0 +1,78 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" + +#include <gtest/gtest.h> + +#include <cstring> +#include <vector> + +#include "suggest/core/policy/dictionary_header_structure_policy.h" + +namespace latinime { +namespace { + +TEST(HeaderReadWriteUtilsTest, TestInsertCharactersIntoVector) { + DictionaryHeaderStructurePolicy::AttributeMap::key_type vector; + + HeaderReadWriteUtils::insertCharactersIntoVector("", &vector); + EXPECT_TRUE(vector.empty()); + + static const char *str = "abc-xyz!?"; + HeaderReadWriteUtils::insertCharactersIntoVector(str, &vector); + EXPECT_EQ(strlen(str) , vector.size()); + for (size_t i = 0; i < vector.size(); ++i) { + EXPECT_EQ(str[i], vector[i]); + } +} + +TEST(HeaderReadWriteUtilsTest, TestAttributeMapForInt) { + DictionaryHeaderStructurePolicy::AttributeMap attributeMap; + + // Returns default value if not exists. + EXPECT_EQ(-1, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "", -1)); + EXPECT_EQ(100, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "abc", 100)); + + HeaderReadWriteUtils::setIntAttribute(&attributeMap, "abc", 10); + EXPECT_EQ(10, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "abc", 100)); + HeaderReadWriteUtils::setIntAttribute(&attributeMap, "abc", 20); + EXPECT_EQ(20, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "abc", 100)); + HeaderReadWriteUtils::setIntAttribute(&attributeMap, "abcd", 30); + EXPECT_EQ(30, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "abcd", 100)); + EXPECT_EQ(20, HeaderReadWriteUtils::readIntAttributeValue(&attributeMap, "abc", 100)); +} + +TEST(HeaderReadWriteUtilsTest, TestAttributeMapCodeForPoints) { + DictionaryHeaderStructurePolicy::AttributeMap attributeMap; + + // Returns empty vector if not exists. 
+ EXPECT_TRUE(HeaderReadWriteUtils::readCodePointVectorAttributeValue(&attributeMap, "").empty()); + EXPECT_TRUE(HeaderReadWriteUtils::readCodePointVectorAttributeValue( + &attributeMap, "abc").empty()); + + HeaderReadWriteUtils::setCodePointVectorAttribute(&attributeMap, "abc", {}); + EXPECT_TRUE(HeaderReadWriteUtils::readCodePointVectorAttributeValue( + &attributeMap, "abc").empty()); + + const std::vector<int> codePoints = { 0x0, 0x20, 0x1F, 0x100000 }; + HeaderReadWriteUtils::setCodePointVectorAttribute(&attributeMap, "abc", codePoints); + EXPECT_EQ(codePoints, HeaderReadWriteUtils::readCodePointVectorAttributeValue( + &attributeMap, "abc")); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index 6eef2040b..4469dc715 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -18,27 +18,37 @@ #include <gtest/gtest.h> +#include <array> +#include <unordered_set> + #include "utils/int_array_view.h" namespace latinime { namespace { TEST(LanguageModelDictContentTest, TestUnigramProbability) { - LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */); + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); - const int flag = 0xFF; + const int flag = 0xF0; const int probability = 10; const int wordId = 100; const ProbabilityEntry probabilityEntry(flag, probability); - LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); const ProbabilityEntry entry = - LanguageModelDictContent.getProbabilityEntry(wordId); + 
languageModelDictContent.getProbabilityEntry(wordId); EXPECT_EQ(flag, entry.getFlags()); EXPECT_EQ(probability, entry.getProbability()); + + // Remove + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); + EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); + EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); } TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { - LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */); + LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */); const int flag = 0xF0; const int timestamp = 0x3FFFFFFF; @@ -46,13 +56,66 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) { const int count = 10; const int wordId = 100; const HistoricalInfo historicalInfo(timestamp, level, count); - const ProbabilityEntry probabilityEntry(flag, NOT_A_PROBABILITY, &historicalInfo); - LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); - const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId); + const ProbabilityEntry probabilityEntry(flag, &historicalInfo); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId); EXPECT_EQ(flag, entry.getFlags()); - EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp()); + EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimestamp()); EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel()); EXPECT_EQ(count, entry.getHistoricalInfo()->getCount()); + + // Remove + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid()); + 
EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId)); + EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry)); + EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId)); +} + +TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) { + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); + + const ProbabilityEntry originalEntry(0xFC, 100); + + const int wordIds[] = { 1, 2, 3, 4, 5 }; + for (const int wordId : wordIds) { + languageModelDictContent.setProbabilityEntry(wordId, &originalEntry); + } + std::unordered_set<int> wordIdSet(std::begin(wordIds), std::end(wordIds)); + for (const auto entry : languageModelDictContent.getProbabilityEntries(WordIdArrayView())) { + EXPECT_EQ(originalEntry.getFlags(), entry.getProbabilityEntry().getFlags()); + EXPECT_EQ(originalEntry.getProbability(), entry.getProbabilityEntry().getProbability()); + wordIdSet.erase(entry.getWordId()); + } + EXPECT_TRUE(wordIdSet.empty()); +} + +TEST(LanguageModelDictContentTest, TestGetWordProbability) { + LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */); + + const int flag = 0xFF; + const int probability = 10; + const int bigramProbability = 20; + const int trigramProbability = 30; + const int wordId = 100; + const std::array<int, 2> prevWordIdArray = {{ 1, 2 }}; + const WordIdArrayView prevWordIds = WordIdArrayView::fromArray(prevWordIdArray); + + const ProbabilityEntry probabilityEntry(flag, probability); + languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry); + const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability); + languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, + &bigramProbabilityEntry); + EXPECT_EQ(bigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, + nullptr /* headerPolicy 
*/).getProbability()); + const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), + prevWordIds[1], &probabilityEntry); + languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, + &trigramProbabilityEntry); + EXPECT_EQ(trigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, + nullptr /* headerPolicy */).getProbability()); } } // namespace diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp index db94550ef..260b347ce 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/probability_entry_test.cpp @@ -43,7 +43,7 @@ TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) { const int count = 10; const HistoricalInfo historicalInfo(timestamp, level, count); - const ProbabilityEntry entry(flag, NOT_A_PROBABILITY, &historicalInfo); + const ProbabilityEntry entry(flag, &historicalInfo); const uint64_t encodedEntry = entry.encode(true /* hasHistoricalInfo */); EXPECT_EQ(0xF03FFFFFFF030Aull, encodedEntry); @@ -51,7 +51,7 @@ TEST(ProbabilityEntryTest, TestEncodeDecodeWithHistoricalInfo) { ProbabilityEntry::decode(encodedEntry, true /* hasHistoricalInfo */); EXPECT_EQ(flag, decodedEntry.getFlags()); - EXPECT_EQ(timestamp, decodedEntry.getHistoricalInfo()->getTimeStamp()); + EXPECT_EQ(timestamp, decodedEntry.getHistoricalInfo()->getTimestamp()); EXPECT_EQ(level, decodedEntry.getHistoricalInfo()->getLevel()); EXPECT_EQ(count, decodedEntry.getHistoricalInfo()->getCount()); } diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table_test.cpp 
b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table_test.cpp new file mode 100644 index 000000000..23b9c55f7 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table_test.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" + +#include <gtest/gtest.h> + +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" + +namespace latinime { +namespace { + +TEST(TerminalPositionLookupTableTest, TestGetFromEmptyTable) { + TerminalPositionLookupTable lookupTable; + + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition(0)); + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition(-1)); + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition( + Ver4DictConstants::NOT_A_TERMINAL_ID)); +} + +TEST(TerminalPositionLookupTableTest, TestSetAndGet) { + TerminalPositionLookupTable lookupTable; + + EXPECT_TRUE(lookupTable.setTerminalPtNodePosition(10, 100)); + EXPECT_EQ(100, lookupTable.getTerminalPtNodePosition(10)); + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition(9)); + EXPECT_TRUE(lookupTable.setTerminalPtNodePosition(9, 200)); + EXPECT_EQ(200, 
lookupTable.getTerminalPtNodePosition(9)); + EXPECT_TRUE(lookupTable.setTerminalPtNodePosition(10, 300)); + EXPECT_EQ(300, lookupTable.getTerminalPtNodePosition(10)); + EXPECT_FALSE(lookupTable.setTerminalPtNodePosition(-1, 400)); + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition(-1)); + EXPECT_FALSE(lookupTable.setTerminalPtNodePosition(Ver4DictConstants::NOT_A_TERMINAL_ID, 500)); + EXPECT_EQ(NOT_A_DICT_POS, lookupTable.getTerminalPtNodePosition( + Ver4DictConstants::NOT_A_TERMINAL_ID)); +} + +TEST(TerminalPositionLookupTableTest, TestGC) { + TerminalPositionLookupTable lookupTable; + + const std::vector<int> terminalIds = { 10, 20, 30 }; + const std::vector<int> terminalPositions = { 100, 200, 300 }; + + for (size_t i = 0; i < terminalIds.size(); ++i) { + EXPECT_TRUE(lookupTable.setTerminalPtNodePosition(terminalIds[i], terminalPositions[i])); + } + + TerminalPositionLookupTable::TerminalIdMap terminalIdMap; + EXPECT_TRUE(lookupTable.runGCTerminalIds(&terminalIdMap)); + + for (size_t i = 0; i < terminalIds.size(); ++i) { + EXPECT_EQ(static_cast<int>(i), terminalIdMap[terminalIds[i]]) + << "Terminal id (" << terminalIds[i] << ") should be changed to " << i; + EXPECT_EQ(terminalPositions[i], lookupTable.getTerminalPtNodePosition(i)); + } +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/byte_array_utils_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/byte_array_utils_test.cpp new file mode 100644 index 000000000..c201e0d00 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/byte_array_utils_test.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +#include <gtest/gtest.h> + +#include <cstdint> + +namespace latinime { +namespace { + +TEST(ByteArrayUtilsTest, TestReadCodePointTable) { + const int codePointTable[] = { 0x6f, 0x6b }; + const uint8_t buffer[] = { 0x20u, 0x21u, 0x00u, 0x01u, 0x00u }; + int pos = 0; + // Expect the first entry of codePointTable + EXPECT_EQ(0x6f, ByteArrayUtils::readCodePointAndAdvancePosition(buffer, codePointTable, &pos)); + // Expect the second entry of codePointTable + EXPECT_EQ(0x6b, ByteArrayUtils::readCodePointAndAdvancePosition(buffer, codePointTable, &pos)); + // Expect the original code point from buffer[2] to buffer[4], 0x100 + // It isn't picked from the codePointTable, since it exceeds the range of the codePointTable. 
+ EXPECT_EQ(0x100, ByteArrayUtils::readCodePointAndAdvancePosition(buffer, codePointTable, &pos)); +} + +TEST(ByteArrayUtilsTest, TestReadInt) { + const uint8_t buffer[] = { 0x1u, 0x8Au, 0x0u, 0xAAu }; + + EXPECT_EQ(0x01u, ByteArrayUtils::readUint8(buffer, 0)); + EXPECT_EQ(0x8Au, ByteArrayUtils::readUint8(buffer, 1)); + EXPECT_EQ(0x0u, ByteArrayUtils::readUint8(buffer, 2)); + EXPECT_EQ(0xAAu, ByteArrayUtils::readUint8(buffer, 3)); + + EXPECT_EQ(0x018Au, ByteArrayUtils::readUint16(buffer, 0)); + EXPECT_EQ(0x8A00u, ByteArrayUtils::readUint16(buffer, 1)); + EXPECT_EQ(0xAAu, ByteArrayUtils::readUint16(buffer, 2)); + + EXPECT_EQ(0x18A00AAu, ByteArrayUtils::readUint32(buffer, 0)); + + int pos = 0; + EXPECT_EQ(0x18A00, ByteArrayUtils::readSint24AndAdvancePosition(buffer, &pos)); + pos = 1; + EXPECT_EQ(-0xA00AA, ByteArrayUtils::readSint24AndAdvancePosition(buffer, &pos)); +} + +TEST(ByteArrayUtilsTest, TestWriteAndReadInt) { + uint8_t buffer[4]; + + int pos = 0; + const uint8_t data_1B = 0xC8; + ByteArrayUtils::writeUintAndAdvancePosition(buffer, data_1B, 1, &pos); + EXPECT_EQ(data_1B, ByteArrayUtils::readUint(buffer, 1, 0)); + + pos = 0; + const uint32_t data_4B = 0xABCD1234; + ByteArrayUtils::writeUintAndAdvancePosition(buffer, data_4B, 4, &pos); + EXPECT_EQ(data_4B, ByteArrayUtils::readUint(buffer, 4, 0)); +} + +TEST(ByteArrayUtilsTest, TestReadCodePoint) { + const uint8_t buffer[] = { 0x10, 0xFF, 0x00u, 0x20u, 0x41u, 0x1Fu, 0x60 }; + + EXPECT_EQ(0x10FF00, ByteArrayUtils::readCodePoint(buffer, 0)); + EXPECT_EQ(0x20, ByteArrayUtils::readCodePoint(buffer, 3)); + EXPECT_EQ(0x41, ByteArrayUtils::readCodePoint(buffer, 4)); + EXPECT_EQ(NOT_A_CODE_POINT, ByteArrayUtils::readCodePoint(buffer, 5)); + + int pos = 0; + int codePointArray[3]; + EXPECT_EQ(3, ByteArrayUtils::readStringAndAdvancePosition(buffer, MAX_WORD_LENGTH, nullptr, + codePointArray, &pos)); + EXPECT_EQ(0x10FF00, codePointArray[0]); + EXPECT_EQ(0x20, codePointArray[1]); + EXPECT_EQ(0x41, codePointArray[2]); + 
EXPECT_EQ(0x60, ByteArrayUtils::readCodePoint(buffer, pos)); +} + +TEST(ByteArrayUtilsTest, TestWriteAndReadCodePoint) { + uint8_t buffer[10]; + + const int codePointArray[] = { 0x10FF00, 0x20, 0x41 }; + int pos = 0; + ByteArrayUtils::writeCodePointsAndAdvancePosition(buffer, codePointArray, 3, + true /* writesTerminator */, &pos); + EXPECT_EQ(0x10FF00, ByteArrayUtils::readCodePoint(buffer, 0)); + EXPECT_EQ(0x20, ByteArrayUtils::readCodePoint(buffer, 3)); + EXPECT_EQ(0x41, ByteArrayUtils::readCodePoint(buffer, 4)); + EXPECT_EQ(NOT_A_CODE_POINT, ByteArrayUtils::readCodePoint(buffer, 5)); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/format_utils_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/format_utils_test.cpp new file mode 100644 index 000000000..15f560cd1 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/format_utils_test.cpp @@ -0,0 +1,97 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +#include <gtest/gtest.h> + +#include <vector> + +#include "utils/byte_array_view.h" + +namespace latinime { +namespace { + +TEST(FormatUtilsTest, TestMagicNumber) { + EXPECT_EQ(0x9BC13AFE, FormatUtils::MAGIC_NUMBER) << "Magic number must not be changed."; +} + +const std::vector<uint8_t> getBuffer(const int magicNumber, const int version, const uint16_t flags, + const size_t headerSize) { + std::vector<uint8_t> buffer; + buffer.push_back(magicNumber >> 24); + buffer.push_back(magicNumber >> 16); + buffer.push_back(magicNumber >> 8); + buffer.push_back(magicNumber); + + buffer.push_back(version >> 8); + buffer.push_back(version); + + buffer.push_back(flags >> 8); + buffer.push_back(flags); + + buffer.push_back(headerSize >> 24); + buffer.push_back(headerSize >> 16); + buffer.push_back(headerSize >> 8); + buffer.push_back(headerSize); + return buffer; +} + +TEST(FormatUtilsTest, TestDetectFormatVersion) { + EXPECT_EQ(FormatUtils::UNKNOWN_VERSION, + FormatUtils::detectFormatVersion(ReadOnlyByteArrayView())); + + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER, FormatUtils::VERSION_2, 0, 0); + EXPECT_EQ(FormatUtils::VERSION_2, FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size()))); + } + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER, FormatUtils::VERSION_4, 0, 0); + EXPECT_EQ(FormatUtils::VERSION_4, FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size()))); + } + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER, FormatUtils::VERSION_4_DEV, 0, 0); + EXPECT_EQ(FormatUtils::VERSION_4_DEV, FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size()))); + } + + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER - 1, FormatUtils::VERSION_2, 0, 0); + EXPECT_EQ(FormatUtils::UNKNOWN_VERSION, 
FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size()))); + } + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER, 100, 0, 0); + EXPECT_EQ(FormatUtils::UNKNOWN_VERSION, FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size()))); + } + { + const std::vector<uint8_t> buffer = + getBuffer(FormatUtils::MAGIC_NUMBER, FormatUtils::VERSION_2, 0, 0); + EXPECT_EQ(FormatUtils::UNKNOWN_VERSION, FormatUtils::detectFormatVersion( + ReadOnlyByteArrayView(buffer.data(), buffer.size() - 1))); + } +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/sparse_table_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/sparse_table_test.cpp new file mode 100644 index 000000000..0b57156a0 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/sparse_table_test.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/utils/sparse_table.h" + +#include <gtest/gtest.h> + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { +namespace { + +TEST(SparseTableTest, TestSetAndGet) { + static const int BLOCK_SIZE = 64; + static const int DATA_SIZE = 4; + BufferWithExtendableBuffer indexTableBuffer( + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); + BufferWithExtendableBuffer contentTableBuffer( + BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE); + SparseTable sparseTable(&indexTableBuffer, &contentTableBuffer, BLOCK_SIZE, DATA_SIZE); + + EXPECT_FALSE(sparseTable.contains(10)); + EXPECT_TRUE(sparseTable.set(10, 100u)); + EXPECT_EQ(100u, sparseTable.get(10)); + EXPECT_TRUE(sparseTable.contains(10)); + EXPECT_TRUE(sparseTable.contains(BLOCK_SIZE - 1)); + EXPECT_FALSE(sparseTable.contains(BLOCK_SIZE)); + EXPECT_TRUE(sparseTable.set(11, 101u)); + EXPECT_EQ(100u, sparseTable.get(10)); + EXPECT_EQ(101u, sparseTable.get(11)); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp index df778b6cf..56b5aa985 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/utils/trie_map_test.cpp @@ -40,6 +40,7 @@ TEST(TrieMapTest, TestSetAndGet) { trieMap.putRoot(11, 1000); EXPECT_EQ(1000ull, trieMap.getRoot(11).mValue); const int next = trieMap.getNextLevelBitmapEntryIndex(10); + EXPECT_EQ(1000ull, trieMap.getRoot(10).mValue); trieMap.put(9, 9, next); EXPECT_EQ(9ull, trieMap.get(9, next).mValue); EXPECT_FALSE(trieMap.get(11, next).mIsValid); @@ -47,6 +48,33 @@ TEST(TrieMapTest, TestSetAndGet) { EXPECT_EQ(0xFFFFFFFFFull, trieMap.getRoot(0).mValue); } +TEST(TrieMapTest, TestRemove) { + TrieMap trieMap; + trieMap.putRoot(10, 10); + EXPECT_EQ(10ull, 
trieMap.getRoot(10).mValue); + EXPECT_TRUE(trieMap.remove(10, trieMap.getRootBitmapEntryIndex())); + EXPECT_FALSE(trieMap.getRoot(10).mIsValid); + for (const auto &element : trieMap.getEntriesInRootLevel()) { + EXPECT_TRUE(false); + } + EXPECT_TRUE(trieMap.putRoot(10, 0x3FFFFF)); + EXPECT_FALSE(trieMap.remove(11, trieMap.getRootBitmapEntryIndex())) + << "Should fail if the key does not exist."; + EXPECT_EQ(0x3FFFFFull, trieMap.getRoot(10).mValue); + trieMap.putRoot(12, 11); + const int nextLevel = trieMap.getNextLevelBitmapEntryIndex(10); + trieMap.put(10, 10, nextLevel); + EXPECT_EQ(0x3FFFFFull, trieMap.getRoot(10).mValue); + EXPECT_EQ(10ull, trieMap.get(10, nextLevel).mValue); + EXPECT_TRUE(trieMap.remove(10, trieMap.getRootBitmapEntryIndex())); + const TrieMap::Result result = trieMap.getRoot(10); + EXPECT_FALSE(result.mIsValid); + EXPECT_EQ(TrieMap::INVALID_INDEX, result.mNextLevelBitmapEntryIndex); + EXPECT_EQ(11ull, trieMap.getRoot(12).mValue); + EXPECT_TRUE(trieMap.putRoot(S_INT_MAX, 0xFFFFFFFFFull)); + EXPECT_TRUE(trieMap.remove(S_INT_MAX, trieMap.getRootBitmapEntryIndex())); +} + TEST(TrieMapTest, TestSetAndGetLarge) { static const int ELEMENT_COUNT = 200000; TrieMap trieMap; diff --git a/native/jni/tests/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy_test.cpp b/native/jni/tests/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy_test.cpp new file mode 100644 index 000000000..d13417964 --- /dev/null +++ b/native/jni/tests/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy_test.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" + +#include <gtest/gtest.h> + +#include <vector> + +#include "suggest/policyimpl/utils/edit_distance.h" +#include "utils/int_array_view.h" + +namespace latinime { +namespace { + +TEST(DamerauLevenshteinEditDistancePolicyTest, TestConstructPolicy) { + const std::vector<int> codePoints0 = { 0x20, 0x40, 0x60 }; + const std::vector<int> codePoints1 = { 0x10, 0x20, 0x30, 0x40, 0x50, 0x60 }; + DamerauLevenshteinEditDistancePolicy policy(codePoints0.data(), codePoints0.size(), + codePoints1.data(), codePoints1.size()); + + EXPECT_EQ(static_cast<int>(codePoints0.size()), policy.getString0Length()); + EXPECT_EQ(static_cast<int>(codePoints1.size()), policy.getString1Length()); +} + +float getEditDistance(const std::vector<int> &codePoints0, const std::vector<int> &codePoints1) { + DamerauLevenshteinEditDistancePolicy policy(codePoints0.data(), codePoints0.size(), + codePoints1.data(), codePoints1.size()); + return EditDistance::getEditDistance(&policy); +} + +TEST(DamerauLevenshteinEditDistancePolicyTest, TestEditDistance) { + EXPECT_FLOAT_EQ(0.0f, getEditDistance({}, {})); + EXPECT_FLOAT_EQ(0.0f, getEditDistance({ 1 }, { 1 })); + EXPECT_FLOAT_EQ(0.0f, getEditDistance({ 1, 2, 3 }, { 1, 2, 3 })); + + EXPECT_FLOAT_EQ(1.0f, getEditDistance({ 1 }, { })); + EXPECT_FLOAT_EQ(1.0f, getEditDistance({}, { 100 })); + EXPECT_FLOAT_EQ(5.0f, getEditDistance({}, { 1, 2, 3, 4, 5 })); + + EXPECT_FLOAT_EQ(1.0f, getEditDistance({ 0 }, { 100 })); + EXPECT_FLOAT_EQ(5.0f, getEditDistance({ 1, 
2, 3, 4, 5 }, { 11, 12, 13, 14, 15 })); + + EXPECT_FLOAT_EQ(1.0f, getEditDistance({ 1 }, { 1, 2 })); + EXPECT_FLOAT_EQ(2.0f, getEditDistance({ 1, 2 }, { 0, 1, 2, 3 })); + EXPECT_FLOAT_EQ(2.0f, getEditDistance({ 0, 1, 2, 3 }, { 1, 2 })); + + EXPECT_FLOAT_EQ(1.0f, getEditDistance({ 1, 2 }, { 2, 1 })); + EXPECT_FLOAT_EQ(2.0f, getEditDistance({ 1, 2, 3, 4 }, { 2, 1, 4, 3 })); +} +} // namespace +} // namespace latinime diff --git a/native/jni/tests/utils/char_utils_test.cpp b/native/jni/tests/utils/char_utils_test.cpp new file mode 100644 index 000000000..01d534043 --- /dev/null +++ b/native/jni/tests/utils/char_utils_test.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utils/char_utils.h" + +#include <gtest/gtest.h> + +#include "defines.h" + +namespace latinime { +namespace { + +TEST(CharUtilsTest, TestIsAsciiUpper) { + EXPECT_TRUE(CharUtils::isAsciiUpper('A')); + EXPECT_TRUE(CharUtils::isAsciiUpper('Z')); + EXPECT_FALSE(CharUtils::isAsciiUpper('a')); + EXPECT_FALSE(CharUtils::isAsciiUpper('z')); + EXPECT_FALSE(CharUtils::isAsciiUpper('@')); + EXPECT_FALSE(CharUtils::isAsciiUpper(' ')); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x00C0 /* LATIN CAPITAL LETTER A WITH GRAVE */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x0410 /* CYRILLIC CAPITAL LETTER A */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x0430 /* CYRILLIC SMALL LETTER A */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x3042 /* HIRAGANA LETTER A */)); + EXPECT_FALSE(CharUtils::isAsciiUpper(0x1F36A /* COOKIE */)); +} + +TEST(CharUtilsTest, TestToLowerCase) { + EXPECT_EQ('a', CharUtils::toLowerCase('A')); + EXPECT_EQ('z', CharUtils::toLowerCase('Z')); + EXPECT_EQ('a', CharUtils::toLowerCase('a')); + EXPECT_EQ('z', CharUtils::toLowerCase('z')); + EXPECT_EQ('@', CharUtils::toLowerCase('@')); + EXPECT_EQ(' ', CharUtils::toLowerCase(' ')); + EXPECT_EQ(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */, + CharUtils::toLowerCase(0x00C0 /* LATIN CAPITAL LETTER A WITH GRAVE */)); + EXPECT_EQ(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */, + CharUtils::toLowerCase(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */)); + EXPECT_EQ(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */, + CharUtils::toLowerCase(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */)); + EXPECT_EQ(0x0430 /* CYRILLIC SMALL LETTER A */, + CharUtils::toLowerCase(0x0410 /* CYRILLIC CAPITAL LETTER A */)); + EXPECT_EQ(0x0430 /* CYRILLIC SMALL LETTER A */, + CharUtils::toLowerCase(0x0430 /* CYRILLIC SMALL LETTER A */)); + EXPECT_EQ(0x3042 /* HIRAGANA LETTER A */, + 
CharUtils::toLowerCase(0x3042 /* HIRAGANA LETTER A */)); + EXPECT_EQ(0x1F36A /* COOKIE */, CharUtils::toLowerCase(0x1F36A /* COOKIE */)); +} + +TEST(CharUtilsTest, TestToBaseLowerCase) { + EXPECT_EQ('a', CharUtils::toBaseLowerCase('A')); + EXPECT_EQ('z', CharUtils::toBaseLowerCase('Z')); + EXPECT_EQ('a', CharUtils::toBaseLowerCase('a')); + EXPECT_EQ('z', CharUtils::toBaseLowerCase('z')); + EXPECT_EQ('@', CharUtils::toBaseLowerCase('@')); + EXPECT_EQ(' ', CharUtils::toBaseLowerCase(' ')); + EXPECT_EQ('a', CharUtils::toBaseLowerCase(0x00C0 /* LATIN CAPITAL LETTER A WITH GRAVE */)); + EXPECT_EQ('a', CharUtils::toBaseLowerCase(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */)); + EXPECT_EQ(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */, + CharUtils::toBaseLowerCase(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */)); + EXPECT_EQ(0x0430 /* CYRILLIC SMALL LETTER A */, + CharUtils::toBaseLowerCase(0x0410 /* CYRILLIC CAPITAL LETTER A */)); + EXPECT_EQ(0x0430 /* CYRILLIC SMALL LETTER A */, + CharUtils::toBaseLowerCase(0x0430 /* CYRILLIC SMALL LETTER A */)); + EXPECT_EQ(0x3042 /* HIRAGANA LETTER A */, + CharUtils::toBaseLowerCase(0x3042 /* HIRAGANA LETTER A */)); + EXPECT_EQ(0x1F36A /* COOKIE */, CharUtils::toBaseLowerCase(0x1F36A /* COOKIE */)); +} + +TEST(CharUtilsTest, TestToBaseCodePoint) { + EXPECT_EQ('A', CharUtils::toBaseCodePoint('A')); + EXPECT_EQ('Z', CharUtils::toBaseCodePoint('Z')); + EXPECT_EQ('a', CharUtils::toBaseCodePoint('a')); + EXPECT_EQ('z', CharUtils::toBaseCodePoint('z')); + EXPECT_EQ('@', CharUtils::toBaseCodePoint('@')); + EXPECT_EQ(' ', CharUtils::toBaseCodePoint(' ')); + EXPECT_EQ('A', CharUtils::toBaseCodePoint(0x00C0 /* LATIN CAPITAL LETTER A WITH GRAVE */)); + EXPECT_EQ('a', CharUtils::toBaseCodePoint(0x00E0 /* LATIN SMALL LETTER A WITH GRAVE */)); + EXPECT_EQ(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */, + CharUtils::toBaseLowerCase(0x03C2 /* GREEK SMALL LETTER FINAL SIGMA */)); + EXPECT_EQ(0x0410 /* CYRILLIC CAPITAL LETTER A */, + 
CharUtils::toBaseCodePoint(0x0410 /* CYRILLIC CAPITAL LETTER A */)); + EXPECT_EQ(0x0430 /* CYRILLIC SMALL LETTER A */, + CharUtils::toBaseCodePoint(0x0430 /* CYRILLIC SMALL LETTER A */)); + EXPECT_EQ(0x3042 /* HIRAGANA LETTER A */, + CharUtils::toBaseCodePoint(0x3042 /* HIRAGANA LETTER A */)); + EXPECT_EQ(0x1F36A /* COOKIE */, CharUtils::toBaseCodePoint(0x1F36A /* COOKIE */)); +} + +TEST(CharUtilsTest, TestIsIntentionalOmissionCodePoint) { + EXPECT_TRUE(CharUtils::isIntentionalOmissionCodePoint('\'')); + EXPECT_TRUE(CharUtils::isIntentionalOmissionCodePoint('-')); + EXPECT_FALSE(CharUtils::isIntentionalOmissionCodePoint('a')); + EXPECT_FALSE(CharUtils::isIntentionalOmissionCodePoint('?')); + EXPECT_FALSE(CharUtils::isIntentionalOmissionCodePoint('/')); +} + +TEST(CharUtilsTest, TestIsInUnicodeSpace) { + EXPECT_FALSE(CharUtils::isInUnicodeSpace(NOT_A_CODE_POINT)); + EXPECT_FALSE(CharUtils::isInUnicodeSpace(CODE_POINT_BEGINNING_OF_SENTENCE)); + EXPECT_TRUE(CharUtils::isInUnicodeSpace('a')); + EXPECT_TRUE(CharUtils::isInUnicodeSpace(0x0410 /* CYRILLIC CAPITAL LETTER A */)); + EXPECT_TRUE(CharUtils::isInUnicodeSpace(0x3042 /* HIRAGANA LETTER A */)); + EXPECT_TRUE(CharUtils::isInUnicodeSpace(0x1F36A /* COOKIE */)); +} + +} // namespace +} // namespace latinime diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp index bd843ab02..4757a416b 100644 --- a/native/jni/tests/utils/int_array_view_test.cpp +++ b/native/jni/tests/utils/int_array_view_test.cpp @@ -18,6 +18,7 @@ #include <gtest/gtest.h> +#include <array> #include <vector> namespace latinime { @@ -45,17 +46,110 @@ TEST(IntArrayViewTest, TestIteration) { TEST(IntArrayViewTest, TestConstructFromArray) { const size_t ARRAY_SIZE = 100; - int intArray[ARRAY_SIZE]; - const auto intArrayView = IntArrayView::fromFixedSizeArray(intArray); + std::array<int, ARRAY_SIZE> intArray; + const auto intArrayView = IntArrayView::fromArray(intArray); EXPECT_EQ(ARRAY_SIZE, 
intArrayView.size()); } TEST(IntArrayViewTest, TestConstructFromObject) { const int object = 10; - const auto intArrayView = IntArrayView::fromObject(&object); - EXPECT_EQ(1, intArrayView.size()); + const auto intArrayView = IntArrayView::singleElementView(&object); + EXPECT_EQ(1u, intArrayView.size()); EXPECT_EQ(object, intArrayView[0]); } +TEST(IntArrayViewTest, TestContains) { + EXPECT_FALSE(IntArrayView().contains(0)); + EXPECT_FALSE(IntArrayView().contains(1)); + + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + EXPECT_TRUE(intArrayView.contains(0)); + EXPECT_TRUE(intArrayView.contains(3)); + EXPECT_TRUE(intArrayView.contains(-2)); + EXPECT_FALSE(intArrayView.contains(-3)); + EXPECT_FALSE(intArrayView.limit(0).contains(3)); +} + +TEST(IntArrayViewTest, TestLimit) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + + EXPECT_TRUE(intArrayView.limit(0).empty()); + EXPECT_EQ(intArrayView.size(), intArrayView.limit(intArrayView.size()).size()); + EXPECT_EQ(intArrayView.size(), intArrayView.limit(1000).size()); + + IntArrayView subView = intArrayView.limit(4); + EXPECT_EQ(4u, subView.size()); + for (size_t i = 0; i < subView.size(); ++i) { + EXPECT_EQ(intVector[i], subView[i]); + } +} + +TEST(IntArrayViewTest, TestSkip) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + + EXPECT_TRUE(intArrayView.skip(intVector.size()).empty()); + EXPECT_TRUE(intArrayView.skip(intVector.size() + 1).empty()); + EXPECT_EQ(intArrayView.size(), intArrayView.skip(0).size()); + EXPECT_EQ(intArrayView.size(), intArrayView.limit(1000).size()); + + static const size_t SKIP_COUNT = 2; + IntArrayView subView = intArrayView.skip(SKIP_COUNT); + EXPECT_EQ(intVector.size() - SKIP_COUNT, subView.size()); + for (size_t i = 0; i < subView.size(); ++i) { + EXPECT_EQ(intVector[i + SKIP_COUNT], subView[i]); + } +} + +TEST(IntArrayViewTest, 
TestCopyToArray) { + // "{{" to suppress warning. + std::array<int, 7> buffer = {{10, 20, 30, 40, 50, 60, 70}}; + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + intArrayView.limit(0).copyToArray(&buffer, 0); + EXPECT_EQ(10, buffer[0]); + EXPECT_EQ(20, buffer[1]); + intArrayView.limit(1).copyToArray(&buffer, 0); + EXPECT_EQ(intVector[0], buffer[0]); + EXPECT_EQ(20, buffer[1]); + intArrayView.limit(1).copyToArray(&buffer, 1); + EXPECT_EQ(intVector[0], buffer[0]); + EXPECT_EQ(intVector[0], buffer[1]); + intArrayView.copyToArray(&buffer, 0); + for (size_t i = 0; i < intArrayView.size(); ++i) { + EXPECT_EQ(intVector[i], buffer[i]); + } + EXPECT_EQ(70, buffer[6]); +} + +TEST(IntArrayViewTest, TestFirstOrDefault) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + + EXPECT_EQ(3, intArrayView.firstOrDefault(10)); + EXPECT_EQ(10, intArrayView.limit(0).firstOrDefault(10)); + EXPECT_EQ(-10, intArrayView.limit(0).firstOrDefault(-10)); + EXPECT_EQ(10, intArrayView.skip(6).firstOrDefault(10)); +} + +TEST(IntArrayViewTest, TestLastOrDefault) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + + EXPECT_EQ(-2, intArrayView.lastOrDefault(10)); + EXPECT_EQ(10, intArrayView.limit(0).lastOrDefault(10)); + EXPECT_EQ(-10, intArrayView.limit(0).lastOrDefault(-10)); + EXPECT_EQ(10, intArrayView.skip(6).lastOrDefault(10)); +} + +TEST(IntArrayViewTest, TestToVector) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + EXPECT_EQ(intVector, intArrayView.toVector()); + EXPECT_EQ(std::vector<int>(), CodePointArrayView().toVector()); +} + } // namespace } // namespace latinime diff --git a/native/jni/tests/utils/time_keeper_test.cpp b/native/jni/tests/utils/time_keeper_test.cpp new file mode 100644 index 000000000..3f54b91f1 --- /dev/null +++ b/native/jni/tests/utils/time_keeper_test.cpp @@ -0,0 
+1,38 @@ +/* + * Copyright (C) 2014 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/time_keeper.h" + +#include <gtest/gtest.h> + +namespace latinime { +namespace { + +TEST(TimeKeeperTest, TestTestMode) { + TimeKeeper::setCurrentTime(); + const int startTime = TimeKeeper::peekCurrentTime(); + static const int TEST_CURRENT_TIME = 100; + TimeKeeper::startTestModeWithForceCurrentTime(TEST_CURRENT_TIME); + EXPECT_EQ(TEST_CURRENT_TIME, TimeKeeper::peekCurrentTime()); + TimeKeeper::setCurrentTime(); + EXPECT_EQ(TEST_CURRENT_TIME, TimeKeeper::peekCurrentTime()); + TimeKeeper::stopTestMode(); + TimeKeeper::setCurrentTime(); + EXPECT_LE(startTime, TimeKeeper::peekCurrentTime()); +} + +} // namespace +} // namespace latinime |