diff options
Diffstat (limited to 'native')
74 files changed, 4712 insertions, 1411 deletions
diff --git a/native/jni/Android.mk b/native/jni/Android.mk index 9e7407abd..0594ddff0 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -43,7 +43,7 @@ LATIN_IME_JNI_SRC_FILES := \ com_android_inputmethod_keyboard_ProximityInfo.cpp \ com_android_inputmethod_latin_BinaryDictionary.cpp \ com_android_inputmethod_latin_DicTraverseSession.cpp \ - com_android_inputmethod_latin_makedict_BinaryDictDecoder.cpp \ + com_android_inputmethod_latin_makedict_Ver3DictDecoder.cpp \ jni_common.cpp LATIN_IME_CORE_SRC_FILES := \ @@ -67,18 +67,26 @@ LATIN_IME_CORE_SRC_FILES := \ suggest/core/policy/weighting.cpp \ suggest/core/session/dic_traverse_session.cpp \ $(addprefix suggest/policyimpl/dictionary/, \ - bigram/bigram_list_reading_utils.cpp \ + bigram/bigram_list_read_write_utils.cpp \ + bigram/dynamic_bigram_list_policy.cpp \ header/header_policy.cpp \ - header/header_reading_utils.cpp \ + header/header_read_write_utils.cpp \ shortcut/shortcut_list_reading_utils.cpp \ - utils/byte_array_utils.cpp \ - utils/format_utils.cpp \ dictionary_structure_with_buffer_policy_factory.cpp \ + dynamic_patricia_trie_gc_event_listeners.cpp \ dynamic_patricia_trie_node_reader.cpp \ dynamic_patricia_trie_policy.cpp \ + dynamic_patricia_trie_reading_helper.cpp \ dynamic_patricia_trie_reading_utils.cpp \ + dynamic_patricia_trie_writing_helper.cpp \ + dynamic_patricia_trie_writing_utils.cpp \ patricia_trie_policy.cpp \ patricia_trie_reading_utils.cpp) \ + $(addprefix suggest/policyimpl/dictionary/utils/, \ + buffer_with_extendable_buffer.cpp \ + byte_array_utils.cpp \ + dict_file_writing_utils.cpp \ + format_utils.cpp) \ suggest/policyimpl/gesture/gesture_suggest_policy_factory.cpp \ $(addprefix suggest/policyimpl/typing/, \ scoring_params.cpp \ diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 86c2394d1..7761ec4d5 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -26,12 +26,55 @@ #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h" +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" #include "utils/autocorrection_threshold_utils.h" namespace latinime { class ProximityInfo; +// TODO: Move to makedict. +static jboolean latinime_BinaryDictionary_createEmptyDictFile(JNIEnv *env, jclass clazz, + jstring filePath, jlong dictVersion, jobjectArray attributeKeyStringArray, + jobjectArray attributeValueStringArray) { + const jsize filePathUtf8Length = env->GetStringUTFLength(filePath); + char filePathChars[filePathUtf8Length + 1]; + env->GetStringUTFRegion(filePath, 0, env->GetStringLength(filePath), filePathChars); + filePathChars[filePathUtf8Length] = '\0'; + + const int keyCount = env->GetArrayLength(attributeKeyStringArray); + const int valueCount = env->GetArrayLength(attributeValueStringArray); + if (keyCount != valueCount) { + return false; + } + + HeaderReadWriteUtils::AttributeMap attributeMap; + for (int i = 0; i < keyCount; i++) { + jstring keyString = static_cast<jstring>( + env->GetObjectArrayElement(attributeKeyStringArray, i)); + const jsize keyUtf8Length = env->GetStringUTFLength(keyString); + char keyChars[keyUtf8Length + 1]; + env->GetStringUTFRegion(keyString, 0, env->GetStringLength(keyString), keyChars); + keyChars[keyUtf8Length] = '\0'; + HeaderReadWriteUtils::AttributeMap::key_type key; + HeaderReadWriteUtils::insertCharactersIntoVector(keyChars, &key); + + jstring valueString = static_cast<jstring>( + env->GetObjectArrayElement(attributeValueStringArray, i)); + const jsize valueUtf8Length = env->GetStringUTFLength(valueString); + char valueChars[valueUtf8Length + 1]; + env->GetStringUTFRegion(valueString, 0, env->GetStringLength(valueString), valueChars); + valueChars[valueUtf8Length] = '\0'; + HeaderReadWriteUtils::AttributeMap::mapped_type value; + HeaderReadWriteUtils::insertCharactersIntoVector(valueChars, &value); + + attributeMap[key] = value; + } + + return DictFileWritingUtils::createEmptyDictFile(filePathChars, static_cast<int>(dictVersion), + &attributeMap); +} + static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir, jlong dictOffset, jlong dictSize, jboolean isUpdatable) { PROF_OPEN; @@ -46,8 +89,7 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s sourceDirChars[sourceDirUtf8Length] = '\0'; DictionaryStructureWithBufferPolicy *const dictionaryStructureWithBufferPolicy = DictionaryStructureWithBufferPolicyFactory::newDictionaryStructureWithBufferPolicy( - sourceDirChars, static_cast<int>(sourceDirUtf8Length), - static_cast<int>(dictOffset), static_cast<int>(dictSize), + sourceDirChars, static_cast<int>(dictOffset), static_cast<int>(dictSize), isUpdatable == JNI_TRUE); if (!dictionaryStructureWithBufferPolicy) { return 0; @@ -59,6 +101,35 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s return reinterpret_cast<jlong>(dictionary); } +static void latinime_BinaryDictionary_flush(JNIEnv *env, jclass clazz, jlong dict, + jstring filePath) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) return; + const jsize filePathUtf8Length = env->GetStringUTFLength(filePath); + char filePathChars[filePathUtf8Length + 1]; + env->GetStringUTFRegion(filePath, 0, env->GetStringLength(filePath), filePathChars); + filePathChars[filePathUtf8Length] = '\0'; + dictionary->flush(filePathChars); +} + +static bool latinime_BinaryDictionary_needsToRunGC(JNIEnv *env, jclass clazz, + jlong dict) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) return false; + return dictionary->needsToRunGC(); +} + +static void latinime_BinaryDictionary_flushWithGC(JNIEnv *env, jclass clazz, jlong dict, + jstring filePath) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) return; + const jsize filePathUtf8Length = env->GetStringUTFLength(filePath); + char filePathChars[filePathUtf8Length + 1]; + env->GetStringUTFRegion(filePath, 0, env->GetStringLength(filePath), filePathChars); + filePathChars[filePathUtf8Length] = '\0'; + dictionary->flushWithGC(filePathChars); +} + static void latinime_BinaryDictionary_close(JNIEnv *env, jclass clazz, jlong dict) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return; @@ -70,7 +141,8 @@ static int latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, j jintArray yCoordinatesArray, jintArray timesArray, jintArray pointerIdsArray, jintArray inputCodePointsArray, jint inputSize, jint commitPoint, jintArray suggestOptions, jintArray prevWordCodePointsForBigrams, jintArray outputCodePointsArray, - jintArray scoresArray, jintArray spaceIndicesArray, jintArray outputTypesArray) { + jintArray scoresArray, jintArray spaceIndicesArray, jintArray outputTypesArray, + jintArray outputAutoCommitFirstWordConfidence) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return 0; ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo); @@ -159,8 +231,8 @@ static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, return dictionary->getProbability(codePoints, wordLength); } -static jboolean latinime_BinaryDictionary_isValidBigram(JNIEnv *env, jclass clazz, jlong dict, - jintArray word0, jintArray word1) { +static jint latinime_BinaryDictionary_getBigramProbability(JNIEnv *env, jclass clazz, + jlong dict, jintArray word0, jintArray word1) { Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); if (!dictionary) return JNI_FALSE; const jsize word0Length = env->GetArrayLength(word0); @@ -169,7 +241,8 @@ static jboolean latinime_BinaryDictionary_isValidBigram(JNIEnv *env, jclass claz int word1CodePoints[word1Length]; env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); - return dictionary->isValidBigram(word0CodePoints, word0Length, word1CodePoints, word1Length); + return dictionary->getBigramProbability(word0CodePoints, word0Length, word1CodePoints, + word1Length); } static jfloat latinime_BinaryDictionary_calcNormalizedScore(JNIEnv *env, jclass clazz, @@ -204,6 +277,7 @@ static void latinime_BinaryDictionary_addUnigramWord(JNIEnv *env, jclass clazz, } jsize wordLength = env->GetArrayLength(word); int codePoints[wordLength]; + env->GetIntArrayRegion(word, 0, wordLength, codePoints); dictionary->addUnigramWord(codePoints, wordLength, probability); } @@ -215,8 +289,10 @@ static void latinime_BinaryDictionary_addBigramWords(JNIEnv *env, jclass clazz, } jsize word0Length = env->GetArrayLength(word0); int word0CodePoints[word0Length]; + env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); jsize word1Length = env->GetArrayLength(word1); int word1CodePoints[word1Length]; + env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); dictionary->addBigramWords(word0CodePoints, word0Length, word1CodePoints, word1Length, probability); } @@ -229,14 +305,31 @@ static void latinime_BinaryDictionary_removeBigramWords(JNIEnv *env, jclass claz } jsize word0Length = env->GetArrayLength(word0); int word0CodePoints[word0Length]; + env->GetIntArrayRegion(word0, 0, word0Length, word0CodePoints); jsize word1Length = env->GetArrayLength(word1); int word1CodePoints[word1Length]; + env->GetIntArrayRegion(word1, 0, word1Length, word1CodePoints); dictionary->removeBigramWords(word0CodePoints, word0Length, word1CodePoints, word1Length); } +static int latinime_BinaryDictionary_calculateProbabilityNative(JNIEnv *env, jclass clazz, + jlong dict, jint unigramProbability, jint bigramProbability) { + Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict); + if (!dictionary) { + return NOT_A_PROBABILITY; + } + return dictionary->getDictionaryStructurePolicy()->getProbability(unigramProbability, + bigramProbability); +} + static const JNINativeMethod sMethods[] = { { + const_cast<char *>("createEmptyDictFileNative"), + const_cast<char *>("(Ljava/lang/String;J[Ljava/lang/String;[Ljava/lang/String;)Z"), + reinterpret_cast<void *>(latinime_BinaryDictionary_createEmptyDictFile) + }, + { const_cast<char *>("openNative"), const_cast<char *>("(Ljava/lang/String;JJZ)J"), reinterpret_cast<void *>(latinime_BinaryDictionary_open) @@ -247,8 +340,23 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionary_close) }, { + const_cast<char *>("flushNative"), + const_cast<char *>("(JLjava/lang/String;)V"), + reinterpret_cast<void *>(latinime_BinaryDictionary_flush) + }, + { + const_cast<char *>("needsToRunGCNative"), + const_cast<char *>("(J)Z"), + reinterpret_cast<void *>(latinime_BinaryDictionary_needsToRunGC) + }, + { + const_cast<char *>("flushWithGCNative"), + const_cast<char *>("(JLjava/lang/String;)V"), + reinterpret_cast<void *>(latinime_BinaryDictionary_flushWithGC) + }, + { const_cast<char *>("getSuggestionsNative"), - const_cast<char *>("(JJJ[I[I[I[I[III[I[I[I[I[I[I)I"), + const_cast<char *>("(JJJ[I[I[I[I[III[I[I[I[I[I[I[I)I"), reinterpret_cast<void *>(latinime_BinaryDictionary_getSuggestions) }, { @@ -257,9 +365,9 @@ static const JNINativeMethod sMethods[] = { reinterpret_cast<void *>(latinime_BinaryDictionary_getProbability) }, { - const_cast<char *>("isValidBigramNative"), - const_cast<char *>("(J[I[I)Z"), - reinterpret_cast<void *>(latinime_BinaryDictionary_isValidBigram) + const_cast<char *>("getBigramProbabilityNative"), + const_cast<char *>("(J[I[I)I"), + reinterpret_cast<void *>(latinime_BinaryDictionary_getBigramProbability) }, { const_cast<char *>("calcNormalizedScoreNative"), @@ -285,6 +393,11 @@ static const JNINativeMethod sMethods[] = { const_cast<char *>("removeBigramWordsNative"), const_cast<char *>("(J[I[I)V"), reinterpret_cast<void *>(latinime_BinaryDictionary_removeBigramWords) + }, + { + const_cast<char *>("calculateProbabilityNative"), + const_cast<char *>("(JII)I"), + reinterpret_cast<void *>(latinime_BinaryDictionary_calculateProbabilityNative) } }; diff --git a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp index 72e625836..386643332 100644 --- a/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp +++ b/native/jni/com_android_inputmethod_latin_DicTraverseSession.cpp @@ -25,8 +25,9 @@ namespace latinime { class Dictionary; -static jlong latinime_setDicTraverseSession(JNIEnv *env, jclass clazz, jstring localeJStr) { - void *traverseSession = DicTraverseSession::getSessionInstance(env, localeJStr); +static jlong latinime_setDicTraverseSession(JNIEnv *env, jclass clazz, jstring localeJStr, + jlong dictSize) { + void *traverseSession = DicTraverseSession::getSessionInstance(env, localeJStr, dictSize); return reinterpret_cast<jlong>(traverseSession); } @@ -53,7 +54,7 @@ static void latinime_releaseDicTraverseSession(JNIEnv *env, jclass clazz, jlong static const JNINativeMethod sMethods[] = { { const_cast<char *>("setDicTraverseSessionNative"), - const_cast<char *>("(Ljava/lang/String;)J"), + const_cast<char *>("(Ljava/lang/String;J)J"), reinterpret_cast<void *>(latinime_setDicTraverseSession) }, { diff --git a/native/jni/com_android_inputmethod_latin_makedict_BinaryDictDecoder.cpp b/native/jni/com_android_inputmethod_latin_makedict_Ver3DictDecoder.cpp index 457b226b6..15088b65a 100644 --- a/native/jni/com_android_inputmethod_latin_makedict_BinaryDictDecoder.cpp +++ b/native/jni/com_android_inputmethod_latin_makedict_Ver3DictDecoder.cpp @@ -14,16 +14,16 @@ * limitations under the License. */ -#define LOG_TAG "LatinIME: jni: BinaryDictDecoder" +#define LOG_TAG "LatinIME: jni: Ver3DictDecoder" -#include "com_android_inputmethod_latin_makedict_BinaryDictDecoder.h" +#include "com_android_inputmethod_latin_makedict_Ver3DictDecoder.h" #include "defines.h" #include "jni.h" #include "jni_common.h" namespace latinime { -static int latinime_BinaryDictDecoder_doNothing(JNIEnv *env, jclass clazz) { +static int latinime_Ver3DictDecoder_doNothing(JNIEnv *env, jclass clazz) { // This is a phony method for test - it does nothing. It just returns some value // unlikely to be in memory by chance for testing purposes. // TODO: remove this method. @@ -35,13 +35,13 @@ static const JNINativeMethod sMethods[] = { // TODO: remove this entry when we have one useful method in here const_cast<char *>("doNothing"), const_cast<char *>("()I"), - reinterpret_cast<void *>(latinime_BinaryDictDecoder_doNothing) + reinterpret_cast<void *>(latinime_Ver3DictDecoder_doNothing) }, }; -int register_BinaryDictDecoder(JNIEnv *env) { +int register_Ver3DictDecoder(JNIEnv *env) { const char *const kClassPathName = - "com/android/inputmethod/latin/makedict/BinaryDictDecoder"; + "com/android/inputmethod/latin/makedict/Ver3DictDecoder"; return registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods)); } } // namespace latinime diff --git a/native/jni/com_android_inputmethod_latin_makedict_BinaryDictDecoder.h b/native/jni/com_android_inputmethod_latin_makedict_Ver3DictDecoder.h index 7f3cb67e6..07e80f1d8 100644 --- a/native/jni/com_android_inputmethod_latin_makedict_BinaryDictDecoder.h +++ b/native/jni/com_android_inputmethod_latin_makedict_Ver3DictDecoder.h @@ -14,12 +14,12 @@ * limitations under the License. */ -#ifndef _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_BINARYDICTDECODER_H -#define _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_BINARYDICTDECODER_H +#ifndef _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_VER3DICTDECODER_H +#define _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_VER3DICTDECODER_H #include "jni.h" namespace latinime { -int register_BinaryDictDecoder(JNIEnv *env); +int register_Ver3DictDecoder(JNIEnv *env); } // namespace latinime -#endif // _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_BINARYDICTDECODER_H +#endif // _COM_ANDROID_INPUTMETHOD_LATIN_MAKEDICT_VER3DICTDECODER_H diff --git a/native/jni/jni_common.cpp b/native/jni/jni_common.cpp index d44be6705..3a8f4362d 100644 --- a/native/jni/jni_common.cpp +++ b/native/jni/jni_common.cpp @@ -23,7 +23,7 @@ #include "com_android_inputmethod_latin_BinaryDictionary.h" #include "com_android_inputmethod_latin_DicTraverseSession.h" #endif -#include "com_android_inputmethod_latin_makedict_BinaryDictDecoder.h" +#include "com_android_inputmethod_latin_makedict_Ver3DictDecoder.h" #include "defines.h" /* @@ -55,8 +55,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { return -1; } #endif - if (!latinime::register_BinaryDictDecoder(env)) { - AKLOGE("ERROR: BinaryDictDecoder native registration failed"); + if (!latinime::register_Ver3DictDecoder(env)) { + AKLOGE("ERROR: Ver3DictDecoder native registration failed"); return -1; } /* success -- return valid version number */ diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 34a646f80..89dfa39b3 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -292,7 +292,6 @@ static inline void prof_out(void) { // of the binary dictionary where a {key,value} string pair scheme is used. #define LARGEST_INT_DIGIT_COUNT 11 -#define NOT_A_VALID_WORD_POS (-99) #define NOT_A_CODE_POINT (-1) #define NOT_A_DISTANCE (-1) #define NOT_A_COORDINATE (-1) @@ -322,13 +321,6 @@ static inline void prof_out(void) { #define MAX_POINTER_COUNT 1 #define MAX_POINTER_COUNT_G 2 -// Queue IDs and size for DicNodesCache -#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_ACTIVE 0 -#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_NEXT_ACTIVE 1 -#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_TERMINAL 2 -#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 -#define DIC_NODES_CACHE_PRIORITY_QUEUES_SIZE 4 - template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; } template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index cdd9f59aa..41ef9d2b2 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -112,7 +112,7 @@ class DicNode { mIsUsed = true; mIsCachedForNextSuggestion = false; mDicNodeProperties.init( - NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_DICT_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, NOT_A_PROBABILITY /* probability */, false /* isTerminal */, true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, 0 /* terminalDepth */); @@ -125,7 +125,7 @@ class DicNode { mIsUsed = true; mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( - NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_DICT_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, NOT_A_PROBABILITY /* probability */, false /* isTerminal */, true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, 0 /* terminalDepth */); @@ -143,7 +143,7 @@ class DicNode { dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), dicNode->getOutputWordBuf(), dicNode->mDicNodeProperties.getDepth(), - dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, + dicNode->mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(), mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -234,7 +234,7 @@ class DicNode { } bool isFirstWord() const { - return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_A_VALID_WORD_POS; + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_A_DICT_POS; } bool isCompletion(const int inputSize) const { @@ -321,8 +321,13 @@ class DicNode { DUMP_WORD_AND_SCORE("OUTPUT"); } - void outputSpacePositionsResult(int *spaceIndices) const { - mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); + int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { + const int inputIndex = mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(); + if (inputIndex == NOT_AN_INDEX) { + return NOT_AN_INDEX; + } else { + return pInfoState->getInputIndexOfSampledPoint(inputIndex); + } } bool hasMultipleWords() const { @@ -573,7 +578,11 @@ class DicNode { } } - AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { + AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { + if (mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() == 1 && isFirstLetter()) { + mDicNodeState.mDicNodeStatePrevWord.setSecondWordFirstInputIndex( + inputStateG->mInputIndex); + } mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, inputStateG->mInputIndex, inputStateG->mPrevCodePoint, inputStateG->mTerminalDiffCost, inputStateG->mRawLength); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h index 2a486b804..7461f0cc6 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h @@ -24,20 +24,16 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_release_listener.h" -// The biggest value among MAX_CACHE_DIC_NODE_SIZE, MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT, ... -#define MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY 310 - namespace latinime { class DicNodePriorityQueue : public DicNodeReleaseListener { public: - AK_FORCE_INLINE DicNodePriorityQueue() - : MAX_CAPACITY(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), - mMaxSize(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), mDicNodesBuf(), mUnusedNodeIndices(), - mNextUnusedNodeId(0), mDicNodesQueue() { - mDicNodesBuf.resize(MAX_CAPACITY + 1); - mUnusedNodeIndices.resize(MAX_CAPACITY + 1); - reset(); + AK_FORCE_INLINE explicit DicNodePriorityQueue(const int capacity) + : mCapacity(capacity), mMaxSize(capacity), mDicNodesBuf(), + mUnusedNodeIndices(), mNextUnusedNodeId(0), mDicNodesQueue() { + mDicNodesBuf.resize(mCapacity + 1); + mUnusedNodeIndices.resize(mCapacity + 1); + clearAndResizeToCapacity(); } // Non virtual inline destructor -- never inherit this class @@ -52,11 +48,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE void setMaxSize(const int maxSize) { - mMaxSize = min(maxSize, MAX_CAPACITY); + ASSERT(maxSize <= mCapacity); + mMaxSize = min(maxSize, mCapacity); } - AK_FORCE_INLINE void reset() { - clearAndResize(MAX_CAPACITY); + AK_FORCE_INLINE void clearAndResizeToCapacity() { + clearAndResize(mCapacity); } AK_FORCE_INLINE void clear() { @@ -64,27 +61,19 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE void clearAndResize(const int maxSize) { + ASSERT(maxSize <= mCapacity); while (!mDicNodesQueue.empty()) { mDicNodesQueue.pop(); } setMaxSize(maxSize); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { mDicNodesBuf[i].remove(); mDicNodesBuf[i].setReleaseListener(this); - mUnusedNodeIndices[i] = i == MAX_CAPACITY ? NOT_A_NODE_ID : static_cast<int>(i) + 1; + mUnusedNodeIndices[i] = i == mCapacity ? NOT_A_NODE_ID : static_cast<int>(i) + 1; } mNextUnusedNodeId = 0; } - AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { - DicNode *newNode = searchEmptyDicNode(); - if (newNode) { - DicNodeUtils::initByCopy(dicNode, newNode); - return newNode; - } - return 0; - } - // Copy AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode) { return copyPush(dicNode, mMaxSize); @@ -111,12 +100,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } mUnusedNodeIndices[index] = mNextUnusedNodeId; mNextUnusedNodeId = index; - ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + ASSERT(index >= 0 && index < (mCapacity + 1)); } AK_FORCE_INLINE void dump() const { AKLOGI("\n\n\n\n\n==========================="); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { if (mDicNodesBuf[i].isUsed()) { mDicNodesBuf[i].dump("QUEUE: "); } @@ -125,7 +114,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } private: - DISALLOW_COPY_AND_ASSIGN(DicNodePriorityQueue); + DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodePriorityQueue); static const int NOT_A_NODE_ID = -1; AK_FORCE_INLINE static bool compareDicNode(DicNode *left, DicNode *right) { @@ -139,7 +128,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { }; typedef std::priority_queue<DicNode *, std::vector<DicNode *>, DicNodeComparator> DicNodesQueue; - const int MAX_CAPACITY; + const int mCapacity; int mMaxSize; std::vector<DicNode> mDicNodesBuf; // of each element of mDicNodesBuf respectively std::vector<int> mUnusedNodeIndices; @@ -163,13 +152,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE DicNode *searchEmptyDicNode() { - // TODO: Currently O(n) but should be improved to O(1) - if (MAX_CAPACITY == 0) { + if (mCapacity == 0) { return 0; } if (mNextUnusedNodeId == NOT_A_NODE_ID) { AKLOGI("No unused node found."); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { AKLOGI("Dump node availability, %d, %d, %d", i, mDicNodesBuf[i].isUsed(), mUnusedNodeIndices[i]); } @@ -185,7 +173,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { const int index = static_cast<int>(dicNode - &mDicNodesBuf[0]); mNextUnusedNodeId = mUnusedNodeIndices[index]; mUnusedNodeIndices[index] = NOT_A_NODE_ID; - ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + ASSERT(index >= 0 && index < (mCapacity + 1)); } AK_FORCE_INLINE DicNode *pushPoolNodeWithMaxSize(DicNode *dicNode, const int maxSize) { @@ -209,6 +197,15 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode, const int maxSize) { return pushPoolNodeWithMaxSize(newDicNode(dicNode), maxSize); } + + AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { + DicNode *newNode = searchEmptyDicNode(); + if (newNode) { + DicNodeUtils::initByCopy(dicNode, newNode); + } + return newNode; + } + }; } // namespace latinime #endif // LATINIME_DIC_NODE_PRIORITY_QUEUE_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index bb54e608e..ec65114c7 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -21,7 +21,6 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/dictionary/multi_bigram_map.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -90,16 +89,18 @@ namespace latinime { const int unigramProbability = node->getProbability(); const int wordPos = node->getPos(); const int prevWordPos = node->getPrevWordPos(); - if (NOT_A_VALID_WORD_POS == wordPos || NOT_A_VALID_WORD_POS == prevWordPos) { + if (NOT_A_DICT_POS == wordPos || NOT_A_DICT_POS == prevWordPos) { // Note: Normally wordPos comes from the dictionary and should never equal // NOT_A_VALID_WORD_POS. - return ProbabilityUtils::backoff(unigramProbability); + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } if (multiBigramMap) { return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, prevWordPos, wordPos, unigramProbability); } - return ProbabilityUtils::backoff(unigramProbability); + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } //////////////// diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp index c3d2a2e74..b6be47e90 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -23,6 +23,11 @@ namespace latinime { +// The biggest value among MAX_CACHE_DIC_NODE_SIZE, MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT, ... +const int DicNodesCache::LARGE_PRIORITY_QUEUE_CAPACITY = 310; +// Capacity for reducing memory footprint. +const int DicNodesCache::SMALL_PRIORITY_QUEUE_CAPACITY = 100; + /** * Truncates all of the dicNodes so that they start at the given commit point. * Only called for multi-word typing input. diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h index 7aab0906e..8493b6a8b 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -31,25 +31,32 @@ class DicNode; */ class DicNodesCache { public: - AK_FORCE_INLINE DicNodesCache() - : mActiveDicNodes(&mDicNodePriorityQueues[DIC_NODES_CACHE_INITIAL_QUEUE_ID_ACTIVE]), - mNextActiveDicNodes(&mDicNodePriorityQueues[ - DIC_NODES_CACHE_INITIAL_QUEUE_ID_NEXT_ACTIVE]), - mTerminalDicNodes(&mDicNodePriorityQueues[DIC_NODES_CACHE_INITIAL_QUEUE_ID_TERMINAL]), - mCachedDicNodesForContinuousSuggestion(&mDicNodePriorityQueues[ - DIC_NODES_CACHE_INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), - mInputIndex(0), mLastCachedInputIndex(0) { - } + AK_FORCE_INLINE explicit DicNodesCache(const bool usesLargeCapacityCache) + : mUsesLargeCapacityCache(usesLargeCapacityCache), + mDicNodePriorityQueue0(getCacheCapacity()), + mDicNodePriorityQueue1(getCacheCapacity()), + mDicNodePriorityQueue2(getCacheCapacity()), + mDicNodePriorityQueueForTerminal(MAX_RESULTS), + mActiveDicNodes(&mDicNodePriorityQueue0), + mNextActiveDicNodes(&mDicNodePriorityQueue1), + mCachedDicNodesForContinuousSuggestion(&mDicNodePriorityQueue2), + mTerminalDicNodes(&mDicNodePriorityQueueForTerminal), + mInputIndex(0), mLastCachedInputIndex(0) {} AK_FORCE_INLINE virtual ~DicNodesCache() {} AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) { mInputIndex = 0; mLastCachedInputIndex = 0; - mActiveDicNodes->reset(); - mNextActiveDicNodes->clearAndResize(nextActiveSize); + // We want to use the max capacity for the current active dic node queue. + mActiveDicNodes->clearAndResizeToCapacity(); + // nextActiveSize is used to limit the next iteration's active dic node size. + const int nextActiveSizeFittingToTheCapacity = min(nextActiveSize, getCacheCapacity()); + mNextActiveDicNodes->clearAndResize(nextActiveSizeFittingToTheCapacity); mTerminalDicNodes->clearAndResize(terminalSize); - mCachedDicNodesForContinuousSuggestion->reset(); + // We want to use the max capacity for the cached dic nodes that will be used for the + // continuous suggestion. + mCachedDicNodesForContinuousSuggestion->clearAndResizeToCapacity(); } AK_FORCE_INLINE void continueSearch() { @@ -157,21 +164,35 @@ class DicNodesCache { return tmp; } + AK_FORCE_INLINE int getCacheCapacity() const { + return mUsesLargeCapacityCache ? + LARGE_PRIORITY_QUEUE_CAPACITY : SMALL_PRIORITY_QUEUE_CAPACITY; + } + AK_FORCE_INLINE void resetTemporaryCaches() { mActiveDicNodes->clear(); mNextActiveDicNodes->clear(); mTerminalDicNodes->clear(); } - DicNodePriorityQueue mDicNodePriorityQueues[DIC_NODES_CACHE_PRIORITY_QUEUES_SIZE]; + static const int LARGE_PRIORITY_QUEUE_CAPACITY; + static const int SMALL_PRIORITY_QUEUE_CAPACITY; + + const bool mUsesLargeCapacityCache; + // Instances + DicNodePriorityQueue mDicNodePriorityQueue0; + DicNodePriorityQueue mDicNodePriorityQueue1; + DicNodePriorityQueue mDicNodePriorityQueue2; + DicNodePriorityQueue mDicNodePriorityQueueForTerminal; + // Active dicNodes currently being expanded. DicNodePriorityQueue *mActiveDicNodes; // Next dicNodes to be expanded. DicNodePriorityQueue *mNextActiveDicNodes; - // Current top terminal dicNodes. - DicNodePriorityQueue *mTerminalDicNodes; // Cached dicNodes used for continuous suggestion. DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion; + // Current top terminal dicNodes. + DicNodePriorityQueue *mTerminalDicNodes; int mInputIndex; int mLastCachedInputIndex; }; diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h index 45c7f5cf9..74eb5dfe7 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h @@ -49,8 +49,10 @@ class DicNodeStateOutput { void addMergedNodeCodePoints(const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) { if (mergedNodeCodePoints) { + const int additionalCodePointCount = min(static_cast<int>(mergedNodeCodePointCount), + MAX_WORD_LENGTH - mOutputtedCodePointCount); memcpy(&mCodePointsBuf[mOutputtedCodePointCount], mergedNodeCodePoints, - mergedNodeCodePointCount * sizeof(mCodePointsBuf[0])); + additionalCodePointCount * sizeof(mCodePointsBuf[0])); mOutputtedCodePointCount = static_cast<uint16_t>( mOutputtedCodePointCount + mergedNodeCodePointCount); if (mOutputtedCodePointCount < MAX_WORD_LENGTH) { diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h index 5854f4f6e..b8986203d 100644 --- a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/layout/proximity_info_state.h" namespace latinime { @@ -29,9 +30,8 @@ class DicNodeStatePrevWord { public: AK_FORCE_INLINE DicNodeStatePrevWord() : mPrevWordCount(0), mPrevWordLength(0), mPrevWordStart(0), mPrevWordProbability(0), - mPrevWordNodePos(NOT_A_VALID_WORD_POS) { + mPrevWordNodePos(NOT_A_DICT_POS), mSecondWordFirstInputIndex(NOT_AN_INDEX) { memset(mPrevWord, 0, sizeof(mPrevWord)); - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); } virtual ~DicNodeStatePrevWord() {} @@ -41,8 +41,8 @@ class DicNodeStatePrevWord { mPrevWordCount = 0; mPrevWordStart = 0; mPrevWordProbability = -1; - mPrevWordNodePos = NOT_A_VALID_WORD_POS; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mPrevWordNodePos = NOT_A_DICT_POS; + mSecondWordFirstInputIndex = NOT_AN_INDEX; } void init(const int prevWordNodePos) { @@ -51,7 +51,7 @@ class DicNodeStatePrevWord { mPrevWordStart = 0; mPrevWordProbability = -1; mPrevWordNodePos = prevWordNodePos; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mSecondWordFirstInputIndex = NOT_AN_INDEX; } // Init by copy @@ -61,24 +61,26 @@ class DicNodeStatePrevWord { mPrevWordStart = prevWord->mPrevWordStart; mPrevWordProbability = prevWord->mPrevWordProbability; mPrevWordNodePos = prevWord->mPrevWordNodePos; + mSecondWordFirstInputIndex = prevWord->mSecondWordFirstInputIndex; memcpy(mPrevWord, prevWord->mPrevWord, prevWord->mPrevWordLength * sizeof(mPrevWord[0])); - memcpy(mPrevSpacePositions, prevWord->mPrevSpacePositions, sizeof(mPrevSpacePositions)); } void init(const int16_t prevWordCount, const int16_t prevWordProbability, const int prevWordNodePos, const int *const src0, const int16_t length0, - const int *const src1, const int16_t length1, const int *const prevSpacePositions, - const int lastInputIndex) { - mPrevWordCount = prevWordCount; + const int *const src1, const int16_t length1, + const int prevWordSecondWordFirstInputIndex, const int lastInputIndex) { + mPrevWordCount = min(prevWordCount, static_cast<int16_t>(MAX_RESULTS)); mPrevWordProbability = prevWordProbability; mPrevWordNodePos = prevWordNodePos; - const int twoWordsLen = + int twoWordsLen = DicNodeUtils::appendTwoWords(src0, length0, src1, length1, mPrevWord); + if (twoWordsLen >= MAX_WORD_LENGTH) { + twoWordsLen = MAX_WORD_LENGTH - 1; + } mPrevWord[twoWordsLen] = KEYCODE_SPACE; mPrevWordStart = length0; mPrevWordLength = static_cast<int16_t>(twoWordsLen + 1); - memcpy(mPrevSpacePositions, prevSpacePositions, sizeof(mPrevSpacePositions)); - mPrevSpacePositions[mPrevWordCount - 1] = lastInputIndex; + mSecondWordFirstInputIndex = prevWordSecondWordFirstInputIndex; } void truncate(const int offset) { @@ -93,11 +95,12 @@ class DicNodeStatePrevWord { mPrevWordLength = newPrevWordLength; } - void outputSpacePositions(int *spaceIndices) const { - // Convert uint16_t to int - for (int i = 0; i < MAX_RESULTS; i++) { - spaceIndices[i] = mPrevSpacePositions[i]; - } + void setSecondWordFirstInputIndex(const int inputIndex) { + mSecondWordFirstInputIndex = inputIndex; + } + + int getSecondWordFirstInputIndex() const { + return mSecondWordFirstInputIndex; } // TODO: remove @@ -113,10 +116,6 @@ class DicNodeStatePrevWord { return mPrevWordStart; } - int16_t getPrevWordProbability() const { - return mPrevWordProbability; - } - int getPrevWordNodePos() const { return mPrevWordNodePos; } @@ -139,8 +138,6 @@ class DicNodeStatePrevWord { // TODO: Move to private int mPrevWord[MAX_WORD_LENGTH]; - // TODO: Move to private - int mPrevSpacePositions[MAX_RESULTS]; private: // Caution!!! @@ -151,6 +148,7 @@ class DicNodeStatePrevWord { int16_t mPrevWordStart; int16_t mPrevWordProbability; int mPrevWordNodePos; + int mSecondWordFirstInputIndex; }; } // namespace latinime #endif // LATINIME_DIC_NODE_STATE_PREVWORD_H diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index ebe76467a..71f4ef6ea 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -23,7 +23,6 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/dictionary.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" @@ -117,18 +116,24 @@ int BigramDictionary::getPredictions(const int *prevWord, const int prevWordLeng mDictionaryStructurePolicy->getBigramsStructurePolicy(), pos); while (bigramsIt.hasNext()) { bigramsIt.next(); - const int length = mDictionaryStructurePolicy-> + if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { + continue; + } + const int codePointCount = mDictionaryStructurePolicy-> getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); + if (codePointCount <= 0) { + continue; + } // Due to space constraints, the probability for bigrams is approximate - the lower the // unigram probability, the worse the precision. The theoritical maximum error in // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 // in very bad cases. This means that sometimes, we'll see some bigrams interverted // here, but it can't get too bad. - const int probability = ProbabilityUtils::computeProbabilityForBigram( + const int probability = mDictionaryStructurePolicy->getProbability( unigramProbability, bigramsIt.getProbability()); - addWordBigram(bigramBuffer, length, probability, outBigramProbability, outBigramCodePoints, - outputTypes); + addWordBigram(bigramBuffer, codePointCount, probability, outBigramProbability, + outBigramCodePoints, outputTypes); ++bigramCount; } return min(bigramCount, MAX_RESULTS); @@ -141,28 +146,30 @@ int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const in if (0 >= prevWordLength) return NOT_A_DICT_POS; int pos = mDictionaryStructurePolicy->getTerminalNodePositionOfWord(prevWord, prevWordLength, forceLowerCaseSearch); - if (NOT_A_VALID_WORD_POS == pos) return NOT_A_DICT_POS; - return mDictionaryStructurePolicy->getBigramsPositionOfNode(pos); + if (NOT_A_DICT_POS == pos) return NOT_A_DICT_POS; + return mDictionaryStructurePolicy->getBigramsPositionOfPtNode(pos); } -bool BigramDictionary::isValidBigram(const int *word0, int length0, const int *word1, +int BigramDictionary::getBigramProbability(const int *word0, int length0, const int *word1, int length1) const { int pos = getBigramListPositionForWord(word0, length0, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (NOT_A_DICT_POS == pos) return false; + if (NOT_A_DICT_POS == pos) return NOT_A_PROBABILITY; int nextWordPos = mDictionaryStructurePolicy->getTerminalNodePositionOfWord(word1, length1, false /* forceLowerCaseSearch */); - if (NOT_A_VALID_WORD_POS == nextWordPos) return false; + if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; BinaryDictionaryBigramsIterator bigramsIt( mDictionaryStructurePolicy->getBigramsStructurePolicy(), pos); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == nextWordPos) { - return true; + return mDictionaryStructurePolicy->getProbability( + mDictionaryStructurePolicy->getUnigramProbabilityOfPtNode(nextWordPos), + bigramsIt.getProbability()); } } - return false; + return NOT_A_PROBABILITY; } // TODO: Move functions related to bigram to here diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h index 99b964c49..8af7ee75d 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h @@ -29,7 +29,7 @@ class BigramDictionary { int getPredictions(const int *word, int length, int *outBigramCodePoints, int *outBigramProbability, int *outputTypes) const; - bool isValidBigram(const int *word1, int length1, const int *word2, int length2) const; + int getBigramProbability(const int *word1, int length1, const int *word2, int length2) const; ~BigramDictionary(); private: diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 8418a608a..ec1b63a12 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -87,14 +87,15 @@ int Dictionary::getBigrams(const int *word, int length, int *outWords, int *freq int Dictionary::getProbability(const int *word, int length) const { int pos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord(word, length, false /* forceLowerCaseSearch */); - if (NOT_A_VALID_WORD_POS == pos) { + if (NOT_A_DICT_POS == pos) { return NOT_A_PROBABILITY; } - return getDictionaryStructurePolicy()->getUnigramProbability(pos); + return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); } -bool Dictionary::isValidBigram(const int *word0, int length0, const int *word1, int length1) const { - return mBigramDictionary->isValidBigram(word0, length0, word1, length1); +int Dictionary::getBigramProbability(const int *word0, int length0, const int *word1, + int length1) const { + return mBigramDictionary->getBigramProbability(word0, length0, word1, length1); } void Dictionary::addUnigramWord(const int *const word, const int length, const int probability) { @@ -112,6 +113,18 @@ void Dictionary::removeBigramWords(const int *const word0, const int length0, mDictionaryStructureWithBufferPolicy->removeBigramWords(word0, length0, word1, length1); } +void Dictionary::flush(const char *const filePath) { + mDictionaryStructureWithBufferPolicy->flush(filePath); +} + +void Dictionary::flushWithGC(const char *const filePath) { + mDictionaryStructureWithBufferPolicy->flushWithGC(filePath); +} + +bool Dictionary::needsToRunGC() { + return mDictionaryStructureWithBufferPolicy->needsToRunGC(); +} + void Dictionary::logDictionaryInfo(JNIEnv *const env) const { const int BUFFER_SIZE = 16; int dictionaryIdCodePointBuffer[BUFFER_SIZE]; diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 0afe5a54b..974447468 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -67,7 +67,7 @@ class Dictionary { int getProbability(const int *word, int length) const; - bool isValidBigram(const int *word0, int length0, const int *word1, int length1) const; + int getBigramProbability(const int *word0, int length0, const int *word1, int length1) const; void addUnigramWord(const int *const word, const int length, const int probability); @@ -77,6 +77,12 @@ class Dictionary { void removeBigramWords(const int *const word0, const int length0, const int *const word1, const int length1); + void flush(const char *const filePath); + + void flushWithGC(const char *const filePath); + + bool needsToRunGC(); + const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const { return mDictionaryStructureWithBufferPolicy; } diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index 97d4cd161..4633c07b0 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -22,7 +22,6 @@ #include "defines.h" #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/bloom_filter.h" -#include "suggest/core/dictionary/probability_utils.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/hash_map_compat.h" @@ -43,11 +42,12 @@ class MultiBigramMap { hash_map_compat<int, BigramMap>::const_iterator mapPosition = mBigramMaps.find(wordPosition); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { addBigramsForWordPosition(structurePolicy, wordPosition); - return mBigramMaps[wordPosition].getBigramProbability( + return mBigramMaps[wordPosition].getBigramProbability(structurePolicy, nextWordPosition, unigramProbability); } return readBigramProbabilityFromBinaryDictionary(structurePolicy, wordPosition, @@ -68,28 +68,31 @@ class MultiBigramMap { void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos) { - const int bigramsListPos = structurePolicy->getBigramsPositionOfNode(nodePos); + const int bigramsListPos = structurePolicy->getBigramsPositionOfPtNode(nodePos); BinaryDictionaryBigramsIterator bigramsIt(structurePolicy->getBigramsStructurePolicy(), bigramsListPos); while (bigramsIt.hasNext()) { bigramsIt.next(); + if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { + continue; + } mBigramMap[bigramsIt.getBigramPos()] = bigramsIt.getProbability(); mBloomFilter.setInFilter(bigramsIt.getBigramPos()); } } AK_FORCE_INLINE int getBigramProbability( + const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nextWordPosition, const int unigramProbability) const { + int bigramProbability = NOT_A_PROBABILITY; if (mBloomFilter.isInFilter(nextWordPosition)) { const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = mBigramMap.find(nextWordPosition); if (bigramProbabilityIt != mBigramMap.end()) { - const int bigramProbability = bigramProbabilityIt->second; - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramProbability); + bigramProbability = bigramProbabilityIt->second; } } - return ProbabilityUtils::backoff(unigramProbability); + return structurePolicy->getProbability(unigramProbability, bigramProbability); } private: @@ -108,17 +111,18 @@ class MultiBigramMap { AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos, const int nextWordPosition, const int unigramProbability) { - const int bigramsListPos = structurePolicy->getBigramsPositionOfNode(nodePos); + int bigramProbability = NOT_A_PROBABILITY; + const int bigramsListPos = structurePolicy->getBigramsPositionOfPtNode(nodePos); BinaryDictionaryBigramsIterator bigramsIt(structurePolicy->getBigramsStructurePolicy(), bigramsListPos); while (bigramsIt.hasNext()) { bigramsIt.next(); if (bigramsIt.getBigramPos() == nextWordPosition) { - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramsIt.getProbability()); + bigramProbability = bigramsIt.getProbability(); + break; } } - return ProbabilityUtils::backoff(unigramProbability); + return structurePolicy->getProbability(unigramProbability, bigramProbability); } static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp index 7780efdfd..fbabd92f2 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp @@ -36,8 +36,8 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi const int *const xCoordinates, const int *const yCoordinates, const int *const times, const int *const pointerIds, const bool isGeometric) { ASSERT(isGeometric || (inputSize < MAX_WORD_LENGTH)); - mIsContinuousSuggestionPossible = - ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( + mIsContinuousSuggestionPossible = (mHasBeenUpdatedByGeometricInput != isGeometric) ? + false : ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( inputSize, xCoordinates, yCoordinates, times, mSampledInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledInputIndice); if (DEBUG_DICT) { @@ -155,6 +155,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (DEBUG_GEO_FULL) { AKLOGI("ProximityState init finished: %d points out of %d", mSampledInputSize, inputSize); } + mHasBeenUpdatedByGeometricInput = isGeometric; } // This function basically converts from a length to an edit distance. Accordingly, it's obviously diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index dbcd54488..c94060fa9 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -46,10 +46,11 @@ class ProximityInfoState { : mProximityInfo(0), mMaxPointToKeyLength(0.0f), mAverageSpeed(0.0f), mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), - mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), - mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), - mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(), - mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mIsContinuousSuggestionPossible(false), mHasBeenUpdatedByGeometricInput(false), + mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mSampledInputIndice(), + mSampledLengthCache(), mBeelineSpeedPercentiles(), + mSampledNormalizedSquaredLengthCache(), mSpeedRates(), mDirections(), + mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); @@ -129,6 +130,10 @@ class ProximityInfoState { return mSampledInputYs[index]; } + int getInputIndexOfSampledPoint(const int sampledIndex) const { + return mSampledInputIndice[sampledIndex]; + } + bool hasSpaceProximity(const int index) const; int getLengthCache(const int index) const { @@ -204,6 +209,7 @@ class ProximityInfoState { int mGridHeight; int mGridWidth; bool mIsContinuousSuggestionPossible; + bool mHasBeenUpdatedByGeometricInput; std::vector<int> mSampledInputXs; std::vector<int> mSampledInputYs; diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index 532411509..b95488ebd 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -47,11 +47,14 @@ class DictionaryStructureWithBufferPolicy { virtual int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const = 0; - virtual int getUnigramProbability(const int nodePos) const = 0; + virtual int getProbability(const int unigramProbability, + const int bigramProbability) const = 0; - virtual int getShortcutPositionOfNode(const int nodePos) const = 0; + virtual int getUnigramProbabilityOfPtNode(const int nodePos) const = 0; - virtual int getBigramsPositionOfNode(const int nodePos) const = 0; + virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; + + virtual int getBigramsPositionOfPtNode(const int nodePos) const = 0; virtual const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const = 0; @@ -71,6 +74,12 @@ class DictionaryStructureWithBufferPolicy { virtual bool removeBigramWords(const int *const word0, const int length0, const int *const word1, const int length1) = 0; + virtual void flush(const char *const filePath) = 0; + + virtual void flushWithGC(const char *const filePath) = 0; + + virtual bool needsToRunGC() const = 0; + protected: DictionaryStructureWithBufferPolicy() {} diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 0ca583f90..50f2bbd8d 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -23,6 +23,11 @@ namespace latinime { +// 256K bytes threshold is heuristically used to distinguish dictionaries containing many unigrams +// (e.g. main dictionary) from small dictionaries (e.g. contacts...) +const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION = + 256 * 1024; + void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; @@ -30,13 +35,13 @@ void DicTraverseSession::init(const Dictionary *const dictionary, const int *pre ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; if (!prevWord) { - mPrevWordPos = NOT_A_VALID_WORD_POS; + mPrevWordPos = NOT_A_DICT_POS; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. mPrevWordPos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord( prevWord, prevWordLength, false /* forceLowerCaseSearch */); - if (mPrevWordPos == NOT_A_VALID_WORD_POS) { + if (mPrevWordPos == NOT_A_DICT_POS) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". mPrevWordPos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord( @@ -59,8 +64,9 @@ const DictionaryStructureWithBufferPolicy *DicTraverseSession::getDictionaryStru return mDictionary->getDictionaryStructurePolicy(); } -void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { - mDicNodesCache.reset(nextActiveCacheSize, maxWords); +void DicTraverseSession::resetCache(const int thresholdForNextActiveDicNodes, const int maxWords) { + mDicNodesCache.reset(thresholdForNextActiveDicNodes /* nextActiveSize */, + maxWords /* terminalSize */); mMultiBigramMap.clear(); mPartiallyCommited = false; } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 23de5cc65..e0b1c67d9 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -37,8 +37,12 @@ class DicTraverseSession { public: // A factory method for DicTraverseSession - static AK_FORCE_INLINE void *getSessionInstance(JNIEnv *env, jstring localeStr) { - return new DicTraverseSession(env, localeStr); + static AK_FORCE_INLINE void *getSessionInstance(JNIEnv *env, jstring localeStr, + jlong dictSize) { + // To deal with the trade-off between accuracy and memory space, large cache is used for + // dictionaries larger that the threshold + return new DicTraverseSession(env, localeStr, + dictSize >= DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION); } static AK_FORCE_INLINE void initSessionInstance(DicTraverseSession *traverseSession, @@ -54,10 +58,10 @@ class DicTraverseSession { delete traverseSession; } - AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) - : mPrevWordPos(NOT_A_VALID_WORD_POS), mProximityInfo(0), - mDictionary(0), mSuggestOptions(0), mDicNodesCache(), mMultiBigramMap(), - mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) + : mPrevWordPos(NOT_A_DICT_POS), mProximityInfo(0), + mDictionary(0), mSuggestOptions(0), mDicNodesCache(usesLargeCache), + mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. @@ -73,7 +77,7 @@ class DicTraverseSession { const int inputSize, const int *const inputXs, const int *const inputYs, const int *const times, const int *const pointerIds, const float maxSpatialDistance, const int maxPointerCount); - void resetCache(const int nextActiveCacheSize, const int maxWords); + void resetCache(const int thresholdForNextActiveDicNodes, const int maxWords); const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const; @@ -109,7 +113,9 @@ class DicTraverseSession { if (usedPointerCount != 1) { return false; } - *pointerId = usedPointerId; + if (pointerId) { + *pointerId = usedPointerId; + } return true; } @@ -181,6 +187,7 @@ class DicTraverseSession { DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching static const int CACHE_START_INPUT_LENGTH_THRESHOLD; + static const int DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION; void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs, const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 7d8dd21c5..b1340e12f 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -117,7 +117,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo * Outputs the final list of suggestions (i.e., terminal nodes). */ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *spaceIndices, int *outputTypes) const { + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -139,6 +139,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen SCORING->getMostProbableString(traverseSession, terminalSize, languageWeight, &outputCodePoints[0], &outputTypes[0], &frequencies[0]); if (hasMostProbableString) { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; ++outputWordIndex; } @@ -160,6 +161,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT && terminals[0].hasMultipleWords())) : false; + // TODO: have partial commit work even with multiple pointers. + const bool outputSecondWordFirstLetterInputIndex = + traverseSession->isOnlyOnePointerUsed(0 /* pointerId */); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -171,7 +175,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + doubleLetterCost; - const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; + const bool isPossiblyOffensiveWord = + traverseSession->getDictionaryStructurePolicy()->getProbability( + terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; const bool isExactMatch = terminalDicNode->isExactMatch(); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. @@ -192,18 +198,21 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalDicNode->isExactMatch() || (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) || (isValidWord && SCORING->doesAutoCorrectValidWord())); - maxScore = max(maxScore, finalScore); - - // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. - // Index for top typing suggestion should be 0. - if (isValidWord && outputWordIndex == 0) { - terminalDicNode->outputSpacePositionsResult(spaceIndices); + if (maxScore < finalScore && isValidWord) { + maxScore = finalScore; } // Don't output invalid words. However, we still need to submit their shortcuts if any. if (isValidWord) { outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[outputWordIndex] = + terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + } else { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; + } // Populate the outputChars array with the suggested word. const int startIndex = outputWordIndex * MAX_WORD_LENGTH; terminalDicNode->outputResult(&outputCodePoints[startIndex]); @@ -214,12 +223,23 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen BinaryDictionaryShortcutIterator shortcutIt( traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), traverseSession->getDictionaryStructurePolicy() - ->getShortcutPositionOfNode(terminalDicNode->getPos())); + ->getShortcutPositionOfPtNode(terminalDicNode->getPos())); // Shortcut is not supported for multiple words suggestions. // TODO: Check shortcuts during traversal for multiple words suggestions. const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); - outputWordIndex = ShortcutUtils::outputShortcuts(&shortcutIt, outputWordIndex, - finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + const int updatedOutputWordIndex = ShortcutUtils::outputShortcuts(&shortcutIt, + outputWordIndex, finalScore, outputCodePoints, frequencies, outputTypes, + sameAsTyped); + const int secondWordFirstInputIndex = terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + for (int i = outputWordIndex; i < updatedOutputWordIndex; ++i) { + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[i] = secondWordFirstInputIndex; + } else { + outputIndicesToPartialCommit[i] = NOT_AN_INDEX; + } + } + outputWordIndex = updatedOutputWordIndex; } DicNode::managedDelete(terminalDicNode); } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 875cbe4e0..b24019632 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -55,7 +55,7 @@ class Suggest : public SuggestInterface { void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const; int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *outputIndices, int *outputTypes) const; + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const; void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h index beb9bee27..6ff95cac4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h @@ -21,7 +21,7 @@ #include "defines.h" #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" -#include "suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" namespace latinime { @@ -33,16 +33,15 @@ class BigramListPolicy : public DictionaryBigramsStructurePolicy { void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, int *const pos) const { - const BigramListReadingUtils::BigramFlags flags = - BigramListReadingUtils::getFlagsAndForwardPointer(mBigramsBuf, pos); - *outBigramPos = BigramListReadingUtils::getBigramAddressAndForwardPointer( - mBigramsBuf, flags, pos); - *outProbability = BigramListReadingUtils::getProbabilityFromFlags(flags); - *outHasNext = BigramListReadingUtils::hasNext(flags); + BigramListReadWriteUtils::BigramFlags flags; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, &flags, + outBigramPos, pos); + *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(flags); + *outHasNext = BigramListReadWriteUtils::hasNext(flags); } void skipAllBigrams(int *const pos) const { - BigramListReadingUtils::skipExistingBigrams(mBigramsBuf, pos); + BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, pos); } private: diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp new file mode 100644 index 000000000..1926b9831 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = + 0x30; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; +// Flag for presence of more attributes +const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::FLAG_ATTRIBUTE_HAS_NEXT = + 0x80; +// Mask for attribute probability, stored on 4 bits inside the flags byte. +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int BigramListReadWriteUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; + +/* static */ void BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + const uint8_t *const bigramsBuf, BigramFlags *const outBigramFlags, + int *const outTargetPtNodePos, int *const bigramEntryPos) { + const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, + bigramEntryPos); + if (outBigramFlags) { + *outBigramFlags = bigramFlags; + } + const int targetPos = getBigramAddressAndAdvancePosition(bigramsBuf, bigramFlags, + bigramEntryPos); + if (outTargetPtNodePos) { + *outTargetPtNodePos = targetPos; + } +} + +/* static */ void BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, + int *const bigramListPos) { + BigramFlags flags; + do { + getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, &flags, 0 /* outTargetPtNodePos */, + bigramListPos); + } while(hasNext(flags)); +} + +/* static */ int BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( + const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { + int offset = 0; + const int origin = *pos; + switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); + break; + } + if (offset == DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID) { + return NOT_A_DICT_POS; + } else if (offset == DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET) { + return origin; + } + if (isOffsetNegative(flags)) { + return origin - offset; + } else { + return origin + offset; + } +} + +/* static */ bool BigramListReadWriteUtils::setHasNextFlag( + BufferWithExtendableBuffer *const buffer, const bool hasNext, const int entryPos) { + const bool usesAdditionalBuffer = buffer->isInAdditionalBuffer(entryPos); + int readingPos = entryPos; + if (usesAdditionalBuffer) { + readingPos -= buffer->getOriginalBufferSize(); + } + BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition( + buffer->getBuffer(usesAdditionalBuffer), &readingPos); + if (hasNext) { + bigramFlags = bigramFlags | FLAG_ATTRIBUTE_HAS_NEXT; + } else { + bigramFlags = bigramFlags & (~FLAG_ATTRIBUTE_HAS_NEXT); + } + int writingPos = entryPos; + return buffer->writeUintAndAdvancePosition(bigramFlags, 1 /* size */, &writingPos); +} + +/* static */ bool BigramListReadWriteUtils::createAndWriteBigramEntry( + BufferWithExtendableBuffer *const buffer, const int targetPos, const int probability, + const bool hasNext, int *const writingPos) { + BigramFlags flags; + if (!createAndGetBigramFlags(*writingPos, targetPos, probability, hasNext, &flags)) { + return false; + } + return writeBigramEntry(buffer, flags, targetPos, writingPos); +} + +/* static */ bool BigramListReadWriteUtils::writeBigramEntry( + BufferWithExtendableBuffer *const bufferToWrite, const BigramFlags flags, + const int targetPtNodePos, int *const writingPos) { + const int offset = getBigramTargetOffset(targetPtNodePos, *writingPos); + const BigramFlags flagsToWrite = (offset < 0) ? + (flags | FLAG_ATTRIBUTE_OFFSET_NEGATIVE) : (flags & ~FLAG_ATTRIBUTE_OFFSET_NEGATIVE); + if (!bufferToWrite->writeUintAndAdvancePosition(flagsToWrite, 1 /* size */, writingPos)) { + return false; + } + const uint32_t absOffest = abs(offset); + const int bigramTargetFieldSize = attributeAddressSize(flags); + return bufferToWrite->writeUintAndAdvancePosition(absOffest, bigramTargetFieldSize, + writingPos); +} + +// Returns true if the bigram entry is valid and put entry flags into out*. +/* static */ bool BigramListReadWriteUtils::createAndGetBigramFlags(const int entryPos, + const int targetPtNodePos, const int probability, const bool hasNext, + BigramFlags *const outBigramFlags) { + BigramFlags flags = probability & MASK_ATTRIBUTE_PROBABILITY; + if (hasNext) { + flags |= FLAG_ATTRIBUTE_HAS_NEXT; + } + const int offset = getBigramTargetOffset(targetPtNodePos, entryPos); + if (offset < 0) { + flags |= FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + } + const uint32_t absOffest = abs(offset); + if ((absOffest >> 24) != 0) { + // Offset is too large. + return false; + } else if ((absOffest >> 16) != 0) { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + } else if ((absOffest >> 8) != 0) { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + } else { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + } + // Currently, all newly written bigram position fields are 3 bytes to simplify dictionary + // writing. + // TODO: Remove following 2 lines and optimize memory space. + flags = (flags & (~MASK_ATTRIBUTE_ADDRESS_TYPE)) | FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + *outBigramFlags = flags; + return true; +} + +/* static */ int BigramListReadWriteUtils::getBigramTargetOffset(const int targetPtNodePos, + const int entryPos) { + if (targetPtNodePos == NOT_A_DICT_POS) { + return DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID; + } else { + const int offset = targetPtNodePos - (entryPos + 1 /* bigramFlagsField */); + if (offset == 0) { + return DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET; + } else { + return offset; + } + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h index d0c584bd5..eabe4e099 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h @@ -14,24 +14,25 @@ * limitations under the License. */ -#ifndef LATINIME_BIGRAM_LIST_READING_UTILS_H -#define LATINIME_BIGRAM_LIST_READING_UTILS_H +#ifndef LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H +#define LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H +#include <cstdlib> #include <stdint.h> #include "defines.h" -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { -class BigramListReadingUtils { +class BufferWithExtendableBuffer; + +class BigramListReadWriteUtils { public: typedef uint8_t BigramFlags; - static AK_FORCE_INLINE BigramFlags getFlagsAndForwardPointer( - const uint8_t *const bigramsBuf, int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); - } + static void getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, + BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + int *const bigramEntryPos); static AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; @@ -42,21 +43,38 @@ public: } // Bigrams reading methods - static AK_FORCE_INLINE void skipExistingBigrams(const uint8_t *const bigramsBuf, - int *const pos) { - BigramFlags flags = getFlagsAndForwardPointer(bigramsBuf, pos); - while (hasNext(flags)) { - *pos += attributeAddressSize(flags); - flags = getFlagsAndForwardPointer(bigramsBuf, pos); - } - *pos += attributeAddressSize(flags); + static void skipExistingBigrams(const uint8_t *const bigramsBuf, int *const bigramListPos); + + // Returns the size of the bigram position field that is stored in bigram flags. + static AK_FORCE_INLINE int attributeAddressSize(const BigramFlags flags) { + return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; + /* Note: this is a value-dependant optimization of what may probably be + more readably written this way: + switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; + default: return 0; + } + */ } - static int getBigramAddressAndForwardPointer(const uint8_t *const bigramsBuf, - const BigramFlags flags, int *const pos); + static bool setHasNextFlag(BufferWithExtendableBuffer *const buffer, + const bool hasNext, const int entryPos); + + static AK_FORCE_INLINE BigramFlags setProbabilityInFlags(const BigramFlags flags, + const int probability) { + return (flags & (~MASK_ATTRIBUTE_PROBABILITY)) | (probability & MASK_ATTRIBUTE_PROBABILITY); + } + + static bool createAndWriteBigramEntry(BufferWithExtendableBuffer *const buffer, + const int targetPos, const int probability, const bool hasNext, int *const writingPos); + + static bool writeBigramEntry(BufferWithExtendableBuffer *const buffer, const BigramFlags flags, + const int targetOffset, int *const writingPos); private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadingUtils); + DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); static const BigramFlags MASK_ATTRIBUTE_ADDRESS_TYPE; static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; @@ -67,22 +85,18 @@ private: static const BigramFlags MASK_ATTRIBUTE_PROBABILITY; static const int ATTRIBUTE_ADDRESS_SHIFT; + // Returns true if the bigram entry is valid and put entry flags into out*. + static bool createAndGetBigramFlags(const int entryPos, const int targetPos, + const int probability, const bool hasNext, BigramFlags *const outBigramFlags); + static AK_FORCE_INLINE bool isOffsetNegative(const BigramFlags flags) { return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; } - static AK_FORCE_INLINE int attributeAddressSize(const BigramFlags flags) { - return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; - /* Note: this is a value-dependant optimization of what may probably be - more readably written this way: - switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; - default: return 0; - } - */ - } + static int getBigramAddressAndAdvancePosition(const uint8_t *const bigramsBuf, + const BigramFlags flags, int *const pos); + + static int getBigramTargetOffset(const int targetPtNodePos, const int entryPos); }; } // namespace latinime -#endif // LATINIME_BIGRAM_LIST_READING_UTILS_H +#endif // LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.cpp deleted file mode 100644 index 6da0e8b85..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/bigram/bigram_list_reading_utils.h" - -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" - -namespace latinime { - -const BigramListReadingUtils::BigramFlags BigramListReadingUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = - 0x30; -const BigramListReadingUtils::BigramFlags - BigramListReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; -const BigramListReadingUtils::BigramFlags - BigramListReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; -const BigramListReadingUtils::BigramFlags - BigramListReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; -const BigramListReadingUtils::BigramFlags - BigramListReadingUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; -// Flag for presence of more attributes -const BigramListReadingUtils::BigramFlags BigramListReadingUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; -// Mask for attribute probability, stored on 4 bits inside the flags byte. -const BigramListReadingUtils::BigramFlags - BigramListReadingUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; -const int BigramListReadingUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; - -/* static */ int BigramListReadingUtils::getBigramAddressAndForwardPointer( - const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { - int offset = 0; - const int origin = *pos; - switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); - break; - } - if (isOffsetNegative(flags)) { - return origin - offset; - } else { - return origin + offset; - } -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp new file mode 100644 index 000000000..29307b56a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" + +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const int DynamicBigramListPolicy::CONTINUING_BIGRAM_LINK_COUNT_LIMIT = 10000; +const int DynamicBigramListPolicy::BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT = 100000; + +void DynamicBigramListPolicy::getNextBigram(int *const outBigramPos, int *const outProbability, + bool *const outHasNext, int *const bigramEntryPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramEntryPos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *bigramEntryPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int originalBigramPos; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(buffer, &bigramFlags, + &originalBigramPos, bigramEntryPos); + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + *outBigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags); + *outHasNext = BigramListReadWriteUtils::hasNext(bigramFlags); + if (usesAdditionalBuffer) { + *bigramEntryPos += mBuffer->getOriginalBufferSize(); + } +} + +void DynamicBigramListPolicy::skipAllBigrams(int *const bigramListPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::skipExistingBigrams(buffer, bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos += mBuffer->getOriginalBufferSize(); + } +} + +bool DynamicBigramListPolicy::copyAllBigrams(BufferWithExtendableBuffer *const bufferToWrite, + int *const fromPos, int *const toPos, int *const outBigramsCount) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); + if (usesAdditionalBuffer) { + *fromPos -= mBuffer->getOriginalBufferSize(); + } + *outBigramsCount = 0; + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + int lastWrittenEntryPos = NOT_A_DICT_POS; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + // The buffer address can be changed after calling buffer writing methods. + int originalBigramPos; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + fromPos); + if (originalBigramPos == NOT_A_DICT_POS) { + // skip invalid bigram entry. + continue; + } + if (usesAdditionalBuffer) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + if (bigramPos == NOT_A_DICT_POS) { + // Target PtNode has been invalidated. + continue; + } + lastWrittenEntryPos = *toPos; + if (!BigramListReadWriteUtils::createAndWriteBigramEntry(bufferToWrite, bigramPos, + BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags), + BigramListReadWriteUtils::hasNext(bigramFlags), toPos)) { + return false; + } + (*outBigramsCount)++; + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + // Makes the last entry the terminal of the list. Updates the flags. + if (lastWrittenEntryPos != NOT_A_DICT_POS) { + if (!BigramListReadWriteUtils::setHasNextFlag(bufferToWrite, false /* hasNext */, + lastWrittenEntryPos)) { + return false; + } + } + if (usesAdditionalBuffer) { + *fromPos += mBuffer->getOriginalBufferSize(); + } + return true; +} + +// Finding useless bigram entries and remove them. Bigram entry is useless when the target PtNode +// has been deleted or is not a valid terminal. +bool DynamicBigramListPolicy::updateAllBigramEntriesAndDeleteUselessEntries( + int *const bigramListPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, this /* bigramsPolicy */, mShortcutPolicy); + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = *bigramListPos; + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + bigramListPos); + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + if (originalBigramPos == NOT_A_DICT_POS) { + // This entry has already been removed. + continue; + } + if (usesAdditionalBuffer) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramTargetNodePos = + followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(bigramTargetNodePos); + // TODO: Update probability for supporting probability decaying. + if (nodeReader.isDeleted() || !nodeReader.isTerminal() + || bigramTargetNodePos == NOT_A_DICT_POS) { + // The target is no longer valid terminal. Invalidate the current bigram entry. + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + NOT_A_DICT_POS /* targetOffset */, &bigramEntryPos)) { + return false; + } + } + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return true; +} + +// Updates bigram target PtNode positions in the list after the placing step in GC. +bool DynamicBigramListPolicy::updateAllBigramTargetPtNodePositions(int *const bigramListPos, + const DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap *const + ptNodePositionRelocationMap) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = *bigramListPos; + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + int bigramTargetPtNodePos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &bigramTargetPtNodePos, + bigramListPos); + if (bigramTargetPtNodePos == NOT_A_DICT_POS) { + continue; + } + if (usesAdditionalBuffer) { + bigramTargetPtNodePos += mBuffer->getOriginalBufferSize(); + } + + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::const_iterator it = + ptNodePositionRelocationMap->find(bigramTargetPtNodePos); + if (it != ptNodePositionRelocationMap->end()) { + bigramTargetPtNodePos = it->second; + } else { + bigramTargetPtNodePos = NOT_A_DICT_POS; + } + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + bigramTargetPtNodePos, &bigramEntryPos)) { + return false; + } + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return true; +} + +bool DynamicBigramListPolicy::addNewBigramEntryToBigramList(const int bigramTargetPos, + const int probability, int *const bigramListPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int entryPos = *bigramListPos; + if (usesAdditionalBuffer) { + entryPos += mBuffer->getOriginalBufferSize(); + } + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + bigramListPos); + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + if (followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos) == bigramTargetPos) { + // Update this bigram entry. + const BigramListReadWriteUtils::BigramFlags updatedFlags = + BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags, probability); + return BigramListReadWriteUtils::writeBigramEntry(mBuffer, updatedFlags, + originalBigramPos, &entryPos); + } + if (BigramListReadWriteUtils::hasNext(bigramFlags)) { + continue; + } + // The current last entry is found. + // First, update the flags of the last entry. + if (!BigramListReadWriteUtils::setHasNextFlag(mBuffer, true /* hasNext */, entryPos)) { + return false; + } + if (usesAdditionalBuffer) { + *bigramListPos += mBuffer->getOriginalBufferSize(); + } + // Then, add a new entry after the last entry. + return writeNewBigramEntry(bigramTargetPos, probability, bigramListPos); + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + // We return directly from the while loop. + ASSERT(false); + return false; +} + +bool DynamicBigramListPolicy::writeNewBigramEntry(const int bigramTargetPos, const int probability, + int *const writingPos) { + // hasNext is false because we are adding a new bigram entry at the end of the bigram list. + return BigramListReadWriteUtils::createAndWriteBigramEntry(mBuffer, bigramTargetPos, + probability, false /* hasNext */, writingPos); +} + +bool DynamicBigramListPolicy::removeBigram(const int bigramListPos, const int bigramTargetPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(bigramListPos); + int pos = bigramListPos; + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = pos; + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, &pos); + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + if (bigramPos != bigramTargetPos) { + continue; + } + // Target entry is found. Write an invalid target position to mark the bigram invalid. + return BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + NOT_A_DICT_POS /* targetOffset */, &bigramEntryPos); + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return false; +} + +int DynamicBigramListPolicy::followBigramLinkAndGetCurrentBigramPtNodePos( + const int originalBigramPos) const { + if (originalBigramPos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + int currentPos = originalBigramPos; + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, this /* bigramsPolicy */, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(currentPos); + int bigramLinkCount = 0; + while (nodeReader.getBigramLinkedNodePos() != NOT_A_DICT_POS) { + currentPos = nodeReader.getBigramLinkedNodePos(); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(currentPos); + bigramLinkCount++; + if (bigramLinkCount > CONTINUING_BIGRAM_LINK_COUNT_LIMIT) { + AKLOGE("Bigram link is invalid. start position: %d", originalBigramPos); + ASSERT(false); + return NOT_A_DICT_POS; + } + } + return currentPos; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h new file mode 100644 index 000000000..8ea318a41 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H +#define LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DictionaryShortcutsStructurePolicy; + +/* + * This is a dynamic version of BigramListPolicy and supports an additional buffer. + */ +class DynamicBigramListPolicy : public DictionaryBigramsStructurePolicy { + public: + DynamicBigramListPolicy(BufferWithExtendableBuffer *const buffer, + const DictionaryShortcutsStructurePolicy *const shortcutPolicy) + : mBuffer(buffer), mShortcutPolicy(shortcutPolicy) {} + + ~DynamicBigramListPolicy() {} + + void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, + int *const bigramEntryPos) const; + + void skipAllBigrams(int *const bigramListPos) const; + + // Copy bigrams from the bigram list that starts at fromPos in mBuffer to toPos in + // bufferToWrite and advance these positions after bigram lists. This method skips invalid + // bigram entries and write the valid bigram entry count to outBigramsCount. + bool copyAllBigrams(BufferWithExtendableBuffer *const bufferToWrite, int *const fromPos, + int *const toPos, int *const outBigramsCount) const; + + bool updateAllBigramEntriesAndDeleteUselessEntries(int *const bigramListPos); + + bool updateAllBigramTargetPtNodePositions(int *const bigramListPos, + const DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap *const + ptNodePositionRelocationMap); + + bool addNewBigramEntryToBigramList(const int bigramTargetPos, const int probability, + int *const bigramListPos); + + bool writeNewBigramEntry(const int bigramTargetPos, const int probability, + int *const writingPos); + + // Return if targetBigramPos is found or not. + bool removeBigram(const int bigramListPos, const int bigramTargetPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicBigramListPolicy); + + static const int CONTINUING_BIGRAM_LINK_COUNT_LIMIT; + static const int BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT; + + BufferWithExtendableBuffer *const mBuffer; + const DictionaryShortcutsStructurePolicy *const mShortcutPolicy; + + // Follow bigram link and return the position of bigram target PtNode that is currently valid. + int followBigramLinkAndGetCurrentBigramPtNodePos(const int originalBigramPos) const; +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/binary_format.h b/native/jni/src/suggest/policyimpl/dictionary/binary_format.h deleted file mode 100644 index 23f4c7fec..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/binary_format.h +++ /dev/null @@ -1,470 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BINARY_FORMAT_H -#define LATINIME_BINARY_FORMAT_H - -#include <stdint.h> - -#include "suggest/core/dictionary/probability_utils.h" -#include "utils/char_utils.h" - -namespace latinime { - -class BinaryFormat { - public: - // Mask and flags for children address type selection. - static const int MASK_GROUP_ADDRESS_TYPE = 0xC0; - - // Flag for single/multiple char group - static const int FLAG_HAS_MULTIPLE_CHARS = 0x20; - - // Flag for terminal groups - static const int FLAG_IS_TERMINAL = 0x10; - - // Flag for shortcut targets presence - static const int FLAG_HAS_SHORTCUT_TARGETS = 0x08; - // Flag for bigram presence - static const int FLAG_HAS_BIGRAMS = 0x04; - // Flag for non-words (typically, shortcut only entries) - static const int FLAG_IS_NOT_A_WORD = 0x02; - // Flag for blacklist - static const int FLAG_IS_BLACKLISTED = 0x01; - - // Attribute (bigram/shortcut) related flags: - // Flag for presence of more attributes - static const int FLAG_ATTRIBUTE_HAS_NEXT = 0x80; - // Flag for sign of offset. If this flag is set, the offset value must be negated. - static const int FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; - - // Mask for attribute probability, stored on 4 bits inside the flags byte. - static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F; - - // Mask and flags for attribute address type selection. - static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; - - static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); - static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); - static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); - static int readProbabilityWithoutMovingPointer(const uint8_t *const dict, const int pos); - static int skipOtherCharacters(const uint8_t *const dict, const int pos); - static int skipChildrenPosition(const uint8_t flags, const int pos); - static int skipProbability(const uint8_t flags, const int pos); - static int skipShortcuts(const uint8_t *const dict, const uint8_t flags, const int pos); - static int skipChildrenPosAndAttributes(const uint8_t *const dict, const uint8_t flags, - const int pos); - static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos); - static bool hasChildrenInFlags(const uint8_t flags); - static int getTerminalPosition(const uint8_t *const root, const int *const inWord, - const int length, const bool forceLowerCaseSearch); - static int getCodePointsAndProbabilityAndReturnCodePointCount( - const uint8_t *const root, const int nodePos, const int maxCodePointCount, - int *const outCodePoints, int *const outUnigramProbability); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); - - static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; - static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; - static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; - static const int FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; - - static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; - static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; - static const int CHARACTER_ARRAY_TERMINATOR = 0x1F; - static const int MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE = 2; - static const int NO_FLAGS = 0; - static int skipAllAttributes(const uint8_t *const dict, const uint8_t flags, const int pos); - static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); -}; - -AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict, - int *pos) { - const int msb = dict[(*pos)++]; - if (msb < 0x80) return msb; - return ((msb & 0x7F) << 8) | dict[(*pos)++]; -} - -inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { - return dict[(*pos)++]; -} - -AK_FORCE_INLINE int BinaryFormat::getCodePointAndForwardPointer(const uint8_t *const dict, - int *pos) { - const int origin = *pos; - const int codePoint = dict[origin]; - if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { - if (codePoint == CHARACTER_ARRAY_TERMINATOR) { - *pos = origin + 1; - return NOT_A_CODE_POINT; - } else { - *pos = origin + 3; - const int char_1 = codePoint << 16; - const int char_2 = char_1 + (dict[origin + 1] << 8); - return char_2 + dict[origin + 2]; - } - } else { - *pos = origin + 1; - return codePoint; - } -} - -inline int BinaryFormat::readProbabilityWithoutMovingPointer(const uint8_t *const dict, - const int pos) { - return dict[pos]; -} - -AK_FORCE_INLINE int BinaryFormat::skipOtherCharacters(const uint8_t *const dict, const int pos) { - int currentPos = pos; - int character = dict[currentPos++]; - while (CHARACTER_ARRAY_TERMINATOR != character) { - if (character < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { - currentPos += MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE; - } - character = dict[currentPos++]; - } - return currentPos; -} - -static inline int attributeAddressSize(const uint8_t flags) { - static const int ATTRIBUTE_ADDRESS_SHIFT = 4; - return (flags & BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; - /* Note: this is a value-dependant optimization of what may probably be - more readably written this way: - switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; - default: return 0; - } - */ -} - -static AK_FORCE_INLINE int skipExistingBigrams(const uint8_t *const dict, const int pos) { - int currentPos = pos; - uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); - while (flags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT) { - currentPos += attributeAddressSize(flags); - flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); - } - currentPos += attributeAddressSize(flags); - return currentPos; -} - -static inline int childrenAddressSize(const uint8_t flags) { - static const int CHILDREN_ADDRESS_SHIFT = 6; - return (BinaryFormat::MASK_GROUP_ADDRESS_TYPE & flags) >> CHILDREN_ADDRESS_SHIFT; - /* See the note in attributeAddressSize. The same applies here */ -} - -static AK_FORCE_INLINE int shortcutByteSize(const uint8_t *const dict, const int pos) { - return (static_cast<int>(dict[pos] << 8)) + (dict[pos + 1]); -} - -inline int BinaryFormat::skipChildrenPosition(const uint8_t flags, const int pos) { - return pos + childrenAddressSize(flags); -} - -inline int BinaryFormat::skipProbability(const uint8_t flags, const int pos) { - return FLAG_IS_TERMINAL & flags ? pos + 1 : pos; -} - -AK_FORCE_INLINE int BinaryFormat::skipShortcuts(const uint8_t *const dict, const uint8_t flags, - const int pos) { - if (FLAG_HAS_SHORTCUT_TARGETS & flags) { - return pos + shortcutByteSize(dict, pos); - } else { - return pos; - } -} - -AK_FORCE_INLINE int BinaryFormat::skipBigrams(const uint8_t *const dict, const uint8_t flags, - const int pos) { - if (FLAG_HAS_BIGRAMS & flags) { - return skipExistingBigrams(dict, pos); - } else { - return pos; - } -} - -AK_FORCE_INLINE int BinaryFormat::skipAllAttributes(const uint8_t *const dict, const uint8_t flags, - const int pos) { - // This function skips all attributes: shortcuts and bigrams. - int newPos = pos; - newPos = skipShortcuts(dict, flags, newPos); - newPos = skipBigrams(dict, flags, newPos); - return newPos; -} - -AK_FORCE_INLINE int BinaryFormat::skipChildrenPosAndAttributes(const uint8_t *const dict, - const uint8_t flags, const int pos) { - int currentPos = pos; - currentPos = skipChildrenPosition(flags, currentPos); - currentPos = skipAllAttributes(dict, flags, currentPos); - return currentPos; -} - -AK_FORCE_INLINE int BinaryFormat::readChildrenPosition(const uint8_t *const dict, - const uint8_t flags, const int pos) { - int offset = 0; - switch (MASK_GROUP_ADDRESS_TYPE & flags) { - case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE: - offset = dict[pos]; - break; - case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES: - offset = dict[pos] << 8; - offset += dict[pos + 1]; - break; - case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES: - offset = dict[pos] << 16; - offset += dict[pos + 1] << 8; - offset += dict[pos + 2]; - break; - default: - // If we come here, it means we asked for the children of a word with - // no children. - return -1; - } - return pos + offset; -} - -inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) { - return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags)); -} - -// This function gets the byte position of the last chargroup of the exact matching word in the -// dictionary. If no match is found, it returns NOT_A_VALID_WORD_POS. -AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, - const int *const inWord, const int length, const bool forceLowerCaseSearch) { - int pos = 0; - int wordPos = 0; - - while (true) { - // If we already traversed the tree further than the word is long, there means - // there was no match (or we would have found it). - if (wordPos >= length) return NOT_A_VALID_WORD_POS; - int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos); - const int wChar = forceLowerCaseSearch - ? CharUtils::toLowerCase(inWord[wordPos]) : inWord[wordPos]; - while (true) { - // If there are no more character groups in this node, it means we could not - // find a matching character for this depth, therefore there is no match. - if (0 >= charGroupCount) return NOT_A_VALID_WORD_POS; - const int charGroupPos = pos; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - if (character == wChar) { - // This is the correct node. Only one character group may start with the same - // char within a node, so either we found our match in this node, or there is - // no match and we can return NOT_A_VALID_WORD_POS. So we will check all the - // characters in this character group indeed does match. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - while (NOT_A_CODE_POINT != character) { - ++wordPos; - // If we shoot the length of the word we search for, or if we find a single - // character that does not match, as explained above, it means the word is - // not in the dictionary (by virtue of this chargroup being the only one to - // match the word on the first character, but not matching the whole word). - if (wordPos >= length) return NOT_A_VALID_WORD_POS; - if (inWord[wordPos] != character) return NOT_A_VALID_WORD_POS; - character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } - } - // If we come here we know that so far, we do match. Either we are on a terminal - // and we match the length, in which case we found it, or we traverse children. - // If we don't match the length AND don't have children, then a word in the - // dictionary fully matches a prefix of the searched word but not the full word. - ++wordPos; - if (FLAG_IS_TERMINAL & flags) { - if (wordPos == length) { - return charGroupPos; - } - pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos); - } - if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) { - return NOT_A_VALID_WORD_POS; - } - // We have children and we are still shorter than the word we are searching for, so - // we need to traverse children. Put the pointer on the children position, and - // break - pos = BinaryFormat::readChildrenPosition(root, flags, pos); - break; - } else { - // This chargroup does not match, so skip the remaining part and go to the next. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } - pos = BinaryFormat::skipProbability(flags, pos); - pos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos); - } - --charGroupCount; - } - } -} - -// This function searches for a terminal in the dictionary by its address. -// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, -// it is possible to check for this with advantageous complexity. For each node, we search -// for groups with children and compare the children address with the address we look for. -// When we shoot the address we look for, it means the word we look for is in the children -// of the previous group. The only tricky part is the fact that if we arrive at the end of a -// node with the last group's children address still less than what we are searching for, we -// must descend the last group's children (for example, if the word we are searching for starts -// with a z, it's the last group of the root node, so all children addresses will be smaller -// than the address we look for, and we have to descend the z node). -/* Parameters : - * root: the dictionary buffer - * address: the byte position of the last chargroup of the word we are searching for (this is - * what is stored as the "bigram address" in each bigram) - * outword: an array to write the found word, with MAX_WORD_LENGTH size. - * outUnigramProbability: a pointer to an int to write the probability into. - * Return value : the length of the word, of 0 if the word was not found. - */ -AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( - const uint8_t *const root, const int nodePos, const int maxCodePointCount, - int *const outCodePoints, int *const outUnigramProbability) { - int pos = 0; - int wordPos = 0; - - // One iteration of the outer loop iterates through nodes. As stated above, we will only - // traverse nodes that are actually a part of the terminal we are searching, so each time - // we enter this loop we are one depth level further than last time. - // The only reason we count nodes is because we want to reduce the probability of infinite - // looping in case there is a bug. Since we know there is an upper bound to the depth we are - // supposed to traverse, it does not hurt to count iterations. - for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { - int lastCandidateGroupPos = 0; - // Let's loop through char groups in this node searching for either the terminal - // or one of its ascendants. - for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0; - --charGroupCount) { - const int startPos = pos; - const uint8_t flags = getFlagsAndForwardPointer(root, &pos); - const int character = getCodePointAndForwardPointer(root, &pos); - if (nodePos == startPos) { - // We found the address. Copy the rest of the word in the buffer and return - // the length. - outCodePoints[wordPos] = character; - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - int nextChar = getCodePointAndForwardPointer(root, &pos); - // We count chars in order to avoid infinite loops if the file is broken or - // if there is some other bug - int charCount = maxCodePointCount; - while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { - outCodePoints[++wordPos] = nextChar; - nextChar = getCodePointAndForwardPointer(root, &pos); - } - } - *outUnigramProbability = readProbabilityWithoutMovingPointer(root, pos); - return ++wordPos; - } - // We need to skip past this char group, so skip any remaining chars after the - // first and possibly the probability. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - pos = skipOtherCharacters(root, pos); - } - pos = skipProbability(flags, pos); - - // The fact that this group has children is very important. Since we already know - // that this group does not match, if it has no children we know it is irrelevant - // to what we are searching for. - const bool hasChildren = (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != - (MASK_GROUP_ADDRESS_TYPE & flags)); - // We will write in `found' whether we have passed the children address we are - // searching for. For example if we search for "beer", the children of b are less - // than the address we are searching for and the children of c are greater. When we - // come here for c, we realize this is too big, and that we should descend b. - bool found; - if (hasChildren) { - // Here comes the tricky part. First, read the children position. - const int childrenPos = readChildrenPosition(root, flags, pos); - if (childrenPos > nodePos) { - // If the children pos is greater than address, it means the previous chargroup, - // which address is stored in lastCandidateGroupPos, was the right one. - found = true; - } else if (1 >= charGroupCount) { - // However if we are on the LAST group of this node, and we have NOT shot the - // address we should descend THIS node. So we trick the lastCandidateGroupPos - // so that we will descend this node, not the previous one. - lastCandidateGroupPos = startPos; - found = true; - } else { - // Else, we should continue looking. - found = false; - } - } else { - // Even if we don't have children here, we could still be on the last group of this - // node. If this is the case, we should descend the last group that had children, - // and their address is already in lastCandidateGroup. - found = (1 >= charGroupCount); - } - - if (found) { - // Okay, we found the group we should descend. Its address is in - // the lastCandidateGroupPos variable, so we just re-read it. - if (0 != lastCandidateGroupPos) { - const uint8_t lastFlags = - getFlagsAndForwardPointer(root, &lastCandidateGroupPos); - const int lastChar = - getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - // We copy all the characters in this group to the buffer - outCodePoints[wordPos] = lastChar; - if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { - int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - int charCount = maxCodePointCount; - while (-1 != nextChar && --charCount > 0) { - outCodePoints[++wordPos] = nextChar; - nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - } - } - ++wordPos; - // Now we only need to branch to the children address. Skip the probability if - // it's there, read pos, and break to resume the search at pos. - lastCandidateGroupPos = skipProbability(lastFlags, lastCandidateGroupPos); - pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos); - break; - } else { - // Here is a little tricky part: we come here if we found out that all children - // addresses in this group are bigger than the address we are searching for. - // Should we conclude the word is not in the dictionary? No! It could still be - // one of the remaining chargroups in this node, so we have to keep looking in - // this node until we find it (or we realize it's not there either, in which - // case it's actually not in the dictionary). Pass the end of this group, ready - // to start the next one. - pos = skipChildrenPosAndAttributes(root, flags, pos); - } - } else { - // If we did not find it, we should record the last children address for the next - // iteration. - if (hasChildren) lastCandidateGroupPos = startPos; - // Now skip the end of this group (children pos and the attributes if any) so that - // our pos is after the end of this char group, at the start of the next one. - pos = skipChildrenPosAndAttributes(root, flags, pos); - } - - } - } - // If we have looked through all the chargroups and found no match, the address is - // not the address of a terminal in this dictionary. - return 0; -} - -} // namespace latinime -#endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp index cc5252c43..ff80dd2f6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp @@ -22,17 +22,17 @@ #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h" #include "suggest/policyimpl/dictionary/patricia_trie_policy.h" #include "suggest/policyimpl/dictionary/utils/format_utils.h" -#include "suggest/policyimpl/dictionary/utils/mmaped_buffer.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" namespace latinime { /* static */ DictionaryStructureWithBufferPolicy *DictionaryStructureWithBufferPolicyFactory - ::newDictionaryStructureWithBufferPolicy(const char *const path, const int pathLength, - const int bufOffset, const int size, const bool isUpdatable) { + ::newDictionaryStructureWithBufferPolicy(const char *const path, const int bufOffset, + const int size, const bool isUpdatable) { // Allocated buffer in MmapedBuffer::openBuffer() will be freed in the destructor of // impl classes of DictionaryStructureWithBufferPolicy. - const MmapedBuffer *const mmapedBuffer = MmapedBuffer::openBuffer(path, pathLength, bufOffset, - size, isUpdatable); + const MmappedBuffer *const mmapedBuffer = MmappedBuffer::openBuffer(path, bufOffset, size, + isUpdatable); if (!mmapedBuffer) { return 0; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h index 1cb7a89c4..8cebc3b16 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h @@ -27,8 +27,7 @@ namespace latinime { class DictionaryStructureWithBufferPolicyFactory { public: static DictionaryStructureWithBufferPolicy *newDictionaryStructureWithBufferPolicy( - const char *const path, const int pathLength, const int bufOffset, const int size, - const bool isUpdatable); + const char *const path, const int bufOffset, const int size, const bool isUpdatable); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryStructureWithBufferPolicyFactory); diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp new file mode 100644 index 000000000..c60e45819 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h" + +namespace latinime { + +bool DynamicPatriciaTrieGcEventListeners + ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + // PtNode is useless when the PtNode is not a terminal and doesn't have any not useless + // children. + bool isUselessPtNode = !node->isTerminal(); + if (mChildrenValue > 0) { + isUselessPtNode = false; + } else if (node->isTerminal()) { + // Remove children as all children are useless. + int writingPos = node->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + mBuffer, NOT_A_DICT_POS /* childrenPosition */, &writingPos)) { + return false; + } + } + if (isUselessPtNode) { + // Current PtNode is no longer needed. Mark it as deleted. + if (!mWritingHelper->markNodeAsDeleted(node)) { + return false; + } + } else { + valueStack.back() += 1; + } + return true; +} + +// Writes dummy PtNode array size when the head of PtNode array is read. +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onDescend(const int ptNodeArrayPos) { + mValidPtNodeCount = 0; + int writingPos = mBufferToWrite->getTailPosition(); + mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodeArrayPositionRelocationMap::value_type( + ptNodeArrayPos, writingPos)); + // Writes dummy PtNode array size because arrays can have a forward link or needles PtNodes. + // This field will be updated later in onReadingPtNodeArrayTail() with actual PtNode count. + mPtNodeArraySizeFieldPos = writingPos; + return DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + mBufferToWrite, 0 /* arraySize */, &writingPos); +} + +// Write PtNode array terminal and actual PtNode array size. +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onReadingPtNodeArrayTail() { + int writingPos = mBufferToWrite->getTailPosition(); + // Write PtNode array terminal. + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition( + mBufferToWrite, NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + // Write actual PtNode array size. + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + mBufferToWrite, mValidPtNodeCount, &mPtNodeArraySizeFieldPos)) { + return false; + } + return true; +} + +// Write valid PtNode to buffer and memorize mapping from the old position to the new position. +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + if (node->isDeleted()) { + // Current PtNode is not written in new buffer because it has been deleted. + mDictPositionRelocationMap->mPtNodePositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::value_type( + node->getHeadPos(), NOT_A_DICT_POS)); + return true; + } + int writingPos = mBufferToWrite->getTailPosition(); + mDictPositionRelocationMap->mPtNodePositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::value_type( + node->getHeadPos(), writingPos)); + mValidPtNodeCount++; + // Writes current PtNode. + return mWritingHelper->writePtNodeToBufferByCopyingPtNodeInfo(mBufferToWrite, node, + node->getParentPos(), nodeCodePoints, node->getCodePointCount(), + node->getProbability(), &writingPos); +} + +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateAllPositionFields + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + // Updates parent position. + int parentPos = node->getParentPos(); + if (parentPos != NOT_A_DICT_POS) { + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::const_iterator it = + mDictPositionRelocationMap->mPtNodePositionRelocationMap.find(parentPos); + if (it != mDictPositionRelocationMap->mPtNodePositionRelocationMap.end()) { + parentPos = it->second; + } + } + int writingPos = node->getHeadPos() + DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE; + // Write updated parent offset. + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition(mBufferToWrite, + parentPos, node->getHeadPos(), &writingPos)) { + return false; + } + + // Updates children position. + int childrenPos = node->getChildrenPos(); + if (childrenPos != NOT_A_DICT_POS) { + DynamicPatriciaTrieWritingHelper::PtNodeArrayPositionRelocationMap::const_iterator it = + mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.find(childrenPos); + if (it != mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.end()) { + childrenPos = it->second; + } + } + writingPos = node->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBufferToWrite, + childrenPos, &writingPos)) { + return false; + } + + // Updates bigram target PtNode positions in the bigram list. + int bigramsPos = node->getBigramsPos(); + if (bigramsPos != NOT_A_DICT_POS) { + if (!mBigramPolicy->updateAllBigramTargetPtNodePositions(&bigramsPos, + &mDictPositionRelocationMap->mPtNodePositionRelocationMap)) { + return false; + } + } + + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h new file mode 100644 index 000000000..4256f22fb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H + +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +class DynamicPatriciaTrieGcEventListeners { + public: + // Updates all PtNodes that can be reached from the root. Checks if each PtNode is useless or + // not and marks useless PtNodes as deleted. Such deleted PtNodes will be discarded in the GC. + // TODO: Concatenate non-terminal PtNodes. + class TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( + DynamicPatriciaTrieWritingHelper *const writingHelper, + BufferWithExtendableBuffer *const buffer) + : mWritingHelper(writingHelper), mBuffer(buffer), valueStack(), + mChildrenValue(0) {} + + ~TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted() {}; + + bool onAscend() { + if (valueStack.empty()) { + return false; + } + mChildrenValue = valueStack.back(); + valueStack.pop_back(); + return true; + } + + bool onDescend(const int ptNodeArrayPos) { + valueStack.push_back(0); + return true; + } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS( + TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted); + + DynamicPatriciaTrieWritingHelper *const mWritingHelper; + BufferWithExtendableBuffer *const mBuffer; + std::vector<int> valueStack; + int mChildrenValue; + }; + + // Updates all bigram entries that are held by valid PtNodes. This removes useless bigram + // entries. + class TraversePolicyToUpdateBigramProbability + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateBigramProbability(DynamicBigramListPolicy *const bigramPolicy) + : mBigramPolicy(bigramPolicy) {} + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos) { return true; } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + if (!node->isDeleted()) { + int pos = node->getBigramsPos(); + if (pos != NOT_A_DICT_POS) { + if (!mBigramPolicy->updateAllBigramEntriesAndDeleteUselessEntries(&pos)) { + return false; + } + } + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateBigramProbability); + + DynamicBigramListPolicy *const mBigramPolicy; + }; + + class TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToPlaceAndWriteValidPtNodesToBuffer( + DynamicPatriciaTrieWritingHelper *const writingHelper, + BufferWithExtendableBuffer *const bufferToWrite, + DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + dictPositionRelocationMap) + : mWritingHelper(writingHelper), mBufferToWrite(bufferToWrite), + mDictPositionRelocationMap(dictPositionRelocationMap), mValidPtNodeCount(0), + mPtNodeArraySizeFieldPos(NOT_A_DICT_POS) {}; + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos); + + bool onReadingPtNodeArrayTail(); + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToPlaceAndWriteValidPtNodesToBuffer); + + DynamicPatriciaTrieWritingHelper *const mWritingHelper; + BufferWithExtendableBuffer *const mBufferToWrite; + DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + mDictPositionRelocationMap; + int mValidPtNodeCount; + int mPtNodeArraySizeFieldPos; + }; + + class TraversePolicyToUpdateAllPositionFields + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateAllPositionFields( + DynamicPatriciaTrieWritingHelper *const writingHelper, + DynamicBigramListPolicy *const bigramPolicy, + BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + dictPositionRelocationMap) + : mWritingHelper(writingHelper), mBigramPolicy(bigramPolicy), + mBufferToWrite(bufferToWrite), + mDictPositionRelocationMap(dictPositionRelocationMap) {}; + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos) { return true; } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateAllPositionFields); + + DynamicPatriciaTrieWritingHelper *const mWritingHelper; + DynamicBigramListPolicy *const mBigramPolicy; + BufferWithExtendableBuffer *const mBufferToWrite; + const DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + mDictPositionRelocationMap; + }; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieGcEventListeners); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp index 77a85c86d..456352c17 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp @@ -19,33 +19,66 @@ #include "suggest/core/policy/dictionary_bigrams_structure_policy.h" #include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" namespace latinime { -void DynamicPatriciaTrieNodeReader::fetchNodeInfoFromBufferAndProcessMovedNode(const int nodePos, - const int maxCodePointCount, int *const outCodePoints) { - int pos = nodePos; - mFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); - const int parentPos = - DynamicPatriciaTrieReadingUtils::getParentPosAndAdvancePosition(mDictRoot, &pos); - mParentPos = (parentPos != 0) ? mNodePos + parentPos : NOT_A_DICT_POS; +void DynamicPatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProcessMovedPtNode( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints) { + if (ptNodePos < 0 || ptNodePos >= mBuffer->getTailPosition()) { + AKLOGE("Fetching PtNode info form invalid dictionary position: %d, dictionary size: %d", + ptNodePos, mBuffer->getTailPosition()); + ASSERT(false); + invalidatePtNodeInfo(); + return; + } + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(ptNodePos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + int pos = ptNodePos; + mHeadPos = ptNodePos; + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + mFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const int parentPosOffset = + DynamicPatriciaTrieReadingUtils::getParentPtNodePosOffsetAndAdvancePosition(dictBuf, + &pos); + mParentPos = DynamicPatriciaTrieReadingUtils::getParentPtNodePos(parentPosOffset, mHeadPos); if (outCodePoints != 0) { mCodePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - mDictRoot, mFlags, maxCodePointCount, outCodePoints, &pos); + dictBuf, mFlags, maxCodePointCount, outCodePoints, &pos); } else { mCodePointCount = PatriciaTrieReadingUtils::skipCharacters( - mDictRoot, mFlags, MAX_WORD_LENGTH, &pos); + dictBuf, mFlags, MAX_WORD_LENGTH, &pos); } if (isTerminal()) { - mProbability = PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + mProbabilityFieldPos = pos; + if (usesAdditionalBuffer) { + mProbabilityFieldPos += mBuffer->getOriginalBufferSize(); + } + mProbability = PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictBuf, &pos); } else { + mProbabilityFieldPos = NOT_A_DICT_POS; mProbability = NOT_A_PROBABILITY; } - if (hasChildren()) { - mChildrenPos = DynamicPatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( - mDictRoot, mFlags, &pos); - } else { - mChildrenPos = NOT_A_DICT_POS; + mChildrenPosFieldPos = pos; + if (usesAdditionalBuffer) { + mChildrenPosFieldPos += mBuffer->getOriginalBufferSize(); + } + mChildrenPos = DynamicPatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + dictBuf, &pos); + if (usesAdditionalBuffer && mChildrenPos != NOT_A_DICT_POS) { + mChildrenPos += mBuffer->getOriginalBufferSize(); + } + if (mSiblingPos == NOT_A_DICT_POS) { + if (DynamicPatriciaTrieReadingUtils::isMoved(mFlags)) { + mBigramLinkedNodePos = mChildrenPos; + } else { + mBigramLinkedNodePos = NOT_A_DICT_POS; + } + } + if (usesAdditionalBuffer) { + pos += mBuffer->getOriginalBufferSize(); } if (PatriciaTrieReadingUtils::hasShortcutTargets(mFlags)) { mShortcutPos = pos; @@ -60,15 +93,31 @@ void DynamicPatriciaTrieNodeReader::fetchNodeInfoFromBufferAndProcessMovedNode(c mBigramPos = NOT_A_DICT_POS; } // Update siblingPos if needed. - if (mSiblingPos == NOT_A_VALID_WORD_POS) { + if (mSiblingPos == NOT_A_DICT_POS) { // Sibling position is the tail position of current node. mSiblingPos = pos; } // Read destination node if the read node is a moved node. if (DynamicPatriciaTrieReadingUtils::isMoved(mFlags)) { // The destination position is stored at the same place as the parent position. - fetchNodeInfoFromBufferAndProcessMovedNode(mParentPos, maxCodePointCount, outCodePoints); + fetchPtNodeInfoFromBufferAndProcessMovedPtNode(mParentPos, maxCodePointCount, + outCodePoints); } } +void DynamicPatriciaTrieNodeReader::invalidatePtNodeInfo() { + mHeadPos = NOT_A_DICT_POS; + mFlags = 0; + mParentPos = NOT_A_DICT_POS; + mCodePointCount = 0; + mProbabilityFieldPos = NOT_A_DICT_POS; + mProbability = NOT_A_PROBABILITY; + mChildrenPosFieldPos = NOT_A_DICT_POS; + mChildrenPos = NOT_A_DICT_POS; + mBigramLinkedNodePos = NOT_A_DICT_POS; + mShortcutPos = NOT_A_DICT_POS; + mBigramPos = NOT_A_DICT_POS; + mSiblingPos = NOT_A_DICT_POS; +} + } diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h index e990809e8..3b36d425f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h @@ -25,6 +25,7 @@ namespace latinime { +class BufferWithExtendableBuffer; class DictionaryBigramsStructurePolicy; class DictionaryShortcutsStructurePolicy; @@ -34,32 +35,35 @@ class DictionaryShortcutsStructurePolicy; */ class DynamicPatriciaTrieNodeReader { public: - DynamicPatriciaTrieNodeReader(const uint8_t *const dictRoot, + DynamicPatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer, const DictionaryBigramsStructurePolicy *const bigramsPolicy, const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) - : mDictRoot(dictRoot), mBigramsPolicy(bigramsPolicy), - mShortcutsPolicy(shortcutsPolicy), mNodePos(NOT_A_VALID_WORD_POS), mFlags(0), - mParentPos(NOT_A_DICT_POS), mCodePointCount(0), mProbability(NOT_A_PROBABILITY), - mChildrenPos(NOT_A_DICT_POS), mShortcutPos(NOT_A_DICT_POS), - mBigramPos(NOT_A_DICT_POS), mSiblingPos(NOT_A_VALID_WORD_POS) {} + : mBuffer(buffer), mBigramsPolicy(bigramsPolicy), + mShortcutsPolicy(shortcutsPolicy), mHeadPos(NOT_A_DICT_POS), mFlags(0), + mParentPos(NOT_A_DICT_POS), mCodePointCount(0), mProbabilityFieldPos(NOT_A_DICT_POS), + mProbability(NOT_A_PROBABILITY), mChildrenPosFieldPos(NOT_A_DICT_POS), + mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS), + mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), + mSiblingPos(NOT_A_DICT_POS) {} ~DynamicPatriciaTrieNodeReader() {} - // Reads node information from dictionary buffer and updates members with the information. - AK_FORCE_INLINE void fetchNodeInfoFromBuffer(const int nodePos) { - fetchNodeInfoFromBufferAndGetNodeCodePoints(nodePos , 0 /* maxCodePointCount */, - 0 /* outCodePoints */); + // Reads PtNode information from dictionary buffer and updates members with the information. + AK_FORCE_INLINE void fetchNodeInfoInBufferFromPtNodePos(const int ptNodePos) { + fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(ptNodePos , + 0 /* maxCodePointCount */, 0 /* outCodePoints */); } - AK_FORCE_INLINE void fetchNodeInfoFromBufferAndGetNodeCodePoints(const int nodePos, - const int maxCodePointCount, int *const outCodePoints) { - mNodePos = nodePos; - mSiblingPos = NOT_A_VALID_WORD_POS; - fetchNodeInfoFromBufferAndProcessMovedNode(mNodePos, maxCodePointCount, outCodePoints); + AK_FORCE_INLINE void fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints) { + mSiblingPos = NOT_A_DICT_POS; + mBigramLinkedNodePos = NOT_A_DICT_POS; + fetchPtNodeInfoFromBufferAndProcessMovedPtNode(ptNodePos, maxCodePointCount, outCodePoints); } - AK_FORCE_INLINE int getNodePos() const { - return mNodePos; + // HeadPos is different from NodePos when the current PtNode is a moved PtNode. + AK_FORCE_INLINE int getHeadPos() const { + return mHeadPos; } // Flags @@ -68,7 +72,7 @@ class DynamicPatriciaTrieNodeReader { } AK_FORCE_INLINE bool hasChildren() const { - return PatriciaTrieReadingUtils::hasChildrenInFlags(mFlags); + return mChildrenPos != NOT_A_DICT_POS; } AK_FORCE_INLINE bool isTerminal() const { @@ -94,15 +98,28 @@ class DynamicPatriciaTrieNodeReader { } // Probability + AK_FORCE_INLINE int getProbabilityFieldPos() const { + return mProbabilityFieldPos; + } + AK_FORCE_INLINE int getProbability() const { return mProbability; } - // Children node group position + // Children PtNode array position + AK_FORCE_INLINE int getChildrenPosFieldPos() const { + return mChildrenPosFieldPos; + } + AK_FORCE_INLINE int getChildrenPos() const { return mChildrenPos; } + // Bigram linked node position. + AK_FORCE_INLINE int getBigramLinkedNodePos() const { + return mBigramLinkedNodePos; + } + // Shortcutlist position AK_FORCE_INLINE int getShortcutPos() const { return mShortcutPos; @@ -121,22 +138,26 @@ class DynamicPatriciaTrieNodeReader { private: DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTrieNodeReader); - // TODO: Consolidate mDictRoot. - const uint8_t *const mDictRoot; + const BufferWithExtendableBuffer *const mBuffer; const DictionaryBigramsStructurePolicy *const mBigramsPolicy; const DictionaryShortcutsStructurePolicy *const mShortcutsPolicy; - int mNodePos; + int mHeadPos; DynamicPatriciaTrieReadingUtils::NodeFlags mFlags; int mParentPos; uint8_t mCodePointCount; + int mProbabilityFieldPos; int mProbability; + int mChildrenPosFieldPos; int mChildrenPos; + int mBigramLinkedNodePos; int mShortcutPos; int mBigramPos; int mSiblingPos; - void fetchNodeInfoFromBufferAndProcessMovedNode(const int nodePos, const int maxCodePointCount, - int *const outCodePoints); + void fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos, + const int maxCodePointCount, int *const outCodePoints); + + void invalidatePtNodeInfo(); }; } // namespace latinime #endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_NODE_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp index 3000860a3..42397c19e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -20,95 +20,70 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" #include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { -// To avoid infinite loop caused by invalid or malicious forward links. -const int DynamicPatriciaTriePolicy::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 100000; - void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const { if (!dicNode->hasChildren()) { return; } - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - int mergedNodeCodePoints[MAX_WORD_LENGTH]; - int nextPos = dicNode->getChildrenPos(); - int totalChildCount = 0; - do { - const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( - mDictRoot, &nextPos); - totalChildCount += childCount; - if (childCount <= 0 || totalChildCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP) { - // Invalid dictionary. - AKLOGI("Invalid dictionary. childCount: %d, totalChildCount: %d, MAX: %d", - childCount, totalChildCount, MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP); - ASSERT(false); - return; - } - for (int i = 0; i < childCount; i++) { - nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(nextPos, MAX_WORD_LENGTH, - mergedNodeCodePoints); - if (!nodeReader.isDeleted()) { - // Push child node when the node is not a deleted node. - childDicNodes->pushLeavingChild(dicNode, nodeReader.getNodePos(), - nodeReader.getChildrenPos(), nodeReader.getProbability(), - nodeReader.isTerminal(), nodeReader.hasChildren(), - nodeReader.isBlacklisted() || nodeReader.isNotAWord(), - nodeReader.getCodePointCount(), mergedNodeCodePoints); - } - nextPos = nodeReader.getSiblingNodePos(); - } - nextPos = DynamicPatriciaTrieReadingUtils::getForwardLinkPosition(mDictRoot, nextPos); - } while (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(nextPos)); + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(dicNode->getChildrenPos()); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + while (!readingHelper.isEnd()) { + childDicNodes->pushLeavingChild(dicNode, nodeReader->getHeadPos(), + nodeReader->getChildrenPos(), nodeReader->getProbability(), + nodeReader->isTerminal() && !nodeReader->isDeleted(), + nodeReader->hasChildren(), nodeReader->isBlacklisted() || nodeReader->isNotAWord(), + nodeReader->getCodePointCount(), readingHelper.getMergedNodeCodePoints()); + readingHelper.readNextSiblingNode(); + } } int DynamicPatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int nodePos, const int maxCodePointCount, int *const outCodePoints, + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { - if (nodePos == NOT_A_VALID_WORD_POS) { - *outUnigramProbability = NOT_A_PROBABILITY; - return 0; - } // This method traverses parent nodes from the terminal by following parent pointers; thus, // node code points are stored in the buffer in the reverse order. int reverseCodePoints[maxCodePointCount]; - int mergedNodeCodePoints[maxCodePointCount]; - int codePointCount = 0; - - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - // First, read terminal node and get its probability. - nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(nodePos, maxCodePointCount, - mergedNodeCodePoints); - // Store terminal node probability. - *outUnigramProbability = nodeReader.getProbability(); - // Store terminal node code points to buffer in the reverse order. - for (int i = nodeReader.getCodePointCount() - 1; i >= 0; --i) { - reverseCodePoints[codePointCount++] = mergedNodeCodePoints[i]; + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + // First, read the terminal node and get its probability. + readingHelper.initWithPtNodePos(ptNodePos); + if (!readingHelper.isValidTerminalNode()) { + // Node at the ptNodePos is not a valid terminal node. + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; } - // Then, follow parent pos toward the root node. - while (nodeReader.getParentPos() != NOT_A_DICT_POS) { - // codePointCount must be incremented at least once in each iteration to ensure preventing - // infinite loop. - if (nodeReader.isDeleted() || codePointCount > maxCodePointCount - || nodeReader.getCodePointCount() <= 0) { - // The nodePos is not a valid terminal node position in the dictionary. + // Store terminal node probability. + *outUnigramProbability = readingHelper.getNodeReader()->getProbability(); + // Then, following parent node link to the dictionary root and fetch node code points. + while (!readingHelper.isEnd()) { + if (readingHelper.getTotalCodePointCount() > maxCodePointCount) { + // The ptNodePos is not a valid terminal node position in the dictionary. *outUnigramProbability = NOT_A_PROBABILITY; return 0; } - // Read parent node. - nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(nodeReader.getParentPos(), - maxCodePointCount, mergedNodeCodePoints); // Store node code points to buffer in the reverse order. - for (int i = nodeReader.getCodePointCount() - 1; i >= 0; --i) { - reverseCodePoints[codePointCount++] = mergedNodeCodePoints[i]; - } + readingHelper.fetchMergedNodeCodePointsInReverseOrder( + readingHelper.getPrevTotalCodePointCount(), reverseCodePoints); + // Follow parent node toward the root node. + readingHelper.readParentNode(); + } + if (readingHelper.isError()) { + // The node position or the dictionary is invalid. + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; } // Reverse the stored code points to output them. + const int codePointCount = readingHelper.getTotalCodePointCount(); for (int i = 0; i < codePointCount; ++i) { outCodePoints[i] = reverseCodePoints[codePointCount - i - 1]; } @@ -121,110 +96,91 @@ int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const in for (int i = 0; i < length; ++i) { searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(inWord[i]) : inWord[i]; } - int mergedNodeCodePoints[MAX_WORD_LENGTH]; - int currentLength = 0; - int pos = getRootPosition(); - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - while (currentLength <= length) { - // When foundMatchedNode becomes true, currentLength is increased at least once. - bool foundMatchedNode = false; - int totalChildCount = 0; - do { - const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( - mDictRoot, &pos); - totalChildCount += childCount; - if (childCount <= 0 || totalChildCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP) { - // Invalid dictionary. - AKLOGI("Invalid dictionary. childCount: %d, totalChildCount: %d, MAX: %d", - childCount, totalChildCount, MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP); - ASSERT(false); - return NOT_A_VALID_WORD_POS; - } - for (int i = 0; i < childCount; i++) { - nodeReader.fetchNodeInfoFromBufferAndGetNodeCodePoints(pos, MAX_WORD_LENGTH, - mergedNodeCodePoints); - if (nodeReader.isDeleted() || nodeReader.getCodePointCount() <= 0) { - // Skip deleted or empty node. - pos = nodeReader.getSiblingNodePos(); - continue; - } - bool matched = true; - for (int j = 0; j < nodeReader.getCodePointCount(); ++j) { - if (mergedNodeCodePoints[j] != searchCodePoints[currentLength + j]) { - // Different code point is found. - matched = false; - break; - } - } - if (matched) { - currentLength += nodeReader.getCodePointCount(); - if (length == currentLength) { - // Terminal position is found. - return nodeReader.getNodePos(); - } - if (!nodeReader.hasChildren()) { - return NOT_A_VALID_WORD_POS; - } - foundMatchedNode = true; - // Advance to the children nodes. - pos = nodeReader.getChildrenPos(); - break; - } - // Try next sibling node. - pos = nodeReader.getSiblingNodePos(); - } - if (foundMatchedNode) { - break; + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + while (!readingHelper.isEnd()) { + const int matchedCodePointCount = readingHelper.getPrevTotalCodePointCount(); + if (readingHelper.getTotalCodePointCount() > length + || !readingHelper.isMatchedCodePoint(0 /* index */, + searchCodePoints[matchedCodePointCount])) { + // Current node has too many code points or its first code point is different from + // target code point. Skip this node and read the next sibling node. + readingHelper.readNextSiblingNode(); + continue; + } + // Check following merged node code points. + const int nodeCodePointCount = nodeReader->getCodePointCount(); + for (int j = 1; j < nodeCodePointCount; ++j) { + if (!readingHelper.isMatchedCodePoint( + j, searchCodePoints[matchedCodePointCount + j])) { + // Different code point is found. The given word is not included in the dictionary. + return NOT_A_DICT_POS; } - // If the matched node is not found in the current node group, try to follow the - // forward link. - pos = DynamicPatriciaTrieReadingUtils::getForwardLinkPosition( - mDictRoot, pos); - } while (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(pos)); - if (!foundMatchedNode) { - // Matched node is not found. - return NOT_A_VALID_WORD_POS; } + // All characters are matched. + if (length == readingHelper.getTotalCodePointCount()) { + // Terminal position is found. + return nodeReader->getHeadPos(); + } + if (!nodeReader->hasChildren()) { + return NOT_A_DICT_POS; + } + // Advance to the children nodes. + readingHelper.readChildNode(); } // If we already traversed the tree further than the word is long, there means // there was no match (or we would have found it). - return NOT_A_VALID_WORD_POS; + return NOT_A_DICT_POS; } -int DynamicPatriciaTriePolicy::getUnigramProbability(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int DynamicPatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + // TODO: check mHeaderPolicy.usesForgettingCurve(); + if (unigramProbability == NOT_A_PROBABILITY) { return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); } - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - nodeReader.fetchNodeInfoFromBuffer(nodePos); +} + +int DynamicPatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_PROBABILITY; + } + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); if (nodeReader.isDeleted() || nodeReader.isBlacklisted() || nodeReader.isNotAWord()) { return NOT_A_PROBABILITY; } - return nodeReader.getProbability(); + return getProbability(nodeReader.getProbability(), NOT_A_PROBABILITY); } -int DynamicPatriciaTriePolicy::getShortcutPositionOfNode(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int DynamicPatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; } - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - nodeReader.fetchNodeInfoFromBuffer(nodePos); + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); if (nodeReader.isDeleted()) { return NOT_A_DICT_POS; } return nodeReader.getShortcutPos(); } -int DynamicPatriciaTriePolicy::getBigramsPositionOfNode(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int DynamicPatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; } - DynamicPatriciaTrieNodeReader nodeReader(mDictRoot, getBigramsStructurePolicy(), - getShortcutsStructurePolicy()); - nodeReader.fetchNodeInfoFromBuffer(nodePos); + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); if (nodeReader.isDeleted()) { return NOT_A_DICT_POS; } @@ -237,8 +193,12 @@ bool DynamicPatriciaTriePolicy::addUnigramWord(const int *const word, const int AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); return false; } - // TODO: Implement. - return false; + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.addUnigramWord(&readingHelper, word, length, probability); } bool DynamicPatriciaTriePolicy::addBigramWords(const int *const word0, const int length0, @@ -247,8 +207,19 @@ bool DynamicPatriciaTriePolicy::addBigramWords(const int *const word0, const int AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); return false; } - // TODO: Implement. - return false; + const int word0Pos = getTerminalNodePositionOfWord(word0, length0, + false /* forceLowerCaseSearch */); + if (word0Pos == NOT_A_DICT_POS) { + return false; + } + const int word1Pos = getTerminalNodePositionOfWord(word1, length1, + false /* forceLowerCaseSearch */); + if (word1Pos == NOT_A_DICT_POS) { + return false; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.addBigramWords(word0Pos, word1Pos, probability); } bool DynamicPatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0, @@ -257,8 +228,48 @@ bool DynamicPatriciaTriePolicy::removeBigramWords(const int *const word0, const AKLOGI("Warning: removeBigramWords() is called for non-updatable dictionary."); return false; } - // TODO: Implement. - return false; + const int word0Pos = getTerminalNodePositionOfWord(word0, length0, + false /* forceLowerCaseSearch */); + if (word0Pos == NOT_A_DICT_POS) { + return false; + } + const int word1Pos = getTerminalNodePositionOfWord(word1, length1, + false /* forceLowerCaseSearch */); + if (word1Pos == NOT_A_DICT_POS) { + return false; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.removeBigramWords(word0Pos, word1Pos); +} + +void DynamicPatriciaTriePolicy::flush(const char *const filePath) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: flush() is called for non-updatable dictionary."); + return; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + writingHelper.writeToDictFile(filePath, &mHeaderPolicy); +} + +void DynamicPatriciaTriePolicy::flushWithGC(const char *const filePath) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); + return; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + writingHelper.writeToDictFileWithGC(getRootPosition(), filePath, &mHeaderPolicy); +} + +bool DynamicPatriciaTriePolicy::needsToRunGC() const { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); + return false; + } + // TODO: Implement more properly. + return mBufferWithExtendableBuffer.isNearSizeLimit(); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h index 708967d7b..06d8095d8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -17,14 +17,13 @@ #ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H #define LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H -#include <stdint.h> - #include "defines.h" #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" -#include "suggest/policyimpl/dictionary/bigram/bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" -#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h" -#include "suggest/policyimpl/dictionary/utils/mmaped_buffer.h" +#include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" namespace latinime { @@ -33,10 +32,12 @@ class DicNodeVector; class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: - DynamicPatriciaTriePolicy(const MmapedBuffer *const buffer) - : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer()), - mDictRoot(mBuffer->getBuffer() + mHeaderPolicy.getSize()), - mBigramListPolicy(mDictRoot), mShortcutListPolicy(mDictRoot) {} + DynamicPatriciaTriePolicy(const MmappedBuffer *const buffer) + : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer(), buffer->getBufferSize()), + mBufferWithExtendableBuffer(mBuffer->getBuffer() + mHeaderPolicy.getSize(), + mBuffer->getBufferSize() - mHeaderPolicy.getSize()), + mShortcutListPolicy(&mBufferWithExtendableBuffer), + mBigramListPolicy(&mBufferWithExtendableBuffer, &mShortcutListPolicy) {} ~DynamicPatriciaTriePolicy() { delete mBuffer; @@ -50,17 +51,19 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { DicNodeVector *const childDicNodes) const; int getCodePointsAndProbabilityAndReturnCodePointCount( - const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, + const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const; int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const; - int getUnigramProbability(const int nodePos) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; + + int getUnigramProbabilityOfPtNode(const int ptNodePos) const; - int getShortcutPositionOfNode(const int nodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; - int getBigramsPositionOfNode(const int nodePos) const; + int getBigramsPositionOfPtNode(const int ptNodePos) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return &mHeaderPolicy; @@ -82,16 +85,20 @@ class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { bool removeBigramWords(const int *const word0, const int length0, const int *const word1, const int length1); + void flush(const char *const filePath); + + void flushWithGC(const char *const filePath); + + bool needsToRunGC() const; + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTriePolicy); - static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP; - const MmapedBuffer *const mBuffer; + const MmappedBuffer *const mBuffer; const HeaderPolicy mHeaderPolicy; - // TODO: Consolidate mDictRoot. - const uint8_t *const mDictRoot; - const BigramListPolicy mBigramListPolicy; - const ShortcutListPolicy mShortcutListPolicy; + BufferWithExtendableBuffer mBufferWithExtendableBuffer; + DynamicShortcutListPolicy mShortcutListPolicy; + DynamicBigramListPolicy mBigramListPolicy; }; } // namespace latinime #endif // LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp new file mode 100644 index 000000000..f4a2ef389 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +// To avoid infinite loop caused by invalid or malicious forward links. +const int DynamicPatriciaTrieReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 100000; +const int DynamicPatriciaTrieReadingHelper::MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000; +const size_t DynamicPatriciaTrieReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH; + +// Visits all PtNodes in post-order depth first manner. +// For example, visits c -> b -> y -> x -> a for the following dictionary: +// a _ b _ c +// \ x _ y +bool DynamicPatriciaTrieReadingHelper::traverseAllPtNodesInPostorderDepthFirstManner( + TraversingEventListener *const listener) { + bool alreadyVisitedChildren = false; + // Descend from the root to the root PtNode array. + if (!listener->onDescend(getPosOfLastPtNodeArrayHead())) { + return false; + } + while (!isEnd()) { + if (!alreadyVisitedChildren) { + if (mNodeReader.hasChildren()) { + // Move to the first child. + if (!listener->onDescend(mNodeReader.getChildrenPos())) { + return false; + } + pushReadingStateToStack(); + readChildNode(); + } else { + alreadyVisitedChildren = true; + } + } else { + if (!listener->onVisitingPtNode(&mNodeReader, mMergedNodeCodePoints)) { + return false; + } + readNextSiblingNode(); + if (isEnd()) { + // All PtNodes in current linked PtNode arrays have been visited. + // Return to the parent. + if (!listener->onReadingPtNodeArrayTail()) { + return false; + } + if (mReadingStateStack.size() <= 0) { + break; + } + if (!listener->onAscend()) { + return false; + } + popReadingStateFromStack(); + alreadyVisitedChildren = true; + } else { + // Process sibling PtNode. + alreadyVisitedChildren = false; + } + } + } + // Ascend from the root PtNode array to the root. + if (!listener->onAscend()) { + return false; + } + return !isError(); +} + +// Visits all PtNodes in PtNode array level pre-order depth first manner, which is the same order +// that PtNodes are written in the dictionary buffer. +// For example, visits a -> b -> x -> c -> y for the following dictionary: +// a _ b _ c +// \ x _ y +bool DynamicPatriciaTrieReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + TraversingEventListener *const listener) { + bool alreadyVisitedAllPtNodesInArray = false; + bool alreadyVisitedChildren = false; + // Descend from the root to the root PtNode array. + if (!listener->onDescend(getPosOfLastPtNodeArrayHead())) { + return false; + } + pushReadingStateToStack(); + while (!isEnd()) { + if (alreadyVisitedAllPtNodesInArray) { + if (alreadyVisitedChildren) { + // Move to next sibling PtNode's children. + readNextSiblingNode(); + if (isEnd()) { + // Return to the parent PTNode. + if (!listener->onAscend()) { + return false; + } + if (mReadingStateStack.size() <= 0) { + break; + } + popReadingStateFromStack(); + alreadyVisitedChildren = true; + alreadyVisitedAllPtNodesInArray = true; + } else { + alreadyVisitedChildren = false; + } + } else { + if (mNodeReader.hasChildren()) { + // Move to the first child. + if (!listener->onDescend(mNodeReader.getChildrenPos())) { + return false; + } + pushReadingStateToStack(); + readChildNode(); + // Push state to return the head of PtNode array. + pushReadingStateToStack(); + alreadyVisitedAllPtNodesInArray = false; + alreadyVisitedChildren = false; + } else { + alreadyVisitedChildren = true; + } + } + } else { + if (!listener->onVisitingPtNode(&mNodeReader, mMergedNodeCodePoints)) { + return false; + } + readNextSiblingNode(); + if (isEnd()) { + if (!listener->onReadingPtNodeArrayTail()) { + return false; + } + // Return to the head of current PtNode array. + popReadingStateFromStack(); + alreadyVisitedAllPtNodesInArray = true; + } + } + } + popReadingStateFromStack(); + // Ascend from the root PtNode array to the root. + if (!listener->onAscend()) { + return false; + } + return !isError(); +} + +// Read node array size and process empty node arrays. Nodes and arrays are counted up in this +// method to avoid an infinite loop. +void DynamicPatriciaTrieReadingHelper::nextPtNodeArray() { + mReadingState.mPosOfLastPtNodeArrayHead = mReadingState.mPos; + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(mReadingState.mPos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + mReadingState.mPos -= mBuffer->getOriginalBufferSize(); + } + mReadingState.mNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( + dictBuf, &mReadingState.mPos); + if (usesAdditionalBuffer) { + mReadingState.mPos += mBuffer->getOriginalBufferSize(); + } + // Count up nodes and node arrays to avoid infinite loop. + mReadingState.mTotalNodeCount += mReadingState.mNodeCount; + mReadingState.mNodeArrayCount++; + if (mReadingState.mNodeCount < 0 + || mReadingState.mTotalNodeCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP + || mReadingState.mNodeArrayCount > MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP) { + // Invalid dictionary. + AKLOGI("Invalid dictionary. nodeCount: %d, totalNodeCount: %d, MAX_CHILD_COUNT: %d" + "nodeArrayCount: %d, MAX_NODE_ARRAY_COUNT: %d", + mReadingState.mNodeCount, mReadingState.mTotalNodeCount, + MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP, mReadingState.mNodeArrayCount, + MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP); + ASSERT(false); + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + return; + } + if (mReadingState.mNodeCount == 0) { + // Empty node array. Try following forward link. + followForwardLink(); + } +} + +// Follow the forward link and read the next node array if exists. +void DynamicPatriciaTrieReadingHelper::followForwardLink() { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(mReadingState.mPos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + mReadingState.mPos -= mBuffer->getOriginalBufferSize(); + } + const int forwardLinkPosition = + DynamicPatriciaTrieReadingUtils::getForwardLinkPosition(dictBuf, mReadingState.mPos); + if (usesAdditionalBuffer) { + mReadingState.mPos += mBuffer->getOriginalBufferSize(); + } + mReadingState.mPosOfLastForwardLinkField = mReadingState.mPos; + if (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(forwardLinkPosition)) { + // Follow the forward link. + mReadingState.mPos += forwardLinkPosition; + nextPtNodeArray(); + } else { + // All node arrays have been read. + mReadingState.mPos = NOT_A_DICT_POS; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h new file mode 100644 index 000000000..c6d8ddcf7 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h @@ -0,0 +1,286 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H + +#include <cstddef> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DictionaryBigramsStructurePolicy; +class DictionaryShortcutsStructurePolicy; + +/* + * This class is used for traversing dynamic patricia trie. This class supports iterating nodes and + * dealing with additional buffer. This class counts nodes and node arrays to avoid infinite loop. + */ +class DynamicPatriciaTrieReadingHelper { + public: + class TraversingEventListener { + public: + virtual ~TraversingEventListener() {}; + + // Returns whether the event handling was succeeded or not. + virtual bool onAscend() = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onDescend(const int ptNodeArrayPos) = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onReadingPtNodeArrayTail() = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) = 0; + + protected: + TraversingEventListener() {}; + + private: + DISALLOW_COPY_AND_ASSIGN(TraversingEventListener); + }; + + DynamicPatriciaTrieReadingHelper(const BufferWithExtendableBuffer *const buffer, + const DictionaryBigramsStructurePolicy *const bigramsPolicy, + const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) + : mIsError(false), mReadingState(), mBuffer(buffer), + mNodeReader(mBuffer, bigramsPolicy, shortcutsPolicy), mReadingStateStack() {} + + ~DynamicPatriciaTrieReadingHelper() {} + + AK_FORCE_INLINE bool isError() const { + return mIsError; + } + + AK_FORCE_INLINE bool isEnd() const { + return mReadingState.mPos == NOT_A_DICT_POS; + } + + // Initialize reading state with the head position of a PtNode array. + AK_FORCE_INLINE void initWithPtNodeArrayPos(const int ptNodeArrayPos) { + if (ptNodeArrayPos == NOT_A_DICT_POS) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mIsError = false; + mReadingState.mPos = ptNodeArrayPos; + mReadingState.mPrevTotalCodePointCount = 0; + mReadingState.mTotalNodeCount = 0; + mReadingState.mNodeArrayCount = 0; + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingStateStack.clear(); + nextPtNodeArray(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } + } + + // Initialize reading state with the head position of a node. + AK_FORCE_INLINE void initWithPtNodePos(const int ptNodePos) { + if (ptNodePos == NOT_A_DICT_POS) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mIsError = false; + mReadingState.mPos = ptNodePos; + mReadingState.mNodeCount = 1; + mReadingState.mPrevTotalCodePointCount = 0; + mReadingState.mTotalNodeCount = 1; + mReadingState.mNodeArrayCount = 1; + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingState.mPosOfLastPtNodeArrayHead = NOT_A_DICT_POS; + mReadingStateStack.clear(); + fetchPtNodeInfo(); + } + } + + AK_FORCE_INLINE const DynamicPatriciaTrieNodeReader* getNodeReader() const { + return &mNodeReader; + } + + AK_FORCE_INLINE bool isValidTerminalNode() const { + return !isEnd() && !mNodeReader.isDeleted() && mNodeReader.isTerminal(); + } + + AK_FORCE_INLINE bool isMatchedCodePoint(const int index, const int codePoint) const { + return mMergedNodeCodePoints[index] == codePoint; + } + + // Return code point count exclude the last read node's code points. + AK_FORCE_INLINE int getPrevTotalCodePointCount() const { + return mReadingState.mPrevTotalCodePointCount; + } + + // Return code point count include the last read node's code points. + AK_FORCE_INLINE int getTotalCodePointCount() const { + return mReadingState.mPrevTotalCodePointCount + mNodeReader.getCodePointCount(); + } + + AK_FORCE_INLINE void fetchMergedNodeCodePointsInReverseOrder( + const int index, int *const outCodePoints) const { + const int nodeCodePointCount = mNodeReader.getCodePointCount(); + for (int i = 0; i < nodeCodePointCount; ++i) { + outCodePoints[index + i] = mMergedNodeCodePoints[nodeCodePointCount - 1 - i]; + } + } + + AK_FORCE_INLINE const int *getMergedNodeCodePoints() const { + return mMergedNodeCodePoints; + } + + AK_FORCE_INLINE void readNextSiblingNode() { + mReadingState.mNodeCount -= 1; + mReadingState.mPos = mNodeReader.getSiblingNodePos(); + if (mReadingState.mNodeCount <= 0) { + // All nodes in the current node array have been read. + followForwardLink(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } else { + fetchPtNodeInfo(); + } + } + + // Read the first child node of the current node. + AK_FORCE_INLINE void readChildNode() { + if (mNodeReader.hasChildren()) { + mReadingState.mPrevTotalCodePointCount += mNodeReader.getCodePointCount(); + mReadingState.mTotalNodeCount = 0; + mReadingState.mNodeArrayCount = 0; + mReadingState.mPos = mNodeReader.getChildrenPos(); + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + // Read children node array. + nextPtNodeArray(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } else { + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + // Read the parent node of the current node. + AK_FORCE_INLINE void readParentNode() { + if (mNodeReader.getParentPos() != NOT_A_DICT_POS) { + mReadingState.mPrevTotalCodePointCount += mNodeReader.getCodePointCount(); + mReadingState.mTotalNodeCount = 1; + mReadingState.mNodeArrayCount = 1; + mReadingState.mNodeCount = 1; + mReadingState.mPos = mNodeReader.getParentPos(); + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingState.mPosOfLastPtNodeArrayHead = NOT_A_DICT_POS; + fetchPtNodeInfo(); + } else { + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + AK_FORCE_INLINE int getPosOfLastForwardLinkField() const { + return mReadingState.mPosOfLastForwardLinkField; + } + + AK_FORCE_INLINE int getPosOfLastPtNodeArrayHead() const { + return mReadingState.mPosOfLastPtNodeArrayHead; + } + + AK_FORCE_INLINE void reloadCurrentPtNodeInfo() { + if (!isEnd()) { + fetchPtNodeInfo(); + } + } + + bool traverseAllPtNodesInPostorderDepthFirstManner(TraversingEventListener *const listener); + + bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + TraversingEventListener *const listener); + + private: + DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTrieReadingHelper); + + class ReadingState { + public: + // Note that copy constructor and assignment operator are used for this class to use + // std::vector. + ReadingState() : mPos(NOT_A_DICT_POS), mNodeCount(0), mPrevTotalCodePointCount(0), + mTotalNodeCount(0), mNodeArrayCount(0), mPosOfLastForwardLinkField(NOT_A_DICT_POS), + mPosOfLastPtNodeArrayHead(NOT_A_DICT_POS) {} + + int mPos; + // Node count of a node array. + int mNodeCount; + int mPrevTotalCodePointCount; + int mTotalNodeCount; + int mNodeArrayCount; + int mPosOfLastForwardLinkField; + int mPosOfLastPtNodeArrayHead; + }; + + static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP; + static const int MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP; + static const size_t MAX_READING_STATE_STACK_SIZE; + + bool mIsError; + ReadingState mReadingState; + const BufferWithExtendableBuffer *const mBuffer; + DynamicPatriciaTrieNodeReader mNodeReader; + int mMergedNodeCodePoints[MAX_WORD_LENGTH]; + std::vector<ReadingState> mReadingStateStack; + + void nextPtNodeArray(); + + void followForwardLink(); + + AK_FORCE_INLINE void fetchPtNodeInfo() { + mNodeReader.fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(mReadingState.mPos, + MAX_WORD_LENGTH, mMergedNodeCodePoints); + if (mNodeReader.getCodePointCount() <= 0) { + // Empty node is not allowed. + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + AK_FORCE_INLINE void pushReadingStateToStack() { + if (mReadingStateStack.size() > MAX_READING_STATE_STACK_SIZE) { + AKLOGI("Reading state stack overflow. Max size: %zd", MAX_READING_STATE_STACK_SIZE); + ASSERT(false); + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mReadingStateStack.push_back(mReadingState); + } + } + + AK_FORCE_INLINE void popReadingStateFromStack() { + if (mReadingStateStack.empty()) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mReadingState = mReadingStateStack.back(); + mReadingStateStack.pop_back(); + fetchPtNodeInfo(); + } + } +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp index 1ef3b65c3..d68446db6 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp @@ -28,13 +28,44 @@ const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_NOT_MOVED = 0xC0; const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_MOVED = 0x40; const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_DELETED = 0x80; -/* static */ int DptReadingUtils::readChildrenPositionAndAdvancePosition( - const uint8_t *const buffer, const NodeFlags flags, int *const pos) { - if ((flags & MASK_MOVED) == FLAG_IS_NOT_MOVED) { - const int base = *pos; - return base + ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); +// TODO: Make DICT_OFFSET_ZERO_OFFSET = 0. +// Currently, DICT_OFFSET_INVALID is 0 in Java side but offset can be 0 during GC. So, the maximum +// value of offsets, which is 0x7FFFFF is used to represent 0 offset. +const int DptReadingUtils::DICT_OFFSET_INVALID = 0; +const int DptReadingUtils::DICT_OFFSET_ZERO_OFFSET = 0x7FFFFF; + +/* static */ int DptReadingUtils::getForwardLinkPosition(const uint8_t *const buffer, + const int pos) { + int linkAddressPos = pos; + return ByteArrayUtils::readSint24AndAdvancePosition(buffer, &linkAddressPos); +} + +/* static */ int DptReadingUtils::getParentPtNodePosOffsetAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); +} + +/* static */ int DptReadingUtils::getParentPtNodePos(const int parentOffset, const int ptNodePos) { + if (parentOffset == DICT_OFFSET_INVALID) { + return NOT_A_DICT_POS; + } else if (parentOffset == DICT_OFFSET_ZERO_OFFSET) { + return ptNodePos; } else { + return parentOffset + ptNodePos; + } +} + +/* static */ int DptReadingUtils::readChildrenPositionAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const int base = *pos; + const int offset = ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); + if (offset == DICT_OFFSET_INVALID) { + // The PtNode does not have children. return NOT_A_DICT_POS; + } else if (offset == DICT_OFFSET_ZERO_OFFSET) { + return base; + } else { + return base + offset; } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h index a6cb46d39..67c3cc57e 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h @@ -20,7 +20,6 @@ #include <stdint.h> #include "defines.h" -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { @@ -28,22 +27,21 @@ class DynamicPatriciaTrieReadingUtils { public: typedef uint8_t NodeFlags; - static AK_FORCE_INLINE int getForwardLinkPosition(const uint8_t *const buffer, const int pos) { - int linkAddressPos = pos; - return ByteArrayUtils::readSint24AndAdvancePosition(buffer, &linkAddressPos); - } + static const int DICT_OFFSET_INVALID; + static const int DICT_OFFSET_ZERO_OFFSET; + + static int getForwardLinkPosition(const uint8_t *const buffer, const int pos); static AK_FORCE_INLINE bool isValidForwardLinkPosition(const int forwardLinkAddress) { return forwardLinkAddress != 0; } - static AK_FORCE_INLINE int getParentPosAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); - } + static int getParentPtNodePosOffsetAndAdvancePosition(const uint8_t *const buffer, + int *const pos); - static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, - const NodeFlags flags, int *const pos); + static int getParentPtNodePos(const int parentOffset, const int ptNodePos); + + static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, int *const pos); /** * Node Flags @@ -56,6 +54,15 @@ class DynamicPatriciaTrieReadingUtils { return FLAG_IS_DELETED == (MASK_MOVED & flags); } + static AK_FORCE_INLINE NodeFlags updateAndGetFlags(const NodeFlags originalFlags, + const bool isMoved, const bool isDeleted) { + NodeFlags flags = originalFlags; + flags = isMoved ? ((flags & (~MASK_MOVED)) | FLAG_IS_MOVED) : flags; + flags = isDeleted ? ((flags & (~MASK_MOVED)) | FLAG_IS_DELETED) : flags; + flags = (!isMoved && !isDeleted) ? ((flags & (~MASK_MOVED)) | FLAG_IS_NOT_MOVED) : flags; + return flags; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieReadingUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp new file mode 100644 index 000000000..578645cd5 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp @@ -0,0 +1,511 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" + +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +const int DynamicPatriciaTrieWritingHelper::CHILDREN_POSITION_FIELD_SIZE = 3; +// TODO: Make MAX_DICTIONARY_SIZE 8MB. +const size_t DynamicPatriciaTrieWritingHelper::MAX_DICTIONARY_SIZE = 2 * 1024 * 1024; + +bool DynamicPatriciaTrieWritingHelper::addUnigramWord( + DynamicPatriciaTrieReadingHelper *const readingHelper, + const int *const wordCodePoints, const int codePointCount, const int probability) { + int parentPos = NOT_A_DICT_POS; + while (!readingHelper->isEnd()) { + const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); + if (!readingHelper->isMatchedCodePoint(0 /* index */, + wordCodePoints[matchedCodePointCount])) { + // The first code point is different from target code point. Skip this node and read + // the next sibling node. + readingHelper->readNextSiblingNode(); + continue; + } + // Check following merged node code points. + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper->getNodeReader(); + const int nodeCodePointCount = nodeReader->getCodePointCount(); + for (int j = 1; j < nodeCodePointCount; ++j) { + const int nextIndex = matchedCodePointCount + j; + if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(j, + wordCodePoints[matchedCodePointCount + j])) { + return reallocatePtNodeAndAddNewPtNodes(nodeReader, + readingHelper->getMergedNodeCodePoints(), j, probability, + wordCodePoints + matchedCodePointCount, + codePointCount - matchedCodePointCount); + } + } + // All characters are matched. + if (codePointCount == readingHelper->getTotalCodePointCount()) { + return setPtNodeProbability(nodeReader, probability, + readingHelper->getMergedNodeCodePoints()); + } + if (!nodeReader->hasChildren()) { + return createChildrenPtNodeArrayAndAChildPtNode(nodeReader, probability, + wordCodePoints + readingHelper->getTotalCodePointCount(), + codePointCount - readingHelper->getTotalCodePointCount()); + } + // Advance to the children nodes. + parentPos = nodeReader->getHeadPos(); + readingHelper->readChildNode(); + } + if (readingHelper->isError()) { + // The dictionary is invalid. + return false; + } + int pos = readingHelper->getPosOfLastForwardLinkField(); + return createAndInsertNodeIntoPtNodeArray(parentPos, + wordCodePoints + readingHelper->getPrevTotalCodePointCount(), + codePointCount - readingHelper->getPrevTotalCodePointCount(), + probability, &pos); +} + +bool DynamicPatriciaTrieWritingHelper::addBigramWords(const int word0Pos, const int word1Pos, + const int probability) { + int mMergedNodeCodePoints[MAX_WORD_LENGTH]; + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(word0Pos, MAX_WORD_LENGTH, + mMergedNodeCodePoints); + // Move node to add bigram entry. + const int newNodePos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(&nodeReader, newNodePos, newNodePos)) { + return false; + } + int writingPos = newNodePos; + // Write a new PtNode using original PtNode's info to the tail of the dictionary in mBuffer. + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, &nodeReader, nodeReader.getParentPos(), + mMergedNodeCodePoints, nodeReader.getCodePointCount(), nodeReader.getProbability(), + &writingPos)) { + return false; + } + nodeReader.fetchNodeInfoInBufferFromPtNodePos(newNodePos); + if (nodeReader.getBigramsPos() != NOT_A_DICT_POS) { + // Insert a new bigram entry into the existing bigram list. + int bigramListPos = nodeReader.getBigramsPos(); + return mBigramPolicy->addNewBigramEntryToBigramList(word1Pos, probability, &bigramListPos); + } else { + // The PtNode doesn't have a bigram list. + // First, Write a bigram entry at the tail position of the PtNode. + if (!mBigramPolicy->writeNewBigramEntry(word1Pos, probability, &writingPos)) { + return false; + } + // Then, Mark as the PtNode having bigram list in the flags. + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + PatriciaTrieReadingUtils::createAndGetFlags(nodeReader.isBlacklisted(), + nodeReader.isNotAWord(), nodeReader.getProbability() != NOT_A_PROBABILITY, + nodeReader.getShortcutPos() != NOT_A_DICT_POS, true /* hasBigrams */, + nodeReader.getCodePointCount() > 1, CHILDREN_POSITION_FIELD_SIZE); + writingPos = newNodePos; + // Write updated flags into the moved PtNode's flags field. + return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos); + } +} + +// Remove a bigram relation from word0Pos to word1Pos. +bool DynamicPatriciaTrieWritingHelper::removeBigramWords(const int word0Pos, const int word1Pos) { + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(word0Pos); + if (nodeReader.getBigramsPos() == NOT_A_DICT_POS) { + return false; + } + return mBigramPolicy->removeBigram(nodeReader.getBigramsPos(), word1Pos); +} + +void DynamicPatriciaTrieWritingHelper::writeToDictFile(const char *const fileName, + const HeaderPolicy *const headerPolicy) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, false /* updatesLastUpdatedTime */)) { + return; + } + DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, mBuffer); +} + +void DynamicPatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeArrayPos, + const char *const fileName, const HeaderPolicy *const headerPolicy) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */)) { + return; + } + BufferWithExtendableBuffer newDictBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */, + MAX_DICTIONARY_SIZE); + if (!runGC(rootPtNodeArrayPos, &newDictBuffer)) { + return; + } + DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, &newDictBuffer); +} + +bool DynamicPatriciaTrieWritingHelper::markNodeAsDeleted( + const DynamicPatriciaTrieNodeReader *const nodeToUpdate) { + int pos = nodeToUpdate->getHeadPos(); + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(pos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + // Read original flags + const PatriciaTrieReadingUtils::NodeFlags originalFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, + true /* isDeleted */); + int writingPos = nodeToUpdate->getHeadPos(); + // Update flags. + return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::markNodeAsMovedAndSetPosition( + const DynamicPatriciaTrieNodeReader *const originalNode, const int movedPos, + const int bigramLinkedNodePos) { + int pos = originalNode->getHeadPos(); + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(pos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + // Read original flags + const PatriciaTrieReadingUtils::NodeFlags originalFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */, + false /* isDeleted */); + int writingPos = originalNode->getHeadPos(); + // Update flags. + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos)) { + return false; + } + // Update moved position, which is stored in the parent offset field. + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + mBuffer, movedPos, originalNode->getHeadPos(), &writingPos)) { + return false; + } + // Update bigram linked node position, which is stored in the children position field. + int childrenPosFieldPos = originalNode->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + mBuffer, bigramLinkedNodePos, &childrenPosFieldPos)) { + return false; + } + if (originalNode->hasChildren()) { + // Update children's parent position. + DynamicPatriciaTrieReadingHelper readingHelper(mBuffer, mBigramPolicy, mShortcutPolicy); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + readingHelper.initWithPtNodeArrayPos(originalNode->getChildrenPos()); + while (!readingHelper.isEnd()) { + int parentOffsetFieldPos = nodeReader->getHeadPos() + + DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE; + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + mBuffer, movedPos, nodeReader->getHeadPos(), &parentOffsetFieldPos)) { + // Parent offset cannot be written because of a bug or a broken dictionary; thus, + // we give up to update dictionary. + return false; + } + readingHelper.readNextSiblingNode(); + } + } + return true; +} + +// Write new PtNode at writingPos. +bool DynamicPatriciaTrieWritingHelper::writePtNodeWithFullInfoToBuffer( + BufferWithExtendableBuffer *const bufferToWrite, const bool isBlacklisted, + const bool isNotAWord, const int parentPos, const int *const codePoints, + const int codePointCount, const int probability, const int childrenPos, + const int originalBigramListPos, const int originalShortcutListPos, + int *const writingPos) { + const int nodePos = *writingPos; + // Write dummy flags. The Node flags are updated with appropriate flags at the last step of the + // PtNode writing. + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(bufferToWrite, + 0 /* nodeFlags */, writingPos)) { + return false; + } + // Calculate a parent offset and write the offset. + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition(bufferToWrite, + parentPos, nodePos, writingPos)) { + return false; + } + // Write code points + if (!DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition(bufferToWrite, + codePoints, codePointCount, writingPos)) { + return false; + } + // Write probability when the probability is a valid probability, which means this node is + // terminal. + if (probability != NOT_A_PROBABILITY) { + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(bufferToWrite, + probability, writingPos)) { + return false; + } + } + // Write children position + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(bufferToWrite, + childrenPos, writingPos)) { + return false; + } + // Copy shortcut list when the originalShortcutListPos is valid dictionary position. + if (originalShortcutListPos != NOT_A_DICT_POS) { + int fromPos = originalShortcutListPos; + if (!mShortcutPolicy->copyAllShortcutsAndReturnIfSucceededOrNot(bufferToWrite, &fromPos, + writingPos)) { + return false; + } + } + // Copy bigram list when the originalBigramListPos is valid dictionary position. + int bigramCount = 0; + if (originalBigramListPos != NOT_A_DICT_POS) { + int fromPos = originalBigramListPos; + if (!mBigramPolicy->copyAllBigrams(bufferToWrite, &fromPos, writingPos, &bigramCount)) { + return false; + } + } + // Create node flags and write them. + PatriciaTrieReadingUtils::NodeFlags nodeFlags = + PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, + probability != NOT_A_PROBABILITY /* isTerminal */, + originalShortcutListPos != NOT_A_DICT_POS /* hasShortcutTargets */, + bigramCount > 0 /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + CHILDREN_POSITION_FIELD_SIZE); + int flagsFieldPos = nodePos; + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(bufferToWrite, nodeFlags, + &flagsFieldPos)) { + return false; + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBuffer( + BufferWithExtendableBuffer *const bufferToWrite, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(bufferToWrite, false /* isBlacklisted */, + false /* isNotAWord */, parentPos, codePoints, codePointCount, probability, + NOT_A_DICT_POS /* childrenPos */, NOT_A_DICT_POS /* originalBigramsPos */, + NOT_A_DICT_POS /* originalShortcutPos */, writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBufferByCopyingPtNodeInfo( + BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(bufferToWrite, originalNode->isBlacklisted(), + originalNode->isNotAWord(), parentPos, codePoints, codePointCount, probability, + originalNode->getChildrenPos(), originalNode->getBigramsPos(), + originalNode->getShortcutPos(), writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, + const int *const nodeCodePoints, const int nodeCodePointCount, const int probability, + int *const forwardLinkFieldPos) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, forwardLinkFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, + probability); +} + +bool DynamicPatriciaTrieWritingHelper::setPtNodeProbability( + const DynamicPatriciaTrieNodeReader *const originalPtNode, const int probability, + const int *const codePoints) { + if (originalPtNode->isTerminal()) { + // Overwrites the probability. + int probabilityFieldPos = originalPtNode->getProbabilityFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(mBuffer, + probability, &probabilityFieldPos)) { + return false; + } + } else { + // Make the node terminal and write the probability. + int movedPos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(originalPtNode, movedPos, movedPos)) { + return false; + } + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, originalPtNode, + originalPtNode->getParentPos(), codePoints, originalPtNode->getCodePointCount(), + probability, &movedPos)) { + return false; + } + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + int childrenPosFieldPos = parentNode->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, &childrenPosFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentNode->getHeadPos(), codePoints, + codePointCount, probability); +} + +bool DynamicPatriciaTrieWritingHelper::createNewPtNodeArrayWithAChildPtNode( + const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int probability) { + int writingPos = mBuffer->getTailPosition(); + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + 1 /* arraySize */, &writingPos)) { + return false; + } + if (!writePtNodeToBuffer(mBuffer, parentPtNodePos, nodeCodePoints, nodeCodePointCount, + probability, &writingPos)) { + return false; + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + return true; +} + +// Returns whether the dictionary updating was succeeded or not. +bool DynamicPatriciaTrieWritingHelper::reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount) { + // When addsExtraChild is true, split the reallocating PtNode and add new child. + // Reallocating PtNode: abcde, newNode: abcxy. + // abc (1st, not terminal) __ de (2nd) + // \_ xy (extra child, terminal) + // Otherwise, this method makes 1st part terminal and write probabilityOfNewPtNode. + // Reallocating PtNode: abcde, newNode: abc. + // abc (1st, terminal) __ de (2nd) + const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const int firstPartOfReallocatedPtNodePos = mBuffer->getTailPosition(); + int writingPos = firstPartOfReallocatedPtNodePos; + // Write the 1st part of the reallocating node. The children position will be updated later + // with actual children position. + const int newProbability = addsExtraChild ? NOT_A_PROBABILITY : probabilityOfNewPtNode; + if (!writePtNodeToBuffer(mBuffer, reallocatingPtNode->getParentPos(), + reallocatingPtNodeCodePoints, overlappingCodePointCount, newProbability, + &writingPos)) { + return false; + } + const int actualChildrenPos = writingPos; + // Create new children PtNode array. + const size_t newPtNodeCount = addsExtraChild ? 2 : 1; + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + newPtNodeCount, &writingPos)) { + return false; + } + // Write the 2nd part of the reallocating node. + const int secondPartOfReallocatedPtNodePos = writingPos; + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, reallocatingPtNode, + firstPartOfReallocatedPtNodePos, + reallocatingPtNodeCodePoints + overlappingCodePointCount, + reallocatingPtNode->getCodePointCount() - overlappingCodePointCount, + reallocatingPtNode->getProbability(), &writingPos)) { + return false; + } + if (addsExtraChild) { + if (!writePtNodeToBuffer(mBuffer, firstPartOfReallocatedPtNodePos, + newNodeCodePoints + overlappingCodePointCount, + newNodeCodePointCount - overlappingCodePointCount, probabilityOfNewPtNode, + &writingPos)) { + return false; + } + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + // Update original reallocatingPtNode as moved. + if (!markNodeAsMovedAndSetPosition(reallocatingPtNode, firstPartOfReallocatedPtNodePos, + secondPartOfReallocatedPtNodePos)) { + return false; + } + // Load node info. Information of the 1st part will be fetched. + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(firstPartOfReallocatedPtNodePos); + // Update children position. + int childrenPosFieldPos = nodeReader.getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + actualChildrenPos, &childrenPosFieldPos)) { + return false; + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, + BufferWithExtendableBuffer *const bufferToWrite) { + DynamicPatriciaTrieReadingHelper readingHelper(mBuffer, mBigramPolicy, mShortcutPolicy); + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners + ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( + this, mBuffer); + if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( + &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { + return false; + } + + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateBigramProbability + traversePolicyToUpdateBigramProbability(mBigramPolicy); + if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( + &traversePolicyToUpdateBigramProbability)) { + return false; + } + + // Mapping from positions in mBuffer to positions in bufferToWrite. + DictPositionRelocationMap dictPositionRelocationMap; + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + traversePolicyToPlaceAndWriteValidPtNodesToBuffer(this, bufferToWrite, + &dictPositionRelocationMap); + if (!readingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + &traversePolicyToPlaceAndWriteValidPtNodesToBuffer)) { + return false; + } + + // Create policy instance for the GCed dictionary. + DynamicShortcutListPolicy newDictShortcutPolicy(bufferToWrite); + DynamicBigramListPolicy newDictBigramPolicy(bufferToWrite, &newDictShortcutPolicy); + // Create reading helper for the GCed dictionary. + DynamicPatriciaTrieReadingHelper newDictReadingHelper(bufferToWrite, &newDictBigramPolicy, + &newDictShortcutPolicy); + newDictReadingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateAllPositionFields + traversePolicyToUpdateAllPositionFields(this, &newDictBigramPolicy, bufferToWrite, + &dictPositionRelocationMap); + if (!newDictReadingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + &traversePolicyToUpdateAllPositionFields)) { + return false; + } + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h new file mode 100644 index 000000000..fe1b2437a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H + +#include <stdint.h> + +#include "defines.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DynamicBigramListPolicy; +class DynamicPatriciaTrieNodeReader; +class DynamicPatriciaTrieReadingHelper; +class DynamicShortcutListPolicy; +class HeaderPolicy; + +class DynamicPatriciaTrieWritingHelper { + public: + typedef hash_map_compat<int, int> PtNodeArrayPositionRelocationMap; + typedef hash_map_compat<int, int> PtNodePositionRelocationMap; + struct DictPositionRelocationMap { + public: + DictPositionRelocationMap() + : mPtNodeArrayPositionRelocationMap(), mPtNodePositionRelocationMap() {} + + PtNodeArrayPositionRelocationMap mPtNodeArrayPositionRelocationMap; + PtNodePositionRelocationMap mPtNodePositionRelocationMap; + + private: + DISALLOW_COPY_AND_ASSIGN(DictPositionRelocationMap); + }; + + DynamicPatriciaTrieWritingHelper(BufferWithExtendableBuffer *const buffer, + DynamicBigramListPolicy *const bigramPolicy, + DynamicShortcutListPolicy *const shortcutPolicy) + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy) {} + + ~DynamicPatriciaTrieWritingHelper() {} + + // Add a word to the dictionary. If the word already exists, update the probability. + bool addUnigramWord(DynamicPatriciaTrieReadingHelper *const readingHelper, + const int *const wordCodePoints, const int codePointCount, const int probability); + + // Add a bigram relation from word0Pos to word1Pos. + bool addBigramWords(const int word0Pos, const int word1Pos, const int probability); + + // Remove a bigram relation from word0Pos to word1Pos. + bool removeBigramWords(const int word0Pos, const int word1Pos); + + void writeToDictFile(const char *const fileName, const HeaderPolicy *const headerPolicy); + + void writeToDictFileWithGC(const int rootPtNodeArrayPos, const char *const fileName, + const HeaderPolicy *const headerPolicy); + + // CAVEAT: This method must be called only from inner classes of + // DynamicPatriciaTrieGcEventListeners. + bool markNodeAsDeleted(const DynamicPatriciaTrieNodeReader *const nodeToUpdate); + + // CAVEAT: This method must be called only from this class or inner classes of + // DynamicPatriciaTrieGcEventListeners. + bool writePtNodeToBufferByCopyingPtNodeInfo(BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingHelper); + + static const int CHILDREN_POSITION_FIELD_SIZE; + static const size_t MAX_DICTIONARY_SIZE; + + BufferWithExtendableBuffer *const mBuffer; + DynamicBigramListPolicy *const mBigramPolicy; + DynamicShortcutListPolicy *const mShortcutPolicy; + + bool markNodeAsMovedAndSetPosition(const DynamicPatriciaTrieNodeReader *const nodeToUpdate, + const int movedPos, const int bigramLinkedNodePos); + + bool writePtNodeWithFullInfoToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const bool isBlacklisted, const bool isNotAWord, + const int parentPos, const int *const codePoints, const int codePointCount, + const int probability, const int childrenPos, const int originalBigramListPos, + const int originalShortcutListPos, int *const writingPos); + + bool writePtNodeToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const int parentPos, const int *const codePoints, const int codePointCount, + const int probability, int *const writingPos); + + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability, int *const forwardLinkFieldPos); + + bool setPtNodeProbability(const DynamicPatriciaTrieNodeReader *const originalNode, + const int probability, const int *const codePoints); + + bool createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount); + + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability); + + bool reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount); + + bool runGC(const int rootPtNodeArrayPos, BufferWithExtendableBuffer *const bufferToWrite); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp new file mode 100644 index 000000000..30ff10cd6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" + +#include <cstddef> +#include <cstdlib> +#include <stdint.h> + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD = 0x7F; +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE = 0x7FFF; +const int DynamicPatriciaTrieWritingUtils::SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE = 2; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG = 0x8000; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_FIELD_SIZE = 3; +const int DynamicPatriciaTrieWritingUtils::MAX_DICT_OFFSET_VALUE = 0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::MIN_DICT_OFFSET_VALUE = -0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_NEGATIVE_FLAG = 0x800000; +const int DynamicPatriciaTrieWritingUtils::PROBABILITY_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE = 1; + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeEmptyDictionary( + BufferWithExtendableBuffer *const buffer, const int rootPos) { + int writingPos = rootPos; + if (!writePtNodeArraySizeAndAdvancePosition(buffer, 0 /* arraySize */, &writingPos)) { + return false; + } + return writeForwardLinkPositionAndAdvancePosition(buffer, NOT_A_DICT_POS /* forwardLinkPos */, + &writingPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos) { + return writeDictOffset(buffer, forwardLinkPos, (*forwardLinkFieldPos), forwardLinkFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const size_t arraySize, + int *const arraySizeFieldPos) { + // Currently, all array size field to be created has LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE to + // simplify updating process. + // TODO: Use SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE for small arrays. + /*if (arraySize <= MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD) { + return buffer->writeUintAndAdvancePosition(arraySize, SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else */ + if (arraySize <= MAX_PTNODE_ARRAY_SIZE) { + uint32_t data = arraySize | LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + return buffer->writeUintAndAdvancePosition(data, LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else { + AKLOGI("PtNode array size cannot be written because arraySize is too large: %zd", + arraySize); + ASSERT(false); + return false; + } +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, int *const nodeFlagsFieldPos) { + return buffer->writeUintAndAdvancePosition(nodeFlags, NODE_FLAG_FIELD_SIZE, nodeFlagsFieldPos); +} + +// Note that parentOffset is offset from node's head position. +/* static */ bool DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int parentPos, const int basePos, + int *const parentPosFieldPos) { + return writeDictOffset(buffer, parentPos, basePos, parentPosFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int *const codePoints, + const int codePointCount, int *const codePointFieldPos) { + if (codePointCount <= 0) { + AKLOGI("code points cannot be written because codePointCount is invalid: %d", + codePointCount); + ASSERT(false); + return false; + } + const bool hasMultipleCodePoints = codePointCount > 1; + return buffer->writeCodePointsAndAdvancePosition(codePoints, codePointCount, + hasMultipleCodePoints, codePointFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int probability, + int *const probabilityFieldPos) { + if (probability < 0 || probability > MAX_PROBABILITY) { + AKLOGI("probability cannot be written because the probability is invalid: %d", + probability); + ASSERT(false); + return false; + } + return buffer->writeUintAndAdvancePosition(probability, PROBABILITY_FIELD_SIZE, + probabilityFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int childrenPosition, + int *const childrenPositionFieldPos) { + return writeDictOffset(buffer, childrenPosition, (*childrenPositionFieldPos), + childrenPositionFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeDictOffset( + BufferWithExtendableBuffer *const buffer, const int targetPos, const int basePos, + int *const offsetFieldPos) { + int offset = targetPos - basePos; + if (targetPos == NOT_A_DICT_POS) { + offset = DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID; + } else if (offset == 0) { + offset = DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET; + } + if (offset > MAX_DICT_OFFSET_VALUE || offset < MIN_DICT_OFFSET_VALUE) { + AKLOGI("offset cannot be written because the offset is too large or too small: %d", + offset); + ASSERT(false); + return false; + } + uint32_t data = 0; + if (offset >= 0) { + data = offset; + } else { + data = abs(offset) | DICT_OFFSET_NEGATIVE_FLAG; + } + return buffer->writeUintAndAdvancePosition(data, DICT_OFFSET_FIELD_SIZE, offsetFieldPos); +} +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h new file mode 100644 index 000000000..af76bc6b5 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H + +#include <cstddef> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class DynamicPatriciaTrieWritingUtils { + public: + static const int NODE_FLAG_FIELD_SIZE; + + static bool writeEmptyDictionary(BufferWithExtendableBuffer *const buffer, const int rootPos); + + static bool writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos); + + static bool writePtNodeArraySizeAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const size_t arraySize, int *const arraySizeFieldPos); + + static bool writeFlagsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, + int *const nodeFlagsFieldPos); + + static bool writeParentPosOffsetAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int parentPosition, const int basePos, int *const parentPosFieldPos); + + static bool writeCodePointsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int *const codePoints, const int codePointCount, int *const codePointFieldPos); + + static bool writeProbabilityAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int probability, int *const probabilityFieldPos); + + static bool writeChildrenPositionAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int childrenPosition, int *const childrenPositionFieldPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingUtils); + + static const size_t MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD; + static const size_t MAX_PTNODE_ARRAY_SIZE; + static const int SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + static const int DICT_OFFSET_FIELD_SIZE; + static const int MAX_DICT_OFFSET_VALUE; + static const int MIN_DICT_OFFSET_VALUE; + static const int DICT_OFFSET_NEGATIVE_FLAG; + static const int PROBABILITY_FIELD_SIZE; + + static bool writeDictOffset(BufferWithExtendableBuffer *const buffer, const int targetPos, + const int basePos, int *const offsetFieldPos); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp index eb828b58c..7bbeacaa0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -16,24 +16,116 @@ #include "suggest/policyimpl/dictionary/header/header_policy.h" +#include <cstddef> +#include <cstdio> +#include <ctime> + namespace latinime { -const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = - "MULTIPLE_WORDS_DEMOTION_RATE"; -const float HeaderPolicy::DEFAULT_MULTI_WORD_COST_MULTIPLIER = 1.0f; -const float HeaderPolicy::MULTI_WORD_COST_MULTIPLIER_SCALE = 100.0f; -float HeaderPolicy::readMultiWordCostMultiplier() const { - const int headerValue = HeaderReadingUtils::readHeaderValueInt( - mDictBuf, MULTIPLE_WORDS_DEMOTION_RATE_KEY); - if (headerValue == S_INT_MIN) { - // not found - return DEFAULT_MULTI_WORD_COST_MULTIPLIER; +// Note that these are corresponding definitions in Java side in FormatSpec.FileHeader. +const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = "MULTIPLE_WORDS_DEMOTION_RATE"; +const char *const HeaderPolicy::USES_FORGETTING_CURVE_KEY = "USES_FORGETTING_CURVE"; +const char *const HeaderPolicy::LAST_UPDATED_TIME_KEY = "date"; +const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; +const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; + +// Used for logging. Question mark is used to indicate that the key is not found. +void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, + int outValueSize) const { + if (outValueSize <= 0) return; + if (outValueSize == 1) { + outValue[0] = '\0'; + return; + } + std::vector<int> keyCodePointVector; + HeaderReadWriteUtils::insertCharactersIntoVector(key, &keyCodePointVector); + HeaderReadWriteUtils::AttributeMap::const_iterator it = mAttributeMap.find(keyCodePointVector); + if (it == mAttributeMap.end()) { + // The key was not found. + outValue[0] = '?'; + outValue[1] = '\0'; + return; } - if (headerValue <= 0) { + const int terminalIndex = min(static_cast<int>(it->second.size()), outValueSize - 1); + for (int i = 0; i < terminalIndex; ++i) { + outValue[i] = it->second[i]; + } + outValue[terminalIndex] = '\0'; +} + +float HeaderPolicy::readMultipleWordCostMultiplier() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(MULTIPLE_WORDS_DEMOTION_RATE_KEY, &keyVector); + const int demotionRate = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + &keyVector, DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE); + if (demotionRate <= 0) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - return MULTI_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(headerValue); + return MULTIPLE_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(demotionRate); +} + +bool HeaderPolicy::readUsesForgettingCurveFlag() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(USES_FORGETTING_CURVE_KEY, &keyVector); + return HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, &keyVector, + false /* defaultValue */); +} + +// Returns current time when the key is not found or the value is invalid. +int HeaderPolicy::readLastUpdatedTime() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(LAST_UPDATED_TIME_KEY, &keyVector); + return HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, &keyVector, + time(0) /* defaultValue */); +} + +bool HeaderPolicy::writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const bool updatesLastUpdatedTime) const { + int writingPos = 0; + if (!HeaderReadWriteUtils::writeDictionaryVersion(bufferToWrite, mDictFormatVersion, + &writingPos)) { + return false; + } + if (!HeaderReadWriteUtils::writeDictionaryFlags(bufferToWrite, mDictionaryFlags, + &writingPos)) { + return false; + } + // Temporarily writes a dummy header size. + int headerSizeFieldPos = writingPos; + if (!HeaderReadWriteUtils::writeDictionaryHeaderSize(bufferToWrite, 0 /* size */, + &writingPos)) { + return false; + } + if (updatesLastUpdatedTime) { + // Set current time as a last updated time. + HeaderReadWriteUtils::AttributeMap attributeMapTowrite(mAttributeMap); + std::vector<int> updatedTimekey; + HeaderReadWriteUtils::insertCharactersIntoVector(LAST_UPDATED_TIME_KEY, &updatedTimekey); + HeaderReadWriteUtils::setIntAttribute(&attributeMapTowrite, &updatedTimekey, time(0)); + if (!HeaderReadWriteUtils::writeHeaderAttributes(bufferToWrite, &attributeMapTowrite, + &writingPos)) { + return false; + } + } else { + if (!HeaderReadWriteUtils::writeHeaderAttributes(bufferToWrite, &mAttributeMap, + &writingPos)) { + return false; + } + } + // Writes an actual header size. + if (!HeaderReadWriteUtils::writeDictionaryHeaderSize(bufferToWrite, writingPos, + &headerSizeFieldPos)) { + return false; + } + return true; +} + +/* static */ HeaderReadWriteUtils::AttributeMap + HeaderPolicy::createAttributeMapAndReadAllAttributes(const uint8_t *const dictBuf) { + HeaderReadWriteUtils::AttributeMap attributeMap; + HeaderReadWriteUtils::fetchAllHeaderAttributes(dictBuf, &attributeMap); + return attributeMap; } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h index e3e6fc077..e97c08ca4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -21,16 +21,32 @@ #include "defines.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" -#include "suggest/policyimpl/dictionary/header/header_reading_utils.h" +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" namespace latinime { class HeaderPolicy : public DictionaryHeaderStructurePolicy { public: - explicit HeaderPolicy(const uint8_t *const dictBuf) - : mDictBuf(dictBuf), mDictionaryFlags(HeaderReadingUtils::getFlags(dictBuf)), - mSize(HeaderReadingUtils::getHeaderSize(dictBuf)), - mMultiWordCostMultiplier(readMultiWordCostMultiplier()) {} + // Reads information from existing dictionary buffer. + HeaderPolicy(const uint8_t *const dictBuf, const int dictSize) + : mDictFormatVersion(FormatUtils::detectFormatVersion(dictBuf, dictSize)), + mDictionaryFlags(HeaderReadWriteUtils::getFlags(dictBuf)), + mSize(HeaderReadWriteUtils::getHeaderSize(dictBuf)), + mAttributeMap(createAttributeMapAndReadAllAttributes(dictBuf)), + mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), + mUsesForgettingCurve(readUsesForgettingCurveFlag()), + mLastUpdatedTime(readLastUpdatedTime()) {} + + // Constructs header information using an attribute map. + HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion, + const HeaderReadWriteUtils::AttributeMap *const attributeMap) + : mDictFormatVersion(dictFormatVersion), + mDictionaryFlags(HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap( + attributeMap)), mSize(0), mAttributeMap(*attributeMap), + mMultiWordCostMultiplier(readUsesForgettingCurveFlag()), + mUsesForgettingCurve(readUsesForgettingCurveFlag()), + mLastUpdatedTime(readLastUpdatedTime()) {} ~HeaderPolicy() {} @@ -39,50 +55,60 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy { } AK_FORCE_INLINE bool supportsDynamicUpdate() const { - return HeaderReadingUtils::supportsDynamicUpdate(mDictionaryFlags); + return HeaderReadWriteUtils::supportsDynamicUpdate(mDictionaryFlags); } AK_FORCE_INLINE bool requiresGermanUmlautProcessing() const { - return HeaderReadingUtils::requiresGermanUmlautProcessing(mDictionaryFlags); + return HeaderReadWriteUtils::requiresGermanUmlautProcessing(mDictionaryFlags); } AK_FORCE_INLINE bool requiresFrenchLigatureProcessing() const { - return HeaderReadingUtils::requiresFrenchLigatureProcessing( - mDictionaryFlags); + return HeaderReadWriteUtils::requiresFrenchLigatureProcessing(mDictionaryFlags); } AK_FORCE_INLINE float getMultiWordCostMultiplier() const { return mMultiWordCostMultiplier; } - AK_FORCE_INLINE void readHeaderValueOrQuestionMark(const char *const key, - int *outValue, int outValueSize) const { - if (outValueSize <= 0) return; - if (outValueSize == 1) { - outValue[0] = '\0'; - return; - } - if (!HeaderReadingUtils::readHeaderValue(mDictBuf, - key, outValue, outValueSize)) { - outValue[0] = '?'; - outValue[1] = '\0'; - } + AK_FORCE_INLINE bool usesForgettingCurve() const { + return mUsesForgettingCurve; } + AK_FORCE_INLINE int getLastUpdatedTime() const { + return mLastUpdatedTime; + } + + void readHeaderValueOrQuestionMark(const char *const key, + int *outValue, int outValueSize) const; + + bool writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const bool updatesLastUpdatedTime) const; + private: DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderPolicy); static const char *const MULTIPLE_WORDS_DEMOTION_RATE_KEY; - static const float DEFAULT_MULTI_WORD_COST_MULTIPLIER; - static const float MULTI_WORD_COST_MULTIPLIER_SCALE; + static const char *const USES_FORGETTING_CURVE_KEY; + static const char *const LAST_UPDATED_TIME_KEY; + static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; + static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; - const uint8_t *const mDictBuf; - const HeaderReadingUtils::DictionaryFlags mDictionaryFlags; + const FormatUtils::FORMAT_VERSION mDictFormatVersion; + const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; const int mSize; + HeaderReadWriteUtils::AttributeMap mAttributeMap; const float mMultiWordCostMultiplier; + const bool mUsesForgettingCurve; + const int mLastUpdatedTime; - float readMultiWordCostMultiplier() const; -}; + float readMultipleWordCostMultiplier() const; + + bool readUsesForgettingCurveFlag() const; + int readLastUpdatedTime() const; + + static HeaderReadWriteUtils::AttributeMap createAttributeMapAndReadAllAttributes( + const uint8_t *const dictBuf); +}; } // namespace latinime #endif /* LATINIME_HEADER_POLICY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp new file mode 100644 index 000000000..3b1c78085 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" + +#include <cctype> +#include <cstdio> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_KEY_LENGTH = 256; +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 256; + +const int HeaderReadWriteUtils::HEADER_MAGIC_NUMBER_SIZE = 4; +const int HeaderReadWriteUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; +const int HeaderReadWriteUtils::HEADER_FLAG_SIZE = 2; +const int HeaderReadWriteUtils::HEADER_SIZE_FIELD_SIZE = 4; + +const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; +// Flags for special processing +// Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAG) or +// something very bad (like, the apocalypse) will happen. Please update both at the same time. +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::GERMAN_UMLAUT_PROCESSING_FLAG = 0x1; +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::SUPPORTS_DYNAMIC_UPDATE_FLAG = 0x2; +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::FRENCH_LIGATURE_PROCESSING_FLAG = 0x4; + +// Note that these are corresponding definitions in Java side in FormatSpec.FileHeader. +const char *const HeaderReadWriteUtils::SUPPORTS_DYNAMIC_UPDATE_KEY = "SUPPORTS_DYNAMIC_UPDATE"; +const char *const HeaderReadWriteUtils::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY = + "REQUIRES_GERMAN_UMLAUT_PROCESSING"; +const char *const HeaderReadWriteUtils::REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY = + "REQUIRES_FRENCH_LIGATURE_PROCESSING"; + +/* static */ int HeaderReadWriteUtils::getHeaderSize(const uint8_t *const dictBuf) { + // See the format of the header in the comment in + // BinaryDictionaryFormatUtils::detectFormatVersion() + return ByteArrayUtils::readUint32(dictBuf, HEADER_MAGIC_NUMBER_SIZE + + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE); +} + +/* static */ HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::getFlags(const uint8_t *const dictBuf) { + return ByteArrayUtils::readUint16(dictBuf, + HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE); +} + +/* static */ HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap( + const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + AttributeMap::key_type key; + insertCharactersIntoVector(REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY, &key); + const bool requiresGermanUmlautProcessing = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + key.clear(); + insertCharactersIntoVector(REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY, &key); + const bool requiresFrenchLigatureProcessing = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + key.clear(); + insertCharactersIntoVector(SUPPORTS_DYNAMIC_UPDATE_KEY, &key); + const bool supportsDynamicUpdate = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + DictionaryFlags dictflags = NO_FLAGS; + dictflags |= requiresGermanUmlautProcessing ? GERMAN_UMLAUT_PROCESSING_FLAG : 0; + dictflags |= requiresFrenchLigatureProcessing ? FRENCH_LIGATURE_PROCESSING_FLAG : 0; + dictflags |= supportsDynamicUpdate ? SUPPORTS_DYNAMIC_UPDATE_FLAG : 0; + return dictflags; +} + +/* static */ void HeaderReadWriteUtils::fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes) { + const int headerSize = getHeaderSize(dictBuf); + int pos = getHeaderOptionsPosition(); + if (pos == NOT_A_DICT_POS) { + // The header doesn't have header options. + return; + } + int keyBuffer[MAX_ATTRIBUTE_KEY_LENGTH]; + int valueBuffer[MAX_ATTRIBUTE_VALUE_LENGTH]; + while (pos < headerSize) { + const int keyLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_KEY_LENGTH, keyBuffer, &pos); + std::vector<int> key; + key.insert(key.end(), keyBuffer, keyBuffer + keyLength); + const int valueLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_VALUE_LENGTH, valueBuffer, &pos); + std::vector<int> value; + value.insert(value.end(), valueBuffer, valueBuffer + valueLength); + headerAttributes->insert(AttributeMap::value_type(key, value)); + } +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryVersion( + BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, + int *const writingPos) { + if (!buffer->writeUintAndAdvancePosition(FormatUtils::MAGIC_NUMBER, HEADER_MAGIC_NUMBER_SIZE, + writingPos)) { + return false; + } + switch (version) { + case FormatUtils::VERSION_2: + // Version 2 dictionary writing is not supported. + return false; + case FormatUtils::VERSION_3: + return buffer->writeUintAndAdvancePosition(3 /* data */, + HEADER_DICTIONARY_VERSION_SIZE, writingPos); + default: + return false; + } +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryFlags( + BufferWithExtendableBuffer *const buffer, const DictionaryFlags flags, + int *const writingPos) { + return buffer->writeUintAndAdvancePosition(flags, HEADER_FLAG_SIZE, writingPos); +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryHeaderSize( + BufferWithExtendableBuffer *const buffer, const int size, int *const writingPos) { + return buffer->writeUintAndAdvancePosition(size, HEADER_SIZE_FIELD_SIZE, writingPos); +} + +/* static */ bool HeaderReadWriteUtils::writeHeaderAttributes( + BufferWithExtendableBuffer *const buffer, const AttributeMap *const headerAttributes, + int *const writingPos) { + for (AttributeMap::const_iterator it = headerAttributes->begin(); + it != headerAttributes->end(); ++it) { + // Write a key. + if (!buffer->writeCodePointsAndAdvancePosition(&(it->first.at(0)), it->first.size(), + true /* writesTerminator */, writingPos)) { + return false; + } + // Write a value. + if (!buffer->writeCodePointsAndAdvancePosition(&(it->second.at(0)), it->second.size(), + true /* writesTerminator */, writingPos)) { + return false; + } + } + return true; +} + +/* static */ void HeaderReadWriteUtils::setBoolAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool value) { + setIntAttribute(headerAttributes, key, value ? 1 : 0); +} + +/* static */ void HeaderReadWriteUtils::setIntAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int value) { + AttributeMap::mapped_type valueVector; + char charBuf[LARGEST_INT_DIGIT_COUNT + 1]; + snprintf(charBuf, LARGEST_INT_DIGIT_COUNT + 1, "%d", value); + insertCharactersIntoVector(charBuf, &valueVector); + (*headerAttributes)[*key] = valueVector; +} + +/* static */ bool HeaderReadWriteUtils::readBoolAttributeValue( + const AttributeMap *const headerAttributes, const AttributeMap::key_type *const key, + const bool defaultValue) { + const int intDefaultValue = defaultValue ? 1 : 0; + const int intValue = readIntAttributeValue(headerAttributes, key, intDefaultValue); + return intValue != 0; +} + +/* static */ int HeaderReadWriteUtils::readIntAttributeValue( + const AttributeMap *const headerAttributes, const AttributeMap::key_type *const key, + const int defaultValue) { + AttributeMap::const_iterator it = headerAttributes->find(*key); + if (it != headerAttributes->end()) { + int value = 0; + bool isNegative = false; + for (size_t i = 0; i < it->second.size(); ++i) { + if (i == 0 && it->second.at(i) == '-') { + isNegative = true; + } else { + if (!isdigit(it->second.at(i))) { + // If not a number. + return defaultValue; + } + value *= 10; + value += it->second.at(i) - '0'; + } + } + return isNegative ? -value : value; + } + return defaultValue; +} + +/* static */ void HeaderReadWriteUtils::insertCharactersIntoVector(const char *const characters, + std::vector<int> *const vector) { + for (int i = 0; characters[i]; ++i) { + vector->push_back(characters[i]); + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h new file mode 100644 index 000000000..caa5097f6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_HEADER_READ_WRITE_UTILS_H +#define LATINIME_HEADER_READ_WRITE_UTILS_H + +#include <map> +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class HeaderReadWriteUtils { + public: + typedef uint16_t DictionaryFlags; + typedef std::map<std::vector<int>, std::vector<int> > AttributeMap; + + static int getHeaderSize(const uint8_t *const dictBuf); + + static DictionaryFlags getFlags(const uint8_t *const dictBuf); + + static AK_FORCE_INLINE bool supportsDynamicUpdate(const DictionaryFlags flags) { + return (flags & SUPPORTS_DYNAMIC_UPDATE_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresGermanUmlautProcessing(const DictionaryFlags flags) { + return (flags & GERMAN_UMLAUT_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresFrenchLigatureProcessing(const DictionaryFlags flags) { + return (flags & FRENCH_LIGATURE_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE int getHeaderOptionsPosition() { + return HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE + + HEADER_SIZE_FIELD_SIZE; + } + + static DictionaryFlags createAndGetDictionaryFlagsUsingAttributeMap( + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static void fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes); + + static bool writeDictionaryVersion(BufferWithExtendableBuffer *const buffer, + const FormatUtils::FORMAT_VERSION version, int *const writingPos); + + static bool writeDictionaryFlags(BufferWithExtendableBuffer *const buffer, + const DictionaryFlags flags, int *const writingPos); + + static bool writeDictionaryHeaderSize(BufferWithExtendableBuffer *const buffer, + const int size, int *const writingPos); + + static bool writeHeaderAttributes(BufferWithExtendableBuffer *const buffer, + const AttributeMap *const headerAttributes, int *const writingPos); + + /** + * Methods for header attributes. + */ + static void setBoolAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool value); + + static void setIntAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int value); + + static bool readBoolAttributeValue(const AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool defaultValue); + + static int readIntAttributeValue(const AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int defaultValue); + + static void insertCharactersIntoVector(const char *const characters, + AttributeMap::key_type *const key); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderReadWriteUtils); + + static const int MAX_ATTRIBUTE_KEY_LENGTH; + static const int MAX_ATTRIBUTE_VALUE_LENGTH; + + static const int HEADER_MAGIC_NUMBER_SIZE; + static const int HEADER_DICTIONARY_VERSION_SIZE; + static const int HEADER_FLAG_SIZE; + static const int HEADER_SIZE_FIELD_SIZE; + + static const DictionaryFlags NO_FLAGS; + // Flags for special processing + // Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAGS) or + // something very bad (like, the apocalypse) will happen. Please update both at the same time. + static const DictionaryFlags GERMAN_UMLAUT_PROCESSING_FLAG; + static const DictionaryFlags SUPPORTS_DYNAMIC_UPDATE_FLAG; + static const DictionaryFlags FRENCH_LIGATURE_PROCESSING_FLAG; + + static const char *const SUPPORTS_DYNAMIC_UPDATE_KEY; + static const char *const REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY; + static const char *const REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY; +}; +} +#endif /* LATINIME_HEADER_READ_WRITE_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp deleted file mode 100644 index f323876c4..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.cpp +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/policyimpl/dictionary/header/header_reading_utils.h" - -#include <cctype> -#include <cstdlib> - -#include "defines.h" -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" - -namespace latinime { - -const int HeaderReadingUtils::MAX_OPTION_KEY_LENGTH = 256; - -const int HeaderReadingUtils::HEADER_MAGIC_NUMBER_SIZE = 4; -const int HeaderReadingUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; -const int HeaderReadingUtils::HEADER_FLAG_SIZE = 2; -const int HeaderReadingUtils::HEADER_SIZE_FIELD_SIZE = 4; - -const HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::NO_FLAGS = 0; -// Flags for special processing -// Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAG) or -// something very bad (like, the apocalypse) will happen. Please update both at the same time. -const HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::GERMAN_UMLAUT_PROCESSING_FLAG = 0x1; -const HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::SUPPORTS_DYNAMIC_UPDATE_FLAG = 0x2; -const HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::FRENCH_LIGATURE_PROCESSING_FLAG = 0x4; - -/* static */ int HeaderReadingUtils::getHeaderSize(const uint8_t *const dictBuf) { - // See the format of the header in the comment in - // BinaryDictionaryFormatUtils::detectFormatVersion() - return ByteArrayUtils::readUint32(dictBuf, HEADER_MAGIC_NUMBER_SIZE - + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE); -} - -/* static */ HeaderReadingUtils::DictionaryFlags - HeaderReadingUtils::getFlags(const uint8_t *const dictBuf) { - return ByteArrayUtils::readUint16(dictBuf, - HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE); -} - -// Returns if the key is found or not and reads the found value into outValue. -/* static */ bool HeaderReadingUtils::readHeaderValue(const uint8_t *const dictBuf, - const char *const key, int *outValue, const int outValueSize) { - if (outValueSize <= 0) { - return false; - } - const int headerSize = getHeaderSize(dictBuf); - int pos = getHeaderOptionsPosition(); - if (pos == NOT_A_DICT_POS) { - // The header doesn't have header options. - return false; - } - while (pos < headerSize) { - if(ByteArrayUtils::compareStringInBufferWithCharArray( - dictBuf, key, headerSize - pos, &pos) == 0) { - // The key was found. - const int length = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, outValueSize, - outValue, &pos); - // Add a 0 terminator to the string. - outValue[length < outValueSize ? length : outValueSize - 1] = '\0'; - return true; - } - ByteArrayUtils::advancePositionToBehindString(dictBuf, headerSize - pos, &pos); - } - // The key was not found. - return false; -} - -/* static */ int HeaderReadingUtils::readHeaderValueInt( - const uint8_t *const dictBuf, const char *const key) { - const int bufferSize = LARGEST_INT_DIGIT_COUNT; - int intBuffer[bufferSize]; - char charBuffer[bufferSize]; - if (!readHeaderValue(dictBuf, key, intBuffer, bufferSize)) { - return S_INT_MIN; - } - for (int i = 0; i < bufferSize; ++i) { - charBuffer[i] = intBuffer[i]; - if (charBuffer[i] == '0') { - break; - } - if (!isdigit(charBuffer[i])) { - // If not a number, return S_INT_MIN - return S_INT_MIN; - } - } - return atoi(charBuffer); -} - -} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h deleted file mode 100644 index c94919640..000000000 --- a/native/jni/src/suggest/policyimpl/dictionary/header/header_reading_utils.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_HEADER_READING_UTILS_H -#define LATINIME_HEADER_READING_UTILS_H - -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -class HeaderReadingUtils { - public: - typedef uint16_t DictionaryFlags; - - static const int MAX_OPTION_KEY_LENGTH; - - static int getHeaderSize(const uint8_t *const dictBuf); - - static DictionaryFlags getFlags(const uint8_t *const dictBuf); - - static AK_FORCE_INLINE bool supportsDynamicUpdate(const DictionaryFlags flags) { - return (flags & SUPPORTS_DYNAMIC_UPDATE_FLAG) != 0; - } - - static AK_FORCE_INLINE bool requiresGermanUmlautProcessing(const DictionaryFlags flags) { - return (flags & GERMAN_UMLAUT_PROCESSING_FLAG) != 0; - } - - static AK_FORCE_INLINE bool requiresFrenchLigatureProcessing(const DictionaryFlags flags) { - return (flags & FRENCH_LIGATURE_PROCESSING_FLAG) != 0; - } - - static AK_FORCE_INLINE int getHeaderOptionsPosition() { - return HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE - + HEADER_SIZE_FIELD_SIZE; - } - - static bool readHeaderValue(const uint8_t *const dictBuf, - const char *const key, int *outValue, const int outValueSize); - - static int readHeaderValueInt(const uint8_t *const dictBuf, const char *const key); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderReadingUtils); - - static const int HEADER_MAGIC_NUMBER_SIZE; - static const int HEADER_DICTIONARY_VERSION_SIZE; - static const int HEADER_FLAG_SIZE; - static const int HEADER_SIZE_FIELD_SIZE; - - static const DictionaryFlags NO_FLAGS; - // Flags for special processing - // Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAGS) or - // something very bad (like, the apocalypse) will happen. Please update both at the same time. - static const DictionaryFlags GERMAN_UMLAUT_PROCESSING_FLAG; - static const DictionaryFlags SUPPORTS_DYNAMIC_UPDATE_FLAG; - static const DictionaryFlags FRENCH_LIGATURE_PROCESSING_FLAG; - static const DictionaryFlags CONTAINS_BIGRAMS_FLAG; -}; -} -#endif /* LATINIME_HEADER_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp index 3e664a29b..8a84bd261 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp @@ -20,8 +20,8 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" -#include "suggest/policyimpl/dictionary/binary_format.h" #include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" namespace latinime { @@ -31,31 +31,311 @@ void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, return; } int nextPos = dicNode->getChildrenPos(); - const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( + if (nextPos < 0 || nextPos >= mDictBufferSize) { + AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", + nextPos, mDictBufferSize); + ASSERT(false); + return; + } + const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( mDictRoot, &nextPos); for (int i = 0; i < childCount; i++) { + if (nextPos < 0 || nextPos >= mDictBufferSize) { + AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", + nextPos, mDictBufferSize, i, childCount); + ASSERT(false); + return; + } nextPos = createAndGetLeavingChildNode(dicNode, nextPos, childDicNodes); } } +// This retrieves code points and the probability of the word by its terminal position. +// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, +// it is possible to check for this with advantageous complexity. For each node, we search +// for PtNodes with children and compare the children position with the position we look for. +// When we shoot the position we look for, it means the word we look for is in the children +// of the previous PtNode. The only tricky part is the fact that if we arrive at the end of a +// PtNode array with the last PtNode's children position still less than what we are searching for, +// we must descend the last PtNode's children (for example, if the word we are searching for starts +// with a z, it's the last PtNode of the root array, so all children addresses will be smaller +// than the position we look for, and we have to descend the z node). +/* Parameters : + * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is + * what is stored as the "bigram position" in each bigram) + * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size. + * outUnigramProbability: a pointer to an int to write the probability into. + * Return value : the code point count, of 0 if the word was not found. + */ +// TODO: Split this function to be more readable int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int nodePos, const int maxCodePointCount, int *const outCodePoints, + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) const { - return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount(mDictRoot, nodePos, - maxCodePointCount, outCodePoints, outUnigramProbability); + int pos = getRootPosition(); + int wordPos = 0; + // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will + // only traverse nodes that are actually a part of the terminal we are searching, so each time + // we enter this loop we are one depth level further than last time. + // The only reason we count nodes is because we want to reduce the probability of infinite + // looping in case there is a bug. Since we know there is an upper bound to the depth we are + // supposed to traverse, it does not hurt to count iterations. + for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { + int lastCandidatePtNodePos = 0; + // Let's loop through PtNodes in this PtNode array searching for either the terminal + // or one of its ascendants. + for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( + mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { + const int startPos = pos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + if (ptNodePos == startPos) { + // We found the position. Copy the rest of the code points in the buffer and return + // the length. + outCodePoints[wordPos] = character; + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + // We count code points in order to avoid infinite loops if the file is broken + // or if there is some other bug + int charCount = maxCodePointCount; + while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { + outCodePoints[++wordPos] = nextChar; + nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + } + } + *outUnigramProbability = + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + &pos); + return ++wordPos; + } + // We need to skip past this PtNode, so skip any remaining code points after the + // first and possibly the probability. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + } + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + // The fact that this PtNode has children is very important. Since we already know + // that this PtNode does not match, if it has no children we know it is irrelevant + // to what we are searching for. + const bool hasChildren = PatriciaTrieReadingUtils::hasChildrenInFlags(flags); + // We will write in `found' whether we have passed the children position we are + // searching for. For example if we search for "beer", the children of b are less + // than the address we are searching for and the children of c are greater. When we + // come here for c, we realize this is too big, and that we should descend b. + bool found; + if (hasChildren) { + int currentPos = pos; + // Here comes the tricky part. First, read the children position. + const int childrenPos = PatriciaTrieReadingUtils + ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, ¤tPos); + if (childrenPos > ptNodePos) { + // If the children pos is greater than the position, it means the previous + // PtNode, which position is stored in lastCandidatePtNodePos, was the right + // one. + found = true; + } else if (1 >= ptNodeCount) { + // However if we are on the LAST PtNode of this array, and we have NOT shot the + // position we should descend THIS node. So we trick the lastCandidatePtNodePos + // so that we will descend this PtNode, not the previous one. + lastCandidatePtNodePos = startPos; + found = true; + } else { + // Else, we should continue looking. + found = false; + } + } else { + // Even if we don't have children here, we could still be on the last PtNode of / + // this array. If this is the case, we should descend the last PtNode that had + // children, and their position is already in lastCandidatePtNodePos. + found = (1 >= ptNodeCount); + } + + if (found) { + // Okay, we found the PtNode we should descend. Its position is in + // the lastCandidatePtNodePos variable, so we just re-read it. + if (0 != lastCandidatePtNodePos) { + const PatriciaTrieReadingUtils::NodeFlags lastFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + // We copy all the characters in this PtNode to the buffer + outCodePoints[wordPos] = lastChar; + if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { + int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + int charCount = maxCodePointCount; + while (-1 != nextChar && --charCount > 0) { + outCodePoints[++wordPos] = nextChar; + nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + } + } + ++wordPos; + // Now we only need to branch to the children address. Skip the probability if + // it's there, read pos, and break to resume the search at pos. + if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + &lastCandidatePtNodePos); + } + pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, lastFlags, &lastCandidatePtNodePos); + break; + } else { + // Here is a little tricky part: we come here if we found out that all children + // addresses in this PtNode are bigger than the address we are searching for. + // Should we conclude the word is not in the dictionary? No! It could still be + // one of the remaining PtNodes in this array, so we have to keep looking in + // this array until we find it (or we realize it's not there either, in which + // case it's actually not in the dictionary). Pass the end of this PtNode, + // ready to start the next one. + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + } else { + // If we did not find it, we should record the last children address for the next + // iteration. + if (hasChildren) lastCandidatePtNodePos = startPos; + // Now skip the end of this PtNode (children pos and the attributes if any) so that + // our pos is after the end of this PtNode, at the start of the next one. + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + + } + } + // If we have looked through all the PtNodes and found no match, the ptNodePos is + // not the position of a terminal in this dictionary. + return 0; } +// This function gets the position of the terminal node of the exact matching word in the +// dictionary. If no match is found, it returns NOT_A_DICT_POS. int PatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const { - return BinaryFormat::getTerminalPosition(mDictRoot, inWord, - length, forceLowerCaseSearch); + int pos = getRootPosition(); + int wordPos = 0; + + while (true) { + // If we already traversed the tree further than the word is long, there means + // there was no match (or we would have found it). + if (wordPos >= length) return NOT_A_DICT_POS; + int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(mDictRoot, + &pos); + const int wChar = forceLowerCaseSearch + ? CharUtils::toLowerCase(inWord[wordPos]) : inWord[wordPos]; + while (true) { + // If there are no more PtNodes in this array, it means we could not + // find a matching character for this depth, therefore there is no match. + if (0 >= ptNodeCount) return NOT_A_DICT_POS; + const int ptNodePos = pos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(mDictRoot, + &pos); + if (character == wChar) { + // This is the correct PtNode. Only one PtNode may start with the same char within + // a PtNode array, so either we found our match in this array, or there is + // no match and we can return NOT_A_DICT_POS. So we will check all the + // characters in this PtNode indeed does match. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(mDictRoot, + &pos); + while (NOT_A_CODE_POINT != character) { + ++wordPos; + // If we shoot the length of the word we search for, or if we find a single + // character that does not match, as explained above, it means the word is + // not in the dictionary (by virtue of this PtNode being the only one to + // match the word on the first character, but not matching the whole word). + if (wordPos >= length) return NOT_A_DICT_POS; + if (inWord[wordPos] != character) return NOT_A_DICT_POS; + character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + } + } + // If we come here we know that so far, we do match. Either we are on a terminal + // and we match the length, in which case we found it, or we traverse children. + // If we don't match the length AND don't have children, then a word in the + // dictionary fully matches a prefix of the searched word but not the full word. + ++wordPos; + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + if (wordPos == length) { + return ptNodePos; + } + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (!PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + return NOT_A_DICT_POS; + } + // We have children and we are still shorter than the word we are searching for, so + // we need to traverse children. Put the pointer on the children position, and + // break + pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, + flags, &pos); + break; + } else { + // This PtNode does not match, so skip the remaining part and go to the next. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, + &pos); + } + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, + flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + --ptNodeCount; + } + } +} + +int PatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } } -int PatriciaTriePolicy::getUnigramProbability(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_PROBABILITY; } - int pos = nodePos; + int pos = ptNodePos; const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); if (!PatriciaTrieReadingUtils::isTerminal(flags)) { @@ -69,14 +349,15 @@ int PatriciaTriePolicy::getUnigramProbability(const int nodePos) const { return NOT_A_PROBABILITY; } PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); - return PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + return getProbability(PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition( + mDictRoot, &pos), NOT_A_PROBABILITY); } -int PatriciaTriePolicy::getShortcutPositionOfNode(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; } - int pos = nodePos; + int pos = ptNodePos; const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); if (!PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { @@ -92,11 +373,11 @@ int PatriciaTriePolicy::getShortcutPositionOfNode(const int nodePos) const { return pos; } -int PatriciaTriePolicy::getBigramsPositionOfNode(const int nodePos) const { - if (nodePos == NOT_A_VALID_WORD_POS) { +int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { return NOT_A_DICT_POS; } - int pos = nodePos; + int pos = ptNodePos; const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); if (!PatriciaTrieReadingUtils::hasBigrams(flags)) { @@ -116,8 +397,8 @@ int PatriciaTriePolicy::getBigramsPositionOfNode(const int nodePos) const { } int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode, - const int nodePos, DicNodeVector *childDicNodes) const { - int pos = nodePos; + const int ptNodePos, DicNodeVector *childDicNodes) const { + int pos = ptNodePos; const PatriciaTrieReadingUtils::NodeFlags flags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); int mergedNodeCodePoints[MAX_WORD_LENGTH]; @@ -135,7 +416,12 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod if (PatriciaTrieReadingUtils::hasBigrams(flags)) { getBigramsStructurePolicy()->skipAllBigrams(&pos); } - childDicNodes->pushLeavingChild(dicNode, nodePos, childrenPos, probability, + if (mergedNodeCodePointCount <= 0) { + AKLOGE("Empty PtNode is not allowed. Code point count: %d", mergedNodeCodePointCount); + ASSERT(false); + return pos; + } + childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability, PatriciaTrieReadingUtils::isTerminal(flags), PatriciaTrieReadingUtils::hasChildrenInFlags(flags), PatriciaTrieReadingUtils::isBlacklisted(flags) || diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h index 0d85050f3..f1de914cb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -24,7 +24,7 @@ #include "suggest/policyimpl/dictionary/bigram/bigram_list_policy.h" #include "suggest/policyimpl/dictionary/header/header_policy.h" #include "suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h" -#include "suggest/policyimpl/dictionary/utils/mmaped_buffer.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" namespace latinime { @@ -33,9 +33,10 @@ class DicNodeVector; class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { public: - PatriciaTriePolicy(const MmapedBuffer *const buffer) - : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer()), + PatriciaTriePolicy(const MmappedBuffer *const buffer) + : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer(), buffer->getBufferSize()), mDictRoot(mBuffer->getBuffer() + mHeaderPolicy.getSize()), + mDictBufferSize(mBuffer->getBufferSize() - mHeaderPolicy.getSize()), mBigramListPolicy(mDictRoot), mShortcutListPolicy(mDictRoot) {} ~PatriciaTriePolicy() { @@ -56,11 +57,13 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { int getTerminalNodePositionOfWord(const int *const inWord, const int length, const bool forceLowerCaseSearch) const; - int getUnigramProbability(const int nodePos) const; + int getProbability(const int unigramProbability, const int bigramProbability) const; - int getShortcutPositionOfNode(const int nodePos) const; + int getUnigramProbabilityOfPtNode(const int ptNodePos) const; - int getBigramsPositionOfNode(const int nodePos) const; + int getShortcutPositionOfPtNode(const int ptNodePos) const; + + int getBigramsPositionOfPtNode(const int ptNodePos) const; const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { return &mHeaderPolicy; @@ -94,16 +97,33 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { return false; } + void flush(const char *const filePath) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: flush() is called for non-updatable dictionary."); + } + + void flushWithGC(const char *const filePath) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); + } + + bool needsToRunGC() const { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); + return false; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTriePolicy); - const MmapedBuffer *const mBuffer; + const MmappedBuffer *const mBuffer; const HeaderPolicy mHeaderPolicy; const uint8_t *const mDictRoot; + const int mDictBufferSize; const BigramListPolicy mBigramListPolicy; const ShortcutListPolicy mShortcutListPolicy; - int createAndGetLeavingChildNode(const DicNode *const dicNode, const int nodePos, + int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, DicNodeVector *const childDicNodes) const; }; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp index 003b9435b..7df55815f 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp @@ -23,15 +23,15 @@ namespace latinime { typedef PatriciaTrieReadingUtils PtReadingUtils; -const PtReadingUtils::NodeFlags PtReadingUtils::MASK_GROUP_ADDRESS_TYPE = 0xC0; -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; -const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0; +const PtReadingUtils::NodeFlags PtReadingUtils::MASK_CHILDREN_POSITION_TYPE = 0xC0; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_NOPOSITION = 0x00; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_ONEBYTE = 0x40; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_TWOBYTES = 0x80; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_THREEBYTES = 0xC0; // Flag for single/multiple char group const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_MULTIPLE_CHARS = 0x20; -// Flag for terminal groups +// Flag for terminal PtNodes const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_TERMINAL = 0x10; // Flag for shortcut targets presence const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08; @@ -42,18 +42,84 @@ const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02; // Flag for blacklist const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; +/* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t firstByte = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + if (firstByte < 0x80) { + return firstByte; + } else { + return ((firstByte & 0x7F) << 8) ^ ByteArrayUtils::readUint8AndAdvancePosition( + buffer, pos); + } +} + +/* static */ PtReadingUtils::NodeFlags PtReadingUtils::getFlagsAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); +} + +/* static */ int PtReadingUtils::getCodePointAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); +} + +// Returns the number of read characters. +/* static */ int PtReadingUtils::getCharsAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { + int length = 0; + if (hasMultipleChars(flags)) { + length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, + pos); + } else { + const int codePoint = getCodePointAndAdvancePosition(buffer, pos); + if (codePoint == NOT_A_CODE_POINT) { + // CAVEAT: codePoint == NOT_A_CODE_POINT means the code point is + // CHARACTER_ARRAY_TERMINATOR. The code point must not be CHARACTER_ARRAY_TERMINATOR + // when the PtNode has a single code point. + length = 0; + AKLOGE("codePoint is NOT_A_CODE_POINT. pos: %d, codePoint: 0x%x, buffer[pos - 1]: 0x%x", + *pos - 1, codePoint, buffer[*pos - 1]); + ASSERT(false); + } else if (maxLength > 0) { + outBuffer[0] = codePoint; + length = 1; + } + } + return length; +} + +// Returns the number of skipped characters. +/* static */ int PtReadingUtils::skipCharacters(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const pos) { + if (hasMultipleChars(flags)) { + return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); + } else { + if (maxLength > 0) { + getCodePointAndAdvancePosition(buffer, pos); + return 1; + } else { + return 0; + } + } +} + +/* static */ int PtReadingUtils::readProbabilityAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); +} + /* static */ int PtReadingUtils::readChildrenPositionAndAdvancePosition( const uint8_t *const buffer, const NodeFlags flags, int *const pos) { const int base = *pos; int offset = 0; - switch (MASK_GROUP_ADDRESS_TYPE & flags) { - case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE: + switch (MASK_CHILDREN_POSITION_TYPE & flags) { + case FLAG_CHILDREN_POSITION_TYPE_ONEBYTE: offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); break; - case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES: + case FLAG_CHILDREN_POSITION_TYPE_TWOBYTES: offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer, pos); break; - case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES: + case FLAG_CHILDREN_POSITION_TYPE_THREEBYTES: offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer, pos); break; default: diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h index 9f2fc2059..8420ee95a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h @@ -20,7 +20,6 @@ #include <stdint.h> #include "defines.h" -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { @@ -28,62 +27,21 @@ class PatriciaTrieReadingUtils { public: typedef uint8_t NodeFlags; - static AK_FORCE_INLINE int getGroupCountAndAdvancePosition( - const uint8_t *const buffer, int *const pos) { - const uint8_t firstByte = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); - if (firstByte < 0x80) { - return firstByte; - } else { - return ((firstByte & 0x7F) << 8) ^ ByteArrayUtils::readUint8AndAdvancePosition( - buffer, pos); - } - } + static int getPtNodeArraySizeAndAdvancePosition(const uint8_t *const buffer, int *const pos); - static AK_FORCE_INLINE NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); - } + static NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, int *const pos); - static AK_FORCE_INLINE int getCodePointAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); - } + static int getCodePointAndAdvancePosition(const uint8_t *const buffer, int *const pos); // Returns the number of read characters. - static AK_FORCE_INLINE int getCharsAndAdvancePosition(const uint8_t *const buffer, - const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { - int length = 0; - if (hasMultipleChars(flags)) { - length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, - pos); - } else { - if (maxLength > 0) { - outBuffer[0] = getCodePointAndAdvancePosition(buffer, pos); - length = 1; - } - } - return length; - } + static int getCharsAndAdvancePosition(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const outBuffer, int *const pos); // Returns the number of skipped characters. - static AK_FORCE_INLINE int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, - const int maxLength, int *const pos) { - if (hasMultipleChars(flags)) { - return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); - } else { - if (maxLength > 0) { - getCodePointAndAdvancePosition(buffer, pos); - return 1; - } else { - return 0; - } - } - } + static int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const pos); - static AK_FORCE_INLINE int readProbabilityAndAdvancePosition(const uint8_t *const buffer, - int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); - } + static int readProbabilityAndAdvancePosition(const uint8_t *const buffer, int *const pos); static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, const NodeFlags flags, int *const pos); @@ -116,17 +74,40 @@ class PatriciaTrieReadingUtils { } static AK_FORCE_INLINE bool hasChildrenInFlags(const NodeFlags flags) { - return FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags); + return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags); + } + + static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted, + const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets, + const bool hasBigrams, const bool hasMultipleChars, + const int childrenPositionFieldSize) { + NodeFlags nodeFlags = 0; + nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags; + nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags; + nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags; + nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags; + nodeFlags = hasBigrams ? (nodeFlags | FLAG_HAS_BIGRAMS) : nodeFlags; + nodeFlags = hasMultipleChars ? (nodeFlags | FLAG_HAS_MULTIPLE_CHARS) : nodeFlags; + if (childrenPositionFieldSize == 1) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_ONEBYTE; + } else if (childrenPositionFieldSize == 2) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_TWOBYTES; + } else if (childrenPositionFieldSize == 3) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_THREEBYTES; + } else { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_NOPOSITION; + } + return nodeFlags; } private: DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); - static const NodeFlags MASK_GROUP_ADDRESS_TYPE; - static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_NOADDRESS; - static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_ONEBYTE; - static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_TWOBYTES; - static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_THREEBYTES; + static const NodeFlags MASK_CHILDREN_POSITION_TYPE; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_NOPOSITION; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_ONEBYTE; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_TWOBYTES; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_THREEBYTES; static const NodeFlags FLAG_HAS_MULTIPLE_CHARS; static const NodeFlags FLAG_IS_TERMINAL; diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h new file mode 100644 index 000000000..bd3211f6a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H +#define LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +/* + * This is a dynamic version of ShortcutListPolicy and supports an additional buffer. + */ +class DynamicShortcutListPolicy : public DictionaryShortcutsStructurePolicy { + public: + explicit DynamicShortcutListPolicy(const BufferWithExtendableBuffer *const buffer) + : mBuffer(buffer) {} + + ~DynamicShortcutListPolicy() {} + + int getStartPos(const int pos) const { + if (pos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + return pos + ShortcutListReadingUtils::getShortcutListSizeFieldSize(); + } + + void getNextShortcut(const int maxCodePointCount, int *const outCodePoint, + int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, + int *const pos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*pos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *pos -= mBuffer->getOriginalBufferSize(); + } + const ShortcutListReadingUtils::ShortcutFlags flags = + ShortcutListReadingUtils::getFlagsAndForwardPointer(buffer, pos); + if (outHasNext) { + *outHasNext = ShortcutListReadingUtils::hasNext(flags); + } + if (outIsWhitelist) { + *outIsWhitelist = ShortcutListReadingUtils::isWhitelist(flags); + } + if (outCodePoint) { + *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( + buffer, maxCodePointCount, outCodePoint, pos); + } + if (usesAdditionalBuffer) { + *pos += mBuffer->getOriginalBufferSize(); + } + } + + void skipAllShortcuts(int *const pos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*pos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *pos -= mBuffer->getOriginalBufferSize(); + } + const int shortcutListSize = ShortcutListReadingUtils + ::getShortcutListSizeAndForwardPointer(buffer, pos); + *pos += shortcutListSize; + if (usesAdditionalBuffer) { + *pos += mBuffer->getOriginalBufferSize(); + } + } + + // Copy shortcuts from the shortcut list that starts at fromPos in mBuffer to toPos in + // bufferToWrite and advance these positions after the shortcut lists. This returns whether + // the copy was succeeded or not. + bool copyAllShortcutsAndReturnIfSucceededOrNot(BufferWithExtendableBuffer *const bufferToWrite, + int *const fromPos, int *const toPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); + if (usesAdditionalBuffer) { + *fromPos -= mBuffer->getOriginalBufferSize(); + } + const int shortcutListSize = ShortcutListReadingUtils + ::getShortcutListSizeAndForwardPointer(mBuffer->getBuffer(usesAdditionalBuffer), + fromPos); + // Copy shortcut list size. + if (!bufferToWrite->writeUintAndAdvancePosition( + shortcutListSize + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), toPos)) { + return false; + } + // Copy shortcut list. + for (int i = 0; i < shortcutListSize; ++i) { + const uint8_t data = ByteArrayUtils::readUint8AndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), fromPos); + if (!bufferToWrite->writeUintAndAdvancePosition(data, 1 /* size */, toPos)) { + return false; + } + } + if (usesAdditionalBuffer) { + *fromPos += mBuffer->getOriginalBufferSize(); + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicShortcutListPolicy); + + const BufferWithExtendableBuffer *const mBuffer; +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp index e70bb5071..847dcdee5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp @@ -16,6 +16,8 @@ #include "suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + namespace latinime { // Flag for presence of more attributes @@ -28,4 +30,22 @@ const int ShortcutListReadingUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; // The numeric value of the shortcut probability that means 'whitelist'. const int ShortcutListReadingUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; +/* static */ ShortcutListReadingUtils::ShortcutFlags + ShortcutListReadingUtils::getFlagsAndForwardPointer(const uint8_t *const dictRoot, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); +} + +/* static */ int ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer( + const uint8_t *const dictRoot, int *const pos) { + // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. + return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) + - SHORTCUT_LIST_SIZE_FIELD_SIZE; +} + +/* static */ int ShortcutListReadingUtils::readShortcutTarget( + const uint8_t *const dictRoot, const int maxLength, int *const outWord, int *const pos) { + return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); +} + } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h index b5bb964d5..a83ed5a50 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h @@ -20,7 +20,6 @@ #include <stdint.h> #include "defines.h" -#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { @@ -28,10 +27,7 @@ class ShortcutListReadingUtils { public: typedef uint8_t ShortcutFlags; - static AK_FORCE_INLINE ShortcutFlags getFlagsAndForwardPointer( - const uint8_t *const dictRoot, int *const pos) { - return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); - } + static ShortcutFlags getFlagsAndForwardPointer(const uint8_t *const dictRoot, int *const pos); static AK_FORCE_INLINE int getProbabilityFromFlags(const ShortcutFlags flags) { return flags & MASK_ATTRIBUTE_PROBABILITY; @@ -43,11 +39,10 @@ class ShortcutListReadingUtils { // This method returns the size of the shortcut list region excluding the shortcut list size // field at the beginning. - static AK_FORCE_INLINE int getShortcutListSizeAndForwardPointer( - const uint8_t *const dictRoot, int *const pos) { - // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. - return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) - - SHORTCUT_LIST_SIZE_FIELD_SIZE; + static int getShortcutListSizeAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + + static AK_FORCE_INLINE int getShortcutListSizeFieldSize() { + return SHORTCUT_LIST_SIZE_FIELD_SIZE; } static AK_FORCE_INLINE void skipShortcuts(const uint8_t *const dictRoot, int *const pos) { @@ -59,11 +54,8 @@ class ShortcutListReadingUtils { return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; } - static AK_FORCE_INLINE int readShortcutTarget( - const uint8_t *const dictRoot, const int maxLength, int *const outWord, - int *const pos) { - return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); - } + static int readShortcutTarget(const uint8_t *const dictRoot, const int maxLength, + int *const outWord, int *const pos); private: DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListReadingUtils); diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp new file mode 100644 index 000000000..f692882f2 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const size_t BufferWithExtendableBuffer::MAX_ADDITIONAL_BUFFER_SIZE = 1024 * 1024; +const int BufferWithExtendableBuffer::NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE = 90; +// TODO: Needs to allocate larger memory corresponding to the current vector size. +const size_t BufferWithExtendableBuffer::EXTEND_ADDITIONAL_BUFFER_SIZE_STEP = 128 * 1024; + +bool BufferWithExtendableBuffer::writeUintAndAdvancePosition(const uint32_t data, const int size, + int *const pos) { + if (!(size >= 1 && size <= 4)) { + AKLOGI("writeUintAndAdvancePosition() is called with invalid size: %d", size); + ASSERT(false); + return false; + } + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeUintAndAdvancePosition(buffer, data, size, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::writeCodePointsAndAdvancePosition(const int *const codePoints, + const int codePointCount, const bool writesTerminator ,int *const pos) { + const size_t size = ByteArrayUtils::calculateRequiredByteCountToStoreCodePoints( + codePoints, codePointCount, writesTerminator); + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeCodePointsAndAdvancePosition(buffer, codePoints, codePointCount, + writesTerminator, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::extendBuffer() { + const size_t sizeAfterExtending = + mAdditionalBuffer.size() + EXTEND_ADDITIONAL_BUFFER_SIZE_STEP; + if (sizeAfterExtending > mMaxAdditionalBufferSize) { + return false; + } + mAdditionalBuffer.resize(mAdditionalBuffer.size() + EXTEND_ADDITIONAL_BUFFER_SIZE_STEP); + return true; +} + +bool BufferWithExtendableBuffer::checkAndPrepareWriting(const int pos, const int size) { + if (isInAdditionalBuffer(pos)) { + const int tailPosition = getTailPosition(); + if (pos == tailPosition) { + // Append data to the tail. + if (pos + size > static_cast<int>(mAdditionalBuffer.size()) + mOriginalBufferSize) { + // Need to extend buffer. + if (!extendBuffer()) { + return false; + } + } + mUsedAdditionalBufferSize += size; + } else if (pos + size > tailPosition) { + // The access will beyond the tail of used region. + return false; + } + } else { + if (pos < 0 || mOriginalBufferSize < pos + size) { + // Invalid position or violate the boundary. + return false; + } + } + return true; +} + +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h new file mode 100644 index 000000000..17d2e39c2 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H +#define LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H + +#include <cstddef> +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +// This is used as a buffer that can be extended for updatable dictionaries. +// To optimize performance, raw pointer is directly used for reading buffer. The position has to be +// adjusted to access additional buffer. On the other hand, this class does not provide writable +// raw pointer but provides several methods that handle boundary checking for writing data. +class BufferWithExtendableBuffer { + public: + BufferWithExtendableBuffer(uint8_t *const originalBuffer, const int originalBufferSize, + const int maxAdditionalBufferSize = MAX_ADDITIONAL_BUFFER_SIZE) + : mOriginalBuffer(originalBuffer), mOriginalBufferSize(originalBufferSize), + mAdditionalBuffer(EXTEND_ADDITIONAL_BUFFER_SIZE_STEP), mUsedAdditionalBufferSize(0), + mMaxAdditionalBufferSize(maxAdditionalBufferSize) {} + + AK_FORCE_INLINE int getTailPosition() const { + return mOriginalBufferSize + mUsedAdditionalBufferSize; + } + + /** + * For reading. + */ + AK_FORCE_INLINE bool isInAdditionalBuffer(const int position) const { + return position >= mOriginalBufferSize; + } + + // TODO: Resolve the issue that the address can be changed when the vector is resized. + // CAVEAT!: Be careful about array out of bound access with buffers + AK_FORCE_INLINE const uint8_t *getBuffer(const bool usesAdditionalBuffer) const { + if (usesAdditionalBuffer) { + return &mAdditionalBuffer[0]; + } else { + return mOriginalBuffer; + } + } + + AK_FORCE_INLINE int getOriginalBufferSize() const { + return mOriginalBufferSize; + } + + AK_FORCE_INLINE bool isNearSizeLimit() const { + return mAdditionalBuffer.size() >= ((mMaxAdditionalBufferSize + * NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE) / 100); + } + + /** + * For writing. + * + * Writing is allowed for original buffer, already written region of additional buffer and the + * tail of additional buffer. + */ + bool writeUintAndAdvancePosition(const uint32_t data, const int size, int *const pos); + + bool writeCodePointsAndAdvancePosition(const int *const codePoints, const int codePointCount, + const bool writesTerminator, int *const pos); + + private: + DISALLOW_COPY_AND_ASSIGN(BufferWithExtendableBuffer); + + static const size_t MAX_ADDITIONAL_BUFFER_SIZE; + static const int NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE; + static const size_t EXTEND_ADDITIONAL_BUFFER_SIZE_STEP; + + uint8_t *const mOriginalBuffer; + const int mOriginalBufferSize; + std::vector<uint8_t> mAdditionalBuffer; + int mUsedAdditionalBufferSize; + const size_t mMaxAdditionalBufferSize; + + // Return if the buffer is successfully extended or not. + bool extendBuffer(); + + // Returns if it is possible to write size-bytes from pos. When pos is at the tail position of + // the additional buffer, try extending the buffer. + bool checkAndPrepareWriting(const int pos, const int size); +}; +} +#endif /* LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp index a84cfb9d5..1833e8832 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp @@ -18,7 +18,8 @@ namespace latinime { -const uint8_t ByteArrayUtils::MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; +const uint8_t ByteArrayUtils::MINIMUM_ONE_BYTE_CHARACTER_VALUE = 0x20; +const uint8_t ByteArrayUtils::MAXIMUM_ONE_BYTE_CHARACTER_VALUE = 0xFF; const uint8_t ByteArrayUtils::CHARACTER_ARRAY_TERMINATOR = 0x1F; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h index 75ccfc766..0c1576818 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -29,7 +29,34 @@ namespace latinime { class ByteArrayUtils { public: /** - * Integer + * Integer writing + * + * Each method write a corresponding size integer in a big endian manner. + */ + static AK_FORCE_INLINE void writeUintAndAdvancePosition(uint8_t *const buffer, + const uint32_t data, const int size, int *const pos) { + // size must be in 1 to 4. + ASSERT(size >= 1 && size <= 4); + switch (size) { + case 1: + ByteArrayUtils::writeUint8AndAdvancePosition(buffer, data, pos); + return; + case 2: + ByteArrayUtils::writeUint16AndAdvancePosition(buffer, data, pos); + return; + case 3: + ByteArrayUtils::writeUint24AndAdvancePosition(buffer, data, pos); + return; + case 4: + ByteArrayUtils::writeUint32AndAdvancePosition(buffer, data, pos); + return; + default: + break; + } + } + + /** + * Integer reading * * Each method read a corresponding size integer in a big endian manner. */ @@ -88,7 +115,7 @@ class ByteArrayUtils { } /** - * Code Point + * Code Point Reading * * 1 byte = bbbbbbbb match * case 000xxxxx: xxxxx << 16 + next byte << 8 + next byte @@ -108,7 +135,7 @@ class ByteArrayUtils { static AK_FORCE_INLINE int readCodePointAndAdvancePosition( const uint8_t *const buffer, int *const pos) { const uint8_t firstByte = readUint8(buffer, *pos); - if (firstByte < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { + if (firstByte < MINIMUM_ONE_BYTE_CHARACTER_VALUE) { if (firstByte == CHARACTER_ARRAY_TERMINATOR) { *pos += 1; return NOT_A_CODE_POINT; @@ -122,7 +149,7 @@ class ByteArrayUtils { } /** - * String (array of code points) + * String (array of code points) Reading * * Reads code points until the terminator is found. */ @@ -145,48 +172,90 @@ class ByteArrayUtils { int codePoint = readCodePointAndAdvancePosition(buffer, pos); while (NOT_A_CODE_POINT != codePoint && length < maxLength) { codePoint = readCodePointAndAdvancePosition(buffer, pos); + length++; } return length; } - // Returns an integer less than, equal to, or greater than zero when string starting from pos - // in buffer is less than, match, or is greater than charArray. - static AK_FORCE_INLINE int compareStringInBufferWithCharArray(const uint8_t *const buffer, - const char *const charArray, const int maxLength, int *const pos) { - int index = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); - const uint8_t *const uint8CharArrayForComparison = - reinterpret_cast<const uint8_t *>(charArray); - while (NOT_A_CODE_POINT != codePoint - && '\0' != uint8CharArrayForComparison[index] && index < maxLength) { - if (codePoint != uint8CharArrayForComparison[index]) { - // Different character is found. - // Skip the rest of the string in the buffer. - advancePositionToBehindString(buffer, maxLength - index, pos); - return codePoint - uint8CharArrayForComparison[index]; + /** + * String (array of code points) Writing + */ + static void writeCodePointsAndAdvancePosition(uint8_t *const buffer, + const int *const codePoints, const int codePointCount, const bool writesTerminator, + int *const pos) { + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMUM_ONE_BYTE_CHARACTER_VALUE + || codePoint > MAXIMUM_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + writeUint24AndAdvancePosition(buffer, codePoint, pos); + } else { + // one byte character. + writeUint8AndAdvancePosition(buffer, codePoint, pos); } - // Advance - codePoint = readCodePointAndAdvancePosition(buffer, pos); - ++index; } - if (NOT_A_CODE_POINT != codePoint && index < maxLength) { - // Skip the rest of the string in the buffer. - advancePositionToBehindString(buffer, maxLength - index, pos); + if (writesTerminator) { + writeUint8AndAdvancePosition(buffer, CHARACTER_ARRAY_TERMINATOR, pos); } - if (NOT_A_CODE_POINT == codePoint && '\0' == uint8CharArrayForComparison[index]) { - // When both of the last characters are terminals, we consider the string in the buffer - // matches the given char array - return 0; - } else { - return codePoint - uint8CharArrayForComparison[index]; + } + + static int calculateRequiredByteCountToStoreCodePoints(const int *const codePoints, + const int codePointCount, const bool writesTerminator) { + int byteCount = 0; + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMUM_ONE_BYTE_CHARACTER_VALUE + || codePoint > MAXIMUM_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + byteCount += 3; + } else { + // one byte character. + byteCount += 1; + } } + if (writesTerminator) { + // The terminator is one byte. + byteCount += 1; + } + return byteCount; } private: DISALLOW_IMPLICIT_CONSTRUCTORS(ByteArrayUtils); - static const uint8_t MINIMAL_ONE_BYTE_CHARACTER_VALUE; + static const uint8_t MINIMUM_ONE_BYTE_CHARACTER_VALUE; + static const uint8_t MAXIMUM_ONE_BYTE_CHARACTER_VALUE; static const uint8_t CHARACTER_ARRAY_TERMINATOR; + + static AK_FORCE_INLINE void writeUint32AndAdvancePosition(uint8_t *const buffer, + const uint32_t data, int *const pos) { + buffer[(*pos)++] = (data >> 24) & 0xFF; + buffer[(*pos)++] = (data >> 16) & 0xFF; + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint24AndAdvancePosition(uint8_t *const buffer, + const uint32_t data, int *const pos) { + buffer[(*pos)++] = (data >> 16) & 0xFF; + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint16AndAdvancePosition(uint8_t *const buffer, + const uint16_t data, int *const pos) { + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint8AndAdvancePosition(uint8_t *const buffer, + const uint8_t data, int *const pos) { + buffer[(*pos)++] = data & 0xFF; + } }; } // namespace latinime #endif /* LATINIME_BYTE_ARRAY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp new file mode 100644 index 000000000..2e4ec2e1d --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" + +#include <cstdio> +#include <cstring> + +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +namespace latinime { + +const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = ".tmp"; + +/* static */ bool DictFileWritingUtils::createEmptyDictFile(const char *const filePath, + const int dictVersion, const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + switch (dictVersion) { + case 3: + return createEmptyV3DictFile(filePath, attributeMap); + default: + // Only version 3 dictionary is supported for now. + return false; + } +} + +/* static */ bool DictFileWritingUtils::createEmptyV3DictFile(const char *const filePath, + const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + HeaderPolicy headerPolicy(FormatUtils::VERSION_3, attributeMap); + headerPolicy.writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */); + BufferWithExtendableBuffer bodyBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!DynamicPatriciaTrieWritingUtils::writeEmptyDictionary(&bodyBuffer, 0 /* rootPos */)) { + return false; + } + return flushAllHeaderAndBodyToFile(filePath, &headerBuffer, &bodyBuffer); +} + +/* static */ bool DictFileWritingUtils::flushAllHeaderAndBodyToFile(const char *const filePath, + BufferWithExtendableBuffer *const dictHeader, BufferWithExtendableBuffer *const dictBody) { + const int tmpFileNameBufSize = strlen(filePath) + + strlen(TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE) + 1 /* terminator */; + // Name of a temporary file used for writing that is a connected string of original name and + // TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE. + char tmpFileName[tmpFileNameBufSize]; + snprintf(tmpFileName, tmpFileNameBufSize, "%s%s", filePath, + TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE); + FILE *const file = fopen(tmpFileName, "wb"); + if (!file) { + AKLOGE("Dictionary file %s cannnot be opened.", tmpFileName); + ASSERT(false); + return false; + } + // Write the dictionary header. + if (!writeBufferToFile(file, dictHeader)) { + remove(tmpFileName); + AKLOGE("Dictionary header cannnot be written. size: %d", dictHeader->getTailPosition()); + ASSERT(false); + return false; + } + // Write the dictionary body. + if (!writeBufferToFile(file, dictBody)) { + remove(tmpFileName); + AKLOGE("Dictionary body cannnot be written. size: %d", dictBody->getTailPosition()); + ASSERT(false); + return false; + } + fclose(file); + rename(tmpFileName, filePath); + return true; +} + +// This closes file pointer when an error is caused and returns whether the writing was succeeded +// or not. +/* static */ bool DictFileWritingUtils::writeBufferToFile(FILE *const file, + const BufferWithExtendableBuffer *const buffer) { + const int originalBufSize = buffer->getOriginalBufferSize(); + if (originalBufSize > 0 && fwrite(buffer->getBuffer(false /* usesAdditionalBuffer */), + originalBufSize, 1, file) < 1) { + fclose(file); + return false; + } + const int additionalBufSize = buffer->getTailPosition() - buffer->getOriginalBufferSize(); + if (additionalBufSize > 0 && fwrite(buffer->getBuffer(true /* usesAdditionalBuffer */), + additionalBufSize, 1, file) < 1) { + fclose(file); + return false; + } + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h new file mode 100644 index 000000000..bd4ac66fd --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICT_FILE_WRITING_UTILS_H +#define LATINIME_DICT_FILE_WRITING_UTILS_H + +#include <cstdio> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class DictFileWritingUtils { + public: + static bool createEmptyDictFile(const char *const filePath, const int dictVersion, + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static bool flushAllHeaderAndBodyToFile(const char *const filePath, + BufferWithExtendableBuffer *const dictHeader, + BufferWithExtendableBuffer *const dictBody); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DictFileWritingUtils); + + static const char *const TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE; + + static bool createEmptyV3DictFile(const char *const filePath, + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static bool writeBufferToFile(FILE *const file, + const BufferWithExtendableBuffer *const buffer); +}; +} // namespace latinime +#endif /* LATINIME_DICT_FILE_WRITING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp index 3796c7b7b..1d77d5c27 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -20,20 +20,10 @@ namespace latinime { -/** - * Dictionary size - */ -// Any file smaller than this is not a dictionary. -const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 4; +const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE; -/** - * Format versions - */ -// 32 bit magic number is stored at the beginning of the dictionary header to reject unsupported -// or obsolete dictionary formats. -const uint32_t FormatUtils::HEADER_VERSION_2_MAGIC_NUMBER = 0x9BC13AFE; -// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 -const int FormatUtils::HEADER_VERSION_2_MINIMUM_SIZE = 12; +// Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12 +const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; /* static */ FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion( const uint8_t *const dict, const int dictSize) { @@ -45,16 +35,10 @@ const int FormatUtils::HEADER_VERSION_2_MINIMUM_SIZE = 12; } const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); switch (magicNumber) { - case HEADER_VERSION_2_MAGIC_NUMBER: - // Version 2 header are at least 12 bytes long. - // If this header has the version 2 magic number but is less than 12 bytes long, - // then it's an unknown format and we need to avoid confidently reading the next bytes. - if (dictSize < HEADER_VERSION_2_MINIMUM_SIZE) { - return UNKNOWN_VERSION; - } + case MAGIC_NUMBER: // Version 2 header is as follows: // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE - // Version number (2 bytes) + // Dictionary format version number (2 bytes) // Options (2 bytes) // Header size (4 bytes) : integer, big endian if (ByteArrayUtils::readUint16(dict, 4) == 2) { diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h index f84321577..79ed0de29 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -34,14 +34,16 @@ class FormatUtils { UNKNOWN_VERSION }; + // 32 bit magic number is stored at the beginning of the dictionary header to reject + // unsupported or obsolete dictionary formats. + static const uint32_t MAGIC_NUMBER; + static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); private: DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils); static const int DICTIONARY_MINIMUM_SIZE; - static const uint32_t HEADER_VERSION_2_MAGIC_NUMBER; - static const int HEADER_VERSION_2_MINIMUM_SIZE; }; } // namespace latinime #endif /* LATINIME_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/mmaped_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h index e3aabb084..6b69116eb 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/utils/mmaped_buffer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef LATINIME_MMAPED_BUFFER_H -#define LATINIME_MMAPED_BUFFER_H +#ifndef LATINIME_MMAPPED_BUFFER_H +#define LATINIME_MMAPPED_BUFFER_H #include <cerrno> #include <fcntl.h> @@ -27,39 +27,40 @@ namespace latinime { -class MmapedBuffer { +class MmappedBuffer { public: - static MmapedBuffer* openBuffer(const char *const path, const int pathLength, - const int bufOffset, const int size, const bool isUpdatable) { + static MmappedBuffer* openBuffer(const char *const path, const int bufferOffset, + const int bufferSize, const bool isUpdatable) { const int openMode = isUpdatable ? O_RDWR : O_RDONLY; - const int fd = open(path, openMode); - if (fd < 0) { + const int mmapFd = open(path, openMode); + if (mmapFd < 0) { AKLOGE("DICT: Can't open the source. path=%s errno=%d", path, errno); return 0; } const int pagesize = getpagesize(); - const int offset = bufOffset % pagesize; - int adjOffset = bufOffset - offset; - int adjSize = size + offset; + const int offset = bufferOffset % pagesize; + int alignedOffset = bufferOffset - offset; + int alignedSize = bufferSize + offset; const int protMode = isUpdatable ? PROT_READ | PROT_WRITE : PROT_READ; - void *const mmapedBuffer = mmap(0, adjSize, protMode, MAP_PRIVATE, fd, adjOffset); - if (mmapedBuffer == MAP_FAILED) { + void *const mmappedBuffer = mmap(0, alignedSize, protMode, MAP_PRIVATE, mmapFd, + alignedOffset); + if (mmappedBuffer == MAP_FAILED) { AKLOGE("DICT: Can't mmap dictionary. errno=%d", errno); - close(fd); + close(mmapFd); return 0; } - uint8_t *const buffer = static_cast<uint8_t *>(mmapedBuffer) + bufOffset; + uint8_t *const buffer = static_cast<uint8_t *>(mmappedBuffer) + offset; if (!buffer) { AKLOGE("DICT: buffer is null"); - close(fd); + close(mmapFd); return 0; } - return new MmapedBuffer(buffer, adjSize, fd, adjOffset, isUpdatable); + return new MmappedBuffer(buffer, bufferSize, mmappedBuffer, alignedSize, mmapFd, + isUpdatable); } - ~MmapedBuffer() { - int ret = munmap(static_cast<void *>(mBuffer - mBufferOffset), - mBufferSize + mBufferOffset); + ~MmappedBuffer() { + int ret = munmap(mMmappedBuffer, mAlignedSize); if (ret != 0) { AKLOGE("DICT: Failure in munmap. ret=%d errno=%d", ret, errno); } @@ -82,18 +83,20 @@ class MmapedBuffer { } private: - AK_FORCE_INLINE MmapedBuffer(uint8_t *const buffer, const int bufferSize, const int mmapFd, - const int bufferOffset, const bool isUpdatable) - : mBuffer(buffer), mBufferSize(bufferSize), mMmapFd(mmapFd), - mBufferOffset(bufferOffset), mIsUpdatable(isUpdatable) {} + AK_FORCE_INLINE MmappedBuffer(uint8_t *const buffer, const int bufferSize, + void *const mmappedBuffer, const int alignedSize, const int mmapFd, + const bool isUpdatable) + : mBuffer(buffer), mBufferSize(bufferSize), mMmappedBuffer(mmappedBuffer), + mAlignedSize(alignedSize), mMmapFd(mmapFd), mIsUpdatable(isUpdatable) {} - DISALLOW_IMPLICIT_CONSTRUCTORS(MmapedBuffer); + DISALLOW_IMPLICIT_CONSTRUCTORS(MmappedBuffer); uint8_t *const mBuffer; const int mBufferSize; + void *const mMmappedBuffer; + const int mAlignedSize; const int mMmapFd; - const int mBufferOffset; const bool mIsUpdatable; }; } -#endif /* LATINIME_MMAPED_BUFFER_H */ +#endif /* LATINIME_MMAPPED_BUFFER_H */ diff --git a/native/jni/src/suggest/core/dictionary/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h index 219213574..21fe355b8 100644 --- a/native/jni/src/suggest/core/dictionary/probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h @@ -41,7 +41,7 @@ class ProbabilityUtils { // the unigram probability to be the median value of the 17th step from the top. A value of // 0 for the bigram probability represents the middle of the 16th step from the top, // while a value of 15 represents the middle of the top step. - // See makedict.BinaryDictDecoder for details. + // See makedict.BinaryDictEncoder#makeBigramFlags for details. const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability) / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY); return unigramProbability diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index b6aa85896..9f0a331e3 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -74,7 +74,8 @@ class TypingWeighting : public Weighting { // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) - ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); + ->getPointToKeyLength(pointIndex, + CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; @@ -113,10 +114,10 @@ class TypingWeighting : public Weighting { const int16_t parentPointIndex = parentDicNode->getInputIndex(0); const int prevCodePoint = parentDicNode->getNodeCodePoint(); const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex + 1, prevCodePoint); + parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); const int codePoint = dicNode->getNodeCodePoint(); const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex, codePoint); + parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); const float distance = distance1 + distance2; const float weightedLengthDistance = distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; @@ -133,7 +134,7 @@ class TypingWeighting : public Weighting { const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) ->existsAdjacentProximityChars(insertedPointIndex); const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - insertedPointIndex + 1, dicNode->getNodeCodePoint()); + insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; const bool singleChar = dicNode->getNodeCodePointCount() == 1; float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); diff --git a/native/jni/src/utils/autocorrection_threshold_utils.cpp b/native/jni/src/utils/autocorrection_threshold_utils.cpp index 3406e0f8e..1f8ee0814 100644 --- a/native/jni/src/utils/autocorrection_threshold_utils.cpp +++ b/native/jni/src/utils/autocorrection_threshold_utils.cpp @@ -83,9 +83,12 @@ const int AutocorrectionThresholdUtils::FULL_WORD_MULTIPLIER = 2; return 0.0f; } + if (score <= 0 || distance >= afterLength) { + // normalizedScore must be 0.0f (the minimum value) if the score is less than or equal to 0, + // or if the edit distance is larger than or equal to afterLength. + return 0.0f; + } // add a weight based on edit distance. - // distance <= max(afterLength, beforeLength) == afterLength, - // so, 0 <= distance / afterLength <= 1 const float weight = 1.0f - static_cast<float>(distance) / static_cast<float>(afterLength); // TODO: Revise the following logic thoroughly by referring to... |