diff options
Diffstat (limited to 'native/jni/src')
115 files changed, 7618 insertions, 4232 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index dd7437f24..89dfa39b3 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -35,46 +35,74 @@ // Must be equal to ProximityInfo.MAX_PROXIMITY_CHARS_SIZE in Java #define MAX_PROXIMITY_CHARS_SIZE 16 #define ADDITIONAL_PROXIMITY_CHAR_DELIMITER_CODE 2 +#define NELEMS(x) (sizeof(x) / sizeof((x)[0])) -#if defined(FLAG_DO_PROFILE) || defined(FLAG_DBG) -#include <android/log.h> -#ifndef LOG_TAG -#define LOG_TAG "LatinIME: " -#endif // LOG_TAG -#define AKLOGE(fmt, ...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, fmt, ##__VA_ARGS__) -#define AKLOGI(fmt, ...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, fmt, ##__VA_ARGS__) - -#define DUMP_RESULT(words, frequencies) do { dumpResult(words, frequencies); } while (0) -#define DUMP_WORD(word, length) do { dumpWord(word, length); } while (0) -#define INTS_TO_CHARS(input, length, output) do { \ - intArrayToCharArray(input, length, output); } while (0) - -// TODO: Support full UTF-8 conversion -AK_FORCE_INLINE static int intArrayToCharArray(const int *source, const int sourceSize, - char *dest) { +AK_FORCE_INLINE static int intArrayToCharArray(const int *const source, const int sourceSize, + char *dest, const int destSize) { + // We want to always terminate with a 0 char, so stop one short of the length to make + // sure there is room. 
+ const int destLimit = destSize - 1; int si = 0; int di = 0; - while (si < sourceSize && di < MAX_WORD_LENGTH - 1 && 0 != source[si]) { + while (si < sourceSize && di < destLimit && 0 != source[si]) { const int codePoint = source[si++]; - if (codePoint < 0x7F) { + if (codePoint < 0x7F) { // One byte dest[di++] = codePoint; - } else if (codePoint < 0x7FF) { + } else if (codePoint < 0x7FF) { // Two bytes + if (di + 1 >= destLimit) break; dest[di++] = 0xC0 + (codePoint >> 6); dest[di++] = 0x80 + (codePoint & 0x3F); - } else if (codePoint < 0xFFFF) { + } else if (codePoint < 0xFFFF) { // Three bytes + if (di + 2 >= destLimit) break; dest[di++] = 0xE0 + (codePoint >> 12); - dest[di++] = 0x80 + ((codePoint & 0xFC0) >> 6); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = 0x80 + (codePoint & 0x3F); + } else if (codePoint <= 0x1FFFFF) { // Four bytes + if (di + 3 >= destLimit) break; + dest[di++] = 0xF0 + (codePoint >> 18); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); dest[di++] = 0x80 + (codePoint & 0x3F); + } else if (codePoint <= 0x3FFFFFF) { // Five bytes + if (di + 4 >= destLimit) break; + dest[di++] = 0xF8 + (codePoint >> 24); + dest[di++] = 0x80 + ((codePoint >> 18) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = codePoint & 0x3F; + } else if (codePoint <= 0x7FFFFFFF) { // Six bytes + if (di + 5 >= destLimit) break; + dest[di++] = 0xFC + (codePoint >> 30); + dest[di++] = 0x80 + ((codePoint >> 24) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 18) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = codePoint & 0x3F; + } else { + // Not a code point... skip. } } dest[di] = 0; return di; } +#if defined(FLAG_DO_PROFILE) || defined(FLAG_DBG) +#include <android/log.h> +#ifndef LOG_TAG +#define LOG_TAG "LatinIME: " +#endif // LOG_TAG +#define AKLOGE(fmt, ...) 
__android_log_print(ANDROID_LOG_ERROR, LOG_TAG, fmt, ##__VA_ARGS__) +#define AKLOGI(fmt, ...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, fmt, ##__VA_ARGS__) + +#define DUMP_RESULT(words, frequencies) do { dumpResult(words, frequencies); } while (0) +#define DUMP_WORD(word, length) do { dumpWord(word, length); } while (0) +#define INTS_TO_CHARS(input, length, output, outlength) do { \ + intArrayToCharArray(input, length, output, outlength); } while (0) + static inline void dumpWordInfo(const int *word, const int length, const int rank, const int probability) { static char charBuf[50]; - const int N = intArrayToCharArray(word, length, charBuf); + const int N = intArrayToCharArray(word, length, charBuf, NELEMS(charBuf)); if (N > 1) { AKLOGI("%2d [ %s ] (%d)", rank, charBuf, probability); } @@ -90,7 +118,7 @@ static inline void dumpResult(const int *outWords, const int *frequencies) { static AK_FORCE_INLINE void dumpWord(const int *word, const int length) { static char charBuf[50]; - const int N = intArrayToCharArray(word, length, charBuf); + const int N = intArrayToCharArray(word, length, charBuf, NELEMS(charBuf)); if (N > 1) { AKLOGI("[ %s ]", charBuf); } @@ -203,14 +231,12 @@ static inline void prof_out(void) { #define DEBUG_DICT true #define DEBUG_DICT_FULL false #define DEBUG_EDIT_DISTANCE false -#define DEBUG_SHOW_FOUND_WORD false #define DEBUG_NODE DEBUG_DICT_FULL #define DEBUG_TRACE DEBUG_DICT_FULL #define DEBUG_PROXIMITY_INFO false #define DEBUG_PROXIMITY_CHARS false #define DEBUG_CORRECTION false #define DEBUG_CORRECTION_FREQ false -#define DEBUG_WORDS_PRIORITY_QUEUE false #define DEBUG_SAMPLING_POINTS false #define DEBUG_POINTS_PROBABILITY false #define DEBUG_DOUBLE_LETTER false @@ -229,14 +255,12 @@ static inline void prof_out(void) { #define DEBUG_DICT false #define DEBUG_DICT_FULL false #define DEBUG_EDIT_DISTANCE false -#define DEBUG_SHOW_FOUND_WORD false #define DEBUG_NODE false #define DEBUG_TRACE false #define DEBUG_PROXIMITY_INFO false #define 
DEBUG_PROXIMITY_CHARS false #define DEBUG_CORRECTION false #define DEBUG_CORRECTION_FREQ false -#define DEBUG_WORDS_PRIORITY_QUEUE false #define DEBUG_SAMPLING_POINTS false #define DEBUG_POINTS_PROBABILITY false #define DEBUG_DOUBLE_LETTER false @@ -268,81 +292,24 @@ static inline void prof_out(void) { // of the binary dictionary where a {key,value} string pair scheme is used. #define LARGEST_INT_DIGIT_COUNT 11 -// Define this to use mmap() for dictionary loading. Undefine to use malloc() instead of mmap(). -// We measured and compared performance of both, and found mmap() is fairly good in terms of -// loading time, and acceptable even for several initial lookups which involve page faults. -#define USE_MMAP_FOR_DICTIONARY - -#define NOT_VALID_WORD (-99) #define NOT_A_CODE_POINT (-1) #define NOT_A_DISTANCE (-1) #define NOT_A_COORDINATE (-1) -#define MATCH_CHAR_WITHOUT_DISTANCE_INFO (-2) -#define PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO (-3) -#define ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO (-4) #define NOT_AN_INDEX (-1) #define NOT_A_PROBABILITY (-1) +#define NOT_A_DICT_POS (S_INT_MIN) #define KEYCODE_SPACE ' ' #define KEYCODE_SINGLE_QUOTE '\'' #define KEYCODE_HYPHEN_MINUS '-' -#define CALIBRATE_SCORE_BY_TOUCH_COORDINATES true -#define SUGGEST_MULTIPLE_WORDS true #define SUGGEST_INTERFACE_OUTPUT_SCALE 1000000.0f - -// The following "rate"s are used as a multiplier before dividing by 100, so they are in percent. 
-#define WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE 80 -#define WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X 12 -#define WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE 58 -#define WORDS_WITH_MISTYPED_SPACE_DEMOTION_RATE 50 -#define WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE 75 -#define WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE 75 -#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 70 -#define FULL_MATCHED_WORDS_PROMOTION_RATE 120 -#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90 -#define WORDS_WITH_ADDITIONAL_PROXIMITY_CHARACTER_DEMOTION_RATE 70 -#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105 -#define WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_RATE 148 -#define WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_MULTIPLIER 3 -#define CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE 45 -#define INPUT_EXCEEDS_OUTPUT_DEMOTION_RATE 70 -#define FIRST_CHAR_DIFFERENT_DEMOTION_RATE 96 -#define TWO_WORDS_CAPITALIZED_DEMOTION_RATE 50 -#define TWO_WORDS_CORRECTION_DEMOTION_BASE 80 -#define TWO_WORDS_PLUS_OTHER_ERROR_CORRECTION_DEMOTION_DIVIDER 1 -#define ZERO_DISTANCE_PROMOTION_RATE 110.0f -#define NEUTRAL_SCORE_SQUARED_RADIUS 8.0f -#define HALF_SCORE_SQUARED_RADIUS 32.0f #define MAX_PROBABILITY 255 #define MAX_BIGRAM_ENCODED_PROBABILITY 15 // Assuming locale strings such as en_US, sr-Latn etc. #define MAX_LOCALE_STRING_LENGTH 10 -// Word limit for sub queues used in WordsPriorityQueuePool. Sub queues are temporary queues used -// for better performance. 
-// Holds up to 1 candidate for each word -#define SUB_QUEUE_MAX_WORDS 1 -#define SUB_QUEUE_MAX_COUNT 10 -#define SUB_QUEUE_MIN_WORD_LENGTH 4 -// TODO: Extend this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_WORDS 5 -// TODO: Remove this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_WORD_LENGTH 12 -// TODO: Remove this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_TOTAL_TRAVERSE_COUNT 45 -#define MULTIPLE_WORDS_DEMOTION_RATE 80 -#define MIN_INPUT_LENGTH_FOR_THREE_OR_MORE_WORDS_CORRECTION 6 - -#define TWO_WORDS_CORRECTION_WITH_OTHER_ERROR_THRESHOLD 0.35f -#define START_TWO_WORDS_CORRECTION_THRESHOLD 0.185f -/* heuristic... This should be changed if we change the unit of the probability. */ -#define SUPPRESS_SHORT_MULTIPLE_WORDS_THRESHOLD_FREQ (MAX_PROBABILITY * 58 / 100) - -#define MAX_DEPTH_MULTIPLIER 3 -#define FIRST_WORD_INDEX 0 - // Max value for length, distance and probability which are used in weighting // TODO: Remove #define MAX_VALUE_FOR_WEIGHTING 10000000 @@ -350,48 +317,13 @@ static inline void prof_out(void) { // The max number of the keys in one keyboard layout #define MAX_KEY_COUNT_IN_A_KEYBOARD 64 -// TODO: Reduce this constant if possible; check the maximum number of digraphs in the same -// word in the dictionary for languages with digraphs, like German and French -#define DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH 5 - -#define MIN_USER_TYPED_LENGTH_FOR_MULTIPLE_WORD_SUGGESTION 3 - // TODO: Remove #define MAX_POINTER_COUNT 1 #define MAX_POINTER_COUNT_G 2 -// Size, in bytes, of the bloom filter index for bigrams -// 128 gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, -// where k is the number of hash functions, n the number of bigrams, and m the number of -// bits we can test. -// At the moment 100 is the maximum number of bigrams for a word with the current -// dictionaries, so n = 100. 1024 buckets give us m = 1024. 
-// With 1 hash function, our false positive rate is about 9.3%, which should be enough for -// our uses since we are only using this to increase average performance. For the record, -// k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, -// and m = 4096 gives 2.4%. -#define BIGRAM_FILTER_BYTE_SIZE 128 -// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest -// prime under 128 * 8. -#define BIGRAM_FILTER_MODULO 1021 -#if BIGRAM_FILTER_BYTE_SIZE * 8 < BIGRAM_FILTER_MODULO -#error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE" -#endif - -// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could -// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage. -// Also, there are diminishing returns since the most frequently used bigrams are typically near -// the beginning of the input and are thus the first ones to be cached. Note that these bigrams -// are reset for each new composing word. -#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25 -// Most common previous word contexts currently have 100 bigrams -#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100 - template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; } template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? 
a : b; } -#define NELEMS(x) (sizeof(x) / sizeof((x)[0])) - // DEBUG #define INPUTLENGTH_FOR_DEBUG (-1) #define MIN_OUTPUT_INDEX_FOR_DEBUG (-1) @@ -441,6 +373,7 @@ typedef enum { CT_TRANSPOSITION, CT_COMPLETION, CT_TERMINAL, + CT_TERMINAL_INSERTION, // Create new word with space omission CT_NEW_WORD_SPACE_OMITTION, // Create new word with space substitution diff --git a/native/jni/src/obsolete/correction.cpp b/native/jni/src/obsolete/correction.cpp deleted file mode 100644 index e6c577f85..000000000 --- a/native/jni/src/obsolete/correction.cpp +++ /dev/null @@ -1,979 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#define LOG_TAG "LatinIME: correction.cpp" - -#include <cmath> - -#include "defines.h" -#include "obsolete/correction.h" -#include "suggest/core/layout/proximity_info_state.h" -#include "suggest/core/layout/touch_position_correction_utils.h" -#include "suggest/policyimpl/utils/edit_distance.h" -#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" -#include "utils/char_utils.h" - -namespace latinime { - -class ProximityInfo; - -///////////////////////////// -// edit distance funcitons // -///////////////////////////// - -inline static void initEditDistance(int *editDistanceTable) { - for (int i = 0; i <= MAX_WORD_LENGTH; ++i) { - editDistanceTable[i] = i; - } -} - -inline static void dumpEditDistance10ForDebug(int *editDistanceTable, - const int editDistanceTableWidth, const int outputLength) { - if (DEBUG_DICT) { - AKLOGI("EditDistanceTable"); - for (int i = 0; i <= 10; ++i) { - int c[11]; - for (int j = 0; j <= 10; ++j) { - if (j < editDistanceTableWidth + 1 && i < outputLength + 1) { - c[j] = (editDistanceTable + i * (editDistanceTableWidth + 1))[j]; - } else { - c[j] = -1; - } - } - AKLOGI("[ %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d ]", - c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7], c[8], c[9], c[10]); - (void)c; // To suppress compiler warning - } - } -} - -inline static int getCurrentEditDistance(int *editDistanceTable, const int editDistanceTableWidth, - const int outputLength, const int inputSize) { - if (DEBUG_EDIT_DISTANCE) { - AKLOGI("getCurrentEditDistance %d, %d", inputSize, outputLength); - } - return editDistanceTable[(editDistanceTableWidth + 1) * (outputLength) + inputSize]; -} - -//////////////// -// Correction // -//////////////// - -void Correction::resetCorrection() { - mTotalTraverseCount = 0; -} - -void Correction::initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth) { - mProximityInfo = pi; - mInputSize = inputSize; - mMaxDepth = maxDepth; - mMaxEditDistance = mInputSize < 5 ? 
2 : mInputSize / 2; - // TODO: This is not supposed to be required. Check what's going wrong with - // editDistance[0 ~ MAX_WORD_LENGTH] - initEditDistance(mEditDistanceTable); -} - -void Correction::initCorrectionState( - const int rootPos, const int childCount, const bool traverseAll) { - latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); - // TODO: remove - mCorrectionStates[0].mTransposedPos = mTransposedPos; - mCorrectionStates[0].mExcessivePos = mExcessivePos; - mCorrectionStates[0].mSkipPos = mSkipPos; -} - -void Correction::setCorrectionParams(const int skipPos, const int excessivePos, - const int transposedPos, const int spaceProximityPos, const int missingSpacePos, - const bool useFullEditDistance, const bool doAutoCompletion, const int maxErrors) { - // TODO: remove - mTransposedPos = transposedPos; - mExcessivePos = excessivePos; - mSkipPos = skipPos; - // TODO: remove - mCorrectionStates[0].mTransposedPos = transposedPos; - mCorrectionStates[0].mExcessivePos = excessivePos; - mCorrectionStates[0].mSkipPos = skipPos; - - mSpaceProximityPos = spaceProximityPos; - mMissingSpacePos = missingSpacePos; - mUseFullEditDistance = useFullEditDistance; - mDoAutoCompletion = doAutoCompletion; - mMaxErrors = maxErrors; -} - -void Correction::checkState() const { - if (DEBUG_DICT) { - int inputCount = 0; - if (mSkipPos >= 0) ++inputCount; - if (mExcessivePos >= 0) ++inputCount; - if (mTransposedPos >= 0) ++inputCount; - } -} - -bool Correction::sameAsTyped() const { - return mProximityInfoState.sameAsTyped(mWord, mOutputIndex); -} - -int Correction::getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, - const int wordCount, const bool isSpaceProximity, const int *word) const { - return Correction::RankingAlgorithm::calcFreqForSplitMultipleWords(freqArray, wordLengthArray, - wordCount, this, isSpaceProximity, word); -} - -int Correction::getFinalProbability(const int probability, int **word, int *wordLength) 
{ - return getFinalProbabilityInternal(probability, word, wordLength, mInputSize); -} - -int Correction::getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, - const int inputSize) { - return getFinalProbabilityInternal(probability, word, wordLength, inputSize); -} - -bool Correction::initProcessState(const int outputIndex) { - if (mCorrectionStates[outputIndex].mChildCount <= 0) { - return false; - } - mOutputIndex = outputIndex; - --(mCorrectionStates[outputIndex].mChildCount); - mInputIndex = mCorrectionStates[outputIndex].mInputIndex; - mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; - - mEquivalentCharCount = mCorrectionStates[outputIndex].mEquivalentCharCount; - mProximityCount = mCorrectionStates[outputIndex].mProximityCount; - mTransposedCount = mCorrectionStates[outputIndex].mTransposedCount; - mExcessiveCount = mCorrectionStates[outputIndex].mExcessiveCount; - mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; - mLastCharExceeded = mCorrectionStates[outputIndex].mLastCharExceeded; - - mTransposedPos = mCorrectionStates[outputIndex].mTransposedPos; - mExcessivePos = mCorrectionStates[outputIndex].mExcessivePos; - mSkipPos = mCorrectionStates[outputIndex].mSkipPos; - - mMatching = false; - mProximityMatching = false; - mAdditionalProximityMatching = false; - mTransposing = false; - mExceeding = false; - mSkipping = false; - - return true; -} - -int Correction::goDownTree(const int parentIndex, const int childCount, const int firstChildPos) { - mCorrectionStates[mOutputIndex].mParentIndex = parentIndex; - mCorrectionStates[mOutputIndex].mChildCount = childCount; - mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos; - return mOutputIndex; -} - -// TODO: remove -int Correction::getInputIndex() const { - return mInputIndex; -} - -bool Correction::needsToPrune() const { - // TODO: use edit distance here - return mOutputIndex - 1 >= mMaxDepth || mProximityCount > 
mMaxEditDistance - // Allow one char longer word for missing character - || (!mDoAutoCompletion && (mOutputIndex > mInputSize)); -} - -inline static bool isEquivalentChar(ProximityType type) { - return type == MATCH_CHAR; -} - -inline static bool isProximityCharOrEquivalentChar(ProximityType type) { - return type == MATCH_CHAR || type == PROXIMITY_CHAR; -} - -Correction::CorrectionType Correction::processCharAndCalcState(const int c, const bool isTerminal) { - const int correctionCount = (mSkippedCount + mExcessiveCount + mTransposedCount); - if (correctionCount > mMaxErrors) { - return processUnrelatedCorrectionType(); - } - - // TODO: Change the limit if we'll allow two or more corrections - const bool noCorrectionsHappenedSoFar = correctionCount == 0; - const bool canTryCorrection = noCorrectionsHappenedSoFar; - int proximityIndex = 0; - mDistances[mOutputIndex] = NOT_A_DISTANCE; - - // Skip checking this node - if (mNeedsToTraverseAllNodes || isSingleQuote(c)) { - bool incremented = false; - if (mLastCharExceeded && mInputIndex == mInputSize - 1) { - // TODO: Do not check the proximity if EditDistance exceeds the threshold - const ProximityType matchId = mProximityInfoState.getProximityType( - mInputIndex, c, true, &proximityIndex); - if (isEquivalentChar(matchId)) { - mLastCharExceeded = false; - --mExcessiveCount; - mDistances[mOutputIndex] = - mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, 0); - } else if (matchId == PROXIMITY_CHAR) { - mLastCharExceeded = false; - --mExcessiveCount; - ++mProximityCount; - mDistances[mOutputIndex] = mProximityInfoState.getNormalizedSquaredDistance( - mInputIndex, proximityIndex); - } - if (!isSingleQuote(c)) { - incrementInputIndex(); - incremented = true; - } - } - return processSkipChar(c, isTerminal, incremented); - } - - // Check possible corrections. 
- if (mExcessivePos >= 0) { - if (mExcessiveCount == 0 && mExcessivePos < mOutputIndex) { - mExcessivePos = mOutputIndex; - } - if (mExcessivePos < mInputSize - 1) { - mExceeding = mExcessivePos == mInputIndex && canTryCorrection; - } - } - - if (mSkipPos >= 0) { - if (mSkippedCount == 0 && mSkipPos < mOutputIndex) { - if (DEBUG_DICT) { - // TODO: Enable this assertion. - //ASSERT(mSkipPos == mOutputIndex - 1); - } - mSkipPos = mOutputIndex; - } - mSkipping = mSkipPos == mOutputIndex && canTryCorrection; - } - - if (mTransposedPos >= 0) { - if (mTransposedCount == 0 && mTransposedPos < mOutputIndex) { - mTransposedPos = mOutputIndex; - } - if (mTransposedPos < mInputSize - 1) { - mTransposing = mInputIndex == mTransposedPos && canTryCorrection; - } - } - - bool secondTransposing = false; - if (mTransposedCount % 2 == 1) { - if (isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex - 1, c, false))) { - ++mTransposedCount; - secondTransposing = true; - } else if (mCorrectionStates[mOutputIndex].mExceeding) { - --mTransposedCount; - ++mExcessiveCount; - --mExcessivePos; - incrementInputIndex(); - } else { - --mTransposedCount; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("UNRELATED(0): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processUnrelatedCorrectionType(); - } - } - - // TODO: Change the limit if we'll allow two or more proximity chars with corrections - // Work around: When the mMaxErrors is 1, we only allow just one error - // including proximity correction. - const bool checkProximityChars = (mMaxErrors > 1) - ? (noCorrectionsHappenedSoFar || mProximityCount == 0) - : (noCorrectionsHappenedSoFar && mProximityCount == 0); - - ProximityType matchedProximityCharId = secondTransposing - ? 
MATCH_CHAR - : mProximityInfoState.getProximityType( - mInputIndex, c, checkProximityChars, &proximityIndex); - - if (SUBSTITUTION_CHAR == matchedProximityCharId - || ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - if (canTryCorrection && mOutputIndex > 0 - && mCorrectionStates[mOutputIndex].mProximityMatching - && mCorrectionStates[mOutputIndex].mExceeding - && isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex, mWord[mOutputIndex - 1], false))) { - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("CONVERSION p->e %c", mWord[mOutputIndex - 1]); - } - // Conversion p->e - // Example: - // wearth -> earth - // px -> (E)mmmmm - ++mExcessiveCount; - --mProximityCount; - mExcessivePos = mOutputIndex - 1; - ++mInputIndex; - // Here, we are doing something equivalent to matchedProximityCharId, - // but we already know that "excessive char correction" just happened - // so that we just need to check "mProximityCount == 0". - matchedProximityCharId = mProximityInfoState.getProximityType( - mInputIndex, c, mProximityCount == 0, &proximityIndex); - } - } - - if (SUBSTITUTION_CHAR == matchedProximityCharId - || ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - if (ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - mAdditionalProximityMatching = true; - } - // TODO: Optimize - // As the current char turned out to be an unrelated char, - // we will try other correction-types. Please note that mCorrectionStates[mOutputIndex] - // here refers to the previous state. 
- if (mInputIndex < mInputSize - 1 && mOutputIndex > 0 && mTransposedCount > 0 - && !mCorrectionStates[mOutputIndex].mTransposing - && mCorrectionStates[mOutputIndex - 1].mTransposing - && isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex, mWord[mOutputIndex - 1], false)) - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // Conversion t->e - // Example: - // occaisional -> occa sional - // mmmmttx -> mmmm(E)mmmmmm - mTransposedCount -= 2; - ++mExcessiveCount; - ++mInputIndex; - } else if (mOutputIndex > 0 && mInputIndex > 0 && mTransposedCount > 0 - && !mCorrectionStates[mOutputIndex].mTransposing - && mCorrectionStates[mOutputIndex - 1].mTransposing - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex - 1, c, false))) { - // Conversion t->s - // Example: - // chcolate -> chocolate - // mmttx -> mmsmmmmmm - mTransposedCount -= 2; - ++mSkippedCount; - --mInputIndex; - } else if (canTryCorrection && mInputIndex > 0 - && mCorrectionStates[mOutputIndex].mProximityMatching - && mCorrectionStates[mOutputIndex].mSkipping - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex - 1, c, false))) { - // Conversion p->s - // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of - // proximity chars of "s", but it should rather be handled as a skipped char. 
- ++mSkippedCount; - --mProximityCount; - return processSkipChar(c, isTerminal, false); - } else if (mInputIndex - 1 < mInputSize - && mSkippedCount > 0 - && mCorrectionStates[mOutputIndex].mSkipping - && mCorrectionStates[mOutputIndex].mAdditionalProximityMatching - && isProximityCharOrEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // Conversion s->a - incrementInputIndex(); - --mSkippedCount; - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO; - } else if ((mExceeding || mTransposing) && mInputIndex - 1 < mInputSize - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // 1.2. Excessive or transpose correction - if (mTransposing) { - ++mTransposedCount; - } else { - ++mExcessiveCount; - incrementInputIndex(); - } - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - if (mTransposing) { - AKLOGI("TRANSPOSE: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } else { - AKLOGI("EXCEED: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } - } else if (mSkipping) { - // 3. 
Skip correction - ++mSkippedCount; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("SKIP: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processSkipChar(c, isTerminal, false); - } else if (ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - // As a last resort, use additional proximity characters - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("ADDITIONALPROX: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } else { - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("UNRELATED(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processUnrelatedCorrectionType(); - } - } else if (secondTransposing) { - // If inputIndex is greater than mInputSize, that means there is no - // proximity chars. So, we don't need to check proximity. 
- mMatching = true; - } else if (isEquivalentChar(matchedProximityCharId)) { - mMatching = true; - ++mEquivalentCharCount; - mDistances[mOutputIndex] = mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, 0); - } else if (PROXIMITY_CHAR == matchedProximityCharId) { - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = - mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, proximityIndex); - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("PROX: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } - - addCharToCurrentWord(c); - - // 4. Last char excessive correction - mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 && mTransposedCount == 0 - && mProximityCount == 0 && (mInputIndex == mInputSize - 2); - const bool isSameAsUserTypedLength = (mInputSize == mInputIndex + 1) || mLastCharExceeded; - if (mLastCharExceeded) { - ++mExcessiveCount; - } - - // Start traversing all nodes after the index exceeds the user typed length - if (isSameAsUserTypedLength) { - startToTraverseAllNodes(); - } - - const bool needsToTryOnTerminalForTheLastPossibleExcessiveChar = - mExceeding && mInputIndex == mInputSize - 2; - - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - incrementInputIndex(); - // Also, the next char is one "virtual node" depth more than this char. 
- incrementOutputIndex(); - - if ((needsToTryOnTerminalForTheLastPossibleExcessiveChar - || isSameAsUserTypedLength) && isTerminal) { - mTerminalInputIndex = mInputIndex - 1; - mTerminalOutputIndex = mOutputIndex - 1; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("ONTERMINAL(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return ON_TERMINAL; - } else { - mTerminalInputIndex = mInputIndex - 1; - mTerminalOutputIndex = mOutputIndex - 1; - return NOT_ON_TERMINAL; - } -} - -inline static int getQuoteCount(const int *word, const int length) { - int quoteCount = 0; - for (int i = 0; i < length; ++i) { - if (word[i] == KEYCODE_SINGLE_QUOTE) { - ++quoteCount; - } - } - return quoteCount; -} - -inline static bool isUpperCase(unsigned short c) { - return CharUtils::isAsciiUpper(CharUtils::toBaseCodePoint(c)); -} - -////////////////////// -// RankingAlgorithm // -////////////////////// - -/* static */ int Correction::RankingAlgorithm::calculateFinalProbability(const int inputIndex, - const int outputIndex, const int freq, int *editDistanceTable, const Correction *correction, - const int inputSize) { - const int excessivePos = correction->getExcessivePos(); - const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; - const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; - const ProximityInfoState *proximityInfoState = &correction->mProximityInfoState; - const int skippedCount = correction->mSkippedCount; - const int transposedCount = correction->mTransposedCount / 2; - const int excessiveCount = correction->mExcessiveCount + correction->mTransposedCount % 2; - const int proximityMatchedCount = correction->mProximityCount; - const bool lastCharExceeded = correction->mLastCharExceeded; - const bool useFullEditDistance = 
correction->mUseFullEditDistance; - const int outputLength = outputIndex + 1; - if (skippedCount >= inputSize || inputSize == 0) { - return -1; - } - - // TODO: find more robust way - bool sameLength = lastCharExceeded ? (inputSize == inputIndex + 2) - : (inputSize == inputIndex + 1); - - // TODO: use mExcessiveCount - const int matchCount = inputSize - correction->mProximityCount - excessiveCount; - - const int *word = correction->mWord; - const bool skipped = skippedCount > 0; - - const int quoteDiffCount = max(0, getQuoteCount(word, outputLength) - - getQuoteCount(proximityInfoState->getPrimaryInputWord(), inputSize)); - - // TODO: Calculate edit distance for transposed and excessive - int ed = 0; - if (DEBUG_DICT_FULL) { - dumpEditDistance10ForDebug(editDistanceTable, correction->mInputSize, outputLength); - } - int adjustedProximityMatchedCount = proximityMatchedCount; - - int finalFreq = freq; - - if (DEBUG_CORRECTION_FREQ - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == inputSize)) { - AKLOGI("FinalFreq0: %d", finalFreq); - } - // TODO: Optimize this. - if (transposedCount > 0 || proximityMatchedCount > 0 || skipped || excessiveCount > 0) { - ed = getCurrentEditDistance(editDistanceTable, correction->mInputSize, outputLength, - inputSize) - transposedCount; - - const int matchWeight = powerIntCapped(typedLetterMultiplier, - max(inputSize, outputLength) - ed); - multiplyIntCapped(matchWeight, &finalFreq); - - // TODO: Demote further if there are two or more excessive chars with longer user input? 
- if (inputSize > outputLength) { - multiplyRate(INPUT_EXCEEDS_OUTPUT_DEMOTION_RATE, &finalFreq); - } - - ed = max(0, ed - quoteDiffCount); - adjustedProximityMatchedCount = min(max(0, ed - (outputLength - inputSize)), - proximityMatchedCount); - if (transposedCount <= 0) { - if (ed == 1 && (inputSize == outputLength - 1 || inputSize == outputLength + 1)) { - // Promote a word with just one skipped or excessive char - if (sameLength) { - multiplyRate(WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_RATE - + WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_MULTIPLIER * outputLength, - &finalFreq); - } else { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - } - } else if (ed == 0) { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - sameLength = true; - } - } - } else { - const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); - multiplyIntCapped(matchWeight, &finalFreq); - } - - if (proximityInfoState->getProximityType(0, word[0], true) == SUBSTITUTION_CHAR) { - multiplyRate(FIRST_CHAR_DIFFERENT_DEMOTION_RATE, &finalFreq); - } - - /////////////////////////////////////////////// - // Promotion and Demotion for each correction - - // Demotion for a word with missing character - if (skipped) { - const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE - * (10 * inputSize - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X) - / (10 * inputSize - - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10); - if (DEBUG_DICT_FULL) { - AKLOGI("Demotion rate for missing character is %d.", demotionRate); - } - multiplyRate(demotionRate, &finalFreq); - } - - // Demotion for a word with transposed character - if (transposedCount > 0) multiplyRate( - WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq); - - // Demotion for a word with excessive character - if (excessiveCount > 0) { - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq); - if (!lastCharExceeded && 
!proximityInfoState->existsAdjacentProximityChars(excessivePos)) { - if (DEBUG_DICT_FULL) { - AKLOGI("Double excessive demotion"); - } - // If an excessive character is not adjacent to the left char or the right char, - // we will demote this word. - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq); - } - } - - int additionalProximityCount = 0; - // Demote additional proximity characters - for (int i = 0; i < outputLength; ++i) { - const int squaredDistance = correction->mDistances[i]; - if (squaredDistance == ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO) { - ++additionalProximityCount; - } - } - - const bool performTouchPositionCorrection = - CALIBRATE_SCORE_BY_TOUCH_COORDINATES - && proximityInfoState->touchPositionCorrectionEnabled() - && skippedCount == 0 && excessiveCount == 0 && transposedCount == 0 - && additionalProximityCount == 0; - - // Score calibration by touch coordinates is being done only for pure-fat finger typing error - // cases. - // TODO: Remove this constraint. 
- if (performTouchPositionCorrection) { - for (int i = 0; i < outputLength; ++i) { - const int squaredDistance = correction->mDistances[i]; - if (i < adjustedProximityMatchedCount) { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - } - const float factor = TouchPositionCorrectionUtils::getLengthScalingFactor( - static_cast<float>(squaredDistance)); - if (factor > 0.0f) { - multiplyRate(static_cast<int>(factor * 100.0f), &finalFreq); - } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } - } - } else { - // Promotion for a word with proximity characters - for (int i = 0; i < adjustedProximityMatchedCount; ++i) { - // A word with proximity corrections - if (DEBUG_DICT_FULL) { - AKLOGI("Found a proximity correction."); - } - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - if (i < additionalProximityCount) { - multiplyRate(WORDS_WITH_ADDITIONAL_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } else { - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } - } - } - - // If the user types too many(three or more) proximity characters with additional proximity - // character,do not treat as the same length word. - if (sameLength && additionalProximityCount > 0 && (adjustedProximityMatchedCount >= 3 - || transposedCount > 0 || skipped || excessiveCount > 0)) { - sameLength = false; - } - - const int errorCount = adjustedProximityMatchedCount > 0 - ? 
adjustedProximityMatchedCount - : (proximityMatchedCount + transposedCount); - multiplyRate( - 100 - CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE * errorCount / inputSize, &finalFreq); - - // Promotion for an exactly matched word - if (ed == 0) { - // Full exact match - if (sameLength && transposedCount == 0 && !skipped && excessiveCount == 0 - && quoteDiffCount == 0 && additionalProximityCount == 0) { - finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); - } - } - - // Promote a word with no correction - if (proximityMatchedCount == 0 && transposedCount == 0 && !skipped && excessiveCount == 0 - && additionalProximityCount == 0) { - multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq); - } - - // TODO: Check excessive count and transposed count - // TODO: Remove this if possible - /* - If the last character of the user input word is the same as the next character - of the output word, and also all of characters of the user input are matched - to the output word, we'll promote that word a bit because - that word can be considered the combination of skipped and matched characters. - This means that the 'sm' pattern wins over the 'ma' pattern. - e.g.) - shel -> shell [mmmma] or [mmmsm] - hel -> hello [mmmaa] or [mmsma] - m ... matching - s ... skipping - a ... traversing all - t ... transposing - e ... exceeding - p ... proximity matching - */ - if (matchCount == inputSize && matchCount >= 2 && !skipped - && word[matchCount] == word[matchCount - 1]) { - multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq); - } - - // TODO: Do not use sameLength? - if (sameLength) { - multiplyIntCapped(fullWordMultiplier, &finalFreq); - } - - if (useFullEditDistance && outputLength > inputSize + 1) { - const int diff = outputLength - inputSize - 1; - const int divider = diff < 31 ? 1 << diff : S_INT_MAX; - finalFreq = divider > finalFreq ? 
1 : finalFreq / divider; - } - - if (DEBUG_DICT_FULL) { - AKLOGI("calc: %d, %d", outputLength, sameLength); - } - - if (DEBUG_CORRECTION_FREQ - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == inputSize)) { - DUMP_WORD(correction->getPrimaryInputWord(), inputSize); - DUMP_WORD(correction->mWord, outputLength); - AKLOGI("FinalFreq: [P%d, S%d, T%d, E%d, A%d] %d, %d, %d, %d, %d, %d", proximityMatchedCount, - skippedCount, transposedCount, excessiveCount, additionalProximityCount, - outputLength, lastCharExceeded, sameLength, quoteDiffCount, ed, finalFreq); - } - - return finalFreq; -} - -/* static */ int Correction::RankingAlgorithm::calcFreqForSplitMultipleWords(const int *freqArray, - const int *wordLengthArray, const int wordCount, const Correction *correction, - const bool isSpaceProximity, const int *word) { - const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; - - bool firstCapitalizedWordDemotion = false; - bool secondCapitalizedWordDemotion = false; - - { - // TODO: Handle multiple capitalized word demotion properly - const int firstWordLength = wordLengthArray[0]; - const int secondWordLength = wordLengthArray[1]; - if (firstWordLength >= 2) { - firstCapitalizedWordDemotion = isUpperCase(word[0]); - } - - if (secondWordLength >= 2) { - // FIXME: word[firstWordLength + 1] is incorrect. 
- secondCapitalizedWordDemotion = isUpperCase(word[firstWordLength + 1]); - } - } - - - const bool capitalizedWordDemotion = - firstCapitalizedWordDemotion ^ secondCapitalizedWordDemotion; - - int totalLength = 0; - int totalFreq = 0; - for (int i = 0; i < wordCount; ++i) { - const int wordLength = wordLengthArray[i]; - if (wordLength <= 0) { - return 0; - } - totalLength += wordLength; - const int demotionRate = 100 - TWO_WORDS_CORRECTION_DEMOTION_BASE / (wordLength + 1); - int tempFirstFreq = freqArray[i]; - multiplyRate(demotionRate, &tempFirstFreq); - totalFreq += tempFirstFreq; - } - - if (totalLength <= 0 || totalFreq <= 0) { - return 0; - } - - // TODO: Currently totalFreq is adjusted to two word metrix. - // Promote pairFreq with multiplying by 2, because the word length is the same as the typed - // length. - totalFreq = totalFreq * 2 / wordCount; - if (wordCount > 2) { - // Safety net for 3+ words -- Caveats: many heuristics and workarounds here. - int oneLengthCounter = 0; - int twoLengthCounter = 0; - for (int i = 0; i < wordCount; ++i) { - const int wordLength = wordLengthArray[i]; - // TODO: Use bigram instead of this safety net - if (i < wordCount - 1) { - const int nextWordLength = wordLengthArray[i + 1]; - if (wordLength == 1 && nextWordLength == 2) { - // Safety net to filter 1 length and 2 length sequential words - return 0; - } - } - const int freq = freqArray[i]; - // Demote too short weak words - if (wordLength <= 4 && freq <= SUPPRESS_SHORT_MULTIPLE_WORDS_THRESHOLD_FREQ) { - multiplyRate(100 * freq / MAX_PROBABILITY, &totalFreq); - } - if (wordLength == 1) { - ++oneLengthCounter; - } else if (wordLength == 2) { - ++twoLengthCounter; - } - if (oneLengthCounter >= 2 || (oneLengthCounter + twoLengthCounter) >= 4) { - // Safety net to filter too many short words - return 0; - } - } - multiplyRate(MULTIPLE_WORDS_DEMOTION_RATE, &totalFreq); - } - - // This is a workaround to try offsetting the not-enough-demotion which will be done in - // 
calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) - // but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by - // (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length)) - const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength); - multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq); - - // At this moment, totalFreq is calculated by the following formula: - // (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1))) - // * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1)) - - multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq); - - // This is another workaround to offset the demotion which will be done in - // calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote - // the same amount because we already have adjusted the synthetic freq of this "missing or - // mistyped space" suggestion candidate above in this method. 
- const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength); - multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq); - - if (isSpaceProximity) { - // A word pair with one space proximity correction - if (DEBUG_DICT) { - AKLOGI("Found a word pair with space proximity correction."); - } - multiplyIntCapped(typedLetterMultiplier, &totalFreq); - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq); - } - - if (isSpaceProximity) { - multiplyRate(WORDS_WITH_MISTYPED_SPACE_DEMOTION_RATE, &totalFreq); - } else { - multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq); - } - - if (capitalizedWordDemotion) { - multiplyRate(TWO_WORDS_CAPITALIZED_DEMOTION_RATE, &totalFreq); - } - - if (DEBUG_CORRECTION_FREQ) { - AKLOGI("Multiple words (%d, %d) (%d, %d) %d, %d", freqArray[0], freqArray[1], - wordLengthArray[0], wordLengthArray[1], capitalizedWordDemotion, totalFreq); - DUMP_WORD(word, wordLengthArray[0]); - } - - return totalFreq; -} - -/* static */ int Correction::RankingAlgorithm::editDistance(const int *before, - const int beforeLength, const int *after, const int afterLength) { - const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( - before, beforeLength, after, afterLength); - return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); -} - - -// In dictionary.cpp, getSuggestion() method, -// When USE_SUGGEST_INTERFACE_FOR_TYPING is true: -// SUGGEST_INTERFACE_OUTPUT_SCALE was multiplied to the original suggestion scores to convert -// them to integers. -// score = (int)((original score) * SUGGEST_INTERFACE_OUTPUT_SCALE) -// Undo the scaling here to recover the original score. -// normalizedScore = ((float)score) / SUGGEST_INTERFACE_OUTPUT_SCALE -// Otherwise: suggestion scores are computed using the below formula. 
-// original score -// := powf(mTypedLetterMultiplier (this is defined 2), -// (the number of matched characters between typed word and suggested word)) -// * (individual word's score which defined in the unigram dictionary, -// and this score is defined in range [0, 255].) -// Then, the following processing is applied. -// - If the dictionary word is matched up to the point of the user entry -// (full match up to min(before.length(), after.length()) -// => Then multiply by FULL_MATCHED_WORDS_PROMOTION_RATE (this is defined 1.2) -// - If the word is a true full match except for differences in accents or -// capitalization, then treat it as if the score was 255. -// - If before.length() == after.length() -// => multiply by mFullWordMultiplier (this is defined 2)) -// So, maximum original score is powf(2, min(before.length(), after.length())) * 255 * 2 * 1.2 -// For historical reasons we ignore the 1.2 modifier (because the measure for a good -// autocorrection threshold was done at a time when it didn't exist). This doesn't change -// the result. -// So, we can normalize original score by dividing powf(2, min(b.l(),a.l())) * 255 * 2. - -/* static */ float Correction::RankingAlgorithm::calcNormalizedScore(const int *before, - const int beforeLength, const int *after, const int afterLength, const int score) { - if (0 == beforeLength || 0 == afterLength) { - return 0.0f; - } - const int distance = editDistance(before, beforeLength, after, afterLength); - int spaceCount = 0; - for (int i = 0; i < afterLength; ++i) { - if (after[i] == KEYCODE_SPACE) { - ++spaceCount; - } - } - - if (spaceCount == afterLength) { - return 0.0f; - } - - // add a weight based on edit distance. 
- // distance <= max(afterLength, beforeLength) == afterLength, - // so, 0 <= distance / afterLength <= 1 - const float weight = 1.0f - static_cast<float>(distance) / static_cast<float>(afterLength); - - if (USE_SUGGEST_INTERFACE_FOR_TYPING) { - return (static_cast<float>(score) / SUGGEST_INTERFACE_OUTPUT_SCALE) * weight; - } - const float maxScore = score >= S_INT_MAX ? static_cast<float>(S_INT_MAX) - : static_cast<float>(MAX_INITIAL_SCORE) - * powf(static_cast<float>(TYPED_LETTER_MULTIPLIER), - static_cast<float>(min(beforeLength, afterLength - spaceCount))) - * static_cast<float>(FULL_WORD_MULTIPLIER); - - return (static_cast<float>(score) / maxScore) * weight; -} -} // namespace latinime diff --git a/native/jni/src/obsolete/correction.h b/native/jni/src/obsolete/correction.h deleted file mode 100644 index 710220d66..000000000 --- a/native/jni/src/obsolete/correction.h +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LATINIME_CORRECTION_H -#define LATINIME_CORRECTION_H - -#include <cstring> // for memset() - -#include "defines.h" -#include "obsolete/correction_state.h" -#include "suggest/core/layout/proximity_info_state.h" -#include "utils/char_utils.h" - -namespace latinime { - -class ProximityInfo; - -class Correction { - public: - typedef enum { - TRAVERSE_ALL_ON_TERMINAL, - TRAVERSE_ALL_NOT_ON_TERMINAL, - UNRELATED, - ON_TERMINAL, - NOT_ON_TERMINAL - } CorrectionType; - - Correction() - : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false), - mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0), - mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0), - mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0), - mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0), - mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0), - mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false), - mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false), - mSkipping(false), mProximityInfoState() { - memset(mWord, 0, sizeof(mWord)); - memset(mDistances, 0, sizeof(mDistances)); - memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable)); - // NOTE: mCorrectionStates is an array of instances. - // No need to initialize it explicitly here. 
- } - - // Non virtual inline destructor -- never inherit this class - ~Correction() {} - void resetCorrection(); - void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth); - void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll); - - // TODO: remove - void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, - const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance, - const bool doAutoCompletion, const int maxErrors); - void checkState() const; - bool sameAsTyped() const; - bool initProcessState(const int index); - - int getInputIndex() const; - - bool needsToPrune() const; - - int pushAndGetTotalTraverseCount() { - return ++mTotalTraverseCount; - } - - int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, - const int wordCount, const bool isSpaceProximity, const int *word) const; - int getFinalProbability(const int probability, int **word, int *wordLength); - int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, - const int inputSize); - - CorrectionType processCharAndCalcState(const int c, const bool isTerminal); - - ///////////////////////// - // Tree helper methods - int goDownTree(const int parentIndex, const int childCount, const int firstChildPos); - - inline int getTreeSiblingPos(const int index) const { - return mCorrectionStates[index].mSiblingPos; - } - - inline void setTreeSiblingPos(const int index, const int pos) { - mCorrectionStates[index].mSiblingPos = pos; - } - - inline int getTreeParentIndex(const int index) const { - return mCorrectionStates[index].mParentIndex; - } - - class RankingAlgorithm { - public: - static int calculateFinalProbability(const int inputIndex, const int depth, - const int probability, int *editDistanceTable, const Correction *correction, - const int inputSize); - static int calcFreqForSplitMultipleWords(const int *freqArray, const 
int *wordLengthArray, - const int wordCount, const Correction *correction, const bool isSpaceProximity, - const int *word); - static float calcNormalizedScore(const int *before, const int beforeLength, - const int *after, const int afterLength, const int score); - static int editDistance(const int *before, const int beforeLength, const int *after, - const int afterLength); - private: - static const int MAX_INITIAL_SCORE = 255; - }; - - // proximity info state - void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes, - const int inputSize, const int *xCoordinates, const int *yCoordinates) { - mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING), - proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false); - } - - const int *getPrimaryInputWord() const { - return mProximityInfoState.getPrimaryInputWord(); - } - - int getPrimaryCodePointAt(const int index) const { - return mProximityInfoState.getPrimaryCodePointAt(index); - } - - private: - DISALLOW_COPY_AND_ASSIGN(Correction); - - ///////////////////////// - // static inline utils // - ///////////////////////// - static const int TWO_31ST_DIV_255 = S_INT_MAX / 255; - static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) { - return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX); - } - - static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; - AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) { - const int temp = *base; - if (temp != S_INT_MAX) { - // Branch if multiplier == 2 for the optimization - if (multiplier < 0) { - if (DEBUG_DICT) { - ASSERT(false); - } - AKLOGI("--- Invalid multiplier: %d", multiplier); - } else if (multiplier == 0) { - *base = 0; - } else if (multiplier == 2) { - *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX; - } else { - // TODO: This overflow check gives a wrong answer when, for example, - // temp = 2^16 + 1 and multiplier = 2^17 + 1. 
- // Fix this behavior. - const int tempRetval = temp * multiplier; - *base = tempRetval >= temp ? tempRetval : S_INT_MAX; - } - } - } - - AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) { - if (n <= 0) return 1; - if (base == 2) { - return n < 31 ? 1 << n : S_INT_MAX; - } - int ret = base; - for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); - return ret; - } - - AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) { - if (*freq != S_INT_MAX) { - if (*freq > 1000000) { - *freq /= 100; - multiplyIntCapped(rate, freq); - } else { - multiplyIntCapped(rate, freq); - *freq /= 100; - } - } - } - - inline int getSpaceProximityPos() const { - return mSpaceProximityPos; - } - inline int getMissingSpacePos() const { - return mMissingSpacePos; - } - - inline int getSkipPos() const { - return mSkipPos; - } - - inline int getExcessivePos() const { - return mExcessivePos; - } - - inline int getTransposedPos() const { - return mTransposedPos; - } - - inline void incrementInputIndex(); - inline void incrementOutputIndex(); - inline void startToTraverseAllNodes(); - inline bool isSingleQuote(const int c); - inline CorrectionType processSkipChar(const int c, const bool isTerminal, - const bool inputIndexIncremented); - inline CorrectionType processUnrelatedCorrectionType(); - inline void addCharToCurrentWord(const int c); - inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength, - const int inputSize); - - static const int TYPED_LETTER_MULTIPLIER = 2; - static const int FULL_WORD_MULTIPLIER = 2; - const ProximityInfo *mProximityInfo; - - bool mUseFullEditDistance; - bool mDoAutoCompletion; - int mMaxEditDistance; - int mMaxDepth; - int mInputSize; - int mSpaceProximityPos; - int mMissingSpacePos; - int mTerminalInputIndex; - int mTerminalOutputIndex; - int mMaxErrors; - - int mTotalTraverseCount; - - // The following arrays are state buffer. 
- int mWord[MAX_WORD_LENGTH]; - int mDistances[MAX_WORD_LENGTH]; - - // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N. - // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot. - int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)]; - - CorrectionState mCorrectionStates[MAX_WORD_LENGTH]; - - // The following member variables are being used as cache values of the correction state. - bool mNeedsToTraverseAllNodes; - int mOutputIndex; - int mInputIndex; - - int mEquivalentCharCount; - int mProximityCount; - int mExcessiveCount; - int mTransposedCount; - int mSkippedCount; - - int mTransposedPos; - int mExcessivePos; - int mSkipPos; - - bool mLastCharExceeded; - - bool mMatching; - bool mProximityMatching; - bool mAdditionalProximityMatching; - bool mExceeding; - bool mTransposing; - bool mSkipping; - ProximityInfoState mProximityInfoState; -}; - -inline void Correction::incrementInputIndex() { - ++mInputIndex; -} - -AK_FORCE_INLINE void Correction::incrementOutputIndex() { - ++mOutputIndex; - mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; - mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; - mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; - mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; - mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; - - mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; - mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; - mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; - mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; - mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; - - mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; - mCorrectionStates[mOutputIndex].mTransposedPos = 
mTransposedPos; - mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; - - mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; - - mCorrectionStates[mOutputIndex].mMatching = mMatching; - mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; - mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; - mCorrectionStates[mOutputIndex].mTransposing = mTransposing; - mCorrectionStates[mOutputIndex].mExceeding = mExceeding; - mCorrectionStates[mOutputIndex].mSkipping = mSkipping; -} - -inline void Correction::startToTraverseAllNodes() { - mNeedsToTraverseAllNodes = true; -} - -AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) { - const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); - return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); -} - -AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c, - const bool isTerminal, const bool inputIndexIncremented) { - addCharToCurrentWord(c); - mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); - mTerminalOutputIndex = mOutputIndex; - incrementOutputIndex(); - if (mNeedsToTraverseAllNodes && isTerminal) { - return TRAVERSE_ALL_ON_TERMINAL; - } - return TRAVERSE_ALL_NOT_ON_TERMINAL; -} - -inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() { - // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType - mTerminalInputIndex = mInputIndex; - mTerminalOutputIndex = mOutputIndex; - return UNRELATED; -} - -AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, - const int inputSize, const int *output, const int outputLength) { - // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched. - // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. - // Assuming that dp[0][0] ... 
dp[outputLength - 1][inputSize] are already calculated, - // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. - int *const current = editDistanceTable + outputLength * (inputSize + 1); - const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); - const int *const prevprev = - outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; - current[0] = outputLength; - const int co = CharUtils::toBaseLowerCase(output[outputLength - 1]); - const int prevCO = outputLength >= 2 ? CharUtils::toBaseLowerCase(output[outputLength - 2]) : 0; - for (int i = 1; i <= inputSize; ++i) { - const int ci = CharUtils::toBaseLowerCase(input[i - 1]); - const int cost = (ci == co) ? 0 : 1; - current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); - if (i >= 2 && prevprev && ci == prevCO && co == CharUtils::toBaseLowerCase(input[i - 2])) { - current[i] = min(current[i], prevprev[i - 2] + 1); - } - } -} - -AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) { - mWord[mOutputIndex] = c; - const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); - calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, - mOutputIndex + 1); -} - -inline int Correction::getFinalProbabilityInternal(const int probability, int **word, - int *wordLength, const int inputSize) { - const int outputIndex = mTerminalOutputIndex; - const int inputIndex = mTerminalInputIndex; - *wordLength = outputIndex + 1; - *word = mWord; - int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( - inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); - return finalProbability; -} - -} // namespace latinime -#endif // LATINIME_CORRECTION_H diff --git a/native/jni/src/obsolete/correction_state.h b/native/jni/src/obsolete/correction_state.h deleted file mode 100644 index a63d4aa94..000000000 --- a/native/jni/src/obsolete/correction_state.h +++ /dev/null @@ 
-1,83 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_CORRECTION_STATE_H -#define LATINIME_CORRECTION_STATE_H - -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -struct CorrectionState { - int mParentIndex; - int mSiblingPos; - uint16_t mChildCount; - uint8_t mInputIndex; - - uint8_t mEquivalentCharCount; - uint8_t mProximityCount; - uint8_t mTransposedCount; - uint8_t mExcessiveCount; - uint8_t mSkippedCount; - - int8_t mTransposedPos; - int8_t mExcessivePos; - int8_t mSkipPos; // should be signed - - // TODO: int? 
- bool mLastCharExceeded; - - bool mMatching; - bool mTransposing; - bool mExceeding; - bool mSkipping; - bool mProximityMatching; - bool mAdditionalProximityMatching; - - bool mNeedsToTraverseAllNodes; -}; - -inline static void initCorrectionState(CorrectionState *state, const int rootPos, - const uint16_t childCount, const bool traverseAll) { - state->mParentIndex = -1; - state->mChildCount = childCount; - state->mInputIndex = 0; - state->mSiblingPos = rootPos; - state->mNeedsToTraverseAllNodes = traverseAll; - - state->mTransposedPos = -1; - state->mExcessivePos = -1; - state->mSkipPos = -1; - - state->mEquivalentCharCount = 0; - state->mProximityCount = 0; - state->mTransposedCount = 0; - state->mExcessiveCount = 0; - state->mSkippedCount = 0; - - state->mLastCharExceeded = false; - - state->mMatching = false; - state->mProximityMatching = false; - state->mTransposing = false; - state->mExceeding = false; - state->mSkipping = false; - state->mAdditionalProximityMatching = false; -} -} // namespace latinime -#endif // LATINIME_CORRECTION_STATE_H diff --git a/native/jni/src/obsolete/words_priority_queue.cpp b/native/jni/src/obsolete/words_priority_queue.cpp deleted file mode 100644 index 563cf918e..000000000 --- a/native/jni/src/obsolete/words_priority_queue.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2012, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "obsolete/words_priority_queue.h" - -namespace latinime { - -int WordsPriorityQueue::outputSuggestions(const int *before, const int beforeLength, - int *frequencies, int *outputCodePoints, int* outputTypes) { - mHighestSuggestedWord = 0; - const int size = min(MAX_WORDS, static_cast<int>(mSuggestions.size())); - SuggestedWord *swBuffer[size]; - int index = size - 1; - while (!mSuggestions.empty() && index >= 0) { - SuggestedWord *sw = mSuggestions.top(); - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("dump word. %d", sw->mScore); - DUMP_WORD(sw->mWord, sw->mWordLength); - } - swBuffer[index] = sw; - mSuggestions.pop(); - --index; - } - if (size >= 2) { - SuggestedWord *nsMaxSw = 0; - int maxIndex = 0; - float maxNs = 0; - for (int i = 0; i < size; ++i) { - SuggestedWord *tempSw = swBuffer[i]; - if (!tempSw) { - continue; - } - const float tempNs = getNormalizedScore(tempSw, before, beforeLength, 0, 0, 0); - if (tempNs >= maxNs) { - maxNs = tempNs; - maxIndex = i; - nsMaxSw = tempSw; - } - } - if (maxIndex > 0 && nsMaxSw) { - memmove(&swBuffer[1], &swBuffer[0], maxIndex * sizeof(swBuffer[0])); - swBuffer[0] = nsMaxSw; - } - } - for (int i = 0; i < size; ++i) { - SuggestedWord *sw = swBuffer[i]; - if (!sw) { - AKLOGE("SuggestedWord is null %d", i); - continue; - } - const int wordLength = sw->mWordLength; - int *targetAddress = outputCodePoints + i * MAX_WORD_LENGTH; - frequencies[i] = sw->mScore; - outputTypes[i] = sw->mType; - memcpy(targetAddress, sw->mWord, wordLength * sizeof(targetAddress[0])); - if (wordLength < MAX_WORD_LENGTH) { - targetAddress[wordLength] = 0; - } - sw->mUsed = false; - } - return size; -} -} // namespace latinime diff --git a/native/jni/src/obsolete/words_priority_queue.h b/native/jni/src/obsolete/words_priority_queue.h deleted file mode 100644 index 337e3e32b..000000000 --- a/native/jni/src/obsolete/words_priority_queue.h +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * 
Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_WORDS_PRIORITY_QUEUE_H -#define LATINIME_WORDS_PRIORITY_QUEUE_H - -#include <cstring> // for memcpy() -#include <queue> - -#include "defines.h" -#include "obsolete/correction.h" - -namespace latinime { - -class WordsPriorityQueue { - public: - struct SuggestedWord { - int mScore; - int mWord[MAX_WORD_LENGTH]; - int mWordLength; - bool mUsed; - int mType; - - void setParams(int score, int *word, int wordLength, int type) { - mScore = score; - mWordLength = wordLength; - memcpy(mWord, word, sizeof(mWord[0]) * wordLength); - mUsed = true; - mType = type; - } - }; - - WordsPriorityQueue(int maxWords) - : mSuggestions(), MAX_WORDS(maxWords), - mSuggestedWords(new SuggestedWord[MAX_WORD_LENGTH]), mHighestSuggestedWord(0) { - for (int i = 0; i < MAX_WORD_LENGTH; ++i) { - mSuggestedWords[i].mUsed = false; - } - } - - // Non virtual inline destructor -- never inherit this class - AK_FORCE_INLINE ~WordsPriorityQueue() { - delete[] mSuggestedWords; - } - - void push(int score, int *word, int wordLength, int type) { - SuggestedWord *sw = 0; - if (size() >= MAX_WORDS) { - sw = mSuggestions.top(); - const int minScore = sw->mScore; - if (minScore >= score) { - return; - } - sw->mUsed = false; - mSuggestions.pop(); - } - if (sw == 0) { - sw = getFreeSuggestedWord(score, word, wordLength, type); - } else { - sw->setParams(score, word, wordLength, type); - } - if (sw == 0) { - AKLOGE("SuggestedWord is 
accidentally null."); - return; - } - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("Push word. %d, %d", score, wordLength); - DUMP_WORD(word, wordLength); - } - mSuggestions.push(sw); - if (!mHighestSuggestedWord || mHighestSuggestedWord->mScore < sw->mScore) { - mHighestSuggestedWord = sw; - } - } - - SuggestedWord *top() const { - if (mSuggestions.empty()) return 0; - SuggestedWord *sw = mSuggestions.top(); - return sw; - } - - int size() const { - return static_cast<int>(mSuggestions.size()); - } - - AK_FORCE_INLINE void clear() { - mHighestSuggestedWord = 0; - while (!mSuggestions.empty()) { - SuggestedWord *sw = mSuggestions.top(); - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("Clear word. %d", sw->mScore); - DUMP_WORD(sw->mWord, sw->mWordLength); - } - sw->mUsed = false; - mSuggestions.pop(); - } - } - - AK_FORCE_INLINE void dumpTopWord() const { - if (size() <= 0) { - return; - } - DUMP_WORD(mHighestSuggestedWord->mWord, mHighestSuggestedWord->mWordLength); - } - - AK_FORCE_INLINE float getHighestNormalizedScore(const int *before, const int beforeLength, - int **outWord, int *outScore, int *outLength) const { - if (!mHighestSuggestedWord) { - return 0.0f; - } - return getNormalizedScore(mHighestSuggestedWord, before, beforeLength, outWord, outScore, - outLength); - } - - int outputSuggestions(const int *before, const int beforeLength, int *frequencies, - int *outputCodePoints, int* outputTypes); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(WordsPriorityQueue); - struct wordComparator { - bool operator ()(SuggestedWord * left, SuggestedWord * right) { - return left->mScore > right->mScore; - } - }; - - SuggestedWord *getFreeSuggestedWord(int score, int *word, int wordLength, int type) const { - for (int i = 0; i < MAX_WORD_LENGTH; ++i) { - if (!mSuggestedWords[i].mUsed) { - mSuggestedWords[i].setParams(score, word, wordLength, type); - return &mSuggestedWords[i]; - } - } - return 0; - } - - static float getNormalizedScore(SuggestedWord *sw, const int *before, 
const int beforeLength, - int **outWord, int *outScore, int *outLength) { - const int score = sw->mScore; - int *word = sw->mWord; - const int wordLength = sw->mWordLength; - if (outScore) { - *outScore = score; - } - if (outWord) { - *outWord = word; - } - if (outLength) { - *outLength = wordLength; - } - return Correction::RankingAlgorithm::calcNormalizedScore(before, beforeLength, word, - wordLength, score); - } - - typedef std::priority_queue<SuggestedWord *, std::vector<SuggestedWord *>, - wordComparator> Suggestions; - Suggestions mSuggestions; - const int MAX_WORDS; - SuggestedWord *mSuggestedWords; - SuggestedWord *mHighestSuggestedWord; -}; -} // namespace latinime -#endif // LATINIME_WORDS_PRIORITY_QUEUE_H diff --git a/native/jni/src/obsolete/words_priority_queue_pool.h b/native/jni/src/obsolete/words_priority_queue_pool.h deleted file mode 100644 index bf04568db..000000000 --- a/native/jni/src/obsolete/words_priority_queue_pool.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LATINIME_WORDS_PRIORITY_QUEUE_POOL_H -#define LATINIME_WORDS_PRIORITY_QUEUE_POOL_H - -#include "defines.h" -#include "obsolete/words_priority_queue.h" - -namespace latinime { - -class WordsPriorityQueuePool { - public: - WordsPriorityQueuePool(int mainQueueMaxWords, int subQueueMaxWords) - // Note: using placement new() requires the caller to call the destructor explicitly. - : mMasterQueue(new(mMasterQueueBuf) WordsPriorityQueue(mainQueueMaxWords)) { - for (int i = 0, subQueueBufOffset = 0; - i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS * SUB_QUEUE_MAX_COUNT; - ++i, subQueueBufOffset += static_cast<int>(sizeof(WordsPriorityQueue))) { - mSubQueues[i] = new(mSubQueueBuf + subQueueBufOffset) - WordsPriorityQueue(subQueueMaxWords); - } - } - - // Non virtual inline destructor -- never inherit this class - ~WordsPriorityQueuePool() { - // Note: these explicit calls to the destructor match the calls to placement new() above. - if (mMasterQueue) mMasterQueue->~WordsPriorityQueue(); - for (int i = 0; i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS * SUB_QUEUE_MAX_COUNT; ++i) { - if (mSubQueues[i]) mSubQueues[i]->~WordsPriorityQueue(); - } - } - - WordsPriorityQueue *getMasterQueue() const { - return mMasterQueue; - } - - WordsPriorityQueue *getSubQueue(const int wordIndex, const int inputWordLength) const { - if (wordIndex >= MULTIPLE_WORDS_SUGGESTION_MAX_WORDS) { - return 0; - } - if (inputWordLength < 0 || inputWordLength >= SUB_QUEUE_MAX_COUNT) { - if (DEBUG_WORDS_PRIORITY_QUEUE) { - ASSERT(false); - } - return 0; - } - return mSubQueues[wordIndex * SUB_QUEUE_MAX_COUNT + inputWordLength]; - } - - inline void clearAll() { - mMasterQueue->clear(); - for (int i = 0; i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS; ++i) { - clearSubQueue(i); - } - } - - AK_FORCE_INLINE void clearSubQueue(const int wordIndex) { - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - WordsPriorityQueue *queue = getSubQueue(wordIndex, i); - if (queue) { - queue->clear(); - } - } - } - - void 
dumpSubQueue1TopSuggestions() const { - AKLOGI("DUMP SUBQUEUE1 TOP SUGGESTIONS"); - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - getSubQueue(0, i)->dumpTopWord(); - } - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(WordsPriorityQueuePool); - char mMasterQueueBuf[sizeof(WordsPriorityQueue)]; - char mSubQueueBuf[SUB_QUEUE_MAX_COUNT * MULTIPLE_WORDS_SUGGESTION_MAX_WORDS - * sizeof(WordsPriorityQueue)]; - WordsPriorityQueue *mMasterQueue; - WordsPriorityQueue *mSubQueues[SUB_QUEUE_MAX_COUNT * MULTIPLE_WORDS_SUGGESTION_MAX_WORDS]; -}; -} // namespace latinime -#endif // LATINIME_WORDS_PRIORITY_QUEUE_POOL_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 3f64d07b2..41ef9d2b2 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -18,25 +18,26 @@ #define LATINIME_DIC_NODE_H #include "defines.h" -#include "suggest/core/dicnode/dic_node_state.h" #include "suggest/core/dicnode/dic_node_profiler.h" -#include "suggest/core/dicnode/dic_node_properties.h" #include "suggest/core/dicnode/dic_node_release_listener.h" +#include "suggest/core/dicnode/internal/dic_node_state.h" +#include "suggest/core/dicnode/internal/dic_node_properties.h" #include "suggest/core/dictionary/digraph_utils.h" #include "utils/char_utils.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ do { char charBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) #define DUMP_WORD_AND_SCORE(header) \ do { char charBuf[50]; char prevWordCharBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), 
getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ - mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf, \ + NELEMS(prevWordCharBuf)); \ AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ @@ -51,6 +52,11 @@ namespace latinime { // This struct is purely a bucket to return values. No instances of this struct should be kept. struct DicNode_InputStateG { + DicNode_InputStateG() + : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), + mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), + mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} + bool mNeedsToUpdateInputStateG; int mPointerId; int16_t mInputIndex; @@ -92,7 +98,6 @@ class DicNode { DicNode &operator=(const DicNode &dicNode); virtual ~DicNode() {} - // TODO: minimize arguments by looking binary_format // Init for copy void initByCopy(const DicNode *dicNode) { mIsUsed = true; @@ -102,35 +107,28 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // TODO: minimize arguments by looking binary_format // Init for root with prevWordNodePos which is used for bigram - void initAsRoot(const int pos, const int childrenPos, const int childrenCount, - const int prevWordNodePos) { + void initAsRoot(const int rootGroupPos, const int prevWordNodePos) { mIsUsed = true; mIsCachedForNextSuggestion = false; mDicNodeProperties.init( - pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + NOT_A_DICT_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); 
mDicNodeState.init(prevWordNodePos); PROF_NODE_RESET(mProfiler); } - void initAsPassingChild(DicNode *parentNode) { - mIsUsed = true; - mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; - const int c = parentNode->getNodeTypedCodePoint(); - mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); - mDicNodeState.init(&parentNode->mDicNodeState); - PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); - } - - // TODO: minimize arguments by looking binary_format // Init for root with previous word - void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, - const int childrenCount) { + void initAsRootWithPreviousWord(DicNode *dicNode, const int rootGroupPos) { mIsUsed = true; - mIsCachedForNextSuggestion = false; + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( - pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + NOT_A_DICT_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); // TODO: Move to dicNodeState? 
mDicNodeState.mDicNodeStateOutput.init(); // reset for next word mDicNodeState.mDicNodeStateInput.init( @@ -145,26 +143,33 @@ class DicNode { dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), dicNode->getOutputWordBuf(), dicNode->mDicNodeProperties.getDepth(), - dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, + dicNode->mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(), mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // TODO: minimize arguments by looking binary_format - void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos, - const int attributesPos, const int siblingPos, const int nodeCodePoint, - const int childrenCount, const int probability, const int bigramProbability, - const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, - const uint16_t additionalSubwordLength, const int *additionalSubword) { + void initAsPassingChild(DicNode *parentNode) { + mIsUsed = true; + mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; + const int c = parentNode->getNodeTypedCodePoint(); + mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); + mDicNodeState.init(&parentNode->mDicNodeState); + PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); + } + + void initAsChild(const DicNode *const dicNode, const int pos, const int childrenPos, + const int probability, const bool isTerminal, const bool hasChildren, + const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { mIsUsed = true; - uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); + uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast<uint16_t>( - dicNode->mDicNodeProperties.getLeavingDepth() + 
additionalSubwordLength); - mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint, - childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars, - hasChildren, newDepth, newLeavingDepth); - mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword); + dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); + mDicNodeProperties.init(pos, childrenPos, mergedNodeCodePoints[0], probability, + isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, newLeavingDepth); + mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, + mergedNodeCodePoints); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -180,7 +185,7 @@ class DicNode { } bool isRoot() const { - return getDepth() == 0; + return getNodeCodePointCount() == 0; } bool hasChildren() const { @@ -188,12 +193,12 @@ class DicNode { } bool isLeavingNode() const { - ASSERT(getDepth() <= getLeavingDepth()); - return getDepth() == getLeavingDepth(); + ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); + return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); } AK_FORCE_INLINE bool isFirstLetter() const { - return getDepth() == 1; + return getNodeCodePointCount() == 1; } bool isCached() const { @@ -206,17 +211,21 @@ class DicNode { // Used to expand the node in DicNodeUtils int getNodeTypedCodePoint() const { - return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); + return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getNodeCodePointCount()); } - bool isImpossibleBigramWord() const { - if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) { - return true; + // Check if the current word and the previous word can be considered as a valid multiple word + // suggestion. 
+ bool isValidMultipleWordSuggestion() const { + if (isBlacklistedOrNotAWord()) { + return false; } + // Treat suggestion as invalid if the current and the previous word are single character + // words. const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; - const int currentWordLen = getDepth(); - return (prevWordLen == 1 && currentWordLen == 1); + const int currentWordLen = getNodeCodePointCount(); + return (prevWordLen != 1 || currentWordLen != 1); } bool isFirstCharUppercase() const { @@ -225,7 +234,7 @@ class DicNode { } bool isFirstWord() const { - return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD; + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_A_DICT_POS; } bool isCompletion(const int inputSize) const { @@ -251,37 +260,27 @@ class DicNode { return mDicNodeProperties.getChildrenPos(); } - // Used in DicNodeUtils - int getChildrenCount() const { - return mDicNodeProperties.getChildrenCount(); - } - - // Used in DicNodeUtils int getProbability() const { return mDicNodeProperties.getProbability(); } AK_FORCE_INLINE bool isTerminalWordNode() const { const bool isTerminalNodes = mDicNodeProperties.isTerminal(); - const int currentNodeDepth = getDepth(); + const int currentNodeDepth = getNodeCodePointCount(); const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; } bool shouldBeFilterdBySafetyNetForBigram() const { - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); } - uint16_t getLeavingDepth() const { - return mDicNodeProperties.getLeavingDepth(); - } - bool 
isTotalInputSizeExceedingLimit() const { const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const int currentWordDepth = getDepth(); + const int currentWordDepth = getNodeCodePointCount(); // TODO: 3 can be 2? Needs to be investigated. // TODO: Have a const variable for 3 (or 2) return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; @@ -316,26 +315,31 @@ class DicNode { void outputResult(int *dest) const { const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, prevWordLength, getOutputWordBuf(), currentDepth, dest); DUMP_WORD_AND_SCORE("OUTPUT"); } - void outputSpacePositionsResult(int *spaceIndices) const { - mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); + int getSecondWordFirstInputIndex(const ProximityInfoState *const pInfoState) const { + const int inputIndex = mDicNodeState.mDicNodeStatePrevWord.getSecondWordFirstInputIndex(); + if (inputIndex == NOT_AN_INDEX) { + return NOT_AN_INDEX; + } else { + return pInfoState->getInputIndexOfSampledPoint(inputIndex); + } } bool hasMultipleWords() const { return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0; } - float getProximityCorrectionCount() const { - return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount()); + int getProximityCorrectionCount() const { + return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); } - float getEditCorrectionCount() const { - return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()); + int getEditCorrectionCount() const { + return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); } // Used to prune nodes @@ -365,7 +369,7 @@ class DicNode { } AK_FORCE_INLINE const int *getOutputWordBuf() const { - return 
mDicNodeState.mDicNodeStateOutput.mWordBuf; + return mDicNodeState.mDicNodeStateOutput.mCodePointsBuf; } int getPrevCodePointG(int pointerId) const { @@ -467,16 +471,17 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.isExactMatch(); } - uint8_t getFlags() const { - return mDicNodeProperties.getFlags(); + bool isBlacklistedOrNotAWord() const { + return mDicNodeProperties.isBlacklistedOrNotAWord(); } - int getAttributesPos() const { - return mDicNodeProperties.getAttributesPos(); + inline uint16_t getNodeCodePointCount() const { + return mDicNodeProperties.getDepth(); } - inline uint16_t getDepth() const { - return mDicNodeProperties.getDepth(); + // Returns code point count including spaces + inline uint16_t getTotalNodeCodePointCount() const { + return getNodeCodePointCount() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); } AK_FORCE_INLINE void dump(const char *tag) const { @@ -503,6 +508,12 @@ class DicNode { if (!right->isUsed()) { return false; } + // Promote exact matches to prevent them from being pruned. 
+ const bool leftExactMatch = isExactMatch(); + const bool rightExactMatch = right->isExactMatch(); + if (leftExactMatch != rightExactMatch) { + return leftExactMatch; + } const float diff = right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); static const float MIN_DIFF = 0.000001f; @@ -511,8 +522,8 @@ class DicNode { } else if (diff < -MIN_DIFF) { return false; } - const int depth = getDepth(); - const int depthDiff = right->getDepth() - depth; + const int depth = getNodeCodePointCount(); + const int depthDiff = right->getNodeCodePointCount() - depth; if (depthDiff != 0) { return depthDiff > 0; } @@ -567,7 +578,11 @@ class DicNode { } } - AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { + AK_FORCE_INLINE void updateInputIndexG(const DicNode_InputStateG *const inputStateG) { + if (mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() == 1 && isFirstLetter()) { + mDicNodeState.mDicNodeStatePrevWord.setSecondWordFirstInputIndex( + inputStateG->mInputIndex); + } mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, inputStateG->mInputIndex, inputStateG->mPrevCodePoint, inputStateG->mTerminalDiffCost, inputStateG->mRawLength); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h index 970e3bda4..7461f0cc6 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h @@ -24,19 +24,16 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_release_listener.h" -#define MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY 200 - namespace latinime { class DicNodePriorityQueue : public DicNodeReleaseListener { public: - AK_FORCE_INLINE DicNodePriorityQueue() - : MAX_CAPACITY(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), - mMaxSize(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), mDicNodesBuf(), mUnusedNodeIndices(), - mNextUnusedNodeId(0), 
mDicNodesQueue() { - mDicNodesBuf.resize(MAX_CAPACITY + 1); - mUnusedNodeIndices.resize(MAX_CAPACITY + 1); - reset(); + AK_FORCE_INLINE explicit DicNodePriorityQueue(const int capacity) + : mCapacity(capacity), mMaxSize(capacity), mDicNodesBuf(), + mUnusedNodeIndices(), mNextUnusedNodeId(0), mDicNodesQueue() { + mDicNodesBuf.resize(mCapacity + 1); + mUnusedNodeIndices.resize(mCapacity + 1); + clearAndResizeToCapacity(); } // Non virtual inline destructor -- never inherit this class @@ -51,11 +48,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE void setMaxSize(const int maxSize) { - mMaxSize = min(maxSize, MAX_CAPACITY); + ASSERT(maxSize <= mCapacity); + mMaxSize = min(maxSize, mCapacity); } - AK_FORCE_INLINE void reset() { - clearAndResize(MAX_CAPACITY); + AK_FORCE_INLINE void clearAndResizeToCapacity() { + clearAndResize(mCapacity); } AK_FORCE_INLINE void clear() { @@ -63,27 +61,19 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE void clearAndResize(const int maxSize) { + ASSERT(maxSize <= mCapacity); while (!mDicNodesQueue.empty()) { mDicNodesQueue.pop(); } setMaxSize(maxSize); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { mDicNodesBuf[i].remove(); mDicNodesBuf[i].setReleaseListener(this); - mUnusedNodeIndices[i] = i == MAX_CAPACITY ? NOT_A_NODE_ID : static_cast<int>(i) + 1; + mUnusedNodeIndices[i] = i == mCapacity ? 
NOT_A_NODE_ID : static_cast<int>(i) + 1; } mNextUnusedNodeId = 0; } - AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { - DicNode *newNode = searchEmptyDicNode(); - if (newNode) { - DicNodeUtils::initByCopy(dicNode, newNode); - return newNode; - } - return 0; - } - // Copy AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode) { return copyPush(dicNode, mMaxSize); @@ -110,12 +100,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } mUnusedNodeIndices[index] = mNextUnusedNodeId; mNextUnusedNodeId = index; - ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + ASSERT(index >= 0 && index < (mCapacity + 1)); } AK_FORCE_INLINE void dump() const { AKLOGI("\n\n\n\n\n==========================="); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { if (mDicNodesBuf[i].isUsed()) { mDicNodesBuf[i].dump("QUEUE: "); } @@ -124,7 +114,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } private: - DISALLOW_COPY_AND_ASSIGN(DicNodePriorityQueue); + DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodePriorityQueue); static const int NOT_A_NODE_ID = -1; AK_FORCE_INLINE static bool compareDicNode(DicNode *left, DicNode *right) { @@ -138,7 +128,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { }; typedef std::priority_queue<DicNode *, std::vector<DicNode *>, DicNodeComparator> DicNodesQueue; - const int MAX_CAPACITY; + const int mCapacity; int mMaxSize; std::vector<DicNode> mDicNodesBuf; // of each element of mDicNodesBuf respectively std::vector<int> mUnusedNodeIndices; @@ -162,13 +152,12 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE DicNode *searchEmptyDicNode() { - // TODO: Currently O(n) but should be improved to O(1) - if (MAX_CAPACITY == 0) { + if (mCapacity == 0) { return 0; } if (mNextUnusedNodeId == NOT_A_NODE_ID) { AKLOGI("No unused node found."); - for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + for (int i = 0; i < mCapacity + 1; ++i) { AKLOGI("Dump node 
availability, %d, %d, %d", i, mDicNodesBuf[i].isUsed(), mUnusedNodeIndices[i]); } @@ -184,7 +173,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { const int index = static_cast<int>(dicNode - &mDicNodesBuf[0]); mNextUnusedNodeId = mUnusedNodeIndices[index]; mUnusedNodeIndices[index] = NOT_A_NODE_ID; - ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + ASSERT(index >= 0 && index < (mCapacity + 1)); } AK_FORCE_INLINE DicNode *pushPoolNodeWithMaxSize(DicNode *dicNode, const int maxSize) { @@ -208,6 +197,15 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode, const int maxSize) { return pushPoolNodeWithMaxSize(newDicNode(dicNode), maxSize); } + + AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { + DicNode *newNode = searchEmptyDicNode(); + if (newNode) { + DicNodeUtils::initByCopy(dicNode, newNode); + } + return newNode; + } + }; } // namespace latinime #endif // LATINIME_DIC_NODE_PRIORITY_QUEUE_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h index 90f75d0c6..1f4d2570e 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h @@ -31,6 +31,7 @@ #define PROF_TRANSPOSITION(profiler) profiler.profTransposition() #define PROF_NEARESTKEY(profiler) profiler.profNearestKey() #define PROF_TERMINAL(profiler) profiler.profTerminal() +#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion() #define PROF_NEW_WORD(profiler) profiler.profNewWord() #define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram() #define PROF_NODE_RESET(profiler) profiler.reset() @@ -47,6 +48,7 @@ #define PROF_TRANSPOSITION(profiler) #define PROF_NEARESTKEY(profiler) #define PROF_TERMINAL(profiler) +#define PROF_TERMINAL_INSERTION(profiler) #define PROF_NEW_WORD(profiler) #define PROF_NEW_WORD_BIGRAM(profiler) #define PROF_NODE_RESET(profiler) @@ 
-62,7 +64,7 @@ class DicNodeProfiler { : mProfOmission(0), mProfInsertion(0), mProfTransposition(0), mProfAdditionalProximity(0), mProfSubstitution(0), mProfSpaceSubstitution(0), mProfSpaceOmission(0), - mProfMatch(0), mProfCompletion(0), mProfTerminal(0), + mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0), mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {} int mProfOmission; @@ -75,6 +77,7 @@ class DicNodeProfiler { int mProfMatch; int mProfCompletion; int mProfTerminal; + int mProfTerminalInsertion; int mProfNearestKey; int mProfNewWord; int mProfNewWordBigram; @@ -123,6 +126,10 @@ class DicNodeProfiler { ++mProfTerminal; } + void profTerminalInsertion() { + ++mProfTerminalInsertion; + } + void profNewWord() { ++mProfNewWord; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h index 2a81c3cae..2ca4f21bd 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h @@ -21,6 +21,8 @@ namespace latinime { +class DicNode; + class DicNodeReleaseListener { public: DicNodeReleaseListener() {} diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_output.h b/native/jni/src/suggest/core/dicnode/dic_node_state_output.h deleted file mode 100644 index 1d4f50a06..000000000 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_output.h +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_DIC_NODE_STATE_OUTPUT_H -#define LATINIME_DIC_NODE_STATE_OUTPUT_H - -#include <cstring> // for memcpy() -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -class DicNodeStateOutput { - public: - DicNodeStateOutput() : mOutputtedLength(0) { - init(); - } - - virtual ~DicNodeStateOutput() {} - - void init() { - mOutputtedLength = 0; - mWordBuf[0] = 0; - } - - void init(const DicNodeStateOutput *const stateOutput) { - memcpy(mWordBuf, stateOutput->mWordBuf, - stateOutput->mOutputtedLength * sizeof(mWordBuf[0])); - mOutputtedLength = stateOutput->mOutputtedLength; - if (mOutputtedLength < MAX_WORD_LENGTH) { - mWordBuf[mOutputtedLength] = 0; - } - } - - void addSubword(const uint16_t additionalSubwordLength, const int *const additionalSubword) { - if (additionalSubword) { - memcpy(&mWordBuf[mOutputtedLength], additionalSubword, - additionalSubwordLength * sizeof(mWordBuf[0])); - mOutputtedLength = static_cast<uint16_t>(mOutputtedLength + additionalSubwordLength); - if (mOutputtedLength < MAX_WORD_LENGTH) { - mWordBuf[mOutputtedLength] = 0; - } - } - } - - // TODO: Remove - int getCodePointAt(const int id) const { - return mWordBuf[id]; - } - - // TODO: Move to private - int mWordBuf[MAX_WORD_LENGTH]; - - private: - // Caution!!! 
- // Use a default copy constructor and an assign operator because shallow copies are ok - // for this class - uint16_t mOutputtedLength; -}; -} // namespace latinime -#endif // LATINIME_DIC_NODE_STATE_OUTPUT_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 3deee1a42..ec65114c7 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -14,18 +14,14 @@ * limitations under the License. */ +#include "suggest/core/dicnode/dic_node_utils.h" + #include <cstring> -#include <vector> #include "suggest/core/dicnode/dic_node.h" -#include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/dicnode/dic_node_vector.h" -#include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/multi_bigram_map.h" -#include "suggest/core/dictionary/probability_utils.h" -#include "suggest/core/layout/proximity_info.h" -#include "suggest/core/layout/proximity_info_state.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" namespace latinime { @@ -34,25 +30,17 @@ namespace latinime { // Node initialization utils // /////////////////////////////// -/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, +/* static */ void DicNodeUtils::initAsRoot( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const int prevWordNodePos, DicNode *const newRootNode) { - int curPos = binaryDictionaryInfo->getRootPosition(); - const int pos = curPos; - const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &curPos); - const int childrenPos = curPos; - newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos); + 
newRootNode->initAsRoot(dictionaryStructurePolicy->getRootPosition(), prevWordNodePos); } /*static */ void DicNodeUtils::initAsRootWithPreviousWord( - const BinaryDictionaryInfo *const binaryDictionaryInfo, + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, DicNode *const prevWordLastNode, DicNode *const newRootNode) { - int curPos = binaryDictionaryInfo->getRootPosition(); - const int pos = curPos; - const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &curPos); - const int childrenPos = curPos; - newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount); + newRootNode->initAsRootWithPreviousWord( + prevWordLastNode, dictionaryStructurePolicy->getRootPosition()); } /* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { @@ -62,141 +50,16 @@ namespace latinime { /////////////////////////////////// // Traverse node expansion utils // /////////////////////////////////// - -/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, - DicNodeVector *childDicNodes) { - // Passing multiple chars node. 
No need to traverse child - const int codePoint = dicNode->getNodeTypedCodePoint(); - const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint); - const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); - if (isMatch || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { - childDicNodes->pushPassingChild(dicNode); - } -} - -/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, - const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalDepth, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, - const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, - DicNodeVector *childDicNodes) { - int nextPos = pos; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &pos); - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); - const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); - - int codePoint = BinaryFormat::getCodePointAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &pos); - ASSERT(NOT_A_CODE_POINT != codePoint); - const int nodeCodePoint = codePoint; - // TODO: optimize this - int additionalWordBuf[MAX_WORD_LENGTH]; - uint16_t additionalSubwordLength = 0; - additionalWordBuf[additionalSubwordLength++] = codePoint; - - do { - const int nextCodePoint = hasMultipleChars - ? BinaryFormat::getCodePointAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &pos) : NOT_A_CODE_POINT; - const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); - if (!isLastChar) { - additionalWordBuf[additionalSubwordLength++] = nextCodePoint; - } - codePoint = nextCodePoint; - } while (NOT_A_CODE_POINT != codePoint); - - const int probability = isTerminal ? 
BinaryFormat::readProbabilityWithoutMovingPointer( - binaryDictionaryInfo->getDictRoot(), pos) : -1; - pos = BinaryFormat::skipProbability(flags, pos); - int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition( - binaryDictionaryInfo->getDictRoot(), flags, pos) : 0; - const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); - const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes( - binaryDictionaryInfo->getDictRoot(), flags, pos); - - if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) { - return siblingPos; - } - if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) { - return siblingPos; - } - const int childrenCount = hasChildren ? BinaryFormat::getGroupCountAndForwardPointer( - binaryDictionaryInfo->getDictRoot(), &childrenPos) : 0; - childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos, - nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, - hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf); - return siblingPos; -} - -/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, - const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { - const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; - if (filterSize <= 0) { - return false; - } - if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX - || CharUtils::isIntentionalOmissionCodePoint(nodeCodePoint))) { - // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never - // filtered. - return false; - } - const int lowerCodePoint = CharUtils::toLowerCase(nodeCodePoint); - const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint); - // TODO: Avoid linear search - for (int i = 0; i < filterSize; ++i) { - // Checking if a normalized code point is in filter characters when pInfo is not - // null. 
When pInfo is null, nodeCodePoint is used to check filtering without - // normalizing. - if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint - || (*codePointsFilter)[i] == baseLowerCodePoint)) - || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { - return false; - } - } - return true; -} - -/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, - const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, - DicNodeVector *childDicNodes) { - const int terminalDepth = dicNode->getLeavingDepth(); - const int childCount = dicNode->getChildrenCount(); - int nextPos = dicNode->getChildrenPos(); - for (int i = 0; i < childCount; i++) { - const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; - nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo, - terminalDepth, pInfoState, pointIndex, exactOnly, codePointsFilter, pInfo, - childDicNodes); - if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { - // All code points have been found. 
- break; - } - } -} - /* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) { - getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes); -} - -/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, - const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, DicNodeVector *childDicNodes) { if (dicNode->isTotalInputSizeExceedingLimit()) { return; } if (!dicNode->isLeavingNode()) { - DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, - childDicNodes); + childDicNodes->pushPassingChild(dicNode); } else { - DicNodeUtils::createAndGetAllLeavingChildNodes( - dicNode, binaryDictionaryInfo, pInfoState, pointIndex, exactOnly, - 0 /* codePointsFilter */, 0 /* pInfo */, childDicNodes); + dictionaryStructurePolicy->createAndGetAllChildNodes(dicNode, childDicNodes); } } @@ -207,12 +70,13 @@ namespace latinime { * Computes the combined bigram / unigram cost for the given dicNode. */ /* static */ float DicNodeUtils::getBigramNodeImprobability( - const BinaryDictionaryInfo *const binaryDictionaryInfo, + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const node, MultiBigramMap *multiBigramMap) { - if (node->isImpossibleBigramWord()) { + if (node->hasMultipleWords() && !node->isValidMultipleWordSuggestion()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap); + const int probability = getBigramNodeProbability(dictionaryStructurePolicy, node, + multiBigramMap); // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. 
const float cost = static_cast<float>(MAX_PROBABILITY - probability) / static_cast<float>(MAX_PROBABILITY); @@ -220,38 +84,23 @@ namespace latinime { } /* static */ int DicNodeUtils::getBigramNodeProbability( - const BinaryDictionaryInfo *const binaryDictionaryInfo, + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const node, MultiBigramMap *multiBigramMap) { const int unigramProbability = node->getProbability(); const int wordPos = node->getPos(); const int prevWordPos = node->getPrevWordPos(); - if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { - // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. - return ProbabilityUtils::backoff(unigramProbability); + if (NOT_A_DICT_POS == wordPos || NOT_A_DICT_POS == prevWordPos) { + // Note: Normally wordPos comes from the dictionary and should never equal + // NOT_A_VALID_WORD_POS. + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } if (multiBigramMap) { - return multiBigramMap->getBigramProbability( - binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability); - } - return BinaryFormat::getBigramProbability( - binaryDictionaryInfo->getDictRoot(), prevWordPos, wordPos, unigramProbability); -} - -/////////////////////////////////////// -// Bigram / Unigram dictionary utils // -/////////////////////////////////////// - -/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, - const int pointIndex, const bool exactOnly, const int nodeCodePoint) { - if (!pInfoState) { - return true; - } - if (exactOnly) { - return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; + return multiBigramMap->getBigramProbability(dictionaryStructurePolicy, prevWordPos, + wordPos, unigramProbability); } - const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, - true /* checkProximityChars */); - return 
isProximityChar(matchedId); + return dictionaryStructurePolicy->getProbability(unigramProbability, + NOT_A_PROBABILITY); } //////////////// @@ -280,7 +129,7 @@ namespace latinime { } actualLength1 = i + 1; } - actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); + actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0); memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); return actualLength0 + actualLength1; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index e198d6181..3fb351a61 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -18,68 +18,42 @@ #define LATINIME_DIC_NODE_UTILS_H #include <stdint.h> -#include <vector> #include "defines.h" namespace latinime { -class BinaryDictionaryInfo; class DicNode; class DicNodeVector; -class ProximityInfo; -class ProximityInfoState; +class DictionaryStructureWithBufferPolicy; class MultiBigramMap; class DicNodeUtils { public: static int appendTwoWords(const int *src0, const int16_t length0, const int *src1, const int16_t length1, int *dest); - static void initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, + static void initAsRoot( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const int prevWordNodePos, DicNode *newRootNode); - static void initAsRootWithPreviousWord(const BinaryDictionaryInfo *const binaryDictionaryInfo, + static void initAsRootWithPreviousWord( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, DicNode *prevWordLastNode, DicNode *newRootNode); static void initByCopy(DicNode *srcNode, DicNode *destNode); static void getAllChildDicNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes); - static float getBigramNodeImprobability(const BinaryDictionaryInfo *const binaryDictionaryInfo, - 
const DicNode *const node, MultiBigramMap *const multiBigramMap); - static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo, - const std::vector<int> *const codePointsFilter); - // TODO: Move to private - static void getProximityChildDicNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, - const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, DicNodeVector *childDicNodes); - - // TODO: Move to proximity info - static bool isProximityChar(ProximityType type) { - return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR; - } + static float getBigramNodeImprobability( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, + const DicNode *const node, MultiBigramMap *const multiBigramMap); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; - static int getBigramNodeProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + static int getBigramNodeProbability( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy, const DicNode *const node, MultiBigramMap *multiBigramMap); - static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState, - const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes); - static void createAndGetAllLeavingChildNodes(DicNode *dicNode, - const BinaryDictionaryInfo *const binaryDictionaryInfo, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, - const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - static int createAndGetLeavingChildNode(DicNode *dicNode, int pos, - const BinaryDictionaryInfo *const binaryDictionaryInfo, const int terminalDepth, - const 
ProximityInfoState *pInfoState, const int pointIndex, - const bool exactOnly, const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - - // TODO: Move to proximity info - static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, - const bool exactOnly, const int nodeCodePoint); }; } // namespace latinime #endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h index e23c411f0..42addae8d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -62,17 +62,15 @@ class DicNodeVector { mDicNodes.back().initAsPassingChild(dicNode); } - void pushLeavingChild(DicNode *dicNode, const int pos, const uint8_t flags, - const int childrenPos, const int attributesPos, const int siblingPos, - const int nodeCodePoint, const int childrenCount, const int probability, - const int bigramProbability, const bool isTerminal, const bool hasMultipleChars, - const bool hasChildren, const uint16_t additionalSubwordLength, - const int *additionalSubword) { + void pushLeavingChild(const DicNode *const dicNode, const int pos, const int childrenPos, + const int probability, const bool isTerminal, const bool hasChildren, + const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { ASSERT(!mLock); mDicNodes.push_back(mEmptyNode); - mDicNodes.back().initAsChild(dicNode, pos, flags, childrenPos, attributesPos, siblingPos, - nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, - hasMultipleChars, hasChildren, additionalSubwordLength, additionalSubword); + mDicNodes.back().initAsChild(dicNode, pos, childrenPos, probability, isTerminal, + hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, + mergedNodeCodePoints); } DicNode 
*operator[](const int id) { diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp index c3d2a2e74..b6be47e90 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -23,6 +23,11 @@ namespace latinime { +// The biggest value among MAX_CACHE_DIC_NODE_SIZE, MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT, ... +const int DicNodesCache::LARGE_PRIORITY_QUEUE_CAPACITY = 310; +// Capacity for reducing memory footprint. +const int DicNodesCache::SMALL_PRIORITY_QUEUE_CAPACITY = 100; + /** * Truncates all of the dicNodes so that they start at the given commit point. * Only called for multi-word typing input. diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h index 7f5bdbcf6..8493b6a8b 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -22,12 +22,6 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_priority_queue.h" -#define INITIAL_QUEUE_ID_ACTIVE 0 -#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1 -#define INITIAL_QUEUE_ID_TERMINAL 2 -#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 -#define PRIORITY_QUEUES_SIZE 4 - namespace latinime { class DicNode; @@ -37,24 +31,32 @@ class DicNode; */ class DicNodesCache { public: - AK_FORCE_INLINE DicNodesCache() - : mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]), - mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]), - mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]), - mCachedDicNodesForContinuousSuggestion( - &mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), - mInputIndex(0), mLastCachedInputIndex(0) { - } + AK_FORCE_INLINE explicit DicNodesCache(const bool usesLargeCapacityCache) + : mUsesLargeCapacityCache(usesLargeCapacityCache), + 
mDicNodePriorityQueue0(getCacheCapacity()), + mDicNodePriorityQueue1(getCacheCapacity()), + mDicNodePriorityQueue2(getCacheCapacity()), + mDicNodePriorityQueueForTerminal(MAX_RESULTS), + mActiveDicNodes(&mDicNodePriorityQueue0), + mNextActiveDicNodes(&mDicNodePriorityQueue1), + mCachedDicNodesForContinuousSuggestion(&mDicNodePriorityQueue2), + mTerminalDicNodes(&mDicNodePriorityQueueForTerminal), + mInputIndex(0), mLastCachedInputIndex(0) {} AK_FORCE_INLINE virtual ~DicNodesCache() {} AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) { mInputIndex = 0; mLastCachedInputIndex = 0; - mActiveDicNodes->reset(); - mNextActiveDicNodes->clearAndResize(nextActiveSize); + // We want to use the max capacity for the current active dic node queue. + mActiveDicNodes->clearAndResizeToCapacity(); + // nextActiveSize is used to limit the next iteration's active dic node size. + const int nextActiveSizeFittingToTheCapacity = min(nextActiveSize, getCacheCapacity()); + mNextActiveDicNodes->clearAndResize(nextActiveSizeFittingToTheCapacity); mTerminalDicNodes->clearAndResize(terminalSize); - mCachedDicNodesForContinuousSuggestion->reset(); + // We want to use the max capacity for the cached dic nodes that will be used for the + // continuous suggestion. 
+ mCachedDicNodesForContinuousSuggestion->clearAndResizeToCapacity(); } AK_FORCE_INLINE void continueSearch() { @@ -147,9 +149,8 @@ class DicNodesCache { mCachedDicNodesForContinuousSuggestion->dump(); } mInputIndex = mLastCachedInputIndex; - mCachedDicNodesForContinuousSuggestion = - moveNodesAndReturnReusableEmptyQueue( - mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); + mCachedDicNodesForContinuousSuggestion = moveNodesAndReturnReusableEmptyQueue( + mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); } AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue( @@ -163,21 +164,35 @@ class DicNodesCache { return tmp; } + AK_FORCE_INLINE int getCacheCapacity() const { + return mUsesLargeCapacityCache ? + LARGE_PRIORITY_QUEUE_CAPACITY : SMALL_PRIORITY_QUEUE_CAPACITY; + } + AK_FORCE_INLINE void resetTemporaryCaches() { mActiveDicNodes->clear(); mNextActiveDicNodes->clear(); mTerminalDicNodes->clear(); } - DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE]; + static const int LARGE_PRIORITY_QUEUE_CAPACITY; + static const int SMALL_PRIORITY_QUEUE_CAPACITY; + + const bool mUsesLargeCapacityCache; + // Instances + DicNodePriorityQueue mDicNodePriorityQueue0; + DicNodePriorityQueue mDicNodePriorityQueue1; + DicNodePriorityQueue mDicNodePriorityQueue2; + DicNodePriorityQueue mDicNodePriorityQueueForTerminal; + // Active dicNodes currently being expanded. DicNodePriorityQueue *mActiveDicNodes; // Next dicNodes to be expanded. DicNodePriorityQueue *mNextActiveDicNodes; - // Current top terminal dicNodes. - DicNodePriorityQueue *mTerminalDicNodes; // Cached dicNodes used for continuous suggestion. DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion; + // Current top terminal dicNodes. 
+ DicNodePriorityQueue *mTerminalDicNodes; int mInputIndex; int mLastCachedInputIndex; }; diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index d2f87c10b..9e0f62ceb 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -20,60 +20,46 @@ #include <stdint.h> #include "defines.h" -#include "suggest/core/dictionary/binary_format.h" namespace latinime { /** * Node for traversing the lexicon trie. */ +// TODO: Introduce a dictionary node class which has attribute members required to understand the +// dictionary structure. class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() - : mPos(0), mFlags(0), mChildrenPos(0), mAttributesPos(0), mSiblingPos(0), - mChildrenCount(0), mProbability(0), mBigramProbability(0), mNodeCodePoint(0), - mDepth(0), mLeavingDepth(0), mIsTerminal(false), mHasMultipleChars(false), - mHasChildren(false) { - } + : mPos(0), mChildrenPos(0), mProbability(0), mNodeCodePoint(0), mIsTerminal(false), + mHasChildren(false), mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} virtual ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. 
- void init(const int pos, const uint8_t flags, const int childrenPos, const int attributesPos, - const int siblingPos, const int nodeCodePoint, const int childrenCount, - const int probability, const int bigramProbability, const bool isTerminal, - const bool hasMultipleChars, const bool hasChildren, const uint16_t depth, - const uint16_t terminalDepth) { + void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, + const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, + const uint16_t depth, const uint16_t leavingDepth) { mPos = pos; - mFlags = flags; mChildrenPos = childrenPos; - mAttributesPos = attributesPos; - mSiblingPos = siblingPos; mNodeCodePoint = nodeCodePoint; - mChildrenCount = childrenCount; mProbability = probability; - mBigramProbability = bigramProbability; mIsTerminal = isTerminal; - mHasMultipleChars = hasMultipleChars; mHasChildren = hasChildren; + mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; mDepth = depth; - mLeavingDepth = terminalDepth; + mLeavingDepth = leavingDepth; } // Init for copy void init(const DicNodeProperties *const nodeProp) { mPos = nodeProp->mPos; - mFlags = nodeProp->mFlags; mChildrenPos = nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; - mSiblingPos = nodeProp->mSiblingPos; mNodeCodePoint = nodeProp->mNodeCodePoint; - mChildrenCount = nodeProp->mChildrenCount; mProbability = nodeProp->mProbability; - mBigramProbability = nodeProp->mBigramProbability; mIsTerminal = nodeProp->mIsTerminal; - mHasMultipleChars = nodeProp->mHasMultipleChars; mHasChildren = nodeProp->mHasChildren; + mIsBlacklistedOrNotAWord = nodeProp->mIsBlacklistedOrNotAWord; mDepth = nodeProp->mDepth; mLeavingDepth = nodeProp->mLeavingDepth; } @@ -81,17 +67,12 @@ class DicNodeProperties { // Init as passing child void init(const DicNodeProperties *const nodeProp, const int codePoint) { mPos = nodeProp->mPos; - mFlags = nodeProp->mFlags; mChildrenPos = 
nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; - mSiblingPos = nodeProp->mSiblingPos; mNodeCodePoint = codePoint; // Overwrite the node char of a passing child - mChildrenCount = nodeProp->mChildrenCount; mProbability = nodeProp->mProbability; - mBigramProbability = nodeProp->mBigramProbability; mIsTerminal = nodeProp->mIsTerminal; - mHasMultipleChars = nodeProp->mHasMultipleChars; mHasChildren = nodeProp->mHasChildren; + mIsBlacklistedOrNotAWord = nodeProp->mIsBlacklistedOrNotAWord; mDepth = nodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = nodeProp->mLeavingDepth; } @@ -100,22 +81,10 @@ class DicNodeProperties { return mPos; } - uint8_t getFlags() const { - return mFlags; - } - int getChildrenPos() const { return mChildrenPos; } - int getAttributesPos() const { - return mAttributesPos; - } - - int getChildrenCount() const { - return mChildrenCount; - } - int getProbability() const { return mProbability; } @@ -137,42 +106,27 @@ class DicNodeProperties { return mIsTerminal; } - bool hasMultipleChars() const { - return mHasMultipleChars; - } - bool hasChildren() const { - return mChildrenCount > 0 || mDepth != mLeavingDepth; + return mHasChildren || mDepth != mLeavingDepth; } - bool hasBlacklistedOrNotAWordFlag() const { - return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags); + bool isBlacklistedOrNotAWord() const { + return mIsBlacklistedOrNotAWord; } private: // Caution!!! 
// Use a default copy constructor and an assign operator because shallow copies are ok // for this class - - // Not used - int getSiblingPos() const { - return mSiblingPos; - } - int mPos; - uint8_t mFlags; int mChildrenPos; - int mAttributesPos; - int mSiblingPos; - int mChildrenCount; int mProbability; - int mBigramProbability; // not used for now int mNodeCodePoint; - uint16_t mDepth; - uint16_t mLeavingDepth; bool mIsTerminal; - bool mHasMultipleChars; bool mHasChildren; + bool mIsBlacklistedOrNotAWord; + uint16_t mDepth; + uint16_t mLeavingDepth; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state.h index d35e7d79f..b0fddb724 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state.h @@ -18,10 +18,10 @@ #define LATINIME_DIC_NODE_STATE_H #include "defines.h" -#include "suggest/core/dicnode/dic_node_state_input.h" -#include "suggest/core/dicnode/dic_node_state_output.h" -#include "suggest/core/dicnode/dic_node_state_prevword.h" -#include "suggest/core/dicnode/dic_node_state_scoring.h" +#include "suggest/core/dicnode/internal/dic_node_state_input.h" +#include "suggest/core/dicnode/internal/dic_node_state_output.h" +#include "suggest/core/dicnode/internal/dic_node_state_prevword.h" +#include "suggest/core/dicnode/internal/dic_node_state_scoring.h" namespace latinime { @@ -55,11 +55,12 @@ class DicNodeState { mDicNodeStateScoring.init(&src->mDicNodeStateScoring); } - // Init by copy and adding subword - void init(const DicNodeState *const src, const uint16_t additionalSubwordLength, - const int *const additionalSubword) { + // Init by copy and adding merged node code points. 
+ void init(const DicNodeState *const src, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { init(src); - mDicNodeStateOutput.addSubword(additionalSubwordLength, additionalSubword); + mDicNodeStateOutput.addMergedNodeCodePoints( + mergedNodeCodePointCount, mergedNodeCodePoints); } private: diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_input.h index bbd9435b5..bbd9435b5 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_input.h diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h new file mode 100644 index 000000000..74eb5dfe7 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DIC_NODE_STATE_OUTPUT_H +#define LATINIME_DIC_NODE_STATE_OUTPUT_H + +#include <cstring> // for memcpy() +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class DicNodeStateOutput { + public: + DicNodeStateOutput() : mOutputtedCodePointCount(0) { + init(); + } + + virtual ~DicNodeStateOutput() {} + + void init() { + mOutputtedCodePointCount = 0; + mCodePointsBuf[0] = 0; + } + + void init(const DicNodeStateOutput *const stateOutput) { + memcpy(mCodePointsBuf, stateOutput->mCodePointsBuf, + stateOutput->mOutputtedCodePointCount * sizeof(mCodePointsBuf[0])); + mOutputtedCodePointCount = stateOutput->mOutputtedCodePointCount; + if (mOutputtedCodePointCount < MAX_WORD_LENGTH) { + mCodePointsBuf[mOutputtedCodePointCount] = 0; + } + } + + void addMergedNodeCodePoints(const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { + if (mergedNodeCodePoints) { + const int additionalCodePointCount = min(static_cast<int>(mergedNodeCodePointCount), + MAX_WORD_LENGTH - mOutputtedCodePointCount); + memcpy(&mCodePointsBuf[mOutputtedCodePointCount], mergedNodeCodePoints, + additionalCodePointCount * sizeof(mCodePointsBuf[0])); + mOutputtedCodePointCount = static_cast<uint16_t>( + mOutputtedCodePointCount + additionalCodePointCount); // clamped count + if (mOutputtedCodePointCount < MAX_WORD_LENGTH) { + mCodePointsBuf[mOutputtedCodePointCount] = 0; + } + } + } + + // TODO: Remove + int getCodePointAt(const int index) const { + return mCodePointsBuf[index]; + } + + // TODO: Move to private + int mCodePointsBuf[MAX_WORD_LENGTH]; + + private: + // Caution!!! 
+ // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + uint16_t mOutputtedCodePointCount; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_OUTPUT_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h index c3968c090..b8986203d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/layout/proximity_info_state.h" namespace latinime { @@ -29,9 +30,8 @@ class DicNodeStatePrevWord { public: AK_FORCE_INLINE DicNodeStatePrevWord() : mPrevWordCount(0), mPrevWordLength(0), mPrevWordStart(0), mPrevWordProbability(0), - mPrevWordNodePos(0) { + mPrevWordNodePos(NOT_A_DICT_POS), mSecondWordFirstInputIndex(NOT_AN_INDEX) { memset(mPrevWord, 0, sizeof(mPrevWord)); - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); } virtual ~DicNodeStatePrevWord() {} @@ -41,8 +41,8 @@ class DicNodeStatePrevWord { mPrevWordCount = 0; mPrevWordStart = 0; mPrevWordProbability = -1; - mPrevWordNodePos = NOT_VALID_WORD; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mPrevWordNodePos = NOT_A_DICT_POS; + mSecondWordFirstInputIndex = NOT_AN_INDEX; } void init(const int prevWordNodePos) { @@ -51,7 +51,7 @@ class DicNodeStatePrevWord { mPrevWordStart = 0; mPrevWordProbability = -1; mPrevWordNodePos = prevWordNodePos; - memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + mSecondWordFirstInputIndex = NOT_AN_INDEX; } // Init by copy @@ -61,24 +61,26 @@ class DicNodeStatePrevWord { mPrevWordStart = prevWord->mPrevWordStart; mPrevWordProbability = prevWord->mPrevWordProbability; mPrevWordNodePos = prevWord->mPrevWordNodePos; + mSecondWordFirstInputIndex = prevWord->mSecondWordFirstInputIndex; 
memcpy(mPrevWord, prevWord->mPrevWord, prevWord->mPrevWordLength * sizeof(mPrevWord[0])); - memcpy(mPrevSpacePositions, prevWord->mPrevSpacePositions, sizeof(mPrevSpacePositions)); } void init(const int16_t prevWordCount, const int16_t prevWordProbability, const int prevWordNodePos, const int *const src0, const int16_t length0, - const int *const src1, const int16_t length1, const int *const prevSpacePositions, - const int lastInputIndex) { - mPrevWordCount = prevWordCount; + const int *const src1, const int16_t length1, + const int prevWordSecondWordFirstInputIndex, const int lastInputIndex) { + mPrevWordCount = min(prevWordCount, static_cast<int16_t>(MAX_RESULTS)); mPrevWordProbability = prevWordProbability; mPrevWordNodePos = prevWordNodePos; - const int twoWordsLen = + int twoWordsLen = DicNodeUtils::appendTwoWords(src0, length0, src1, length1, mPrevWord); + if (twoWordsLen >= MAX_WORD_LENGTH) { + twoWordsLen = MAX_WORD_LENGTH - 1; + } mPrevWord[twoWordsLen] = KEYCODE_SPACE; mPrevWordStart = length0; mPrevWordLength = static_cast<int16_t>(twoWordsLen + 1); - memcpy(mPrevSpacePositions, prevSpacePositions, sizeof(mPrevSpacePositions)); - mPrevSpacePositions[mPrevWordCount - 1] = lastInputIndex; + mSecondWordFirstInputIndex = prevWordSecondWordFirstInputIndex; } void truncate(const int offset) { @@ -93,11 +95,12 @@ class DicNodeStatePrevWord { mPrevWordLength = newPrevWordLength; } - void outputSpacePositions(int *spaceIndices) const { - // Convert uint16_t to int - for (int i = 0; i < MAX_RESULTS; i++) { - spaceIndices[i] = mPrevSpacePositions[i]; - } + void setSecondWordFirstInputIndex(const int inputIndex) { + mSecondWordFirstInputIndex = inputIndex; + } + + int getSecondWordFirstInputIndex() const { + return mSecondWordFirstInputIndex; } // TODO: remove @@ -113,10 +116,6 @@ class DicNodeStatePrevWord { return mPrevWordStart; } - int16_t getPrevWordProbability() const { - return mPrevWordProbability; - } - int getPrevWordNodePos() const { return 
mPrevWordNodePos; } @@ -139,8 +138,6 @@ class DicNodeStatePrevWord { // TODO: Move to private int mPrevWord[MAX_WORD_LENGTH]; - // TODO: Move to private - int mPrevSpacePositions[MAX_RESULTS]; private: // Caution!!! @@ -151,6 +148,7 @@ class DicNodeStatePrevWord { int16_t mPrevWordStart; int16_t mPrevWordProbability; int mPrevWordNodePos; + int mSecondWordFirstInputIndex; }; } // namespace latinime #endif // LATINIME_DIC_NODE_STATE_PREVWORD_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index 4c884225a..4c884225a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index 59d1b19b6..71f4ef6ea 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -21,17 +21,16 @@ #include "bigram_dictionary.h" #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" -#include "suggest/core/dictionary/bloom_filter.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/dictionary.h" -#include "suggest/core/dictionary/probability_utils.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/char_utils.h" namespace latinime { -BigramDictionary::BigramDictionary(const BinaryDictionaryInfo *const binaryDictionaryInfo) - : mBinaryDictionaryInfo(binaryDictionaryInfo) { +BigramDictionary::BigramDictionary( + const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy) + : mDictionaryStructurePolicy(dictionaryStructurePolicy) { if (DEBUG_DICT) { AKLOGI("BigramDictionary - constructor"); } @@ 
-88,147 +87,89 @@ void BigramDictionary::addWordBigram(int *word, int length, int probability, int /* Parameters : * prevWord: the word before, the one for which we need to look up bigrams. * prevWordLength: its length. - * inputCodePoints: what user typed, in the same format as for UnigramDictionary::getSuggestions. - * inputSize: the size of the codes array. - * bigramCodePoints: an array for output, at the same format as outwords for getSuggestions. - * bigramProbability: an array to output frequencies. + * outBigramCodePoints: an array for output, at the same format as outwords for getSuggestions. + * outBigramProbability: an array to output frequencies. * outputTypes: an array to output types. * This method returns the number of bigrams this word has, for backward compatibility. - * Note: this is not the number of bigrams output in the array, which is the number of - * bigrams this word has WHOSE first letter also matches the letter the user typed. - * TODO: this may not be a sensible thing to do. It makes sense when the bigrams are - * used to match the first letter of the second word, but once the user has typed more - * and the bigrams are used to boost unigram result scores, it makes little sense to - * reduce their scope to the ones that match the first letter. 
*/ -int BigramDictionary::getBigrams(const int *prevWord, int prevWordLength, int *inputCodePoints, - int inputSize, int *bigramCodePoints, int *bigramProbability, int *outputTypes) const { +int BigramDictionary::getPredictions(const int *prevWord, const int prevWordLength, + int *const outBigramCodePoints, int *const outBigramProbability, + int *const outputTypes) const { // TODO: remove unused arguments, and refrain from storing stuff in members of this class // TODO: have "in" arguments before "out" ones, and make out args explicit in the name - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); int pos = getBigramListPositionForWord(prevWord, prevWordLength, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) { + if (NOT_A_DICT_POS == pos) { // If no bigrams for this exact word, search again in lower case. pos = getBigramListPositionForWord(prevWord, prevWordLength, true /* forceLowerCaseSearch */); } // If still no bigrams, we really don't have them! - if (0 == pos) return 0; - uint8_t bigramFlags; + if (NOT_A_DICT_POS == pos) return 0; + int bigramCount = 0; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - int bigramBuffer[MAX_WORD_LENGTH]; - int unigramProbability = 0; - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH, - bigramBuffer, &unigramProbability); - - // inputSize == 0 means we are trying to find bigram predictions. - if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { - const int bigramProbabilityTemp = - BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - // Due to space constraints, the probability for bigrams is approximate - the lower the - // unigram probability, the worse the precision. 
The theoritical maximum error in - // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 - // in very bad cases. This means that sometimes, we'll see some bigrams interverted - // here, but it can't get too bad. - const int probability = ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramProbabilityTemp); - addWordBigram(bigramBuffer, length, probability, bigramProbability, bigramCodePoints, - outputTypes); - ++bigramCount; + int unigramProbability = 0; + int bigramBuffer[MAX_WORD_LENGTH]; + BinaryDictionaryBigramsIterator bigramsIt( + mDictionaryStructurePolicy->getBigramsStructurePolicy(), pos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { + continue; + } + const int codePointCount = mDictionaryStructurePolicy-> + getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(), + MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); + if (codePointCount <= 0) { + continue; } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + // Due to space constraints, the probability for bigrams is approximate - the lower the + // unigram probability, the worse the precision. The theoritical maximum error in + // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 + // in very bad cases. This means that sometimes, we'll see some bigrams interverted + // here, but it can't get too bad. + const int probability = mDictionaryStructurePolicy->getProbability( + unigramProbability, bigramsIt.getProbability()); + addWordBigram(bigramBuffer, codePointCount, probability, outBigramProbability, + outBigramCodePoints, outputTypes); + ++bigramCount; + } return min(bigramCount, MAX_RESULTS); } // Returns a pointer to the start of the bigram list. -// If the word is not found or has no bigrams, this function returns 0. +// If the word is not found or has no bigrams, this function returns NOT_A_DICT_POS. 
int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const int prevWordLength, const bool forceLowerCaseSearch) const { - if (0 >= prevWordLength) return 0; - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); - int pos = BinaryFormat::getTerminalPosition(root, prevWord, prevWordLength, + if (0 >= prevWordLength) return NOT_A_DICT_POS; + int pos = mDictionaryStructurePolicy->getTerminalNodePositionOfWord(prevWord, prevWordLength, forceLowerCaseSearch); - - if (NOT_VALID_WORD == pos) return 0; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0; - if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { - BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } else { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } - pos = BinaryFormat::skipProbability(flags, pos); - pos = BinaryFormat::skipChildrenPosition(flags, pos); - pos = BinaryFormat::skipShortcuts(root, flags, pos); - return pos; -} - -void BigramDictionary::fillBigramAddressToProbabilityMapAndFilter(const int *prevWord, - const int prevWordLength, std::map<int, int> *map, uint8_t *filter) const { - memset(filter, 0, BIGRAM_FILTER_BYTE_SIZE); - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); - int pos = getBigramListPositionForWord(prevWord, prevWordLength, - false /* forceLowerCaseSearch */); - if (0 == pos) { - // If no bigrams for this exact string, search again in lower case. 
- pos = getBigramListPositionForWord(prevWord, prevWordLength, - true /* forceLowerCaseSearch */); - } - if (0 == pos) return; - - uint8_t bigramFlags; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - const int probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - (*map)[bigramPos] = probability; - setInFilter(filter, bigramPos); - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + if (NOT_A_DICT_POS == pos) return NOT_A_DICT_POS; + return mDictionaryStructurePolicy->getBigramsPositionOfPtNode(pos); } -bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { - // Checks whether this word starts with same character or neighboring characters of - // what user typed. - - int maxAlt = MAX_ALTERNATIVES; - const int firstBaseLowerCodePoint = CharUtils::toBaseLowerCase(*word); - while (maxAlt > 0) { - if (CharUtils::toBaseLowerCase(*inputCodePoints) == firstBaseLowerCodePoint) { - return true; - } - inputCodePoints++; - maxAlt--; - } - return false; -} - -bool BigramDictionary::isValidBigram(const int *word1, int length1, const int *word2, - int length2) const { - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); - int pos = getBigramListPositionForWord(word1, length1, false /* forceLowerCaseSearch */); +int BigramDictionary::getBigramProbability(const int *word0, int length0, const int *word1, + int length1) const { + int pos = getBigramListPositionForWord(word0, length0, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) return false; - int nextWordPos = BinaryFormat::getTerminalPosition(root, word2, length2, + if (NOT_A_DICT_POS == pos) return NOT_A_PROBABILITY; + int nextWordPos = mDictionaryStructurePolicy->getTerminalNodePositionOfWord(word1, length1, false /* 
forceLowerCaseSearch */); - if (NOT_VALID_WORD == nextWordPos) return false; - uint8_t bigramFlags; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - if (bigramPos == nextWordPos) { - return true; + if (NOT_A_DICT_POS == nextWordPos) return NOT_A_PROBABILITY; + + BinaryDictionaryBigramsIterator bigramsIt( + mDictionaryStructurePolicy->getBigramsStructurePolicy(), pos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPos) { + return mDictionaryStructurePolicy->getProbability( + mDictionaryStructurePolicy->getUnigramProbabilityOfPtNode(nextWordPos), + bigramsIt.getProbability()); } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); - return false; + } + return NOT_A_PROBABILITY; } // TODO: Move functions related to bigram to here diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h index 8b7a253a2..8af7ee75d 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h @@ -17,37 +17,30 @@ #ifndef LATINIME_BIGRAM_DICTIONARY_H #define LATINIME_BIGRAM_DICTIONARY_H -#include <map> -#include <stdint.h> - #include "defines.h" namespace latinime { -class BinaryDictionaryInfo; +class DictionaryStructureWithBufferPolicy; class BigramDictionary { public: - BigramDictionary(const BinaryDictionaryInfo *const binaryDictionaryInfo); + BigramDictionary(const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy); - int getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, - int *frequencies, int *outputTypes) const; - void fillBigramAddressToProbabilityMapAndFilter(const int *prevWord, const int prevWordLength, - std::map<int, int> *map, uint8_t *filter) const; - bool isValidBigram(const 
int *word1, int length1, const int *word2, int length2) const; + int getPredictions(const int *word, int length, int *outBigramCodePoints, + int *outBigramProbability, int *outputTypes) const; + int getBigramProbability(const int *word1, int length1, const int *word2, int length2) const; ~BigramDictionary(); + private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramDictionary); void addWordBigram(int *word, int length, int probability, int *bigramProbability, int *bigramCodePoints, int *outputTypes) const; - bool checkFirstCharacter(int *word, int *inputCodePoints) const; int getBigramListPositionForWord(const int *prevWord, const int prevWordLength, const bool forceLowerCaseSearch) const; - const BinaryDictionaryInfo *const mBinaryDictionaryInfo; - // TODO: Re-implement proximity correction for bigram correction - static const int MAX_ALTERNATIVES = 1; + const DictionaryStructureWithBufferPolicy *const mDictionaryStructurePolicy; }; } // namespace latinime #endif // LATINIME_BIGRAM_DICTIONARY_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h new file mode 100644 index 000000000..d16ac47fe --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H +#define LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H + +#include "defines.h" +#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" + +namespace latinime { + +class BinaryDictionaryBigramsIterator { + public: + BinaryDictionaryBigramsIterator( + const DictionaryBigramsStructurePolicy *const bigramsStructurePolicy, const int pos) + : mBigramsStructurePolicy(bigramsStructurePolicy), mPos(pos), + mBigramPos(NOT_A_DICT_POS), mProbability(NOT_A_PROBABILITY), + mHasNext(pos != NOT_A_DICT_POS) {} + + AK_FORCE_INLINE bool hasNext() const { + return mHasNext; + } + + AK_FORCE_INLINE void next() { + mBigramsStructurePolicy->getNextBigram(&mBigramPos, &mProbability, &mHasNext, &mPos); + } + + AK_FORCE_INLINE int getProbability() const { + return mProbability; + } + + AK_FORCE_INLINE int getBigramPos() const { + return mBigramPos; + } + + private: + DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryBigramsIterator); + + const DictionaryBigramsStructurePolicy *const mBigramsStructurePolicy; + int mPos; + int mBigramPos; + int mProbability; + bool mHasNext; +}; +} // namespace latinime +#endif // LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_format.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_format.cpp deleted file mode 100644 index 50e0211d7..000000000 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_format.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "suggest/core/dictionary/binary_dictionary_format.h" - -namespace latinime { - -/** - * Dictionary size - */ -// Any file smaller than this is not a dictionary. -const int BinaryDictionaryFormat::DICTIONARY_MINIMUM_SIZE = 4; - -/** - * Format versions - */ -// Originally, format version 1 had a 16-bit magic number, then the version number `01' -// then options that must be 0. Hence the first 32-bits of the format are always as follow -// and it's okay to consider them a magic number as a whole. -const uint32_t BinaryDictionaryFormat::FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100; -const int BinaryDictionaryFormat::FORMAT_VERSION_1_HEADER_SIZE = 5; - -// The versions of Latin IME that only handle format version 1 only test for the magic -// number, so we had to change it so that version 2 files would be rejected by older -// implementations. On this occasion, we made the magic number 32 bits long. 
-const uint32_t BinaryDictionaryFormat::FORMAT_VERSION_2_MAGIC_NUMBER = 0x9BC13AFE; -// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 -const int BinaryDictionaryFormat::FORMAT_VERSION_2_MINIMUM_SIZE = 12; -const int BinaryDictionaryFormat::VERSION_2_MAGIC_NUMBER_SIZE = 4; -const int BinaryDictionaryFormat::VERSION_2_DICTIONARY_VERSION_SIZE = 2; -const int BinaryDictionaryFormat::VERSION_2_DICTIONARY_FLAG_SIZE = 2; - -/* static */ BinaryDictionaryFormat::FORMAT_VERSION BinaryDictionaryFormat::detectFormatVersion( - const uint8_t *const dict, const int dictSize) { - // The magic number is stored big-endian. - // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't - // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) { - return UNKNOWN_VERSION; - } - const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); - switch (magicNumber) { - case FORMAT_VERSION_1_MAGIC_NUMBER: - // Format 1 header is exactly 5 bytes long and looks like: - // Magic number (2 bytes) 0x78 0xB1 - // Version number (1 byte) 0x01 - // Options (2 bytes) must be 0x00 0x00 - return VERSION_1; - case FORMAT_VERSION_2_MAGIC_NUMBER: - // Version 2 dictionaries are at least 12 bytes long. - // If this dictionary has the version 2 magic number but is less than 12 bytes long, - // then it's an unknown format and we need to avoid confidently reading the next bytes. 
- if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) { - return UNKNOWN_VERSION; - } - // Format 2 header is as follows: - // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE - // Version number (2 bytes) 0x00 0x02 - // Options (2 bytes) - // Header size (4 bytes) : integer, big endian - if (ByteArrayUtils::readUint16(dict, 4) == 2) { - return VERSION_2; - } else { - return UNKNOWN_VERSION; - } - default: - return UNKNOWN_VERSION; - } -} - -} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_format.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_format.h deleted file mode 100644 index 3aa1662da..000000000 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_format.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BINARY_DICTIONARY_FORMAT_H -#define LATINIME_BINARY_DICTIONARY_FORMAT_H - -#include <stdint.h> - -#include "defines.h" -#include "suggest/core/dictionary/byte_array_utils.h" - -namespace latinime { - -/** - * Methods to handle binary dictionary format version. - * - * Currently, we have a file with a similar name, binary_format.h. binary_format.h contains binary - * reading methods and utility methods for various purposes. - * On the other hand, this file deals with only about dictionary format version. 
- */ -class BinaryDictionaryFormat { - public: - // TODO: Remove obsolete version logic - enum FORMAT_VERSION { - VERSION_1, - VERSION_2, - UNKNOWN_VERSION - }; - - static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); - - static AK_FORCE_INLINE int getHeaderSize( - const uint8_t *const dict, const FORMAT_VERSION format) { - switch (format) { - case VERSION_1: - return FORMAT_VERSION_1_HEADER_SIZE; - case VERSION_2: - // See the format of the header in the comment in detectFormat() above - return ByteArrayUtils::readUint32(dict, 8); - default: - return S_INT_MAX; - } - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryFormat); - - static const int DICTIONARY_MINIMUM_SIZE; - static const uint32_t FORMAT_VERSION_1_MAGIC_NUMBER; - static const int FORMAT_VERSION_1_HEADER_SIZE; - static const uint32_t FORMAT_VERSION_2_MAGIC_NUMBER; - static const int FORMAT_VERSION_2_MINIMUM_SIZE; - static const int VERSION_2_MAGIC_NUMBER_SIZE; - static const int VERSION_2_DICTIONARY_VERSION_SIZE ; - static const int VERSION_2_DICTIONARY_FLAG_SIZE; -}; -} // namespace latinime -#endif /* LATINIME_BINARY_DICTIONARY_FORMAT_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h deleted file mode 100644 index 8508c6786..000000000 --- a/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BINARY_DICTIONARY_INFO_H -#define LATINIME_BINARY_DICTIONARY_INFO_H - -#include <stdint.h> - -#include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_format.h" - -namespace latinime { - -class BinaryDictionaryInfo { - public: - BinaryDictionaryInfo(const uint8_t *const dictBuf, const int dictSize) - : mDictBuf(dictBuf), - mFormat(BinaryDictionaryFormat::detectFormatVersion(mDictBuf, dictSize)), - mDictRoot(mDictBuf + BinaryDictionaryFormat::getHeaderSize(mDictBuf, mFormat)) {} - - AK_FORCE_INLINE const uint8_t *getDictBuf() const { - return mDictBuf; - } - - AK_FORCE_INLINE const uint8_t *getDictRoot() const { - return mDictRoot; - } - - AK_FORCE_INLINE BinaryDictionaryFormat::FORMAT_VERSION getFormat() const { - return mFormat; - } - - AK_FORCE_INLINE int getRootPosition() const { - return 0; - } - - private: - DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryInfo); - - const uint8_t *const mDictBuf; - const BinaryDictionaryFormat::FORMAT_VERSION mFormat; - const uint8_t *const mDictRoot; -}; -} -#endif /* LATINIME_BINARY_DICTIONARY_INFO_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h new file mode 100644 index 000000000..558e0a5c3 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_shortcut_iterator.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_SHORTCUT_ITERATOR_H +#define LATINIME_BINARY_DICTIONARY_SHORTCUT_ITERATOR_H + +#include "defines.h" +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" + +namespace latinime { + +class BinaryDictionaryShortcutIterator { + public: + BinaryDictionaryShortcutIterator( + const DictionaryShortcutsStructurePolicy *const shortcutStructurePolicy, + const int shortcutPos) + : mShortcutStructurePolicy(shortcutStructurePolicy), + mPos(shortcutStructurePolicy->getStartPos(shortcutPos)), + mHasNextShortcutTarget(shortcutPos != NOT_A_DICT_POS) {} + + AK_FORCE_INLINE bool hasNextShortcutTarget() const { + return mHasNextShortcutTarget; + } + + // Gets the shortcut target itself as an int string and put it to outTarget, put its length + // to outTargetLength, put whether it is whitelist to outIsWhitelist. 
+ AK_FORCE_INLINE void nextShortcutTarget( + const int maxDepth, int *const outTarget, int *const outTargetLength, + bool *const outIsWhitelist) { + mShortcutStructurePolicy->getNextShortcut(maxDepth, outTarget, outTargetLength, + outIsWhitelist, &mHasNextShortcutTarget, &mPos); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryShortcutIterator); + + const DictionaryShortcutsStructurePolicy *const mShortcutStructurePolicy; + int mPos; + bool mHasNextShortcutTarget; +}; +} // namespace latinime +#endif // LATINIME_BINARY_DICTIONARY_SHORTCUT_ITERATOR_H diff --git a/native/jni/src/suggest/core/dictionary/binary_format.h b/native/jni/src/suggest/core/dictionary/binary_format.h deleted file mode 100644 index 1b57793fa..000000000 --- a/native/jni/src/suggest/core/dictionary/binary_format.h +++ /dev/null @@ -1,745 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BINARY_FORMAT_H -#define LATINIME_BINARY_FORMAT_H - -#include <cstdlib> -#include <stdint.h> - -#include "suggest/core/dictionary/bloom_filter.h" -#include "suggest/core/dictionary/probability_utils.h" -#include "utils/char_utils.h" -#include "utils/hash_map_compat.h" - -namespace latinime { - -class BinaryFormat { - public: - // Mask and flags for children address type selection. 
- static const int MASK_GROUP_ADDRESS_TYPE = 0xC0; - - // Flag for single/multiple char group - static const int FLAG_HAS_MULTIPLE_CHARS = 0x20; - - // Flag for terminal groups - static const int FLAG_IS_TERMINAL = 0x10; - - // Flag for shortcut targets presence - static const int FLAG_HAS_SHORTCUT_TARGETS = 0x08; - // Flag for bigram presence - static const int FLAG_HAS_BIGRAMS = 0x04; - // Flag for non-words (typically, shortcut only entries) - static const int FLAG_IS_NOT_A_WORD = 0x02; - // Flag for blacklist - static const int FLAG_IS_BLACKLISTED = 0x01; - - // Attribute (bigram/shortcut) related flags: - // Flag for presence of more attributes - static const int FLAG_ATTRIBUTE_HAS_NEXT = 0x80; - // Flag for sign of offset. If this flag is set, the offset value must be negated. - static const int FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; - - // Mask for attribute probability, stored on 4 bits inside the flags byte. - static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F; - // The numeric value of the shortcut probability that means 'whitelist'. - static const int WHITELIST_SHORTCUT_PROBABILITY = 15; - - // Mask and flags for attribute address type selection. 
- static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; - - static const int UNKNOWN_FORMAT = -1; - static const int SHORTCUT_LIST_SIZE_SIZE = 2; - - static int detectFormat(const uint8_t *const dict, const int dictSize); - static int getHeaderSize(const uint8_t *const dict, const int dictSize); - static int getFlags(const uint8_t *const dict, const int dictSize); - static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const int dictSize, - const char *const key, int *outValue, const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, - const char *const key); - static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); - static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); - static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); - static int readProbabilityWithoutMovingPointer(const uint8_t *const dict, const int pos); - static int skipOtherCharacters(const uint8_t *const dict, const int pos); - static int skipChildrenPosition(const uint8_t flags, const int pos); - static int skipProbability(const uint8_t flags, const int pos); - static int skipShortcuts(const uint8_t *const dict, const uint8_t flags, const int pos); - static int skipChildrenPosAndAttributes(const uint8_t *const dict, const uint8_t flags, - const int pos); - static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos); - static bool hasChildrenInFlags(const uint8_t flags); - static int getAttributeAddressAndForwardPointer(const uint8_t *const dict, const uint8_t flags, - int *pos); - static int getAttributeProbabilityFromFlags(const int flags); - static int getTerminalPosition(const uint8_t *const root, const int *const inWord, - const int length, const bool forceLowerCaseSearch); - static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, - int 
*outWord, int *outUnigramProbability); - static int getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); - static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, - hash_map_compat<int, int> *bigramMap); - static int getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability); - - // Flags for special processing - // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or - // something very bad (like, the apocalypse) will happen. Please update both at the same time. - enum { - REQUIRES_GERMAN_UMLAUT_PROCESSING = 0x1, - REQUIRES_FRENCH_LIGATURES_PROCESSING = 0x4 - }; - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); - static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); - - static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; - static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; - static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; - static const int FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; - static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; - - // Any file smaller than this is not a dictionary. - static const int DICTIONARY_MINIMUM_SIZE = 4; - // Originally, format version 1 had a 16-bit magic number, then the version number `01' - // then options that must be 0. Hence the first 32-bits of the format are always as follow - // and it's okay to consider them a magic number as a whole. 
- static const int FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100; - static const int FORMAT_VERSION_1_HEADER_SIZE = 5; - // The versions of Latin IME that only handle format version 1 only test for the magic - // number, so we had to change it so that version 2 files would be rejected by older - // implementations. On this occasion, we made the magic number 32 bits long. - static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE - // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 - static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; - - static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; - static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; - static const int CHARACTER_ARRAY_TERMINATOR = 0x1F; - static const int MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE = 2; - static const int NO_FLAGS = 0; - static int skipAllAttributes(const uint8_t *const dict, const uint8_t flags, const int pos); - static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); -}; - -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { - // The magic number is stored big-endian. - // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't - // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; - const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; - switch (magicNumber) { - case FORMAT_VERSION_1_MAGIC_NUMBER: - // Format 1 header is exactly 5 bytes long and looks like: - // Magic number (2 bytes) 0x78 0xB1 - // Version number (1 byte) 0x01 - // Options (2 bytes) must be 0x00 0x00 - return 1; - case FORMAT_VERSION_2_MAGIC_NUMBER: - // Version 2 dictionaries are at least 12 bytes long (see below details for the header). 
- // If this dictionary has the version 2 magic number but is less than 12 bytes long, then - // it's an unknown format and we need to avoid confidently reading the next bytes. - if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; - // Format 2 header is as follows: - // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE - // Version number (2 bytes) 0x00 0x02 - // Options (2 bytes) - // Header size (4 bytes) : integer, big endian - return (dict[4] << 8) + dict[5]; - default: - return UNKNOWN_FORMAT; - } -} - -inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { - switch (detectFormat(dict, dictSize)) { - case 1: - return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? - default: - return (dict[6] << 8) + dict[7]; - } -} - -inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { - return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; -} - -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { - switch (detectFormat(dict, dictSize)) { - case 1: - return FORMAT_VERSION_1_HEADER_SIZE; - case 2: - // See the format of the header in the comment in detectFormat() above - return (dict[8] << 24) + (dict[9] << 16) + (dict[10] << 8) + dict[11]; - default: - return S_INT_MAX; - } -} - -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, - const char *const key, int *outValue, const int outValueSize) { - int outValueIndex = 0; - // Only format 2 and above have header attributes as {key,value} string pairs. For prior - // formats, we just return an empty string, as if the key wasn't found. 
- if (2 <= detectFormat(dict, dictSize)) { - const int headerOptionsOffset = 4 /* magic number */ - + 2 /* dictionary version */ + 2 /* flags */; - const int headerSize = - (dict[headerOptionsOffset] << 24) + (dict[headerOptionsOffset + 1] << 16) - + (dict[headerOptionsOffset + 2] << 8) + dict[headerOptionsOffset + 3]; - const int headerEnd = headerOptionsOffset + 4 + headerSize; - int index = headerOptionsOffset + 4; - while (index < headerEnd) { - int keyIndex = 0; - int codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT) { - if (codePoint != key[keyIndex++]) { - break; - } - codePoint = getCodePointAndForwardPointer(dict, &index); - } - if (codePoint == NOT_A_CODE_POINT && key[keyIndex] == 0) { - // We found the key! Copy and return the value. - codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT && outValueIndex < outValueSize) { - outValue[outValueIndex++] = codePoint; - codePoint = getCodePointAndForwardPointer(dict, &index); - } - // Finished copying. Break to go to the termination code. - break; - } - // We didn't find the key, skip the remainder of it and its value - while (codePoint != NOT_A_CODE_POINT) { - codePoint = getCodePointAndForwardPointer(dict, &index); - } - codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT) { - codePoint = getCodePointAndForwardPointer(dict, &index); - } - } - // We couldn't find it - fall through and return an empty value. 
- } - // Put a terminator 0 if possible at all (always unless outValueSize is <= 0) - if (outValueIndex >= outValueSize) outValueIndex = outValueSize - 1; - if (outValueIndex >= 0) outValue[outValueIndex] = 0; -} - -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, - const char *const key) { - const int bufferSize = LARGEST_INT_DIGIT_COUNT; - int intBuffer[bufferSize]; - char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); - for (int i = 0; i < bufferSize; ++i) { - charBuffer[i] = intBuffer[i]; - } - // If not a number, return S_INT_MIN - if (!isdigit(charBuffer[0])) return S_INT_MIN; - return atoi(charBuffer); -} - -AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict, - int *pos) { - const int msb = dict[(*pos)++]; - if (msb < 0x80) return msb; - return ((msb & 0x7F) << 8) | dict[(*pos)++]; -} - -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, - const int dictSize) { - const int headerValue = readHeaderValueInt(dict, dictSize, - "MULTIPLE_WORDS_DEMOTION_RATE"); - if (headerValue == S_INT_MIN) { - return 1.0f; - } - if (headerValue <= 0) { - return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); - } - return 100.0f / static_cast<float>(headerValue); -} - -inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { - return dict[(*pos)++]; -} - -AK_FORCE_INLINE int BinaryFormat::getCodePointAndForwardPointer(const uint8_t *const dict, - int *pos) { - const int origin = *pos; - const int codePoint = dict[origin]; - if (codePoint < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { - if (codePoint == CHARACTER_ARRAY_TERMINATOR) { - *pos = origin + 1; - return NOT_A_CODE_POINT; - } else { - *pos = origin + 3; - const int char_1 = codePoint << 16; - const int char_2 = char_1 + (dict[origin + 1] << 8); - return char_2 + dict[origin + 2]; - } - } else { - *pos = origin + 1; - return codePoint; 
- } -} - -inline int BinaryFormat::readProbabilityWithoutMovingPointer(const uint8_t *const dict, - const int pos) { - return dict[pos]; -} - -AK_FORCE_INLINE int BinaryFormat::skipOtherCharacters(const uint8_t *const dict, const int pos) { - int currentPos = pos; - int character = dict[currentPos++]; - while (CHARACTER_ARRAY_TERMINATOR != character) { - if (character < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { - currentPos += MULTIPLE_BYTE_CHARACTER_ADDITIONAL_SIZE; - } - character = dict[currentPos++]; - } - return currentPos; -} - -static inline int attributeAddressSize(const uint8_t flags) { - static const int ATTRIBUTE_ADDRESS_SHIFT = 4; - return (flags & BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; - /* Note: this is a value-dependant optimization of what may probably be - more readably written this way: - switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; - default: return 0; - } - */ -} - -static AK_FORCE_INLINE int skipExistingBigrams(const uint8_t *const dict, const int pos) { - int currentPos = pos; - uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); - while (flags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT) { - currentPos += attributeAddressSize(flags); - flags = BinaryFormat::getFlagsAndForwardPointer(dict, ¤tPos); - } - currentPos += attributeAddressSize(flags); - return currentPos; -} - -static inline int childrenAddressSize(const uint8_t flags) { - static const int CHILDREN_ADDRESS_SHIFT = 6; - return (BinaryFormat::MASK_GROUP_ADDRESS_TYPE & flags) >> CHILDREN_ADDRESS_SHIFT; - /* See the note in attributeAddressSize. 
The same applies here */ -} - -static AK_FORCE_INLINE int shortcutByteSize(const uint8_t *const dict, const int pos) { - return (static_cast<int>(dict[pos] << 8)) + (dict[pos + 1]); -} - -inline int BinaryFormat::skipChildrenPosition(const uint8_t flags, const int pos) { - return pos + childrenAddressSize(flags); -} - -inline int BinaryFormat::skipProbability(const uint8_t flags, const int pos) { - return FLAG_IS_TERMINAL & flags ? pos + 1 : pos; -} - -AK_FORCE_INLINE int BinaryFormat::skipShortcuts(const uint8_t *const dict, const uint8_t flags, - const int pos) { - if (FLAG_HAS_SHORTCUT_TARGETS & flags) { - return pos + shortcutByteSize(dict, pos); - } else { - return pos; - } -} - -AK_FORCE_INLINE int BinaryFormat::skipBigrams(const uint8_t *const dict, const uint8_t flags, - const int pos) { - if (FLAG_HAS_BIGRAMS & flags) { - return skipExistingBigrams(dict, pos); - } else { - return pos; - } -} - -AK_FORCE_INLINE int BinaryFormat::skipAllAttributes(const uint8_t *const dict, const uint8_t flags, - const int pos) { - // This function skips all attributes: shortcuts and bigrams. 
- int newPos = pos; - newPos = skipShortcuts(dict, flags, newPos); - newPos = skipBigrams(dict, flags, newPos); - return newPos; -} - -AK_FORCE_INLINE int BinaryFormat::skipChildrenPosAndAttributes(const uint8_t *const dict, - const uint8_t flags, const int pos) { - int currentPos = pos; - currentPos = skipChildrenPosition(flags, currentPos); - currentPos = skipAllAttributes(dict, flags, currentPos); - return currentPos; -} - -AK_FORCE_INLINE int BinaryFormat::readChildrenPosition(const uint8_t *const dict, - const uint8_t flags, const int pos) { - int offset = 0; - switch (MASK_GROUP_ADDRESS_TYPE & flags) { - case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE: - offset = dict[pos]; - break; - case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES: - offset = dict[pos] << 8; - offset += dict[pos + 1]; - break; - case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES: - offset = dict[pos] << 16; - offset += dict[pos + 1] << 8; - offset += dict[pos + 2]; - break; - default: - // If we come here, it means we asked for the children of a word with - // no children. 
- return -1; - } - return pos + offset; -} - -inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) { - return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags)); -} - -AK_FORCE_INLINE int BinaryFormat::getAttributeAddressAndForwardPointer(const uint8_t *const dict, - const uint8_t flags, int *pos) { - int offset = 0; - const int origin = *pos; - switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = dict[origin]; - *pos = origin + 1; - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = dict[origin] << 8; - offset += dict[origin + 1]; - *pos = origin + 2; - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = dict[origin] << 16; - offset += dict[origin + 1] << 8; - offset += dict[origin + 2]; - *pos = origin + 3; - break; - } - if (FLAG_ATTRIBUTE_OFFSET_NEGATIVE & flags) { - return origin - offset; - } else { - return origin + offset; - } -} - -inline int BinaryFormat::getAttributeProbabilityFromFlags(const int flags) { - return flags & MASK_ATTRIBUTE_PROBABILITY; -} - -// This function gets the byte position of the last chargroup of the exact matching word in the -// dictionary. If no match is found, it returns NOT_VALID_WORD. -AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, - const int *const inWord, const int length, const bool forceLowerCaseSearch) { - int pos = 0; - int wordPos = 0; - - while (true) { - // If we already traversed the tree further than the word is long, there means - // there was no match (or we would have found it). - if (wordPos >= length) return NOT_VALID_WORD; - int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos); - const int wChar = forceLowerCaseSearch - ? 
CharUtils::toLowerCase(inWord[wordPos]) : inWord[wordPos]; - while (true) { - // If there are no more character groups in this node, it means we could not - // find a matching character for this depth, therefore there is no match. - if (0 >= charGroupCount) return NOT_VALID_WORD; - const int charGroupPos = pos; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - if (character == wChar) { - // This is the correct node. Only one character group may start with the same - // char within a node, so either we found our match in this node, or there is - // no match and we can return NOT_VALID_WORD. So we will check all the characters - // in this character group indeed does match. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - while (NOT_A_CODE_POINT != character) { - ++wordPos; - // If we shoot the length of the word we search for, or if we find a single - // character that does not match, as explained above, it means the word is - // not in the dictionary (by virtue of this chargroup being the only one to - // match the word on the first character, but not matching the whole word). - if (wordPos >= length) return NOT_VALID_WORD; - if (inWord[wordPos] != character) return NOT_VALID_WORD; - character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } - } - // If we come here we know that so far, we do match. Either we are on a terminal - // and we match the length, in which case we found it, or we traverse children. - // If we don't match the length AND don't have children, then a word in the - // dictionary fully matches a prefix of the searched word but not the full word. 
- ++wordPos; - if (FLAG_IS_TERMINAL & flags) { - if (wordPos == length) { - return charGroupPos; - } - pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos); - } - if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) { - return NOT_VALID_WORD; - } - // We have children and we are still shorter than the word we are searching for, so - // we need to traverse children. Put the pointer on the children position, and - // break - pos = BinaryFormat::readChildrenPosition(root, flags, pos); - break; - } else { - // This chargroup does not match, so skip the remaining part and go to the next. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } - pos = BinaryFormat::skipProbability(flags, pos); - pos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos); - } - --charGroupCount; - } - } -} - -// This function searches for a terminal in the dictionary by its address. -// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, -// it is possible to check for this with advantageous complexity. For each node, we search -// for groups with children and compare the children address with the address we look for. -// When we shoot the address we look for, it means the word we look for is in the children -// of the previous group. The only tricky part is the fact that if we arrive at the end of a -// node with the last group's children address still less than what we are searching for, we -// must descend the last group's children (for example, if the word we are searching for starts -// with a z, it's the last group of the root node, so all children addresses will be smaller -// than the address we look for, and we have to descend the z node). 
-/* Parameters : - * root: the dictionary buffer - * address: the byte position of the last chargroup of the word we are searching for (this is - * what is stored as the "bigram address" in each bigram) - * outword: an array to write the found word, with MAX_WORD_LENGTH size. - * outUnigramProbability: a pointer to an int to write the probability into. - * Return value : the length of the word, of 0 if the word was not found. - */ -AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, - const int maxDepth, int *outWord, int *outUnigramProbability) { - int pos = 0; - int wordPos = 0; - - // One iteration of the outer loop iterates through nodes. As stated above, we will only - // traverse nodes that are actually a part of the terminal we are searching, so each time - // we enter this loop we are one depth level further than last time. - // The only reason we count nodes is because we want to reduce the probability of infinite - // looping in case there is a bug. Since we know there is an upper bound to the depth we are - // supposed to traverse, it does not hurt to count iterations. - for (int loopCount = maxDepth; loopCount > 0; --loopCount) { - int lastCandidateGroupPos = 0; - // Let's loop through char groups in this node searching for either the terminal - // or one of its ascendants. - for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0; - --charGroupCount) { - const int startPos = pos; - const uint8_t flags = getFlagsAndForwardPointer(root, &pos); - const int character = getCodePointAndForwardPointer(root, &pos); - if (address == startPos) { - // We found the address. Copy the rest of the word in the buffer and return - // the length. 
- outWord[wordPos] = character; - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - int nextChar = getCodePointAndForwardPointer(root, &pos); - // We count chars in order to avoid infinite loops if the file is broken or - // if there is some other bug - int charCount = maxDepth; - while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { - outWord[++wordPos] = nextChar; - nextChar = getCodePointAndForwardPointer(root, &pos); - } - } - *outUnigramProbability = readProbabilityWithoutMovingPointer(root, pos); - return ++wordPos; - } - // We need to skip past this char group, so skip any remaining chars after the - // first and possibly the probability. - if (FLAG_HAS_MULTIPLE_CHARS & flags) { - pos = skipOtherCharacters(root, pos); - } - pos = skipProbability(flags, pos); - - // The fact that this group has children is very important. Since we already know - // that this group does not match, if it has no children we know it is irrelevant - // to what we are searching for. - const bool hasChildren = (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != - (MASK_GROUP_ADDRESS_TYPE & flags)); - // We will write in `found' whether we have passed the children address we are - // searching for. For example if we search for "beer", the children of b are less - // than the address we are searching for and the children of c are greater. When we - // come here for c, we realize this is too big, and that we should descend b. - bool found; - if (hasChildren) { - // Here comes the tricky part. First, read the children position. - const int childrenPos = readChildrenPosition(root, flags, pos); - if (childrenPos > address) { - // If the children pos is greater than address, it means the previous chargroup, - // which address is stored in lastCandidateGroupPos, was the right one. - found = true; - } else if (1 >= charGroupCount) { - // However if we are on the LAST group of this node, and we have NOT shot the - // address we should descend THIS node. 
So we trick the lastCandidateGroupPos - // so that we will descend this node, not the previous one. - lastCandidateGroupPos = startPos; - found = true; - } else { - // Else, we should continue looking. - found = false; - } - } else { - // Even if we don't have children here, we could still be on the last group of this - // node. If this is the case, we should descend the last group that had children, - // and their address is already in lastCandidateGroup. - found = (1 >= charGroupCount); - } - - if (found) { - // Okay, we found the group we should descend. Its address is in - // the lastCandidateGroupPos variable, so we just re-read it. - if (0 != lastCandidateGroupPos) { - const uint8_t lastFlags = - getFlagsAndForwardPointer(root, &lastCandidateGroupPos); - const int lastChar = - getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - // We copy all the characters in this group to the buffer - outWord[wordPos] = lastChar; - if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { - int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - int charCount = maxDepth; - while (-1 != nextChar && --charCount > 0) { - outWord[++wordPos] = nextChar; - nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - } - } - ++wordPos; - // Now we only need to branch to the children address. Skip the probability if - // it's there, read pos, and break to resume the search at pos. - lastCandidateGroupPos = skipProbability(lastFlags, lastCandidateGroupPos); - pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos); - break; - } else { - // Here is a little tricky part: we come here if we found out that all children - // addresses in this group are bigger than the address we are searching for. - // Should we conclude the word is not in the dictionary? No! 
It could still be - // one of the remaining chargroups in this node, so we have to keep looking in - // this node until we find it (or we realize it's not there either, in which - // case it's actually not in the dictionary). Pass the end of this group, ready - // to start the next one. - pos = skipChildrenPosAndAttributes(root, flags, pos); - } - } else { - // If we did not find it, we should record the last children address for the next - // iteration. - if (hasChildren) lastCandidateGroupPos = startPos; - // Now skip the end of this group (children pos and the attributes if any) so that - // our pos is after the end of this char group, at the start of the next one. - pos = skipChildrenPosAndAttributes(root, flags, pos); - } - - } - } - // If we have looked through all the chargroups and found no match, the address is - // not the address of a terminal in this dictionary. - return 0; -} - -// This returns a probability in log space. -inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { - if (!bigramMap) { - return ProbabilityUtils::backoff(unigramProbability); - } - const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); - if (bigramProbabilityIt != bigramMap->end()) { - const int bigramProbability = bigramProbabilityIt->second; - return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, bigramProbability); - } - return ProbabilityUtils::backoff(unigramProbability); -} - -AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( - const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) return; - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - const int bigramPos = 
getAttributeAddressAndForwardPointer(root, bigramFlags, - &position); - (*bigramMap)[bigramPos] = probability; - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); -} - -AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) { - return ProbabilityUtils::backoff(unigramProbability); - } - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int bigramPos = getAttributeAddressAndForwardPointer( - root, bigramFlags, &position); - if (bigramPos == nextPosition) { - const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramProbability); - } - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); - return ProbabilityUtils::backoff(unigramProbability); -} - -// Returns a pointer to the start of the bigram list. 
-AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( - const uint8_t *const root, int position) { - if (NOT_VALID_WORD == position) return 0; - const uint8_t flags = getFlagsAndForwardPointer(root, &position); - if (!(flags & FLAG_HAS_BIGRAMS)) return 0; - if (flags & FLAG_HAS_MULTIPLE_CHARS) { - position = skipOtherCharacters(root, position); - } else { - getCodePointAndForwardPointer(root, &position); - } - position = skipProbability(flags, position); - position = skipChildrenPosition(flags, position); - position = skipShortcuts(root, flags, position); - return position; -} - -} // namespace latinime -#endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/suggest/core/dictionary/bloom_filter.cpp b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp new file mode 100644 index 000000000..4ae474e0c --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/bloom_filter.h" + +namespace latinime { + +// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest +// prime under 128 * 8. 
+const int BloomFilter::BIGRAM_FILTER_MODULO = 1021; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/bloom_filter.h b/native/jni/src/suggest/core/dictionary/bloom_filter.h index bcce1f7ea..5205456a8 100644 --- a/native/jni/src/suggest/core/dictionary/bloom_filter.h +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.h @@ -23,16 +23,48 @@ namespace latinime { -// TODO: uint32_t position -static inline void setInFilter(uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - filter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); -} - -// TODO: uint32_t position -static inline bool isInFilter(const uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - return filter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7)); -} +// This bloom filter is used for optimizing bigram retrieval. +// Execution times with previous word "this" are as follows: +// without bloom filter (use only hash_map): +// Total 147792.34 (sum of others 147771.57) +// with bloom filter: +// Total 145900.64 (sum of others 145874.30) +// always read binary dictionary: +// Total 148603.14 (sum of others 148579.90) +class BloomFilter { + public: + BloomFilter() { + ASSERT(BIGRAM_FILTER_BYTE_SIZE * 8 >= BIGRAM_FILTER_MODULO); + } + + // TODO: uint32_t position + AK_FORCE_INLINE void setInFilter(const int32_t position) { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + mFilter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); + } + + // TODO: uint32_t position + AK_FORCE_INLINE bool isInFilter(const int32_t position) const { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + return (mFilter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7))) != 0; + } + + private: + // Size, in bytes, of the bloom filter index for bigrams + // 128 
gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, + // where k is the number of hash functions, n the number of bigrams, and m the number of + // bits we can test. + // At the moment 100 is the maximum number of bigrams for a word with the current + // dictionaries, so n = 100. 1024 buckets give us m = 1024. + // With 1 hash function, our false positive rate is about 9.3%, which should be enough for + // our uses since we are only using this to increase average performance. For the record, + // k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, + // and m = 4096 gives 2.4%. + // This is assigned here because it is used for array size. + static const int BIGRAM_FILTER_BYTE_SIZE = 128; + static const int BIGRAM_FILTER_MODULO; + + uint8_t mFilter[BIGRAM_FILTER_BYTE_SIZE]; +}; } // namespace latinime #endif // LATINIME_BLOOM_FILTER_H diff --git a/native/jni/src/suggest/core/dictionary/byte_array_utils.h b/native/jni/src/suggest/core/dictionary/byte_array_utils.h deleted file mode 100644 index 832b74725..000000000 --- a/native/jni/src/suggest/core/dictionary/byte_array_utils.h +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (C) 2013, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef LATINIME_BYTE_ARRAY_UTILS_H -#define LATINIME_BYTE_ARRAY_UTILS_H - -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -/** - * Utility methods for reading byte arrays. - */ -class ByteArrayUtils { - public: - /** - * Integer - * - * Each method read a corresponding size integer in a big endian manner. - */ - static AK_FORCE_INLINE uint32_t readUint32(const uint8_t *const buffer, const int pos) { - return (buffer[pos] << 24) ^ (buffer[pos + 1] << 16) - ^ (buffer[pos + 2] << 8) ^ buffer[pos + 3]; - } - - static AK_FORCE_INLINE uint32_t readUint24(const uint8_t *const buffer, const int pos) { - return (buffer[pos] << 16) ^ (buffer[pos + 1] << 8) ^ buffer[pos + 2]; - } - - static AK_FORCE_INLINE uint16_t readUint16(const uint8_t *const buffer, const int pos) { - return (buffer[pos] << 8) ^ buffer[pos + 1]; - } - - static AK_FORCE_INLINE uint8_t readUint8(const uint8_t *const buffer, const int pos) { - return buffer[pos]; - } - - static AK_FORCE_INLINE uint32_t readUint32andAdvancePosition( - const uint8_t *const buffer, int *const pos) { - const uint32_t value = readUint32(buffer, *pos); - *pos += 4; - return value; - } - - static AK_FORCE_INLINE uint32_t readUint24andAdvancePosition( - const uint8_t *const buffer, int *const pos) { - const uint32_t value = readUint24(buffer, *pos); - *pos += 3; - return value; - } - - static AK_FORCE_INLINE uint16_t readUint16andAdvancePosition( - const uint8_t *const buffer, int *const pos) { - const uint16_t value = readUint16(buffer, *pos); - *pos += 2; - return value; - } - - static AK_FORCE_INLINE uint8_t readUint8andAdvancePosition( - const uint8_t *const buffer, int *const pos) { - return buffer[(*pos)++]; - } - - /** - * Code Point - * - * 1 byte = bbbbbbbb match - * case 000xxxxx: xxxxx << 16 + next byte << 8 + next byte - * else: if 00011111 (= 0x1F) : this is the terminator. 
This is a relevant choice because - * unicode code points range from 0 to 0x10FFFF, so any 3-byte value starting with - * 00011111 would be outside unicode. - * else: iso-latin-1 code - * This allows for the whole unicode range to be encoded, including chars outside of - * the BMP. Also everything in the iso-latin-1 charset is only 1 byte, except control - * characters which should never happen anyway (and still work, but take 3 bytes). - */ - static AK_FORCE_INLINE int readCodePoint(const uint8_t *const buffer, const int pos) { - int p = pos; - return readCodePointAndAdvancePosition(buffer, &p); - } - - static AK_FORCE_INLINE int readCodePointAndAdvancePosition( - const uint8_t *const buffer, int *const pos) { - const uint8_t firstByte = readUint8(buffer, *pos); - if (firstByte < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { - if (firstByte == CHARACTER_ARRAY_TERMINATOR) { - *pos += 1; - return NOT_A_CODE_POINT; - } else { - return readUint24andAdvancePosition(buffer, pos); - } - } else { - *pos += 1; - return firstByte; - } - } - - /** - * String (array of code points) - * - * Reads code points until the terminator is found. - */ - // Returns the length of the string. - static int readStringAndAdvancePosition(const uint8_t *const buffer, int *const pos, - int *const outBuffer, const int maxLength) { - int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); - while (NOT_A_CODE_POINT != codePoint && length < maxLength) { - outBuffer[length++] = codePoint; - codePoint = readCodePointAndAdvancePosition(buffer, pos); - } - return length; - } - - // Advances the position and returns the length of the string. 
- static int advancePositionToBehindString( - const uint8_t *const buffer, int *const pos, const int maxLength) { - int length = 0; - int codePoint = readCodePointAndAdvancePosition(buffer, pos); - while (NOT_A_CODE_POINT != codePoint && length < maxLength) { - codePoint = readCodePointAndAdvancePosition(buffer, pos); - } - return length; - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(ByteArrayUtils); - - static const uint8_t MINIMAL_ONE_BYTE_CHARACTER_VALUE; - static const uint8_t CHARACTER_ARRAY_TERMINATOR; -}; -} // namespace latinime -#endif /* LATINIME_BYTE_ARRAY_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 2d4ad5df5..ec1b63a12 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -18,34 +18,35 @@ #include "suggest/core/dictionary/dictionary.h" -#include <map> // TODO: remove #include <stdint.h> #include "defines.h" #include "suggest/core/dictionary/bigram_dictionary.h" -#include "suggest/core/dictionary/binary_format.h" +#include "suggest/core/policy/dictionary_header_structure_policy.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/core/suggest.h" #include "suggest/core/suggest_options.h" #include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" #include "suggest/policyimpl/typing/typing_suggest_policy_factory.h" +#include "utils/log_utils.h" namespace latinime { -Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) - : mBinaryDicitonaryInfo(static_cast<const uint8_t *>(dict), dictSize), - mDictSize(dictSize), - mDictFlags(BinaryFormat::getFlags(mBinaryDicitonaryInfo.getDictBuf(), dictSize)), - mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), - mBigramDictionary(new BigramDictionary(&mBinaryDicitonaryInfo)), +Dictionary::Dictionary(JNIEnv *env, + 
DictionaryStructureWithBufferPolicy *const dictionaryStructureWithBufferPolicy) + : mDictionaryStructureWithBufferPolicy(dictionaryStructureWithBufferPolicy), + mBigramDictionary(new BigramDictionary(mDictionaryStructureWithBufferPolicy)), mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { + logDictionaryInfo(env); } Dictionary::~Dictionary() { delete mBigramDictionary; delete mGestureSuggest; delete mTypingSuggest; + delete mDictionaryStructureWithBufferPolicy; } int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, @@ -77,43 +78,79 @@ int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession } } -int Dictionary::getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, - int *outWords, int *frequencies, int *outputTypes) const { +int Dictionary::getBigrams(const int *word, int length, int *outWords, int *frequencies, + int *outputTypes) const { if (length <= 0) return 0; - return mBigramDictionary->getBigrams(word, length, inputCodePoints, inputSize, outWords, - frequencies, outputTypes); + return mBigramDictionary->getPredictions(word, length, outWords, frequencies, outputTypes); } int Dictionary::getProbability(const int *word, int length) const { - const uint8_t *const root = mBinaryDicitonaryInfo.getDictRoot(); - int pos = BinaryFormat::getTerminalPosition(root, word, length, + int pos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord(word, length, false /* forceLowerCaseSearch */); - if (NOT_VALID_WORD == pos) { + if (NOT_A_DICT_POS == pos) { return NOT_A_PROBABILITY; } - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - if (flags & (BinaryFormat::FLAG_IS_BLACKLISTED | BinaryFormat::FLAG_IS_NOT_A_WORD)) { - // If this is not a word, or if it's a blacklisted entry, it should behave as - // having no probability outside 
of the suggestion process (where it should be used - // for shortcuts). - return NOT_A_PROBABILITY; - } - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - if (hasMultipleChars) { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } else { - BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } - const int unigramProbability = BinaryFormat::readProbabilityWithoutMovingPointer(root, pos); - return unigramProbability; + return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos); +} + +int Dictionary::getBigramProbability(const int *word0, int length0, const int *word1, + int length1) const { + return mBigramDictionary->getBigramProbability(word0, length0, word1, length1); +} + +void Dictionary::addUnigramWord(const int *const word, const int length, const int probability) { + mDictionaryStructureWithBufferPolicy->addUnigramWord(word, length, probability); +} + +void Dictionary::addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability) { + mDictionaryStructureWithBufferPolicy->addBigramWords(word0, length0, word1, length1, + probability); +} + +void Dictionary::removeBigramWords(const int *const word0, const int length0, + const int *const word1, const int length1) { + mDictionaryStructureWithBufferPolicy->removeBigramWords(word0, length0, word1, length1); +} + +void Dictionary::flush(const char *const filePath) { + mDictionaryStructureWithBufferPolicy->flush(filePath); } -bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const { - return mBigramDictionary->isValidBigram(word1, length1, word2, length2); +void Dictionary::flushWithGC(const char *const filePath) { + mDictionaryStructureWithBufferPolicy->flushWithGC(filePath); } -int Dictionary::getDictFlags() const { - return mDictFlags; +bool Dictionary::needsToRunGC() { + return mDictionaryStructureWithBufferPolicy->needsToRunGC(); +} + +void 
Dictionary::logDictionaryInfo(JNIEnv *const env) const { + const int BUFFER_SIZE = 16; + int dictionaryIdCodePointBuffer[BUFFER_SIZE]; + int versionStringCodePointBuffer[BUFFER_SIZE]; + int dateStringCodePointBuffer[BUFFER_SIZE]; + const DictionaryHeaderStructurePolicy *const headerPolicy = + getDictionaryStructurePolicy()->getHeaderStructurePolicy(); + headerPolicy->readHeaderValueOrQuestionMark("dictionary", dictionaryIdCodePointBuffer, + BUFFER_SIZE); + headerPolicy->readHeaderValueOrQuestionMark("version", versionStringCodePointBuffer, + BUFFER_SIZE); + headerPolicy->readHeaderValueOrQuestionMark("date", dateStringCodePointBuffer, BUFFER_SIZE); + + char dictionaryIdCharBuffer[BUFFER_SIZE]; + char versionStringCharBuffer[BUFFER_SIZE]; + char dateStringCharBuffer[BUFFER_SIZE]; + intArrayToCharArray(dictionaryIdCodePointBuffer, BUFFER_SIZE, + dictionaryIdCharBuffer, BUFFER_SIZE); + intArrayToCharArray(versionStringCodePointBuffer, BUFFER_SIZE, + versionStringCharBuffer, BUFFER_SIZE); + intArrayToCharArray(dateStringCodePointBuffer, BUFFER_SIZE, + dateStringCharBuffer, BUFFER_SIZE); + + LogUtils::logToJava(env, + "Dictionary info: dictionary = %s ; version = %s ; date = %s", + dictionaryIdCharBuffer, versionStringCharBuffer, dateStringCharBuffer); } } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 1f25080b1..974447468 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -20,11 +20,12 @@ #include <stdint.h> #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "jni.h" namespace latinime { class BigramDictionary; +class DictionaryStructureWithBufferPolicy; class DicTraverseSession; class ProximityInfo; class SuggestInterface; @@ -52,7 +53,8 @@ class Dictionary { static const int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000; static const int KIND_FLAG_EXACT_MATCH 
= 0x40000000; - Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust); + Dictionary(JNIEnv *env, + DictionaryStructureWithBufferPolicy *const dictionaryStructureWithBufferPoilcy); int getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, @@ -60,34 +62,42 @@ class Dictionary { const SuggestOptions *const suggestOptions, int *outWords, int *frequencies, int *spaceIndices, int *outputTypes) const; - int getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, - int *frequencies, int *outputTypes) const; + int getBigrams(const int *word, int length, int *outWords, int *frequencies, + int *outputTypes) const; int getProbability(const int *word, int length) const; - bool isValidBigram(const int *word1, int length1, const int *word2, int length2) const; - const BinaryDictionaryInfo *getBinaryDictionaryInfo() const { - return &mBinaryDicitonaryInfo; + + int getBigramProbability(const int *word0, int length0, const int *word1, int length1) const; + + void addUnigramWord(const int *const word, const int length, const int probability); + + void addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability); + + void removeBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1); + + void flush(const char *const filePath); + + void flushWithGC(const char *const filePath); + + bool needsToRunGC(); + + const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const { + return mDictionaryStructureWithBufferPolicy; } - int getDictSize() const { return mDictSize; } - int getMmapFd() const { return mMmapFd; } - int getDictBufAdjust() const { return mDictBufAdjust; } - int getDictFlags() const; + virtual ~Dictionary(); private: DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary); - const BinaryDictionaryInfo 
mBinaryDicitonaryInfo; - // Used only for the mmap version of dictionary loading, but we use these as dummy variables - // also for the malloc version. - const int mDictSize; - const int mDictFlags; - const int mMmapFd; - const int mDictBufAdjust; - - const BigramDictionary *mBigramDictionary; - SuggestInterface *mGestureSuggest; - SuggestInterface *mTypingSuggest; + DictionaryStructureWithBufferPolicy *const mDictionaryStructureWithBufferPolicy; + const BigramDictionary *const mBigramDictionary; + const SuggestInterface *const mGestureSuggest; + const SuggestInterface *const mTypingSuggest; + + void logDictionaryInfo(JNIEnv *const env) const; }; } // namespace latinime #endif // LATINIME_DICTIONARY_H diff --git a/native/jni/src/suggest/core/dictionary/digraph_utils.cpp b/native/jni/src/suggest/core/dictionary/digraph_utils.cpp index f53e56ef1..3271c1bfb 100644 --- a/native/jni/src/suggest/core/dictionary/digraph_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/digraph_utils.cpp @@ -16,8 +16,10 @@ #include "suggest/core/dictionary/digraph_utils.h" +#include <cstdlib> + #include "defines.h" -#include "suggest/core/dictionary/binary_format.h" +#include "suggest/core/policy/dictionary_header_structure_policy.h" #include "utils/char_utils.h" namespace latinime { @@ -33,8 +35,9 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES }; /* static */ bool DigraphUtils::hasDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint) { - const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + const DictionaryHeaderStructurePolicy *const headerPolicy, + const int compositeGlyphCodePoint) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(headerPolicy); if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) { return true; } @@ -43,24 +46,16 @@ const DigraphUtils::DigraphType 
DigraphUtils::USED_DIGRAPH_TYPES[] = // Returns the digraph type associated with the given dictionary. /* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary( - const int dictFlags) { - if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) { + const DictionaryHeaderStructurePolicy *const headerPolicy) { + if (headerPolicy->requiresGermanUmlautProcessing()) { return DIGRAPH_TYPE_GERMAN_UMLAUT; } - if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) { + if (headerPolicy->requiresFrenchLigatureProcessing()) { return DIGRAPH_TYPE_FRENCH_LIGATURES; } return DIGRAPH_TYPE_NONE; } -// Retrieves the set of all digraphs associated with the given dictionary flags. -// Returns the size of the digraph array, or 0 if none exist. -/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const DigraphUtils::digraph_t **const digraphs) { - const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); - return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs); -} - // Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index // (which specifies the first or second codepoint in the digraph). 
/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint, @@ -124,7 +119,7 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = const DigraphUtils::digraph_t *digraphs = 0; const int compositeGlyphLowerCodePoint = CharUtils::toLowerCase(compositeGlyphCodePoint); const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs); + DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize(digraphType, &digraphs); for (int i = 0; i < digraphsSize; i++) { if (digraphs[i].compositeGlyph == compositeGlyphLowerCodePoint) { return &digraphs[i]; diff --git a/native/jni/src/suggest/core/dictionary/digraph_utils.h b/native/jni/src/suggest/core/dictionary/digraph_utils.h index c1205940c..6ae16e390 100644 --- a/native/jni/src/suggest/core/dictionary/digraph_utils.h +++ b/native/jni/src/suggest/core/dictionary/digraph_utils.h @@ -21,6 +21,8 @@ namespace latinime { +class DictionaryHeaderStructurePolicy; + class DigraphUtils { public: typedef enum { @@ -37,17 +39,15 @@ class DigraphUtils { typedef struct { int first; int second; int compositeGlyph; } digraph_t; - static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint); - static int getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const digraph_t **const digraphs); - static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint, - const DigraphCodePointIndex digraphCodePointIndex); + static bool hasDigraphForCodePoint(const DictionaryHeaderStructurePolicy *const headerPolicy, + const int compositeGlyphCodePoint); static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils); - static DigraphType getDigraphTypeForDictionary(const int dictFlags); + static DigraphType getDigraphTypeForDictionary( + const DictionaryHeaderStructurePolicy *const 
headerPolicy); static int getAllDigraphsForDigraphTypeAndReturnSize( const DigraphType digraphType, const digraph_t **const digraphs); static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint); diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp new file mode 100644 index 000000000..b1d2f4b4d --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/multi_bigram_map.h" + +#include <cstddef> + +namespace latinime { + +// Max number of bigram maps (previous word contexts) to be cached. Increasing this number +// could improve bigram lookup speed for multi-word suggestions, but at the cost of more memory +// usage. Also, there are diminishing returns since the most frequently used bigrams are +// typically near the beginning of the input and are thus the first ones to be cached. Note +// that these bigrams are reset for each new composing word. 
+const size_t MultiBigramMap::MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP = 25; + +// Most common previous word contexts currently have 100 bigrams +const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = 100; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index ba97e5842..4633c07b0 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -17,9 +17,12 @@ #ifndef LATINIME_MULTI_BIGRAM_MAP_H #define LATINIME_MULTI_BIGRAM_MAP_H +#include <cstddef> + #include "defines.h" -#include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/bloom_filter.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "utils/hash_map_compat.h" namespace latinime { @@ -34,20 +37,21 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. 
- int getBigramProbability(const BinaryDictionaryInfo *const binaryDicitonaryInfo, + int getBigramProbability(const DictionaryStructureWithBufferPolicy *const structurePolicy, const int wordPosition, const int nextWordPosition, const int unigramProbability) { hash_map_compat<int, BigramMap>::const_iterator mapPosition = mBigramMaps.find(wordPosition); if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + return mapPosition->second.getBigramProbability(structurePolicy, nextWordPosition, + unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(binaryDicitonaryInfo, wordPosition); - return mBigramMaps[wordPosition].getBigramProbability( + addBigramsForWordPosition(structurePolicy, wordPosition); + return mBigramMaps[wordPosition].getBigramProbability(structurePolicy, nextWordPosition, unigramProbability); } - return BinaryFormat::getBigramProbability(binaryDicitonaryInfo->getDictRoot(), - wordPosition, nextWordPosition, unigramProbability); + return readBigramProbabilityFromBinaryDictionary(structurePolicy, wordPosition, + nextWordPosition, unigramProbability); } void clear() { @@ -59,30 +63,69 @@ class MultiBigramMap { class BigramMap { public: - BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {} + BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP), mBloomFilter() {} ~BigramMap() {} - void init(const BinaryDictionaryInfo *const binaryDicitonaryInfo, const int position) { - BinaryFormat::fillBigramProbabilityToHashMap( - binaryDicitonaryInfo->getDictRoot(), position, &mBigramMap); + void init(const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int nodePos) { + const int bigramsListPos = structurePolicy->getBigramsPositionOfPtNode(nodePos); + BinaryDictionaryBigramsIterator bigramsIt(structurePolicy->getBigramsStructurePolicy(), + bigramsListPos); + while (bigramsIt.hasNext()) 
{ + bigramsIt.next(); + if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { + continue; + } + mBigramMap[bigramsIt.getBigramPos()] = bigramsIt.getProbability(); + mBloomFilter.setInFilter(bigramsIt.getBigramPos()); + } } - inline int getBigramProbability(const int nextWordPosition, const int unigramProbability) - const { - return BinaryFormat::getBigramProbabilityFromHashMap( - nextWordPosition, &mBigramMap, unigramProbability); + AK_FORCE_INLINE int getBigramProbability( + const DictionaryStructureWithBufferPolicy *const structurePolicy, + const int nextWordPosition, const int unigramProbability) const { + int bigramProbability = NOT_A_PROBABILITY; + if (mBloomFilter.isInFilter(nextWordPosition)) { + const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = + mBigramMap.find(nextWordPosition); + if (bigramProbabilityIt != mBigramMap.end()) { + bigramProbability = bigramProbabilityIt->second; + } + } + return structurePolicy->getProbability(unigramProbability, bigramProbability); } private: - // Note: Default copy constructor needed for use in hash_map. + // NOTE: The BigramMap class doesn't use DISALLOW_COPY_AND_ASSIGN() because its default + // copy constructor is needed for use in hash_map. 
+ static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; hash_map_compat<int, int> mBigramMap; + BloomFilter mBloomFilter; }; - void addBigramsForWordPosition(const BinaryDictionaryInfo *const binaryDicitonaryInfo, - const int position) { - mBigramMaps[position].init(binaryDicitonaryInfo, position); + AK_FORCE_INLINE void addBigramsForWordPosition( + const DictionaryStructureWithBufferPolicy *const structurePolicy, const int position) { + mBigramMaps[position].init(structurePolicy, position); + } + + AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( + const DictionaryStructureWithBufferPolicy *const structurePolicy, const int nodePos, + const int nextWordPosition, const int unigramProbability) { + int bigramProbability = NOT_A_PROBABILITY; + const int bigramsListPos = structurePolicy->getBigramsPositionOfPtNode(nodePos); + BinaryDictionaryBigramsIterator bigramsIt(structurePolicy->getBigramsStructurePolicy(), + bigramsListPos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPosition) { + bigramProbability = bigramsIt.getProbability(); + break; + } + } + return structurePolicy->getProbability(unigramProbability, bigramProbability); } + static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; hash_map_compat<int, BigramMap> mBigramMaps; }; } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/shortcut_utils.h b/native/jni/src/suggest/core/dictionary/shortcut_utils.h index 601ac5f5a..461d7b454 100644 --- a/native/jni/src/suggest/core/dictionary/shortcut_utils.h +++ b/native/jni/src/suggest/core/dictionary/shortcut_utils.h @@ -19,25 +19,24 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_utils.h" -#include "suggest/core/dictionary/terminal_attributes.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" namespace latinime { class ShortcutUtils { public: - static int outputShortcuts(const TerminalAttributes *const terminalAttributes, + static int 
outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, int outputWordIndex, const int finalScore, int *const outputCodePoints, int *const frequencies, int *const outputTypes, const bool sameAsTyped) { - TerminalAttributes::ShortcutIterator iterator = terminalAttributes->getShortcutIterator(); - while (iterator.hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { - int shortcutTarget[MAX_WORD_LENGTH]; - int shortcutProbability; - const int shortcutTargetStringLength = iterator.getNextShortcutTarget( - MAX_WORD_LENGTH, shortcutTarget, &shortcutProbability); + int shortcutTarget[MAX_WORD_LENGTH]; + while (shortcutIt->hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { + bool isWhilelist; + int shortcutTargetStringLength; + shortcutIt->nextShortcutTarget(MAX_WORD_LENGTH, shortcutTarget, + &shortcutTargetStringLength, &isWhilelist); int shortcutScore; int kind; - if (shortcutProbability == BinaryFormat::WHITELIST_SHORTCUT_PROBABILITY - && sameAsTyped) { + if (isWhilelist && sameAsTyped) { shortcutScore = S_INT_MAX; kind = Dictionary::KIND_WHITELIST; } else { diff --git a/native/jni/src/suggest/core/dictionary/terminal_attributes.h b/native/jni/src/suggest/core/dictionary/terminal_attributes.h deleted file mode 100644 index bbd9af090..000000000 --- a/native/jni/src/suggest/core/dictionary/terminal_attributes.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_TERMINAL_ATTRIBUTES_H -#define LATINIME_TERMINAL_ATTRIBUTES_H - -#include <stdint.h> - -#include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" - -namespace latinime { - -/** - * This class encapsulates information about a terminal that allows to - * retrieve local node attributes like the list of shortcuts without - * exposing the format structure to the client. - */ -class TerminalAttributes { - public: - class ShortcutIterator { - public: - ShortcutIterator(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int pos, - const uint8_t flags) - : mBinaryDicitionaryInfo(binaryDictionaryInfo), mPos(pos), - mHasNextShortcutTarget(0 != (flags & BinaryFormat::FLAG_HAS_SHORTCUT_TARGETS)) { - } - - inline bool hasNextShortcutTarget() const { - return mHasNextShortcutTarget; - } - - // Gets the shortcut target itself as an int string. For parameters and return value - // see BinaryFormat::getWordAtAddress. 
- inline int getNextShortcutTarget(const int maxDepth, int *outWord, int *outFreq) { - const int shortcutFlags = BinaryFormat::getFlagsAndForwardPointer( - mBinaryDicitionaryInfo->getDictRoot(), &mPos); - mHasNextShortcutTarget = 0 != (shortcutFlags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT); - unsigned int i; - for (i = 0; i < MAX_WORD_LENGTH; ++i) { - const int codePoint = BinaryFormat::getCodePointAndForwardPointer( - mBinaryDicitionaryInfo->getDictRoot(), &mPos); - if (NOT_A_CODE_POINT == codePoint) break; - outWord[i] = codePoint; - } - *outFreq = BinaryFormat::getAttributeProbabilityFromFlags(shortcutFlags); - return i; - } - - private: - const BinaryDictionaryInfo *const mBinaryDicitionaryInfo; - int mPos; - bool mHasNextShortcutTarget; - }; - - TerminalAttributes(const BinaryDictionaryInfo *const binaryDicitonaryInfo, - const uint8_t flags, const int pos) - : mBinaryDicitionaryInfo(binaryDicitonaryInfo), mFlags(flags), mStartPos(pos) { - } - - inline ShortcutIterator getShortcutIterator() const { - // The size of the shortcuts is stored here so that the whole shortcut chunk can be - // skipped quickly, so we ignore it. 
- return ShortcutIterator( - mBinaryDicitionaryInfo, mStartPos + BinaryFormat::SHORTCUT_LIST_SIZE_SIZE, mFlags); - } - - bool isBlacklistedOrNotAWord() const { - return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags); - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(TerminalAttributes); - const BinaryDictionaryInfo *const mBinaryDicitionaryInfo; - const uint8_t mFlags; - const int mStartPos; -}; -} // namespace latinime -#endif // LATINIME_TERMINAL_ATTRIBUTES_H diff --git a/native/jni/src/suggest/core/layout/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp index 80355c148..e64476d82 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info.cpp @@ -134,24 +134,13 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const { } float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, const float verticalScale) const { - const bool correctTouchPosition = hasTouchPositionCorrectionData(); - const float centerX = static_cast<float>(correctTouchPosition ? 
getSweetSpotCenterXAt(keyId) - : getKeyCenterXOfKeyIdG(keyId)); - const float visualKeyCenterY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId)); - float centerY; - if (correctTouchPosition) { - const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId)); - const float gapY = sweetSpotCenterY - visualKeyCenterY; - centerY = visualKeyCenterY + gapY * verticalScale; - } else { - centerY = visualKeyCenterY; - } + const int keyId, const int x, const int y, const bool isGeometric) const { + const float centerX = static_cast<float>(getKeyCenterXOfKeyIdG(keyId, x, isGeometric)); + const float centerY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId, y, isGeometric)); const float touchX = static_cast<float>(x); const float touchY = static_cast<float>(y); - const float keyWidth = static_cast<float>(getMostCommonKeyWidth()); return ProximityInfoUtils::getSquaredDistanceFloat(centerX, centerY, touchX, touchY) - / GeometryUtils::SQUARE_FLOAT(keyWidth); + / GeometryUtils::SQUARE_FLOAT(static_cast<float>(getMostCommonKeyWidth())); } int ProximityInfo::getCodePointOf(const int keyIndex) const { @@ -168,41 +157,88 @@ void ProximityInfo::initializeG() { const int lowerCode = CharUtils::toLowerCase(code); mCenterXsG[i] = mKeyXCoordinates[i] + mKeyWidths[i] / 2; mCenterYsG[i] = mKeyYCoordinates[i] + mKeyHeights[i] / 2; + if (hasTouchPositionCorrectionData()) { + // Computes sweet spot center points for geometric input. 
+ const float verticalScale = ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G; + const float sweetSpotCenterY = static_cast<float>(mSweetSpotCenterYs[i]); + const float gapY = sweetSpotCenterY - mCenterYsG[i]; + mSweetSpotCenterYsG[i] = static_cast<int>(mCenterYsG[i] + gapY * verticalScale); + } mCodeToKeyMap[lowerCode] = i; mKeyIndexToCodePointG[i] = lowerCode; } for (int i = 0; i < KEY_COUNT; i++) { mKeyKeyDistancesG[i][i] = 0; for (int j = i + 1; j < KEY_COUNT; j++) { - mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( - mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + if (hasTouchPositionCorrectionData()) { + // Computes distances using sweet spots if they exist. + // We have two types of Y coordinate sweet spots, for geometric and for the others. + // The sweet spots for geometric input are used for calculating key-key distances + // here. + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mSweetSpotCenterXs[i], mSweetSpotCenterYsG[i], + mSweetSpotCenterXs[j], mSweetSpotCenterYsG[j]); + } else { + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + } mKeyKeyDistancesG[j][i] = mKeyKeyDistancesG[i][j]; } } } -int ProximityInfo::getKeyCenterXOfCodePointG(int charCode) const { - return getKeyCenterXOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterYOfCodePointG(int charCode) const { - return getKeyCenterYOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterXOfKeyIdG(int keyId) const { - if (keyId >= 0) { - return mCenterXsG[keyId]; +// referencePointX is used only for keys wider than most common key width. When the referencePointX +// is NOT_A_COORDINATE, this method calculates the return value without using the line segment. 
+// isGeometric is currently not used because we don't have extra X coordinates sweet spots for +// geometric input. +int ProximityInfo::getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const { + if (keyId < 0) { + return 0; + } + int centerX = (hasTouchPositionCorrectionData()) ? static_cast<int>(mSweetSpotCenterXs[keyId]) + : mCenterXsG[keyId]; + const int keyWidth = mKeyWidths[keyId]; + if (referencePointX != NOT_A_COORDINATE + && keyWidth > getMostCommonKeyWidth()) { + // For keys wider than most common keys, we use a line segment instead of the center point; + // thus, centerX is adjusted depending on referencePointX. + const int keyWidthHalfDiff = (keyWidth - getMostCommonKeyWidth()) / 2; + if (referencePointX < centerX - keyWidthHalfDiff) { + centerX -= keyWidthHalfDiff; + } else if (referencePointX > centerX + keyWidthHalfDiff) { + centerX += keyWidthHalfDiff; + } else { + centerX = referencePointX; + } } - return 0; + return centerX; } -int ProximityInfo::getKeyCenterYOfKeyIdG(int keyId) const { - if (keyId >= 0) { - return mCenterYsG[keyId]; +// When the referencePointY is NOT_A_COORDINATE, this method calculates the return value without +// using the line segment. +int ProximityInfo::getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const { + // TODO: Remove "isGeometric" and have separate "proximity_info"s for gesture and typing. 
+ if (keyId < 0) { + return 0; + } + int centerY; + if (!hasTouchPositionCorrectionData()) { + centerY = mCenterYsG[keyId]; + } else if (isGeometric) { + centerY = static_cast<int>(mSweetSpotCenterYsG[keyId]); + } else { + centerY = static_cast<int>(mSweetSpotCenterYs[keyId]); + } + if (referencePointY != NOT_A_COORDINATE && + centerY + mKeyHeights[keyId] > KEYBOARD_HEIGHT && centerY < referencePointY) { + // When the distance between center point and bottom edge of the keyboard is shorter than + // the key height, we assume the key is located at the bottom row of the keyboard. + // The center point is extended to the bottom edge for such keys. + return referencePointY; } - return 0; + return centerY; } int ProximityInfo::getKeyKeyDistanceG(const int keyId0, const int keyId1) const { diff --git a/native/jni/src/suggest/core/layout/proximity_info.h b/native/jni/src/suggest/core/layout/proximity_info.h index 6ca2fdd7b..f25949001 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.h +++ b/native/jni/src/suggest/core/layout/proximity_info.h @@ -24,8 +24,6 @@ namespace latinime { -class Correction; - class ProximityInfo { public: ProximityInfo(JNIEnv *env, const jstring localeJStr, @@ -39,9 +37,7 @@ class ProximityInfo { bool hasSpaceProximity(const int x, const int y) const; int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const; float getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, - const float verticalScale) const; - bool sameAsTyped(const unsigned short *word, int length) const; + const int keyId, const int x, const int y, const bool isGeometric) const; int getCodePointOf(const int keyIndex) const; bool hasSweetSpotData(const int keyIndex) const { // When there are no calibration data for a key, @@ -68,10 +64,10 @@ class ProximityInfo { int getKeyboardHeight() const { return KEYBOARD_HEIGHT; } float getKeyboardHypotenuse() const { return KEYBOARD_HYPOTENUSE; } - int 
getKeyCenterXOfCodePointG(int charCode) const; - int getKeyCenterYOfCodePointG(int charCode) const; - int getKeyCenterXOfKeyIdG(int keyId) const; - int getKeyCenterYOfKeyIdG(int keyId) const; + int getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const; + int getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const; int getKeyKeyDistanceG(int keyId0, int keyId1) const; AK_FORCE_INLINE void initializeProximities(const int *const inputCodes, @@ -95,8 +91,6 @@ class ProximityInfo { DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfo); void initializeG(); - float calculateNormalizedSquaredDistance(const int keyIndex, const int inputIndex) const; - bool hasInputCoordinates() const; const int GRID_WIDTH; const int GRID_HEIGHT; @@ -120,6 +114,8 @@ class ProximityInfo { int mKeyCodePoints[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterXs[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterYs[MAX_KEY_COUNT_IN_A_KEYBOARD]; + // Sweet spots for geometric input. Note that we have extra sweet spots only for Y coordinates. 
+ float mSweetSpotCenterYsG[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotRadii[MAX_KEY_COUNT_IN_A_KEYBOARD]; hash_map_compat<int, int> mCodeToKeyMap; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp index 4e53992d4..fbabd92f2 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp @@ -36,8 +36,8 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi const int *const xCoordinates, const int *const yCoordinates, const int *const times, const int *const pointerIds, const bool isGeometric) { ASSERT(isGeometric || (inputSize < MAX_WORD_LENGTH)); - mIsContinuousSuggestionPossible = - ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( + mIsContinuousSuggestionPossible = (mHasBeenUpdatedByGeometricInput != isGeometric) ? + false : ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( inputSize, xCoordinates, yCoordinates, times, mSampledInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledInputIndice); if (DEBUG_DICT) { @@ -97,15 +97,10 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi pushTouchPointStartIndex, lastSavedInputSize); } - // TODO: Remove the dependency of "isGeometric" - const float verticalSweetSpotScale = isGeometric - ? 
ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G - : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE; - if (xCoordinates && yCoordinates) { mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo, mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times, - pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId, + pointerIds, inputSize, isGeometric, pointerId, pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache, &mSampledInputIndice); } @@ -123,7 +118,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, - lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, + lastSavedInputSize, isGeometric, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. @@ -156,15 +151,11 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (!isGeometric && pointerId == 0) { ProximityInfoStateUtils::initPrimaryInputWord( inputSize, mInputProximities, mPrimaryInputWord); - if (mTouchPositionCorrectionEnabled) { - ProximityInfoStateUtils::initNormalizedSquaredDistances( - mProximityInfo, inputSize, xCoordinates, yCoordinates, mInputProximities, - &mSampledInputXs, &mSampledInputYs, mNormalizedSquaredDistances); - } } if (DEBUG_GEO_FULL) { AKLOGI("ProximityState init finished: %d points out of %d", mSampledInputSize, inputSize); } + mHasBeenUpdatedByGeometricInput = isGeometric; } // This function basically converts from a length to an edit distance. 
Accordingly, it's obviously @@ -279,26 +270,6 @@ float ProximityInfoState::getDirection(const int index0, const int index1) const &mSampledInputXs, &mSampledInputYs, index0, index1); } -float ProximityInfoState::getLineToKeyDistance( - const int from, const int to, const int keyId, const bool extend) const { - if (from < 0 || from > mSampledInputSize - 1) { - return 0.0f; - } - if (to < 0 || to > mSampledInputSize - 1) { - return 0.0f; - } - const int x0 = mSampledInputXs[from]; - const int y0 = mSampledInputYs[from]; - const int x1 = mSampledInputXs[to]; - const int y1 = mSampledInputYs[to]; - - const int keyX = mProximityInfo->getKeyCenterXOfKeyIdG(keyId); - const int keyY = mProximityInfo->getKeyCenterYOfKeyIdG(keyId); - - return ProximityInfoUtils::pointToLineSegSquaredDistanceFloat( - keyX, keyY, x0, y0, x1, y1, extend); -} - float ProximityInfoState::getMostProbableString(int *const codePointBuf) const { memcpy(codePointBuf, mMostProbableString, sizeof(mMostProbableString)); return mMostProbableStringProbability; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index 0079ab5b8..c94060fa9 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -46,14 +46,14 @@ class ProximityInfoState { : mProximityInfo(0), mMaxPointToKeyLength(0.0f), mAverageSpeed(0.0f), mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), - mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), - mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), - mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(), - mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mIsContinuousSuggestionPossible(false), mHasBeenUpdatedByGeometricInput(false), + 
mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mSampledInputIndice(), + mSampledLengthCache(), mBeelineSpeedPercentiles(), + mSampledNormalizedSquaredLengthCache(), mSpeedRates(), mDirections(), + mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); - memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances)); memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord)); memset(mMostProbableString, 0, sizeof(mMostProbableString)); } @@ -91,7 +91,7 @@ class ProximityInfoState { return false; } - inline bool existsAdjacentProximityChars(const int index) const { + AK_FORCE_INLINE bool existsAdjacentProximityChars(const int index) const { if (index < 0 || index >= mSampledInputSize) return false; const int currentCodePoint = getPrimaryCodePointAt(index); const int leftIndex = index - 1; @@ -106,12 +106,6 @@ class ProximityInfoState { return false; } - inline int getNormalizedSquaredDistance( - const int inputIndex, const int proximityIndex) const { - return mNormalizedSquaredDistances[ - inputIndex * MAX_PROXIMITY_CHARS_SIZE + proximityIndex]; - } - inline const int *getPrimaryInputWord() const { return mPrimaryInputWord; } @@ -136,6 +130,10 @@ class ProximityInfoState { return mSampledInputYs[index]; } + int getInputIndexOfSampledPoint(const int sampledIndex) const { + return mSampledInputIndice[sampledIndex]; + } + bool hasSpaceProximity(const int index) const; int getLengthCache(const int index) const { @@ -190,24 +188,10 @@ class ProximityInfoState { float getProbability(const int index, const int charCode) const; - float getLineToKeyDistance( - const int from, const int to, const int keyId, const bool extend) const; - bool isKeyInSerchKeysAfterIndex(const int index, const int keyId) const; private: DISALLOW_COPY_AND_ASSIGN(ProximityInfoState); - 
///////////////////////////////////////// - // Defined in proximity_info_state.cpp // - ///////////////////////////////////////// - float calculateNormalizedSquaredDistance(const int keyIndex, const int inputIndex) const; - - float calculateSquaredDistanceFromSweetSpotCenter( - const int keyIndex, const int inputIndex) const; - - ///////////////////////////////////////// - // Defined here // - ///////////////////////////////////////// inline const int *getProximityCodePointsAt(const int index) const { return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index); @@ -225,6 +209,7 @@ class ProximityInfoState { int mGridHeight; int mGridWidth; bool mIsContinuousSuggestionPossible; + bool mHasBeenUpdatedByGeometricInput; std::vector<int> mSampledInputXs; std::vector<int> mSampledInputYs; @@ -249,7 +234,6 @@ class ProximityInfoState { std::vector<std::vector<int> > mSampledSearchKeyVectors; bool mTouchPositionCorrectionEnabled; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; - int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mSampledInputSize; int mPrimaryInputWord[MAX_WORD_LENGTH]; float mMostProbableStringProbability; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index 6f88833a2..904671f7f 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -43,8 +43,8 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, const bool isGeometric, - const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, + const int 
inputSize, const bool isGeometric, const int pointerId, + const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) { if (DEBUG_SAMPLING_POINTS) { @@ -113,7 +113,7 @@ namespace latinime { } if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time, - verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex, + isGeometric, isGeometric /* doSampling */, i == lastInputIndex, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache, sampledInputIndice)) { @@ -181,51 +181,9 @@ namespace latinime { return squaredDistance / squaredRadius; } -/* static */ void ProximityInfoStateUtils::initNormalizedSquaredDistances( - const ProximityInfo *const proximityInfo, const int inputSize, const int *inputXCoordinates, - const int *inputYCoordinates, const int *const inputProximities, - const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - int *normalizedSquaredDistances) { - memset(normalizedSquaredDistances, NOT_A_DISTANCE, - sizeof(normalizedSquaredDistances[0]) * MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH); - const bool hasInputCoordinates = sampledInputXs->size() > 0 && sampledInputYs->size() > 0; - for (int i = 0; i < inputSize; ++i) { - const int *proximityCodePoints = getProximityCodePointsAt(inputProximities, i); - const int primaryKey = proximityCodePoints[0]; - const int x = inputXCoordinates[i]; - const int y = inputYCoordinates[i]; - if (DEBUG_PROXIMITY_CHARS) { - int a = x + y + primaryKey; - a += 0; - AKLOGI("--- Primary = %c, x = %d, y = %d", primaryKey, x, y); - } - for (int j = 0; j < MAX_PROXIMITY_CHARS_SIZE && proximityCodePoints[j] > 0; ++j) { - const int currentCodePoint = proximityCodePoints[j]; - const float squaredDistance = - 
hasInputCoordinates ? calculateNormalizedSquaredDistance( - proximityInfo, sampledInputXs, sampledInputYs, - proximityInfo->getKeyIndexOf(currentCodePoint), i) : - ProximityInfoParams::NOT_A_DISTANCE_FLOAT; - if (squaredDistance >= 0.0f) { - normalizedSquaredDistances[i * MAX_PROXIMITY_CHARS_SIZE + j] = - static_cast<int>(squaredDistance - * ProximityInfoParams::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR); - } else { - normalizedSquaredDistances[i * MAX_PROXIMITY_CHARS_SIZE + j] = - (j == 0) ? MATCH_CHAR_WITHOUT_DISTANCE_INFO : - PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO; - } - if (DEBUG_PROXIMITY_CHARS) { - AKLOGI("--- Proximity (%d) = %c", j, currentCodePoint); - } - } - } - -} - /* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const int lastSavedInputSize, const float verticalSweetSpotScale, + const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -241,7 +199,7 @@ namespace latinime { const int y = (*sampledInputYs)[i]; const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( - k, x, y, verticalSweetSpotScale); + k, x, y, isGeometric); (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { @@ -359,14 +317,13 @@ namespace latinime { // the given point and the nearest key position. 
/* static */ float ProximityInfoStateUtils::updateNearKeysDistances( const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, - const int y, const float verticalSweetspotScale, - NearKeysDistanceMap *const currentNearKeysDistances) { + const int y, const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances) { currentNearKeysDistances->clear(); const int keyCount = proximityInfo->getKeyCount(); float nearestKeyDistance = maxPointToKeyLength; for (int k = 0; k < keyCount; ++k) { const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y, - verticalSweetspotScale); + isGeometric); if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) { currentNearKeysDistances->insert(std::pair<int, float>(k, dist)); } @@ -447,7 +404,7 @@ namespace latinime { // Returning if previous point is popped or not. /* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y, - const int time, const float verticalSweetSpotScale, const bool doSampling, + const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -461,7 +418,7 @@ namespace latinime { bool popped = false; if (nodeCodePoint < 0 && doSampling) { const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y, - verticalSweetSpotScale, currentNearKeysDistances); + isGeometric, currentNearKeysDistances); const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs); @@ -495,8 +452,8 @@ namespace latinime { if (nodeCodePoint >= 0 && (x < 0 || y < 0)) { const int keyId = 
proximityInfo->getKeyIndexOf(nodeCodePoint); if (keyId >= 0) { - x = proximityInfo->getKeyCenterXOfKeyIdG(keyId); - y = proximityInfo->getKeyCenterYOfKeyIdG(keyId); + x = proximityInfo->getKeyCenterXOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); + y = proximityInfo->getKeyCenterYOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); } } diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h index 66fe07926..6de970033 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h @@ -38,8 +38,7 @@ class ProximityInfoStateUtils { static int updateTouchPoints(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, - const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, + const int *const times, const int *const pointerIds, const int inputSize, const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, @@ -84,8 +83,7 @@ class ProximityInfoStateUtils { const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, - const int sampledInputSize, const int lastSavedInputSize, - const float verticalSweetSpotScale, + const int sampledInputSize, const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -120,7 +118,7 @@ class ProximityInfoStateUtils { static float updateNearKeysDistances(const 
ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, const int y, - const float verticalSweetSpotScale, + const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances); static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -133,7 +131,7 @@ class ProximityInfoStateUtils { std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs); static bool pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, - int y, const int time, const float verticalSweetSpotScale, + int y, const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, diff --git a/native/jni/src/suggest/core/layout/proximity_info_utils.h b/native/jni/src/suggest/core/layout/proximity_info_utils.h index 54f7539d1..0e28560fc 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_utils.h @@ -117,6 +117,10 @@ class ProximityInfoUtils { return getSquaredDistanceFloat(x, y, projectionX, projectionY); } + static AK_FORCE_INLINE bool isMatchOrProximityChar(const ProximityType type) { + return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR; + } + // Normal distribution N(u, sigma^2). 
struct NormalDistribution { public: diff --git a/native/jni/src/suggest/core/layout/touch_position_correction_utils.h b/native/jni/src/suggest/core/layout/touch_position_correction_utils.h index 429dcae0d..9130e87d3 100644 --- a/native/jni/src/suggest/core/layout/touch_position_correction_utils.h +++ b/native/jni/src/suggest/core/layout/touch_position_correction_utils.h @@ -23,31 +23,6 @@ namespace latinime { class TouchPositionCorrectionUtils { public: - // TODO: (OLD) Remove - static float getLengthScalingFactor(const float normalizedSquaredDistance) { - // Promote or demote the score according to the distance from the sweet spot - static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; - static const float B = 1.0f; - static const float C = 0.5f; - static const float MIN = 0.3f; - static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS; - static const float R2 = HALF_SCORE_SQUARED_RADIUS; - const float x = normalizedSquaredDistance / static_cast<float>( - ProximityInfoParams::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR); - const float factor = max((x < R1) - ? (A * (R1 - x) + B * x) / R1 - : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN); - // factor is a piecewise linear function like: - // A -_ . - // ^-_ . - // B \ . - // \_ . - // C ------------. - // . - // 0 R1 R2 . 
- return factor; - } - static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled, const float normalizedSquaredDistance) { // Promote or demote the score according to the distance from the sweet spot diff --git a/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h new file mode 100644 index 000000000..661ef1b1a --- /dev/null +++ b/native/jni/src/suggest/core/policy/dictionary_bigrams_structure_policy.h @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_BIGRAMS_STRUCTURE_POLICY_H +#define LATINIME_DICTIONARY_BIGRAMS_STRUCTURE_POLICY_H + +#include "defines.h" + +namespace latinime { + +/* + * This class abstracts structure of bigrams. 
+ */ +class DictionaryBigramsStructurePolicy { + public: + virtual ~DictionaryBigramsStructurePolicy() {} + + virtual void getNextBigram(int *const outBigramPos, int *const outProbability, + bool *const outHasNext, int *const pos) const = 0; + virtual void skipAllBigrams(int *const pos) const = 0; + + protected: + DictionaryBigramsStructurePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DictionaryBigramsStructurePolicy); +}; +} // namespace latinime +#endif /* LATINIME_DICTIONARY_BIGRAMS_STRUCTURE_POLICY_H */ diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h new file mode 100644 index 000000000..a6829b476 --- /dev/null +++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_HEADER_STRUCTURE_POLICY_H +#define LATINIME_DICTIONARY_HEADER_STRUCTURE_POLICY_H + +#include "defines.h" + +namespace latinime { + +/* + * This class abstracts structure of dictionaries. + * Implement this policy to support additional dictionaries. 
+ */ +class DictionaryHeaderStructurePolicy { + public: + virtual ~DictionaryHeaderStructurePolicy() {} + + virtual bool supportsDynamicUpdate() const = 0; + + virtual bool requiresGermanUmlautProcessing() const = 0; + + virtual bool requiresFrenchLigatureProcessing() const = 0; + + virtual float getMultiWordCostMultiplier() const = 0; + + virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue, + int outValueSize) const = 0; + + protected: + DictionaryHeaderStructurePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DictionaryHeaderStructurePolicy); +}; +} // namespace latinime +#endif /* LATINIME_DICTIONARY_HEADER_STRUCTURE_POLICY_H */ diff --git a/native/jni/src/suggest/core/policy/dictionary_shortcuts_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_shortcuts_structure_policy.h new file mode 100644 index 000000000..40b6c2de1 --- /dev/null +++ b/native/jni/src/suggest/core/policy/dictionary_shortcuts_structure_policy.h @@ -0,0 +1,46 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_SHORTCUTS_STRUCTURE_POLICY_H +#define LATINIME_DICTIONARY_SHORTCUTS_STRUCTURE_POLICY_H + +#include "defines.h" + +namespace latinime { + +/* + * This class abstracts structure of shortcuts. 
+ */ +class DictionaryShortcutsStructurePolicy { + public: + virtual ~DictionaryShortcutsStructurePolicy() {} + + virtual int getStartPos(const int pos) const = 0; + + virtual void getNextShortcut(const int maxCodePointCount, int *const outCodePoint, + int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, + int *const pos) const = 0; + + virtual void skipAllShortcuts(int *const pos) const = 0; + + protected: + DictionaryShortcutsStructurePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DictionaryShortcutsStructurePolicy); +}; +} // namespace latinime +#endif /* LATINIME_DICTIONARY_SHORTCUTS_STRUCTURE_POLICY_H */ diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h new file mode 100644 index 000000000..b95488ebd --- /dev/null +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_STRUCTURE_POLICY_H +#define LATINIME_DICTIONARY_STRUCTURE_POLICY_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicNodeVector; +class DictionaryBigramsStructurePolicy; +class DictionaryHeaderStructurePolicy; +class DictionaryShortcutsStructurePolicy; + +/* + * This class abstracts structure of dictionaries. 
+ * Implement this policy to support additional dictionaries. + */ +class DictionaryStructureWithBufferPolicy { + public: + virtual ~DictionaryStructureWithBufferPolicy() {} + + virtual int getRootPosition() const = 0; + + virtual void createAndGetAllChildNodes(const DicNode *const dicNode, + DicNodeVector *const childDicNodes) const = 0; + + virtual int getCodePointsAndProbabilityAndReturnCodePointCount( + const int nodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const = 0; + + virtual int getTerminalNodePositionOfWord(const int *const inWord, + const int length, const bool forceLowerCaseSearch) const = 0; + + virtual int getProbability(const int unigramProbability, + const int bigramProbability) const = 0; + + virtual int getUnigramProbabilityOfPtNode(const int nodePos) const = 0; + + virtual int getShortcutPositionOfPtNode(const int nodePos) const = 0; + + virtual int getBigramsPositionOfPtNode(const int nodePos) const = 0; + + virtual const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const = 0; + + virtual const DictionaryBigramsStructurePolicy *getBigramsStructurePolicy() const = 0; + + virtual const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const = 0; + + // Returns whether the update was successful or not. + virtual bool addUnigramWord(const int *const word, const int length, + const int probability) = 0; + + // Returns whether the update was successful or not. + virtual bool addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability) = 0; + + // Returns whether the update was successful or not. 
+ virtual bool removeBigramWords(const int *const word0, const int length0, + const int *const word1, const int length1) = 0; + + virtual void flush(const char *const filePath) = 0; + + virtual void flushWithGC(const char *const filePath) = 0; + + virtual bool needsToRunGC() const = 0; + + protected: + DictionaryStructureWithBufferPolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DictionaryStructureWithBufferPolicy); +}; +} // namespace latinime +#endif /* LATINIME_DICTIONARY_STRUCTURE_POLICY_H */ diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index c6f66f231..e935533f2 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -45,9 +45,9 @@ class Traversal { const DicNode *const dicNode) const = 0; virtual bool needsToTraverseAllUserInput() const = 0; virtual float getMaxSpatialDistance() const = 0; - virtual bool allowPartialCommit() const = 0; + virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; virtual int getDefaultExpandDicNodeSize() const = 0; - virtual int getMaxCacheSize() const = 0; + virtual int getMaxCacheSize(const int inputSize) const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 0c57ca001..f9b777df2 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -50,6 +50,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: PROF_TERMINAL(node->mProfiler); return; + case CT_TERMINAL_INSERTION: + PROF_TERMINAL_INSERTION(node->mProfiler); + return; case CT_NEW_WORD_SPACE_SUBSTITUTION: 
PROF_SPACE_SUBSTITUTION(node->mProfiler); return; @@ -106,13 +109,15 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n // only used for typing return weighting->getSubstitutionCost(); case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordCost(traverseSession, dicNode); + return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: return weighting->getCompletionCost(traverseSession, dicNode); case CT_TERMINAL: return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_TERMINAL_INSERTION: + return weighting->getTerminalInsertionCost(traverseSession, dicNode); case CT_NEW_WORD_SPACE_SUBSTITUTION: return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); case CT_INSERTION: @@ -134,7 +139,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -142,11 +148,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: { const float languageImprobability = DicNodeUtils::getBigramNodeImprobability( - traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap); + traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } + case CT_TERMINAL_INSERTION: + return 0.0f; case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, 
multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: @@ -161,9 +170,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_OMISSION: return 0; case CT_ADDITIONAL_PROXIMITY: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_SUBSTITUTION: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_NEW_WORD_SPACE_OMITTION: return 0; case CT_MATCH: @@ -172,12 +181,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 1; case CT_TERMINAL: return 0; + case CT_TERMINAL_INSERTION: + return 1; case CT_NEW_WORD_SPACE_SUBSTITUTION: return 1; case CT_INSERTION: - return 2; + return 2; /* look ahead + skip the current char */ case CT_TRANSPOSITION: - return 2; + return 2; /* look ahead + skip the current char */ default: return 0; } diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index 0d2745b40..2d49e98a6 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,10 +56,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const = 0; + virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0; - virtual float getNewWordBigramCost( + virtual float getNewWordBigramLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const = 0; @@ -67,6 +67,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; + virtual float getTerminalInsertionCost( + const DicTraverseSession *const traverseSession, 
+ const DicNode *const dicNode) const = 0; + virtual float getTerminalLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, float dicNodeLanguageImprobability) const = 0; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index c398caefa..50f2bbd8d 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -17,35 +17,35 @@ #include "suggest/core/session/dic_traverse_session.h" #include "defines.h" -#include "jni.h" -#include "suggest/core/dicnode/dic_node_utils.h" -#include "suggest/core/dictionary/binary_dictionary_info.h" -#include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/dictionary.h" +#include "suggest/core/policy/dictionary_header_structure_policy.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" namespace latinime { +// 256K bytes threshold is heuristically used to distinguish dictionaries containing many unigrams +// (e.g. main dictionary) from small dictionaries (e.g. contacts...) +const int DicTraverseSession::DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION = + 256 * 1024; + void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; - mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier( - mDictionary->getBinaryDictionaryInfo()->getDictBuf(), - mDictionary->getDictSize()); + mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy() + ->getMultiWordCostMultiplier(); mSuggestOptions = suggestOptions; if (!prevWord) { - mPrevWordPos = NOT_VALID_WORD; + mPrevWordPos = NOT_A_DICT_POS; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. 
- mPrevWordPos = BinaryFormat::getTerminalPosition( - dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, - prevWordLength, false /* forceLowerCaseSearch */); - if (mPrevWordPos == NOT_VALID_WORD) { + mPrevWordPos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord( + prevWord, prevWordLength, false /* forceLowerCaseSearch */); + if (mPrevWordPos == NOT_A_DICT_POS) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". - mPrevWordPos = BinaryFormat::getTerminalPosition( - dictionary->getBinaryDictionaryInfo()->getDictRoot(), prevWord, - prevWordLength, true /* forceLowerCaseSearch */); + mPrevWordPos = getDictionaryStructurePolicy()->getTerminalNodePositionOfWord( + prevWord, prevWordLength, true /* forceLowerCaseSearch */); } } @@ -59,16 +59,14 @@ void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, maxSpatialDistance, maxPointerCount); } -const BinaryDictionaryInfo *DicTraverseSession::getBinaryDictionaryInfo() const { - return mDictionary->getBinaryDictionaryInfo(); -} - -int DicTraverseSession::getDictFlags() const { - return mDictionary->getDictFlags(); +const DictionaryStructureWithBufferPolicy *DicTraverseSession::getDictionaryStructurePolicy() + const { + return mDictionary->getDictionaryStructurePolicy(); } -void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { - mDicNodesCache.reset(nextActiveCacheSize, maxWords); +void DicTraverseSession::resetCache(const int thresholdForNextActiveDicNodes, const int maxWords) { + mDicNodesCache.reset(thresholdForNextActiveDicNodes /* nextActiveSize */, + maxWords /* terminalSize */); mMultiBigramMap.clear(); mPartiallyCommited = false; } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 630b3b59b..e0b1c67d9 100644 --- 
a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -28,8 +28,8 @@ namespace latinime { -class BinaryDictionaryInfo; class Dictionary; +class DictionaryStructureWithBufferPolicy; class ProximityInfo; class SuggestOptions; @@ -37,8 +37,12 @@ class DicTraverseSession { public: // A factory method for DicTraverseSession - static AK_FORCE_INLINE void *getSessionInstance(JNIEnv *env, jstring localeStr) { - return new DicTraverseSession(env, localeStr); + static AK_FORCE_INLINE void *getSessionInstance(JNIEnv *env, jstring localeStr, + jlong dictSize) { + // To deal with the trade-off between accuracy and memory space, large cache is used for + // dictionaries larger than the threshold + return new DicTraverseSession(env, localeStr, + dictSize >= DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION); } static AK_FORCE_INLINE void initSessionInstance(DicTraverseSession *traverseSession, @@ -54,10 +58,10 @@ class DicTraverseSession { delete traverseSession; } - AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) - : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), - mDictionary(0), mSuggestOptions(0), mDicNodesCache(), mMultiBigramMap(), - mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) + : mPrevWordPos(NOT_A_DICT_POS), mProximityInfo(0), + mDictionary(0), mSuggestOptions(0), mDicNodesCache(usesLargeCache), + mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. 
@@ -73,11 +77,9 @@ class DicTraverseSession { const int inputSize, const int *const inputXs, const int *const inputYs, const int *const times, const int *const pointerIds, const float maxSpatialDistance, const int maxPointerCount); - void resetCache(const int nextActiveCacheSize, const int maxWords); + void resetCache(const int thresholdForNextActiveDicNodes, const int maxWords); - // TODO: Remove - const BinaryDictionaryInfo *getBinaryDictionaryInfo() const; - int getDictFlags() const; + const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const; //-------------------- // getters and setters @@ -111,7 +113,9 @@ class DicTraverseSession { if (usedPointerCount != 1) { return false; } - *pointerId = usedPointerId; + if (pointerId) { + *pointerId = usedPointerId; + } return true; } @@ -183,6 +187,7 @@ class DicTraverseSession { DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching static const int CACHE_START_INPUT_LENGTH_THRESHOLD; + static const int DICTIONARY_SIZE_THRESHOLD_TO_USE_LARGE_CACHE_FOR_SUGGESTION; void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs, const int *const inputYs, const int *const times, const int *const pointerIds, const int inputSize, const float maxSpatialDistance, const int maxPointerCount); diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 1f108e400..b1340e12f 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -19,11 +19,12 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_priority_queue.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" #include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/digraph_utils.h" #include "suggest/core/dictionary/shortcut_utils.h" -#include 
"suggest/core/dictionary/terminal_attributes.h" #include "suggest/core/layout/proximity_info.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" #include "suggest/core/policy/scoring.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/policy/weighting.h" @@ -83,9 +84,9 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo if (!traverseSession->getProximityInfoState(0)->isUsed()) { return; } - if (TRAVERSAL->allowPartialCommit()) { - commitPoint = 0; - } + + // Never auto partial commit for now. + commitPoint = 0; if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE && traverseSession->isContinuousSuggestionPossible()) { @@ -102,10 +103,11 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo } } else { // Restart recognition at the root. - traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(), MAX_RESULTS); + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()), + MAX_RESULTS); // Create a new dic node here DicNode rootNode; - DicNodeUtils::initAsRoot(traverseSession->getBinaryDictionaryInfo(), + DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), traverseSession->getPrevWordPos(), &rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); } @@ -115,7 +117,7 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo * Outputs the final list of suggestions (i.e., terminal nodes). 
*/ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *spaceIndices, int *outputTypes) const { + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else @@ -137,6 +139,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen SCORING->getMostProbableString(traverseSession, terminalSize, languageWeight, &outputCodePoints[0], &outputTypes[0], &frequencies[0]); if (hasMostProbableString) { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; ++outputWordIndex; } @@ -147,6 +150,20 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen &doubleLetterTerminalIndex, &doubleLetterLevel); int maxScore = S_INT_MIN; + // Force autocorrection for obvious long multi-word suggestions when the top suggestion is + // a long multiple words suggestion. + // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. + // traverseSession->isPartiallyCommited() always returns false because we never auto partial + // commit for now. + const bool forceCommitMultiWords = (terminalSize > 0) ? + TRAVERSAL->autoCorrectsToMultiWordSuggestionIfTop() + && (traverseSession->isPartiallyCommited() + || (traverseSession->getInputSize() + >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT + && terminals[0].hasMultipleWords())) : false; + // TODO: have partial commit work even with multiple pointers. 
+ const bool outputSecondWordFirstLetterInputIndex = + traverseSession->isOnlyOnePointerUsed(0 /* pointerId */); // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -158,9 +175,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + doubleLetterCost; - const TerminalAttributes terminalAttributes(traverseSession->getBinaryDictionaryInfo(), - terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); - const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; + const bool isPossiblyOffensiveWord = + traverseSession->getDictionaryStructurePolicy()->getProbability( + terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; const bool isExactMatch = terminalDicNode->isExactMatch(); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. @@ -172,42 +189,58 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. - const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord(); + const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. - // Force autocorrection for obvious long multi-word suggestions. 
- const bool isForceCommitMultiWords = TRAVERSAL->allowPartialCommit() - && (traverseSession->isPartiallyCommited() - || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT - && terminalDicNode->hasMultipleWords())); - const int finalScore = SCORING->calculateFinalScore( compoundDistance, traverseSession->getInputSize(), - isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord())); - - maxScore = max(maxScore, finalScore); - - if (TRAVERSAL->allowPartialCommit()) { - // Index for top typing suggestion should be 0. - if (isValidWord && outputWordIndex == 0) { - terminalDicNode->outputSpacePositionsResult(spaceIndices); - } + terminalDicNode->isExactMatch() + || (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) + || (isValidWord && SCORING->doesAutoCorrectValidWord())); + if (maxScore < finalScore && isValidWord) { + maxScore = finalScore; } // Don't output invalid words. However, we still need to submit their shortcuts if any. if (isValidWord) { outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[outputWordIndex] = + terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + } else { + outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; + } // Populate the outputChars array with the suggested word. 
const int startIndex = outputWordIndex * MAX_WORD_LENGTH; terminalDicNode->outputResult(&outputCodePoints[startIndex]); ++outputWordIndex; } - const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); - outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, - finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + if (!terminalDicNode->hasMultipleWords()) { + BinaryDictionaryShortcutIterator shortcutIt( + traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), + traverseSession->getDictionaryStructurePolicy() + ->getShortcutPositionOfPtNode(terminalDicNode->getPos())); + // Shortcut is not supported for multiple words suggestions. + // TODO: Check shortcuts during traversal for multiple words suggestions. + const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); + const int updatedOutputWordIndex = ShortcutUtils::outputShortcuts(&shortcutIt, + outputWordIndex, finalScore, outputCodePoints, frequencies, outputTypes, + sameAsTyped); + const int secondWordFirstInputIndex = terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)); + for (int i = outputWordIndex; i < updatedOutputWordIndex; ++i) { + if (outputSecondWordFirstLetterInputIndex) { + outputIndicesToPartialCommit[i] = secondWordFirstInputIndex; + } else { + outputIndicesToPartialCommit[i] = NOT_AN_INDEX; + } + } + outputWordIndex = updatedOutputWordIndex; + } DicNode::managedDelete(terminalDicNode); } @@ -284,7 +317,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { } DicNodeUtils::getAllChildDicNodes( - &dicNode, traverseSession->getBinaryDictionaryInfo(), &childDicNodes); + &dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes); const int childDicNodesSize = childDicNodes.getSizeAndLock(); for (int i = 0; i < childDicNodesSize; ++i) { @@ -294,7 +327,9 @@ void 
Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { processDicNodeAsMatch(traverseSession, childDicNode); continue; } - if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(), + if (DigraphUtils::hasDigraphForCodePoint( + traverseSession->getDictionaryStructurePolicy() + ->getHeaderStructurePolicy(), childDicNode->getNodeCodePoint())) { correctionDicNode.initByCopy(childDicNode); correctionDicNode.advanceDigraphIndex(); @@ -351,17 +386,17 @@ void Suggest::processTerminalDicNode( if (!dicNode->isTerminalWordNode()) { return; } - if (TRAVERSAL->needsToTraverseAllUserInput() - && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { - return; - } - if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { return; } // Create a non-cached node here. DicNode terminalDicNode; DicNodeUtils::initByCopy(dicNode, &terminalDicNode); + if (TRAVERSAL->needsToTraverseAllUserInput() + && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0, + &terminalDicNode, traverseSession->getMultiBigramMap()); + } Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, &terminalDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); @@ -432,7 +467,7 @@ void Suggest::processDicNodeAsOmission( DicTraverseSession *traverseSession, DicNode *dicNode) const { DicNodeVector childDicNodes; DicNodeUtils::getAllChildDicNodes( - dicNode, traverseSession->getBinaryDictionaryInfo(), &childDicNodes); + dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { @@ -441,7 +476,6 @@ void Suggest::processDicNodeAsOmission( Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); 
weightChildNode(traverseSession, childDicNode); - if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { continue; } @@ -457,10 +491,14 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const { const int16_t pointIndex = dicNode->getInputIndex(0); DicNodeVector childDicNodes; - DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getBinaryDictionaryInfo(), - traverseSession->getProximityInfoState(0), pointIndex + 1, true, &childDicNodes); + DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(), + &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { + if (traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex + 1) + != childDicNodes[i]->getNodeCodePoint()) { + continue; + } DicNode *const childDicNode = childDicNodes[i]; Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); @@ -475,18 +513,29 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, DicNode *dicNode) const { const int16_t pointIndex = dicNode->getInputIndex(0); DicNodeVector childDicNodes1; - DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getBinaryDictionaryInfo(), - traverseSession->getProximityInfoState(0), pointIndex + 1, false, &childDicNodes1); + DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(), + &childDicNodes1); const int childSize1 = childDicNodes1.getSizeAndLock(); for (int i = 0; i < childSize1; i++) { + const ProximityType matchedId1 = traverseSession->getProximityInfoState(0) + ->getProximityType(pointIndex + 1, childDicNodes1[i]->getNodeCodePoint(), + true /* checkProximityChars */); + if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId1)) { + continue; + } if (childDicNodes1[i]->hasChildren()) { DicNodeVector childDicNodes2; - 
DicNodeUtils::getProximityChildDicNodes( - childDicNodes1[i], traverseSession->getBinaryDictionaryInfo(), - traverseSession->getProximityInfoState(0), pointIndex, false, &childDicNodes2); + DicNodeUtils::getAllChildDicNodes(childDicNodes1[i], + traverseSession->getDictionaryStructurePolicy(), &childDicNodes2); const int childSize2 = childDicNodes2.getSizeAndLock(); for (int j = 0; j < childSize2; j++) { DicNode *const childDicNode2 = childDicNodes2[j]; + const ProximityType matchedId2 = traverseSession->getProximityInfoState(0) + ->getProximityType(pointIndex, childDicNode2->getNodeCodePoint(), + true /* checkProximityChars */); + if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId2)) { + continue; + } Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */); processExpandedDicNode(traverseSession, childDicNode2); @@ -523,11 +572,17 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode // Create a non-cached node here. DicNode newDicNode; DicNodeUtils::initAsRootWithPreviousWord( - traverseSession->getBinaryDictionaryInfo(), dicNode, &newDicNode); + traverseSession->getDictionaryStructurePolicy(), dicNode, &newDicNode); const CorrectionType correctionType = spaceSubstitution ? CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getMultiBigramMap()); - traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + // newDicNode is worth continuing to traverse. + // CAVEAT: This pruning is important for speed. Remove this when we can afford not to prune + // here because here is not the right place to do pruning. Pruning should take place only + // in DicNodePriorityQueue. 
+ traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + } } } // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 875cbe4e0..b24019632 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -55,7 +55,7 @@ class Suggest : public SuggestInterface { void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, const bool spaceSubstitution) const; int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, - int *outputCodePoints, int *outputIndices, int *outputTypes) const; + int *outputCodePoints, int *outputIndicesToPartialCommit, int *outputTypes) const; void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h new file mode 100644 index 000000000..6ff95cac4 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_policy.h @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_BIGRAM_LIST_POLICY_H +#define LATINIME_BIGRAM_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" + +namespace latinime { + +class BigramListPolicy : public DictionaryBigramsStructurePolicy { + public: + explicit BigramListPolicy(const uint8_t *const bigramsBuf) : mBigramsBuf(bigramsBuf) {} + + ~BigramListPolicy() {} + + void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, + int *const pos) const { + BigramListReadWriteUtils::BigramFlags flags; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(mBigramsBuf, &flags, + outBigramPos, pos); + *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(flags); + *outHasNext = BigramListReadWriteUtils::hasNext(flags); + } + + void skipAllBigrams(int *const pos) const { + BigramListReadWriteUtils::skipExistingBigrams(mBigramsBuf, pos); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListPolicy); + + const uint8_t *const mBigramsBuf; +}; +} // namespace latinime +#endif // LATINIME_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp new file mode 100644 index 000000000..1926b9831 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.cpp @@ -0,0 +1,182 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = + 0x30; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; +const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; +// Flag for presence of more attributes +const BigramListReadWriteUtils::BigramFlags BigramListReadWriteUtils::FLAG_ATTRIBUTE_HAS_NEXT = + 0x80; +// Mask for attribute probability, stored on 4 bits inside the flags byte. 
+const BigramListReadWriteUtils::BigramFlags + BigramListReadWriteUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int BigramListReadWriteUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; + +/* static */ void BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + const uint8_t *const bigramsBuf, BigramFlags *const outBigramFlags, + int *const outTargetPtNodePos, int *const bigramEntryPos) { + const BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, + bigramEntryPos); + if (outBigramFlags) { + *outBigramFlags = bigramFlags; + } + const int targetPos = getBigramAddressAndAdvancePosition(bigramsBuf, bigramFlags, + bigramEntryPos); + if (outTargetPtNodePos) { + *outTargetPtNodePos = targetPos; + } +} + +/* static */ void BigramListReadWriteUtils::skipExistingBigrams(const uint8_t *const bigramsBuf, + int *const bigramListPos) { + BigramFlags flags; + do { + getBigramEntryPropertiesAndAdvancePosition(bigramsBuf, &flags, 0 /* outTargetPtNodePos */, + bigramListPos); + } while(hasNext(flags)); +} + +/* static */ int BigramListReadWriteUtils::getBigramAddressAndAdvancePosition( + const uint8_t *const bigramsBuf, const BigramFlags flags, int *const pos) { + int offset = 0; + const int origin = *pos; + switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8AndAdvancePosition(bigramsBuf, pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16AndAdvancePosition(bigramsBuf, pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24AndAdvancePosition(bigramsBuf, pos); + break; + } + if (offset == DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID) { + return NOT_A_DICT_POS; + } else if (offset == DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET) { + return origin; + } + if (isOffsetNegative(flags)) { + return origin - offset; + } else { + return origin + offset; + } +} + +/* static 
*/ bool BigramListReadWriteUtils::setHasNextFlag( + BufferWithExtendableBuffer *const buffer, const bool hasNext, const int entryPos) { + const bool usesAdditionalBuffer = buffer->isInAdditionalBuffer(entryPos); + int readingPos = entryPos; + if (usesAdditionalBuffer) { + readingPos -= buffer->getOriginalBufferSize(); + } + BigramFlags bigramFlags = ByteArrayUtils::readUint8AndAdvancePosition( + buffer->getBuffer(usesAdditionalBuffer), &readingPos); + if (hasNext) { + bigramFlags = bigramFlags | FLAG_ATTRIBUTE_HAS_NEXT; + } else { + bigramFlags = bigramFlags & (~FLAG_ATTRIBUTE_HAS_NEXT); + } + int writingPos = entryPos; + return buffer->writeUintAndAdvancePosition(bigramFlags, 1 /* size */, &writingPos); +} + +/* static */ bool BigramListReadWriteUtils::createAndWriteBigramEntry( + BufferWithExtendableBuffer *const buffer, const int targetPos, const int probability, + const bool hasNext, int *const writingPos) { + BigramFlags flags; + if (!createAndGetBigramFlags(*writingPos, targetPos, probability, hasNext, &flags)) { + return false; + } + return writeBigramEntry(buffer, flags, targetPos, writingPos); +} + +/* static */ bool BigramListReadWriteUtils::writeBigramEntry( + BufferWithExtendableBuffer *const bufferToWrite, const BigramFlags flags, + const int targetPtNodePos, int *const writingPos) { + const int offset = getBigramTargetOffset(targetPtNodePos, *writingPos); + const BigramFlags flagsToWrite = (offset < 0) ? + (flags | FLAG_ATTRIBUTE_OFFSET_NEGATIVE) : (flags & ~FLAG_ATTRIBUTE_OFFSET_NEGATIVE); + if (!bufferToWrite->writeUintAndAdvancePosition(flagsToWrite, 1 /* size */, writingPos)) { + return false; + } + const uint32_t absOffest = abs(offset); + const int bigramTargetFieldSize = attributeAddressSize(flags); + return bufferToWrite->writeUintAndAdvancePosition(absOffest, bigramTargetFieldSize, + writingPos); +} + +// Returns true if the bigram entry is valid and put entry flags into out*. 
+/* static */ bool BigramListReadWriteUtils::createAndGetBigramFlags(const int entryPos, + const int targetPtNodePos, const int probability, const bool hasNext, + BigramFlags *const outBigramFlags) { + BigramFlags flags = probability & MASK_ATTRIBUTE_PROBABILITY; + if (hasNext) { + flags |= FLAG_ATTRIBUTE_HAS_NEXT; + } + const int offset = getBigramTargetOffset(targetPtNodePos, entryPos); + if (offset < 0) { + flags |= FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + } + const uint32_t absOffest = abs(offset); + if ((absOffest >> 24) != 0) { + // Offset is too large. + return false; + } else if ((absOffest >> 16) != 0) { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + } else if ((absOffest >> 8) != 0) { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + } else { + flags |= FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + } + // Currently, all newly written bigram position fields are 3 bytes to simplify dictionary + // writing. + // TODO: Remove following 2 lines and optimize memory space. + flags = (flags & (~MASK_ATTRIBUTE_ADDRESS_TYPE)) | FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + *outBigramFlags = flags; + return true; +} + +/* static */ int BigramListReadWriteUtils::getBigramTargetOffset(const int targetPtNodePos, + const int entryPos) { + if (targetPtNodePos == NOT_A_DICT_POS) { + return DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID; + } else { + const int offset = targetPtNodePos - (entryPos + 1 /* bigramFlagsField */); + if (offset == 0) { + return DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET; + } else { + return offset; + } + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h new file mode 100644 index 000000000..eabe4e099 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2013 The Android Open Source Project 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H +#define LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H + +#include <cstdlib> +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class BigramListReadWriteUtils { +public: + typedef uint8_t BigramFlags; + + static void getBigramEntryPropertiesAndAdvancePosition(const uint8_t *const bigramsBuf, + BigramFlags *const outBigramFlags, int *const outTargetPtNodePos, + int *const bigramEntryPos); + + static AK_FORCE_INLINE int getProbabilityFromFlags(const BigramFlags flags) { + return flags & MASK_ATTRIBUTE_PROBABILITY; + } + + static AK_FORCE_INLINE bool hasNext(const BigramFlags flags) { + return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; + } + + // Bigrams reading methods + static void skipExistingBigrams(const uint8_t *const bigramsBuf, int *const bigramListPos); + + // Returns the size of the bigram position field that is stored in bigram flags. 
+ static AK_FORCE_INLINE int attributeAddressSize(const BigramFlags flags) { + return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; + /* Note: this is a value-dependent optimization of what may probably be + more readably written this way: + switch (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: return 3; + default: return 0; + } + */ + } + + static bool setHasNextFlag(BufferWithExtendableBuffer *const buffer, + const bool hasNext, const int entryPos); + + static AK_FORCE_INLINE BigramFlags setProbabilityInFlags(const BigramFlags flags, + const int probability) { + return (flags & (~MASK_ATTRIBUTE_PROBABILITY)) | (probability & MASK_ATTRIBUTE_PROBABILITY); + } + + static bool createAndWriteBigramEntry(BufferWithExtendableBuffer *const buffer, + const int targetPos, const int probability, const bool hasNext, int *const writingPos); + + static bool writeBigramEntry(BufferWithExtendableBuffer *const buffer, const BigramFlags flags, + const int targetOffset, int *const writingPos); + +private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BigramListReadWriteUtils); + + static const BigramFlags MASK_ATTRIBUTE_ADDRESS_TYPE; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + static const BigramFlags FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + static const BigramFlags FLAG_ATTRIBUTE_HAS_NEXT; + static const BigramFlags MASK_ATTRIBUTE_PROBABILITY; + static const int ATTRIBUTE_ADDRESS_SHIFT; + + // Returns true if the bigram entry is valid and puts the entry flags into outBigramFlags.
+ static bool createAndGetBigramFlags(const int entryPos, const int targetPos, + const int probability, const bool hasNext, BigramFlags *const outBigramFlags); + + static AK_FORCE_INLINE bool isOffsetNegative(const BigramFlags flags) { + return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; + } + + static int getBigramAddressAndAdvancePosition(const uint8_t *const bigramsBuf, + const BigramFlags flags, int *const pos); + + static int getBigramTargetOffset(const int targetPtNodePos, const int entryPos); +}; +} // namespace latinime +#endif // LATINIME_BIGRAM_LIST_READ_WRITE_UTILS_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp new file mode 100644 index 000000000..29307b56a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.cpp @@ -0,0 +1,336 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" + +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_read_write_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const int DynamicBigramListPolicy::CONTINUING_BIGRAM_LINK_COUNT_LIMIT = 10000; +const int DynamicBigramListPolicy::BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT = 100000; + +void DynamicBigramListPolicy::getNextBigram(int *const outBigramPos, int *const outProbability, + bool *const outHasNext, int *const bigramEntryPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramEntryPos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *bigramEntryPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int originalBigramPos; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition(buffer, &bigramFlags, + &originalBigramPos, bigramEntryPos); + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + *outBigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + *outProbability = BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags); + *outHasNext = BigramListReadWriteUtils::hasNext(bigramFlags); + if (usesAdditionalBuffer) { + *bigramEntryPos += mBuffer->getOriginalBufferSize(); + } +} + +void DynamicBigramListPolicy::skipAllBigrams(int *const bigramListPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if 
(usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::skipExistingBigrams(buffer, bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos += mBuffer->getOriginalBufferSize(); + } +} + +bool DynamicBigramListPolicy::copyAllBigrams(BufferWithExtendableBuffer *const bufferToWrite, + int *const fromPos, int *const toPos, int *const outBigramsCount) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); + if (usesAdditionalBuffer) { + *fromPos -= mBuffer->getOriginalBufferSize(); + } + *outBigramsCount = 0; + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + int lastWrittenEntryPos = NOT_A_DICT_POS; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + // The buffer address can be changed after calling buffer writing methods. + int originalBigramPos; + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + fromPos); + if (originalBigramPos == NOT_A_DICT_POS) { + // skip invalid bigram entry. + continue; + } + if (usesAdditionalBuffer) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + if (bigramPos == NOT_A_DICT_POS) { + // Target PtNode has been invalidated. + continue; + } + lastWrittenEntryPos = *toPos; + if (!BigramListReadWriteUtils::createAndWriteBigramEntry(bufferToWrite, bigramPos, + BigramListReadWriteUtils::getProbabilityFromFlags(bigramFlags), + BigramListReadWriteUtils::hasNext(bigramFlags), toPos)) { + return false; + } + (*outBigramsCount)++; + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + // Makes the last entry the terminal of the list. 
Updates the flags. + if (lastWrittenEntryPos != NOT_A_DICT_POS) { + if (!BigramListReadWriteUtils::setHasNextFlag(bufferToWrite, false /* hasNext */, + lastWrittenEntryPos)) { + return false; + } + } + if (usesAdditionalBuffer) { + *fromPos += mBuffer->getOriginalBufferSize(); + } + return true; +} + +// Finds useless bigram entries and invalidates them. A bigram entry is useless when the target +// PtNode has been deleted or is not a valid terminal. NOTE(review): unlike skipAllBigrams(), this never shifts *bigramListPos back by the original buffer size before returning; confirm callers expect that. +bool DynamicBigramListPolicy::updateAllBigramEntriesAndDeleteUselessEntries( + int *const bigramListPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, this /* bigramsPolicy */, mShortcutPolicy); + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = *bigramListPos; + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + bigramListPos); + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + if (originalBigramPos == NOT_A_DICT_POS) { + // This entry has already been removed. + continue; + } + if (usesAdditionalBuffer) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramTargetNodePos = + followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(bigramTargetNodePos); + // TODO: Update probability for supporting probability decaying.
+ if (nodeReader.isDeleted() || !nodeReader.isTerminal() + || bigramTargetNodePos == NOT_A_DICT_POS) { + // The target is no longer a valid terminal. Invalidate the current bigram entry. + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + NOT_A_DICT_POS /* targetOffset */, &bigramEntryPos)) { + return false; + } + } + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return true; +} + +// Updates bigram target PtNode positions in the list after the placing step in GC. +bool DynamicBigramListPolicy::updateAllBigramTargetPtNodePositions(int *const bigramListPos, + const DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap *const + ptNodePositionRelocationMap) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = *bigramListPos; + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + int bigramTargetPtNodePos; + // The buffer address can be changed after calling buffer writing methods.
+ BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &bigramTargetPtNodePos, + bigramListPos); + if (bigramTargetPtNodePos == NOT_A_DICT_POS) { + continue; + } + if (usesAdditionalBuffer) { + bigramTargetPtNodePos += mBuffer->getOriginalBufferSize(); + } + + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::const_iterator it = + ptNodePositionRelocationMap->find(bigramTargetPtNodePos); + if (it != ptNodePositionRelocationMap->end()) { + bigramTargetPtNodePos = it->second; + } else { + bigramTargetPtNodePos = NOT_A_DICT_POS; + } + if (!BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + bigramTargetPtNodePos, &bigramEntryPos)) { + return false; + } + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return true; +} + +bool DynamicBigramListPolicy::addNewBigramEntryToBigramList(const int bigramTargetPos, + const int probability, int *const bigramListPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*bigramListPos); + if (usesAdditionalBuffer) { + *bigramListPos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int entryPos = *bigramListPos; + if (usesAdditionalBuffer) { + entryPos += mBuffer->getOriginalBufferSize(); + } + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. 
+ BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, + bigramListPos); + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + if (followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos) == bigramTargetPos) { + // Update this bigram entry. + const BigramListReadWriteUtils::BigramFlags updatedFlags = + BigramListReadWriteUtils::setProbabilityInFlags(bigramFlags, probability); + return BigramListReadWriteUtils::writeBigramEntry(mBuffer, updatedFlags, + originalBigramPos, &entryPos); + } + if (BigramListReadWriteUtils::hasNext(bigramFlags)) { + continue; + } + // The current last entry is found. + // First, update the flags of the last entry. + if (!BigramListReadWriteUtils::setHasNextFlag(mBuffer, true /* hasNext */, entryPos)) { + return false; + } + if (usesAdditionalBuffer) { + *bigramListPos += mBuffer->getOriginalBufferSize(); + } + // Then, add a new entry after the last entry. + return writeNewBigramEntry(bigramTargetPos, probability, bigramListPos); + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + // We return directly from the while loop. + ASSERT(false); + return false; +} + +bool DynamicBigramListPolicy::writeNewBigramEntry(const int bigramTargetPos, const int probability, + int *const writingPos) { + // hasNext is false because we are adding a new bigram entry at the end of the bigram list. 
+ return BigramListReadWriteUtils::createAndWriteBigramEntry(mBuffer, bigramTargetPos, + probability, false /* hasNext */, writingPos); +} + +bool DynamicBigramListPolicy::removeBigram(const int bigramListPos, const int bigramTargetPos) { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(bigramListPos); + int pos = bigramListPos; + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + BigramListReadWriteUtils::BigramFlags bigramFlags; + int bigramEntryCount = 0; + do { + if (++bigramEntryCount > BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT) { + AKLOGE("Too many bigram entries. Entry count: %d, Limit: %d", + bigramEntryCount, BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT); + ASSERT(false); + return false; + } + int bigramEntryPos = pos; + int originalBigramPos; + // The buffer address can be changed after calling buffer writing methods. + BigramListReadWriteUtils::getBigramEntryPropertiesAndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), &bigramFlags, &originalBigramPos, &pos); + if (usesAdditionalBuffer) { + bigramEntryPos += mBuffer->getOriginalBufferSize(); + } + if (usesAdditionalBuffer && originalBigramPos != NOT_A_DICT_POS) { + originalBigramPos += mBuffer->getOriginalBufferSize(); + } + const int bigramPos = followBigramLinkAndGetCurrentBigramPtNodePos(originalBigramPos); + if (bigramPos != bigramTargetPos) { + continue; + } + // Target entry is found. Write an invalid target position to mark the bigram invalid. 
+ return BigramListReadWriteUtils::writeBigramEntry(mBuffer, bigramFlags, + NOT_A_DICT_POS /* targetOffset */, &bigramEntryPos); + } while(BigramListReadWriteUtils::hasNext(bigramFlags)); + return false; +} + +int DynamicBigramListPolicy::followBigramLinkAndGetCurrentBigramPtNodePos( + const int originalBigramPos) const { + if (originalBigramPos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + int currentPos = originalBigramPos; + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, this /* bigramsPolicy */, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(currentPos); + int bigramLinkCount = 0; + while (nodeReader.getBigramLinkedNodePos() != NOT_A_DICT_POS) { + currentPos = nodeReader.getBigramLinkedNodePos(); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(currentPos); + bigramLinkCount++; + if (bigramLinkCount > CONTINUING_BIGRAM_LINK_COUNT_LIMIT) { + AKLOGE("Bigram link is invalid. start position: %d", originalBigramPos); + ASSERT(false); + return NOT_A_DICT_POS; + } + } + return currentPos; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h new file mode 100644 index 000000000..8ea318a41 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h @@ -0,0 +1,81 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H +#define LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DictionaryShortcutsStructurePolicy; + +/* + * This is a dynamic version of BigramListPolicy and supports an additional buffer. + */ +class DynamicBigramListPolicy : public DictionaryBigramsStructurePolicy { + public: + DynamicBigramListPolicy(BufferWithExtendableBuffer *const buffer, + const DictionaryShortcutsStructurePolicy *const shortcutPolicy) + : mBuffer(buffer), mShortcutPolicy(shortcutPolicy) {} + + ~DynamicBigramListPolicy() {} + + void getNextBigram(int *const outBigramPos, int *const outProbability, bool *const outHasNext, + int *const bigramEntryPos) const; + + void skipAllBigrams(int *const bigramListPos) const; + + // Copy bigrams from the bigram list that starts at fromPos in mBuffer to toPos in + // bufferToWrite and advance these positions after bigram lists. This method skips invalid + // bigram entries and write the valid bigram entry count to outBigramsCount. 
+ bool copyAllBigrams(BufferWithExtendableBuffer *const bufferToWrite, int *const fromPos, + int *const toPos, int *const outBigramsCount) const; + + bool updateAllBigramEntriesAndDeleteUselessEntries(int *const bigramListPos); + + bool updateAllBigramTargetPtNodePositions(int *const bigramListPos, + const DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap *const + ptNodePositionRelocationMap); + + bool addNewBigramEntryToBigramList(const int bigramTargetPos, const int probability, + int *const bigramListPos); + + bool writeNewBigramEntry(const int bigramTargetPos, const int probability, + int *const writingPos); + + // Return if targetBigramPos is found or not. + bool removeBigram(const int bigramListPos, const int bigramTargetPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicBigramListPolicy); + + static const int CONTINUING_BIGRAM_LINK_COUNT_LIMIT; + static const int BIGRAM_ENTRY_COUNT_IN_A_BIGRAM_LIST_LIMIT; + + BufferWithExtendableBuffer *const mBuffer; + const DictionaryShortcutsStructurePolicy *const mShortcutPolicy; + + // Follow bigram link and return the position of bigram target PtNode that is currently valid. + int followBigramLinkAndGetCurrentBigramPtNodePos(const int originalBigramPos) const; +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_BIGRAM_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp new file mode 100644 index 000000000..ff80dd2f6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h" + +#include <stdint.h> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h" +#include "suggest/policyimpl/dictionary/patricia_trie_policy.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" + +namespace latinime { + +/* static */ DictionaryStructureWithBufferPolicy *DictionaryStructureWithBufferPolicyFactory + ::newDictionaryStructureWithBufferPolicy(const char *const path, const int bufOffset, + const int size, const bool isUpdatable) { + // Allocated buffer in MmapedBuffer::openBuffer() will be freed in the destructor of + // impl classes of DictionaryStructureWithBufferPolicy. 
+ const MmappedBuffer *const mmapedBuffer = MmappedBuffer::openBuffer(path, bufOffset, size, + isUpdatable); + if (!mmapedBuffer) { + return 0; + } + switch (FormatUtils::detectFormatVersion(mmapedBuffer->getBuffer(), + mmapedBuffer->getBufferSize())) { + case FormatUtils::VERSION_2: + return new PatriciaTriePolicy(mmapedBuffer); + case FormatUtils::VERSION_3: + return new DynamicPatriciaTriePolicy(mmapedBuffer); + default: + AKLOGE("DICT: dictionary format is unknown, bad magic number"); + delete mmapedBuffer; + ASSERT(false); + return 0; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h new file mode 100644 index 000000000..8cebc3b16 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_with_buffer_policy_factory.h @@ -0,0 +1,36 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DICTIONARY_STRUCTURE_WITH_BUFFER_POLICY_FACTORY_H +#define LATINIME_DICTIONARY_STRUCTURE_WITH_BUFFER_POLICY_FACTORY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" + +namespace latinime { + +class DictionaryStructureWithBufferPolicyFactory { + public: + static DictionaryStructureWithBufferPolicy *newDictionaryStructureWithBufferPolicy( + const char *const path, const int bufOffset, const int size, const bool isUpdatable); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryStructureWithBufferPolicyFactory); +}; +} // namespace latinime +#endif // LATINIME_DICTIONARY_STRUCTURE_WITH_BUFFER_POLICY_FACTORY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp new file mode 100644 index 000000000..c60e45819 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.cpp @@ -0,0 +1,149 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h" + +namespace latinime { + +bool DynamicPatriciaTrieGcEventListeners + ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + // PtNode is useless when the PtNode is not a terminal and doesn't have any not useless + // children. + bool isUselessPtNode = !node->isTerminal(); + if (mChildrenValue > 0) { + isUselessPtNode = false; + } else if (node->isTerminal()) { + // Remove children as all children are useless. + int writingPos = node->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + mBuffer, NOT_A_DICT_POS /* childrenPosition */, &writingPos)) { + return false; + } + } + if (isUselessPtNode) { + // Current PtNode is no longer needed. Mark it as deleted. + if (!mWritingHelper->markNodeAsDeleted(node)) { + return false; + } + } else { + valueStack.back() += 1; + } + return true; +} + +// Writes dummy PtNode array size when the head of PtNode array is read. +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onDescend(const int ptNodeArrayPos) { + mValidPtNodeCount = 0; + int writingPos = mBufferToWrite->getTailPosition(); + mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodeArrayPositionRelocationMap::value_type( + ptNodeArrayPos, writingPos)); + // Writes dummy PtNode array size because arrays can have a forward link or needles PtNodes. + // This field will be updated later in onReadingPtNodeArrayTail() with actual PtNode count. + mPtNodeArraySizeFieldPos = writingPos; + return DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + mBufferToWrite, 0 /* arraySize */, &writingPos); +} + +// Write PtNode array terminal and actual PtNode array size. 
+bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onReadingPtNodeArrayTail() { + int writingPos = mBufferToWrite->getTailPosition(); + // Write PtNode array terminal. + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition( + mBufferToWrite, NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + // Write actual PtNode array size. + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + mBufferToWrite, mValidPtNodeCount, &mPtNodeArraySizeFieldPos)) { + return false; + } + return true; +} + +// Write valid PtNode to buffer and memorize mapping from the old position to the new position. +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + if (node->isDeleted()) { + // Current PtNode is not written in new buffer because it has been deleted. + mDictPositionRelocationMap->mPtNodePositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::value_type( + node->getHeadPos(), NOT_A_DICT_POS)); + return true; + } + int writingPos = mBufferToWrite->getTailPosition(); + mDictPositionRelocationMap->mPtNodePositionRelocationMap.insert( + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::value_type( + node->getHeadPos(), writingPos)); + mValidPtNodeCount++; + // Writes current PtNode. + return mWritingHelper->writePtNodeToBufferByCopyingPtNodeInfo(mBufferToWrite, node, + node->getParentPos(), nodeCodePoints, node->getCodePointCount(), + node->getProbability(), &writingPos); +} + +bool DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateAllPositionFields + ::onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + // Updates parent position. 
+ int parentPos = node->getParentPos(); + if (parentPos != NOT_A_DICT_POS) { + DynamicPatriciaTrieWritingHelper::PtNodePositionRelocationMap::const_iterator it = + mDictPositionRelocationMap->mPtNodePositionRelocationMap.find(parentPos); + if (it != mDictPositionRelocationMap->mPtNodePositionRelocationMap.end()) { + parentPos = it->second; + } + } + int writingPos = node->getHeadPos() + DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE; + // Write updated parent offset. + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition(mBufferToWrite, + parentPos, node->getHeadPos(), &writingPos)) { + return false; + } + + // Updates children position. + int childrenPos = node->getChildrenPos(); + if (childrenPos != NOT_A_DICT_POS) { + DynamicPatriciaTrieWritingHelper::PtNodeArrayPositionRelocationMap::const_iterator it = + mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.find(childrenPos); + if (it != mDictPositionRelocationMap->mPtNodeArrayPositionRelocationMap.end()) { + childrenPos = it->second; + } + } + writingPos = node->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBufferToWrite, + childrenPos, &writingPos)) { + return false; + } + + // Updates bigram target PtNode positions in the bigram list. 
+ int bigramsPos = node->getBigramsPos(); + if (bigramsPos != NOT_A_DICT_POS) { + if (!mBigramPolicy->updateAllBigramTargetPtNodePositions(&bigramsPos, + &mDictPositionRelocationMap->mPtNodePositionRelocationMap)) { + return false; + } + } + + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h new file mode 100644 index 000000000..4256f22fb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h @@ -0,0 +1,178 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H + +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +class DynamicPatriciaTrieGcEventListeners { + public: + // Updates all PtNodes that can be reached from the root. Checks if each PtNode is useless or + // not and marks useless PtNodes as deleted. Such deleted PtNodes will be discarded in the GC. + // TODO: Concatenate non-terminal PtNodes. + class TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( + DynamicPatriciaTrieWritingHelper *const writingHelper, + BufferWithExtendableBuffer *const buffer) + : mWritingHelper(writingHelper), mBuffer(buffer), valueStack(), + mChildrenValue(0) {} + + ~TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted() {}; + + bool onAscend() { + if (valueStack.empty()) { + return false; + } + mChildrenValue = valueStack.back(); + valueStack.pop_back(); + return true; + } + + bool onDescend(const int ptNodeArrayPos) { + valueStack.push_back(0); + return true; + } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS( + TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted); + + 
DynamicPatriciaTrieWritingHelper *const mWritingHelper; + BufferWithExtendableBuffer *const mBuffer; + std::vector<int> valueStack; + int mChildrenValue; + }; + + // Updates all bigram entries that are held by valid PtNodes. This removes useless bigram + // entries. + class TraversePolicyToUpdateBigramProbability + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateBigramProbability(DynamicBigramListPolicy *const bigramPolicy) + : mBigramPolicy(bigramPolicy) {} + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos) { return true; } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) { + if (!node->isDeleted()) { + int pos = node->getBigramsPos(); + if (pos != NOT_A_DICT_POS) { + if (!mBigramPolicy->updateAllBigramEntriesAndDeleteUselessEntries(&pos)) { + return false; + } + } + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateBigramProbability); + + DynamicBigramListPolicy *const mBigramPolicy; + }; + + class TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToPlaceAndWriteValidPtNodesToBuffer( + DynamicPatriciaTrieWritingHelper *const writingHelper, + BufferWithExtendableBuffer *const bufferToWrite, + DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + dictPositionRelocationMap) + : mWritingHelper(writingHelper), mBufferToWrite(bufferToWrite), + mDictPositionRelocationMap(dictPositionRelocationMap), mValidPtNodeCount(0), + mPtNodeArraySizeFieldPos(NOT_A_DICT_POS) {}; + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos); + + bool onReadingPtNodeArrayTail(); + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + 
DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToPlaceAndWriteValidPtNodesToBuffer); + + DynamicPatriciaTrieWritingHelper *const mWritingHelper; + BufferWithExtendableBuffer *const mBufferToWrite; + DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + mDictPositionRelocationMap; + int mValidPtNodeCount; + int mPtNodeArraySizeFieldPos; + }; + + class TraversePolicyToUpdateAllPositionFields + : public DynamicPatriciaTrieReadingHelper::TraversingEventListener { + public: + TraversePolicyToUpdateAllPositionFields( + DynamicPatriciaTrieWritingHelper *const writingHelper, + DynamicBigramListPolicy *const bigramPolicy, + BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + dictPositionRelocationMap) + : mWritingHelper(writingHelper), mBigramPolicy(bigramPolicy), + mBufferToWrite(bufferToWrite), + mDictPositionRelocationMap(dictPositionRelocationMap) {}; + + bool onAscend() { return true; } + + bool onDescend(const int ptNodeArrayPos) { return true; } + + bool onReadingPtNodeArrayTail() { return true; } + + bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToUpdateAllPositionFields); + + DynamicPatriciaTrieWritingHelper *const mWritingHelper; + DynamicBigramListPolicy *const mBigramPolicy; + BufferWithExtendableBuffer *const mBufferToWrite; + const DynamicPatriciaTrieWritingHelper::DictPositionRelocationMap *const + mDictPositionRelocationMap; + }; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieGcEventListeners); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_GC_EVENT_LISTENERS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp new file mode 100644 index 000000000..456352c17 --- /dev/null +++ 
b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" + +#include "suggest/core/policy/dictionary_bigrams_structure_policy.h" +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +void DynamicPatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProcessMovedPtNode( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints) { + if (ptNodePos < 0 || ptNodePos >= mBuffer->getTailPosition()) { + AKLOGE("Fetching PtNode info form invalid dictionary position: %d, dictionary size: %d", + ptNodePos, mBuffer->getTailPosition()); + ASSERT(false); + invalidatePtNodeInfo(); + return; + } + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(ptNodePos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + int pos = ptNodePos; + mHeadPos = ptNodePos; + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + mFlags = PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const int parentPosOffset = + 
DynamicPatriciaTrieReadingUtils::getParentPtNodePosOffsetAndAdvancePosition(dictBuf, + &pos); + mParentPos = DynamicPatriciaTrieReadingUtils::getParentPtNodePos(parentPosOffset, mHeadPos); + if (outCodePoints != 0) { + mCodePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictBuf, mFlags, maxCodePointCount, outCodePoints, &pos); + } else { + mCodePointCount = PatriciaTrieReadingUtils::skipCharacters( + dictBuf, mFlags, MAX_WORD_LENGTH, &pos); + } + if (isTerminal()) { + mProbabilityFieldPos = pos; + if (usesAdditionalBuffer) { + mProbabilityFieldPos += mBuffer->getOriginalBufferSize(); + } + mProbability = PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictBuf, &pos); + } else { + mProbabilityFieldPos = NOT_A_DICT_POS; + mProbability = NOT_A_PROBABILITY; + } + mChildrenPosFieldPos = pos; + if (usesAdditionalBuffer) { + mChildrenPosFieldPos += mBuffer->getOriginalBufferSize(); + } + mChildrenPos = DynamicPatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + dictBuf, &pos); + if (usesAdditionalBuffer && mChildrenPos != NOT_A_DICT_POS) { + mChildrenPos += mBuffer->getOriginalBufferSize(); + } + if (mSiblingPos == NOT_A_DICT_POS) { + if (DynamicPatriciaTrieReadingUtils::isMoved(mFlags)) { + mBigramLinkedNodePos = mChildrenPos; + } else { + mBigramLinkedNodePos = NOT_A_DICT_POS; + } + } + if (usesAdditionalBuffer) { + pos += mBuffer->getOriginalBufferSize(); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(mFlags)) { + mShortcutPos = pos; + mShortcutsPolicy->skipAllShortcuts(&pos); + } else { + mShortcutPos = NOT_A_DICT_POS; + } + if (PatriciaTrieReadingUtils::hasBigrams(mFlags)) { + mBigramPos = pos; + mBigramsPolicy->skipAllBigrams(&pos); + } else { + mBigramPos = NOT_A_DICT_POS; + } + // Update siblingPos if needed. + if (mSiblingPos == NOT_A_DICT_POS) { + // Sibling position is the tail position of current node. + mSiblingPos = pos; + } + // Read destination node if the read node is a moved node. 
+ if (DynamicPatriciaTrieReadingUtils::isMoved(mFlags)) { + // The destination position is stored at the same place as the parent position. + fetchPtNodeInfoFromBufferAndProcessMovedPtNode(mParentPos, maxCodePointCount, + outCodePoints); + } +} + +void DynamicPatriciaTrieNodeReader::invalidatePtNodeInfo() { + mHeadPos = NOT_A_DICT_POS; + mFlags = 0; + mParentPos = NOT_A_DICT_POS; + mCodePointCount = 0; + mProbabilityFieldPos = NOT_A_DICT_POS; + mProbability = NOT_A_PROBABILITY; + mChildrenPosFieldPos = NOT_A_DICT_POS; + mChildrenPos = NOT_A_DICT_POS; + mBigramLinkedNodePos = NOT_A_DICT_POS; + mShortcutPos = NOT_A_DICT_POS; + mBigramPos = NOT_A_DICT_POS; + mSiblingPos = NOT_A_DICT_POS; +} + +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h new file mode 100644 index 000000000..3b36d425f --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h @@ -0,0 +1,163 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_NODE_READER_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_NODE_READER_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DictionaryBigramsStructurePolicy; +class DictionaryShortcutsStructurePolicy; + +/* + * This class is used for helping to read nodes of dynamic patricia trie. This class handles moved + * node and reads node attributes. + */ +class DynamicPatriciaTrieNodeReader { + public: + DynamicPatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer, + const DictionaryBigramsStructurePolicy *const bigramsPolicy, + const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) + : mBuffer(buffer), mBigramsPolicy(bigramsPolicy), + mShortcutsPolicy(shortcutsPolicy), mHeadPos(NOT_A_DICT_POS), mFlags(0), + mParentPos(NOT_A_DICT_POS), mCodePointCount(0), mProbabilityFieldPos(NOT_A_DICT_POS), + mProbability(NOT_A_PROBABILITY), mChildrenPosFieldPos(NOT_A_DICT_POS), + mChildrenPos(NOT_A_DICT_POS), mBigramLinkedNodePos(NOT_A_DICT_POS), + mShortcutPos(NOT_A_DICT_POS), mBigramPos(NOT_A_DICT_POS), + mSiblingPos(NOT_A_DICT_POS) {} + + ~DynamicPatriciaTrieNodeReader() {} + + // Reads PtNode information from dictionary buffer and updates members with the information. 
+ AK_FORCE_INLINE void fetchNodeInfoInBufferFromPtNodePos(const int ptNodePos) { + fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(ptNodePos , + 0 /* maxCodePointCount */, 0 /* outCodePoints */); + } + + AK_FORCE_INLINE void fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints) { + mSiblingPos = NOT_A_DICT_POS; + mBigramLinkedNodePos = NOT_A_DICT_POS; + fetchPtNodeInfoFromBufferAndProcessMovedPtNode(ptNodePos, maxCodePointCount, outCodePoints); + } + + // HeadPos is different from NodePos when the current PtNode is a moved PtNode. + AK_FORCE_INLINE int getHeadPos() const { + return mHeadPos; + } + + // Flags + AK_FORCE_INLINE bool isDeleted() const { + return DynamicPatriciaTrieReadingUtils::isDeleted(mFlags); + } + + AK_FORCE_INLINE bool hasChildren() const { + return mChildrenPos != NOT_A_DICT_POS; + } + + AK_FORCE_INLINE bool isTerminal() const { + return PatriciaTrieReadingUtils::isTerminal(mFlags); + } + + AK_FORCE_INLINE bool isBlacklisted() const { + return PatriciaTrieReadingUtils::isBlacklisted(mFlags); + } + + AK_FORCE_INLINE bool isNotAWord() const { + return PatriciaTrieReadingUtils::isNotAWord(mFlags); + } + + // Parent node position + AK_FORCE_INLINE int getParentPos() const { + return mParentPos; + } + + // Number of code points + AK_FORCE_INLINE uint8_t getCodePointCount() const { + return mCodePointCount; + } + + // Probability + AK_FORCE_INLINE int getProbabilityFieldPos() const { + return mProbabilityFieldPos; + } + + AK_FORCE_INLINE int getProbability() const { + return mProbability; + } + + // Children PtNode array position + AK_FORCE_INLINE int getChildrenPosFieldPos() const { + return mChildrenPosFieldPos; + } + + AK_FORCE_INLINE int getChildrenPos() const { + return mChildrenPos; + } + + // Bigram linked node position. 
+ AK_FORCE_INLINE int getBigramLinkedNodePos() const { + return mBigramLinkedNodePos; + } + + // Shortcutlist position + AK_FORCE_INLINE int getShortcutPos() const { + return mShortcutPos; + } + + // Bigrams position + AK_FORCE_INLINE int getBigramsPos() const { + return mBigramPos; + } + + // Sibling node position + AK_FORCE_INLINE int getSiblingNodePos() const { + return mSiblingPos; + } + + private: + DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTrieNodeReader); + + const BufferWithExtendableBuffer *const mBuffer; + const DictionaryBigramsStructurePolicy *const mBigramsPolicy; + const DictionaryShortcutsStructurePolicy *const mShortcutsPolicy; + int mHeadPos; + DynamicPatriciaTrieReadingUtils::NodeFlags mFlags; + int mParentPos; + uint8_t mCodePointCount; + int mProbabilityFieldPos; + int mProbability; + int mChildrenPosFieldPos; + int mChildrenPos; + int mBigramLinkedNodePos; + int mShortcutPos; + int mBigramPos; + int mSiblingPos; + + void fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos, + const int maxCodePointCount, int *const outCodePoints); + + void invalidatePtNodeInfo(); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_NODE_READER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp new file mode 100644 index 000000000..42397c19e --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -0,0 +1,275 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h" + +#include "defines.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +namespace latinime { + +void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, + DicNodeVector *const childDicNodes) const { + if (!dicNode->hasChildren()) { + return; + } + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(dicNode->getChildrenPos()); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + while (!readingHelper.isEnd()) { + childDicNodes->pushLeavingChild(dicNode, nodeReader->getHeadPos(), + nodeReader->getChildrenPos(), nodeReader->getProbability(), + nodeReader->isTerminal() && !nodeReader->isDeleted(), + nodeReader->hasChildren(), nodeReader->isBlacklisted() || nodeReader->isNotAWord(), + nodeReader->getCodePointCount(), readingHelper.getMergedNodeCodePoints()); + 
readingHelper.readNextSiblingNode(); + } +} + +int DynamicPatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const { + // This method traverses parent nodes from the terminal by following parent pointers; thus, + // node code points are stored in the buffer in the reverse order. + int reverseCodePoints[maxCodePointCount]; + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + // First, read the terminal node and get its probability. + readingHelper.initWithPtNodePos(ptNodePos); + if (!readingHelper.isValidTerminalNode()) { + // Node at the ptNodePos is not a valid terminal node. + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } + // Store terminal node probability. + *outUnigramProbability = readingHelper.getNodeReader()->getProbability(); + // Then, following parent node link to the dictionary root and fetch node code points. + while (!readingHelper.isEnd()) { + if (readingHelper.getTotalCodePointCount() > maxCodePointCount) { + // The ptNodePos is not a valid terminal node position in the dictionary. + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } + // Store node code points to buffer in the reverse order. + readingHelper.fetchMergedNodeCodePointsInReverseOrder( + readingHelper.getPrevTotalCodePointCount(), reverseCodePoints); + // Follow parent node toward the root node. + readingHelper.readParentNode(); + } + if (readingHelper.isError()) { + // The node position or the dictionary is invalid. + *outUnigramProbability = NOT_A_PROBABILITY; + return 0; + } + // Reverse the stored code points to output them. 
+ const int codePointCount = readingHelper.getTotalCodePointCount(); + for (int i = 0; i < codePointCount; ++i) { + outCodePoints[i] = reverseCodePoints[codePointCount - i - 1]; + } + return codePointCount; +} + +int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const inWord, + const int length, const bool forceLowerCaseSearch) const { + int searchCodePoints[length]; + for (int i = 0; i < length; ++i) { + searchCodePoints[i] = forceLowerCaseSearch ? CharUtils::toLowerCase(inWord[i]) : inWord[i]; + } + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + while (!readingHelper.isEnd()) { + const int matchedCodePointCount = readingHelper.getPrevTotalCodePointCount(); + if (readingHelper.getTotalCodePointCount() > length + || !readingHelper.isMatchedCodePoint(0 /* index */, + searchCodePoints[matchedCodePointCount])) { + // Current node has too many code points or its first code point is different from + // target code point. Skip this node and read the next sibling node. + readingHelper.readNextSiblingNode(); + continue; + } + // Check following merged node code points. + const int nodeCodePointCount = nodeReader->getCodePointCount(); + for (int j = 1; j < nodeCodePointCount; ++j) { + if (!readingHelper.isMatchedCodePoint( + j, searchCodePoints[matchedCodePointCount + j])) { + // Different code point is found. The given word is not included in the dictionary. + return NOT_A_DICT_POS; + } + } + // All characters are matched. + if (length == readingHelper.getTotalCodePointCount()) { + // Terminal position is found. + return nodeReader->getHeadPos(); + } + if (!nodeReader->hasChildren()) { + return NOT_A_DICT_POS; + } + // Advance to the children nodes. 
+ readingHelper.readChildNode(); + } + // If we already traversed the tree further than the word is long, there means + // there was no match (or we would have found it). + return NOT_A_DICT_POS; +} + +int DynamicPatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + // TODO: check mHeaderPolicy.usesForgettingCurve(); + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } +} + +int DynamicPatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_PROBABILITY; + } + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (nodeReader.isDeleted() || nodeReader.isBlacklisted() || nodeReader.isNotAWord()) { + return NOT_A_PROBABILITY; + } + return getProbability(nodeReader.getProbability(), NOT_A_PROBABILITY); +} + +int DynamicPatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (nodeReader.isDeleted()) { + return NOT_A_DICT_POS; + } + return nodeReader.getShortcutPos(); +} + +int DynamicPatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + DynamicPatriciaTrieNodeReader nodeReader(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + 
nodeReader.fetchNodeInfoInBufferFromPtNodePos(ptNodePos); + if (nodeReader.isDeleted()) { + return NOT_A_DICT_POS; + } + return nodeReader.getBigramsPos(); +} + +bool DynamicPatriciaTriePolicy::addUnigramWord(const int *const word, const int length, + const int probability) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); + return false; + } + DynamicPatriciaTrieReadingHelper readingHelper(&mBufferWithExtendableBuffer, + getBigramsStructurePolicy(), getShortcutsStructurePolicy()); + readingHelper.initWithPtNodeArrayPos(getRootPosition()); + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.addUnigramWord(&readingHelper, word, length, probability); +} + +bool DynamicPatriciaTriePolicy::addBigramWords(const int *const word0, const int length0, + const int *const word1, const int length1, const int probability) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + return false; + } + const int word0Pos = getTerminalNodePositionOfWord(word0, length0, + false /* forceLowerCaseSearch */); + if (word0Pos == NOT_A_DICT_POS) { + return false; + } + const int word1Pos = getTerminalNodePositionOfWord(word1, length1, + false /* forceLowerCaseSearch */); + if (word1Pos == NOT_A_DICT_POS) { + return false; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.addBigramWords(word0Pos, word1Pos, probability); +} + +bool DynamicPatriciaTriePolicy::removeBigramWords(const int *const word0, const int length0, + const int *const word1, const int length1) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: removeBigramWords() is called for non-updatable dictionary."); + return false; + } + const int word0Pos = getTerminalNodePositionOfWord(word0, length0, + false /* 
forceLowerCaseSearch */); + if (word0Pos == NOT_A_DICT_POS) { + return false; + } + const int word1Pos = getTerminalNodePositionOfWord(word1, length1, + false /* forceLowerCaseSearch */); + if (word1Pos == NOT_A_DICT_POS) { + return false; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + return writingHelper.removeBigramWords(word0Pos, word1Pos); +} + +void DynamicPatriciaTriePolicy::flush(const char *const filePath) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: flush() is called for non-updatable dictionary."); + return; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + writingHelper.writeToDictFile(filePath, &mHeaderPolicy); +} + +void DynamicPatriciaTriePolicy::flushWithGC(const char *const filePath) { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); + return; + } + DynamicPatriciaTrieWritingHelper writingHelper(&mBufferWithExtendableBuffer, + &mBigramListPolicy, &mShortcutListPolicy); + writingHelper.writeToDictFileWithGC(getRootPosition(), filePath, &mHeaderPolicy); +} + +bool DynamicPatriciaTriePolicy::needsToRunGC() const { + if (!mBuffer->isUpdatable()) { + AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); + return false; + } + // TODO: Implement more properly. 
+ return mBufferWithExtendableBuffer.isNearSizeLimit(); +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h new file mode 100644 index 000000000..06d8095d8 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H + +#include "defines.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" + +namespace latinime { + +class DicNode; +class DicNodeVector; + +class DynamicPatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { + public: + DynamicPatriciaTriePolicy(const MmappedBuffer *const buffer) + : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer(), buffer->getBufferSize()), + mBufferWithExtendableBuffer(mBuffer->getBuffer() + mHeaderPolicy.getSize(), + mBuffer->getBufferSize() - mHeaderPolicy.getSize()), + mShortcutListPolicy(&mBufferWithExtendableBuffer), + mBigramListPolicy(&mBufferWithExtendableBuffer, &mShortcutListPolicy) {} + + ~DynamicPatriciaTriePolicy() { + delete mBuffer; + } + + AK_FORCE_INLINE int getRootPosition() const { + return 0; + } + + void createAndGetAllChildNodes(const DicNode *const dicNode, + DicNodeVector *const childDicNodes) const; + + int getCodePointsAndProbabilityAndReturnCodePointCount( + const int terminalPtNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; + + int getTerminalNodePositionOfWord(const int *const inWord, + const int length, const bool forceLowerCaseSearch) const; + + int getProbability(const int unigramProbability, const int bigramProbability) const; + + int getUnigramProbabilityOfPtNode(const int ptNodePos) const; + + int getShortcutPositionOfPtNode(const int ptNodePos) const; + + int getBigramsPositionOfPtNode(const int ptNodePos) const; + + const 
DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { + return &mHeaderPolicy; + } + + const DictionaryBigramsStructurePolicy *getBigramsStructurePolicy() const { + return &mBigramListPolicy; + } + + const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { + return &mShortcutListPolicy; + } + + bool addUnigramWord(const int *const word, const int length, const int probability); + + bool addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability); + + bool removeBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1); + + void flush(const char *const filePath); + + void flushWithGC(const char *const filePath); + + bool needsToRunGC() const; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTriePolicy); + + const MmappedBuffer *const mBuffer; + const HeaderPolicy mHeaderPolicy; + BufferWithExtendableBuffer mBufferWithExtendableBuffer; + DynamicShortcutListPolicy mShortcutListPolicy; + DynamicBigramListPolicy mBigramListPolicy; +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp new file mode 100644 index 000000000..f4a2ef389 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +// To avoid infinite loop caused by invalid or malicious forward links. +const int DynamicPatriciaTrieReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 100000; +const int DynamicPatriciaTrieReadingHelper::MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000; +const size_t DynamicPatriciaTrieReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH; + +// Visits all PtNodes in post-order depth first manner. +// For example, visits c -> b -> y -> x -> a for the following dictionary: +// a _ b _ c +// \ x _ y +bool DynamicPatriciaTrieReadingHelper::traverseAllPtNodesInPostorderDepthFirstManner( + TraversingEventListener *const listener) { + bool alreadyVisitedChildren = false; + // Descend from the root to the root PtNode array. + if (!listener->onDescend(getPosOfLastPtNodeArrayHead())) { + return false; + } + while (!isEnd()) { + if (!alreadyVisitedChildren) { + if (mNodeReader.hasChildren()) { + // Move to the first child. + if (!listener->onDescend(mNodeReader.getChildrenPos())) { + return false; + } + pushReadingStateToStack(); + readChildNode(); + } else { + alreadyVisitedChildren = true; + } + } else { + if (!listener->onVisitingPtNode(&mNodeReader, mMergedNodeCodePoints)) { + return false; + } + readNextSiblingNode(); + if (isEnd()) { + // All PtNodes in current linked PtNode arrays have been visited. + // Return to the parent. 
+ if (!listener->onReadingPtNodeArrayTail()) { + return false; + } + if (mReadingStateStack.size() <= 0) { + break; + } + if (!listener->onAscend()) { + return false; + } + popReadingStateFromStack(); + alreadyVisitedChildren = true; + } else { + // Process sibling PtNode. + alreadyVisitedChildren = false; + } + } + } + // Ascend from the root PtNode array to the root. + if (!listener->onAscend()) { + return false; + } + return !isError(); +} + +// Visits all PtNodes in PtNode array level pre-order depth first manner, which is the same order +// that PtNodes are written in the dictionary buffer. +// For example, visits a -> b -> x -> c -> y for the following dictionary: +// a _ b _ c +// \ x _ y +bool DynamicPatriciaTrieReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + TraversingEventListener *const listener) { + bool alreadyVisitedAllPtNodesInArray = false; + bool alreadyVisitedChildren = false; + // Descend from the root to the root PtNode array. + if (!listener->onDescend(getPosOfLastPtNodeArrayHead())) { + return false; + } + pushReadingStateToStack(); + while (!isEnd()) { + if (alreadyVisitedAllPtNodesInArray) { + if (alreadyVisitedChildren) { + // Move to next sibling PtNode's children. + readNextSiblingNode(); + if (isEnd()) { + // Return to the parent PTNode. + if (!listener->onAscend()) { + return false; + } + if (mReadingStateStack.size() <= 0) { + break; + } + popReadingStateFromStack(); + alreadyVisitedChildren = true; + alreadyVisitedAllPtNodesInArray = true; + } else { + alreadyVisitedChildren = false; + } + } else { + if (mNodeReader.hasChildren()) { + // Move to the first child. + if (!listener->onDescend(mNodeReader.getChildrenPos())) { + return false; + } + pushReadingStateToStack(); + readChildNode(); + // Push state to return the head of PtNode array. 
+ pushReadingStateToStack(); + alreadyVisitedAllPtNodesInArray = false; + alreadyVisitedChildren = false; + } else { + alreadyVisitedChildren = true; + } + } + } else { + if (!listener->onVisitingPtNode(&mNodeReader, mMergedNodeCodePoints)) { + return false; + } + readNextSiblingNode(); + if (isEnd()) { + if (!listener->onReadingPtNodeArrayTail()) { + return false; + } + // Return to the head of current PtNode array. + popReadingStateFromStack(); + alreadyVisitedAllPtNodesInArray = true; + } + } + } + popReadingStateFromStack(); + // Ascend from the root PtNode array to the root. + if (!listener->onAscend()) { + return false; + } + return !isError(); +} + +// Read node array size and process empty node arrays. Nodes and arrays are counted up in this +// method to avoid an infinite loop. +void DynamicPatriciaTrieReadingHelper::nextPtNodeArray() { + mReadingState.mPosOfLastPtNodeArrayHead = mReadingState.mPos; + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(mReadingState.mPos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + mReadingState.mPos -= mBuffer->getOriginalBufferSize(); + } + mReadingState.mNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( + dictBuf, &mReadingState.mPos); + if (usesAdditionalBuffer) { + mReadingState.mPos += mBuffer->getOriginalBufferSize(); + } + // Count up nodes and node arrays to avoid infinite loop. + mReadingState.mTotalNodeCount += mReadingState.mNodeCount; + mReadingState.mNodeArrayCount++; + if (mReadingState.mNodeCount < 0 + || mReadingState.mTotalNodeCount > MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP + || mReadingState.mNodeArrayCount > MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP) { + // Invalid dictionary. + AKLOGI("Invalid dictionary. 
nodeCount: %d, totalNodeCount: %d, MAX_CHILD_COUNT: %d" + "nodeArrayCount: %d, MAX_NODE_ARRAY_COUNT: %d", + mReadingState.mNodeCount, mReadingState.mTotalNodeCount, + MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP, mReadingState.mNodeArrayCount, + MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP); + ASSERT(false); + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + return; + } + if (mReadingState.mNodeCount == 0) { + // Empty node array. Try following forward link. + followForwardLink(); + } +} + +// Follow the forward link and read the next node array if exists. +void DynamicPatriciaTrieReadingHelper::followForwardLink() { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(mReadingState.mPos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + mReadingState.mPos -= mBuffer->getOriginalBufferSize(); + } + const int forwardLinkPosition = + DynamicPatriciaTrieReadingUtils::getForwardLinkPosition(dictBuf, mReadingState.mPos); + if (usesAdditionalBuffer) { + mReadingState.mPos += mBuffer->getOriginalBufferSize(); + } + mReadingState.mPosOfLastForwardLinkField = mReadingState.mPos; + if (DynamicPatriciaTrieReadingUtils::isValidForwardLinkPosition(forwardLinkPosition)) { + // Follow the forward link. + mReadingState.mPos += forwardLinkPosition; + nextPtNodeArray(); + } else { + // All node arrays have been read. 
+ mReadingState.mPos = NOT_A_DICT_POS; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h new file mode 100644 index 000000000..c6d8ddcf7 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h @@ -0,0 +1,286 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H + +#include <cstddef> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DictionaryBigramsStructurePolicy; +class DictionaryShortcutsStructurePolicy; + +/* + * This class is used for traversing dynamic patricia trie. This class supports iterating nodes and + * dealing with additional buffer. This class counts nodes and node arrays to avoid infinite loop. 
+ */ +class DynamicPatriciaTrieReadingHelper { + public: + class TraversingEventListener { + public: + virtual ~TraversingEventListener() {}; + + // Returns whether the event handling was succeeded or not. + virtual bool onAscend() = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onDescend(const int ptNodeArrayPos) = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onReadingPtNodeArrayTail() = 0; + + // Returns whether the event handling was succeeded or not. + virtual bool onVisitingPtNode(const DynamicPatriciaTrieNodeReader *const node, + const int *const nodeCodePoints) = 0; + + protected: + TraversingEventListener() {}; + + private: + DISALLOW_COPY_AND_ASSIGN(TraversingEventListener); + }; + + DynamicPatriciaTrieReadingHelper(const BufferWithExtendableBuffer *const buffer, + const DictionaryBigramsStructurePolicy *const bigramsPolicy, + const DictionaryShortcutsStructurePolicy *const shortcutsPolicy) + : mIsError(false), mReadingState(), mBuffer(buffer), + mNodeReader(mBuffer, bigramsPolicy, shortcutsPolicy), mReadingStateStack() {} + + ~DynamicPatriciaTrieReadingHelper() {} + + AK_FORCE_INLINE bool isError() const { + return mIsError; + } + + AK_FORCE_INLINE bool isEnd() const { + return mReadingState.mPos == NOT_A_DICT_POS; + } + + // Initialize reading state with the head position of a PtNode array. + AK_FORCE_INLINE void initWithPtNodeArrayPos(const int ptNodeArrayPos) { + if (ptNodeArrayPos == NOT_A_DICT_POS) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mIsError = false; + mReadingState.mPos = ptNodeArrayPos; + mReadingState.mPrevTotalCodePointCount = 0; + mReadingState.mTotalNodeCount = 0; + mReadingState.mNodeArrayCount = 0; + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingStateStack.clear(); + nextPtNodeArray(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } + } + + // Initialize reading state with the head position of a node. 
+ AK_FORCE_INLINE void initWithPtNodePos(const int ptNodePos) { + if (ptNodePos == NOT_A_DICT_POS) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mIsError = false; + mReadingState.mPos = ptNodePos; + mReadingState.mNodeCount = 1; + mReadingState.mPrevTotalCodePointCount = 0; + mReadingState.mTotalNodeCount = 1; + mReadingState.mNodeArrayCount = 1; + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingState.mPosOfLastPtNodeArrayHead = NOT_A_DICT_POS; + mReadingStateStack.clear(); + fetchPtNodeInfo(); + } + } + + AK_FORCE_INLINE const DynamicPatriciaTrieNodeReader* getNodeReader() const { + return &mNodeReader; + } + + AK_FORCE_INLINE bool isValidTerminalNode() const { + return !isEnd() && !mNodeReader.isDeleted() && mNodeReader.isTerminal(); + } + + AK_FORCE_INLINE bool isMatchedCodePoint(const int index, const int codePoint) const { + return mMergedNodeCodePoints[index] == codePoint; + } + + // Return code point count exclude the last read node's code points. + AK_FORCE_INLINE int getPrevTotalCodePointCount() const { + return mReadingState.mPrevTotalCodePointCount; + } + + // Return code point count include the last read node's code points. + AK_FORCE_INLINE int getTotalCodePointCount() const { + return mReadingState.mPrevTotalCodePointCount + mNodeReader.getCodePointCount(); + } + + AK_FORCE_INLINE void fetchMergedNodeCodePointsInReverseOrder( + const int index, int *const outCodePoints) const { + const int nodeCodePointCount = mNodeReader.getCodePointCount(); + for (int i = 0; i < nodeCodePointCount; ++i) { + outCodePoints[index + i] = mMergedNodeCodePoints[nodeCodePointCount - 1 - i]; + } + } + + AK_FORCE_INLINE const int *getMergedNodeCodePoints() const { + return mMergedNodeCodePoints; + } + + AK_FORCE_INLINE void readNextSiblingNode() { + mReadingState.mNodeCount -= 1; + mReadingState.mPos = mNodeReader.getSiblingNodePos(); + if (mReadingState.mNodeCount <= 0) { + // All nodes in the current node array have been read. 
+ followForwardLink(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } else { + fetchPtNodeInfo(); + } + } + + // Read the first child node of the current node. + AK_FORCE_INLINE void readChildNode() { + if (mNodeReader.hasChildren()) { + mReadingState.mPrevTotalCodePointCount += mNodeReader.getCodePointCount(); + mReadingState.mTotalNodeCount = 0; + mReadingState.mNodeArrayCount = 0; + mReadingState.mPos = mNodeReader.getChildrenPos(); + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + // Read children node array. + nextPtNodeArray(); + if (!isEnd()) { + fetchPtNodeInfo(); + } + } else { + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + // Read the parent node of the current node. + AK_FORCE_INLINE void readParentNode() { + if (mNodeReader.getParentPos() != NOT_A_DICT_POS) { + mReadingState.mPrevTotalCodePointCount += mNodeReader.getCodePointCount(); + mReadingState.mTotalNodeCount = 1; + mReadingState.mNodeArrayCount = 1; + mReadingState.mNodeCount = 1; + mReadingState.mPos = mNodeReader.getParentPos(); + mReadingState.mPosOfLastForwardLinkField = NOT_A_DICT_POS; + mReadingState.mPosOfLastPtNodeArrayHead = NOT_A_DICT_POS; + fetchPtNodeInfo(); + } else { + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + AK_FORCE_INLINE int getPosOfLastForwardLinkField() const { + return mReadingState.mPosOfLastForwardLinkField; + } + + AK_FORCE_INLINE int getPosOfLastPtNodeArrayHead() const { + return mReadingState.mPosOfLastPtNodeArrayHead; + } + + AK_FORCE_INLINE void reloadCurrentPtNodeInfo() { + if (!isEnd()) { + fetchPtNodeInfo(); + } + } + + bool traverseAllPtNodesInPostorderDepthFirstManner(TraversingEventListener *const listener); + + bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + TraversingEventListener *const listener); + + private: + DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTrieReadingHelper); + + class ReadingState { + public: + // Note that copy constructor and assignment operator are used for this class to use + // std::vector. 
+ ReadingState() : mPos(NOT_A_DICT_POS), mNodeCount(0), mPrevTotalCodePointCount(0), + mTotalNodeCount(0), mNodeArrayCount(0), mPosOfLastForwardLinkField(NOT_A_DICT_POS), + mPosOfLastPtNodeArrayHead(NOT_A_DICT_POS) {} + + int mPos; + // Node count of a node array. + int mNodeCount; + int mPrevTotalCodePointCount; + int mTotalNodeCount; + int mNodeArrayCount; + int mPosOfLastForwardLinkField; + int mPosOfLastPtNodeArrayHead; + }; + + static const int MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP; + static const int MAX_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP; + static const size_t MAX_READING_STATE_STACK_SIZE; + + bool mIsError; + ReadingState mReadingState; + const BufferWithExtendableBuffer *const mBuffer; + DynamicPatriciaTrieNodeReader mNodeReader; + int mMergedNodeCodePoints[MAX_WORD_LENGTH]; + std::vector<ReadingState> mReadingStateStack; + + void nextPtNodeArray(); + + void followForwardLink(); + + AK_FORCE_INLINE void fetchPtNodeInfo() { + mNodeReader.fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(mReadingState.mPos, + MAX_WORD_LENGTH, mMergedNodeCodePoints); + if (mNodeReader.getCodePointCount() <= 0) { + // Empty node is not allowed. + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + } + } + + AK_FORCE_INLINE void pushReadingStateToStack() { + if (mReadingStateStack.size() > MAX_READING_STATE_STACK_SIZE) { + AKLOGI("Reading state stack overflow. 
Max size: %zd", MAX_READING_STATE_STACK_SIZE); + ASSERT(false); + mIsError = true; + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mReadingStateStack.push_back(mReadingState); + } + } + + AK_FORCE_INLINE void popReadingStateFromStack() { + if (mReadingStateStack.empty()) { + mReadingState.mPos = NOT_A_DICT_POS; + } else { + mReadingState = mReadingStateStack.back(); + mReadingStateStack.pop_back(); + fetchPtNodeInfo(); + } + } +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_READING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp new file mode 100644 index 000000000..d68446db6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +typedef DynamicPatriciaTrieReadingUtils DptReadingUtils; + +const DptReadingUtils::NodeFlags DptReadingUtils::MASK_MOVED = 0xC0; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_NOT_MOVED = 0xC0; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_MOVED = 0x40; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_DELETED = 0x80; + +// TODO: Make DICT_OFFSET_ZERO_OFFSET = 0. +// Currently, DICT_OFFSET_INVALID is 0 in Java side but offset can be 0 during GC. So, the maximum +// value of offsets, which is 0x7FFFFF is used to represent 0 offset. +const int DptReadingUtils::DICT_OFFSET_INVALID = 0; +const int DptReadingUtils::DICT_OFFSET_ZERO_OFFSET = 0x7FFFFF; + +/* static */ int DptReadingUtils::getForwardLinkPosition(const uint8_t *const buffer, + const int pos) { + int linkAddressPos = pos; + return ByteArrayUtils::readSint24AndAdvancePosition(buffer, &linkAddressPos); +} + +/* static */ int DptReadingUtils::getParentPtNodePosOffsetAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); +} + +/* static */ int DptReadingUtils::getParentPtNodePos(const int parentOffset, const int ptNodePos) { + if (parentOffset == DICT_OFFSET_INVALID) { + return NOT_A_DICT_POS; + } else if (parentOffset == DICT_OFFSET_ZERO_OFFSET) { + return ptNodePos; + } else { + return parentOffset + ptNodePos; + } +} + +/* static */ int DptReadingUtils::readChildrenPositionAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const int base = *pos; + const int offset = ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); + if (offset == DICT_OFFSET_INVALID) { + // The PtNode does not have children. 
+ return NOT_A_DICT_POS; + } else if (offset == DICT_OFFSET_ZERO_OFFSET) { + return base; + } else { + return base + offset; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h new file mode 100644 index 000000000..67c3cc57e --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class DynamicPatriciaTrieReadingUtils { + public: + typedef uint8_t NodeFlags; + + static const int DICT_OFFSET_INVALID; + static const int DICT_OFFSET_ZERO_OFFSET; + + static int getForwardLinkPosition(const uint8_t *const buffer, const int pos); + + static AK_FORCE_INLINE bool isValidForwardLinkPosition(const int forwardLinkAddress) { + return forwardLinkAddress != 0; + } + + static int getParentPtNodePosOffsetAndAdvancePosition(const uint8_t *const buffer, + int *const pos); + + static int getParentPtNodePos(const int parentOffset, const int ptNodePos); + + static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, int *const pos); + + /** + * Node Flags + */ + static AK_FORCE_INLINE bool isMoved(const NodeFlags flags) { + return FLAG_IS_MOVED == (MASK_MOVED & flags); + } + + static AK_FORCE_INLINE bool isDeleted(const NodeFlags flags) { + return FLAG_IS_DELETED == (MASK_MOVED & flags); + } + + static AK_FORCE_INLINE NodeFlags updateAndGetFlags(const NodeFlags originalFlags, + const bool isMoved, const bool isDeleted) { + NodeFlags flags = originalFlags; + flags = isMoved ? ((flags & (~MASK_MOVED)) | FLAG_IS_MOVED) : flags; + flags = isDeleted ? ((flags & (~MASK_MOVED)) | FLAG_IS_DELETED) : flags; + flags = (!isMoved && !isDeleted) ? 
((flags & (~MASK_MOVED)) | FLAG_IS_NOT_MOVED) : flags; + return flags; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieReadingUtils); + + static const NodeFlags MASK_MOVED; + static const NodeFlags FLAG_IS_NOT_MOVED; + static const NodeFlags FLAG_IS_MOVED; + static const NodeFlags FLAG_IS_DELETED; +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp new file mode 100644 index 000000000..578645cd5 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.cpp @@ -0,0 +1,511 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h" + +#include "suggest/policyimpl/dictionary/bigram/dynamic_bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_gc_event_listeners.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_node_reader.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_helper.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h" +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +const int DynamicPatriciaTrieWritingHelper::CHILDREN_POSITION_FIELD_SIZE = 3; +// TODO: Make MAX_DICTIONARY_SIZE 8MB. +const size_t DynamicPatriciaTrieWritingHelper::MAX_DICTIONARY_SIZE = 2 * 1024 * 1024; + +bool DynamicPatriciaTrieWritingHelper::addUnigramWord( + DynamicPatriciaTrieReadingHelper *const readingHelper, + const int *const wordCodePoints, const int codePointCount, const int probability) { + int parentPos = NOT_A_DICT_POS; + while (!readingHelper->isEnd()) { + const int matchedCodePointCount = readingHelper->getPrevTotalCodePointCount(); + if (!readingHelper->isMatchedCodePoint(0 /* index */, + wordCodePoints[matchedCodePointCount])) { + // The first code point is different from target code point. Skip this node and read + // the next sibling node. + readingHelper->readNextSiblingNode(); + continue; + } + // Check following merged node code points. 
+ const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper->getNodeReader(); + const int nodeCodePointCount = nodeReader->getCodePointCount(); + for (int j = 1; j < nodeCodePointCount; ++j) { + const int nextIndex = matchedCodePointCount + j; + if (nextIndex >= codePointCount || !readingHelper->isMatchedCodePoint(j, + wordCodePoints[matchedCodePointCount + j])) { + return reallocatePtNodeAndAddNewPtNodes(nodeReader, + readingHelper->getMergedNodeCodePoints(), j, probability, + wordCodePoints + matchedCodePointCount, + codePointCount - matchedCodePointCount); + } + } + // All characters are matched. + if (codePointCount == readingHelper->getTotalCodePointCount()) { + return setPtNodeProbability(nodeReader, probability, + readingHelper->getMergedNodeCodePoints()); + } + if (!nodeReader->hasChildren()) { + return createChildrenPtNodeArrayAndAChildPtNode(nodeReader, probability, + wordCodePoints + readingHelper->getTotalCodePointCount(), + codePointCount - readingHelper->getTotalCodePointCount()); + } + // Advance to the children nodes. + parentPos = nodeReader->getHeadPos(); + readingHelper->readChildNode(); + } + if (readingHelper->isError()) { + // The dictionary is invalid. + return false; + } + int pos = readingHelper->getPosOfLastForwardLinkField(); + return createAndInsertNodeIntoPtNodeArray(parentPos, + wordCodePoints + readingHelper->getPrevTotalCodePointCount(), + codePointCount - readingHelper->getPrevTotalCodePointCount(), + probability, &pos); +} + +bool DynamicPatriciaTrieWritingHelper::addBigramWords(const int word0Pos, const int word1Pos, + const int probability) { + int mMergedNodeCodePoints[MAX_WORD_LENGTH]; + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePosAndGetNodeCodePoints(word0Pos, MAX_WORD_LENGTH, + mMergedNodeCodePoints); + // Move node to add bigram entry. 
+ const int newNodePos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(&nodeReader, newNodePos, newNodePos)) { + return false; + } + int writingPos = newNodePos; + // Write a new PtNode using original PtNode's info to the tail of the dictionary in mBuffer. + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, &nodeReader, nodeReader.getParentPos(), + mMergedNodeCodePoints, nodeReader.getCodePointCount(), nodeReader.getProbability(), + &writingPos)) { + return false; + } + nodeReader.fetchNodeInfoInBufferFromPtNodePos(newNodePos); + if (nodeReader.getBigramsPos() != NOT_A_DICT_POS) { + // Insert a new bigram entry into the existing bigram list. + int bigramListPos = nodeReader.getBigramsPos(); + return mBigramPolicy->addNewBigramEntryToBigramList(word1Pos, probability, &bigramListPos); + } else { + // The PtNode doesn't have a bigram list. + // First, Write a bigram entry at the tail position of the PtNode. + if (!mBigramPolicy->writeNewBigramEntry(word1Pos, probability, &writingPos)) { + return false; + } + // Then, Mark as the PtNode having bigram list in the flags. + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + PatriciaTrieReadingUtils::createAndGetFlags(nodeReader.isBlacklisted(), + nodeReader.isNotAWord(), nodeReader.getProbability() != NOT_A_PROBABILITY, + nodeReader.getShortcutPos() != NOT_A_DICT_POS, true /* hasBigrams */, + nodeReader.getCodePointCount() > 1, CHILDREN_POSITION_FIELD_SIZE); + writingPos = newNodePos; + // Write updated flags into the moved PtNode's flags field. + return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos); + } +} + +// Remove a bigram relation from word0Pos to word1Pos. 
+bool DynamicPatriciaTrieWritingHelper::removeBigramWords(const int word0Pos, const int word1Pos) { + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(word0Pos); + if (nodeReader.getBigramsPos() == NOT_A_DICT_POS) { + return false; + } + return mBigramPolicy->removeBigram(nodeReader.getBigramsPos(), word1Pos); +} + +void DynamicPatriciaTrieWritingHelper::writeToDictFile(const char *const fileName, + const HeaderPolicy *const headerPolicy) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, false /* updatesLastUpdatedTime */)) { + return; + } + DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, mBuffer); +} + +void DynamicPatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeArrayPos, + const char *const fileName, const HeaderPolicy *const headerPolicy) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!headerPolicy->writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */)) { + return; + } + BufferWithExtendableBuffer newDictBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */, + MAX_DICTIONARY_SIZE); + if (!runGC(rootPtNodeArrayPos, &newDictBuffer)) { + return; + } + DictFileWritingUtils::flushAllHeaderAndBodyToFile(fileName, &headerBuffer, &newDictBuffer); +} + +bool DynamicPatriciaTrieWritingHelper::markNodeAsDeleted( + const DynamicPatriciaTrieNodeReader *const nodeToUpdate) { + int pos = nodeToUpdate->getHeadPos(); + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(pos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + // Read original flags + const PatriciaTrieReadingUtils::NodeFlags originalFlags = + 
PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */, + true /* isDeleted */); + int writingPos = nodeToUpdate->getHeadPos(); + // Update flags. + return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::markNodeAsMovedAndSetPosition( + const DynamicPatriciaTrieNodeReader *const originalNode, const int movedPos, + const int bigramLinkedNodePos) { + int pos = originalNode->getHeadPos(); + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(pos); + const uint8_t *const dictBuf = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + pos -= mBuffer->getOriginalBufferSize(); + } + // Read original flags + const PatriciaTrieReadingUtils::NodeFlags originalFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictBuf, &pos); + const PatriciaTrieReadingUtils::NodeFlags updatedFlags = + DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, true /* isMoved */, + false /* isDeleted */); + int writingPos = originalNode->getHeadPos(); + // Update flags. + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mBuffer, updatedFlags, + &writingPos)) { + return false; + } + // Update moved position, which is stored in the parent offset field. + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + mBuffer, movedPos, originalNode->getHeadPos(), &writingPos)) { + return false; + } + // Update bigram linked node position, which is stored in the children position field. 
+ int childrenPosFieldPos = originalNode->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + mBuffer, bigramLinkedNodePos, &childrenPosFieldPos)) { + return false; + } + if (originalNode->hasChildren()) { + // Update children's parent position. + DynamicPatriciaTrieReadingHelper readingHelper(mBuffer, mBigramPolicy, mShortcutPolicy); + const DynamicPatriciaTrieNodeReader *const nodeReader = readingHelper.getNodeReader(); + readingHelper.initWithPtNodeArrayPos(originalNode->getChildrenPos()); + while (!readingHelper.isEnd()) { + int parentOffsetFieldPos = nodeReader->getHeadPos() + + DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE; + if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + mBuffer, movedPos, nodeReader->getHeadPos(), &parentOffsetFieldPos)) { + // Parent offset cannot be written because of a bug or a broken dictionary; thus, + // we give up to update dictionary. + return false; + } + readingHelper.readNextSiblingNode(); + } + } + return true; +} + +// Write new PtNode at writingPos. +bool DynamicPatriciaTrieWritingHelper::writePtNodeWithFullInfoToBuffer( + BufferWithExtendableBuffer *const bufferToWrite, const bool isBlacklisted, + const bool isNotAWord, const int parentPos, const int *const codePoints, + const int codePointCount, const int probability, const int childrenPos, + const int originalBigramListPos, const int originalShortcutListPos, + int *const writingPos) { + const int nodePos = *writingPos; + // Write dummy flags. The Node flags are updated with appropriate flags at the last step of the + // PtNode writing. + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(bufferToWrite, + 0 /* nodeFlags */, writingPos)) { + return false; + } + // Calculate a parent offset and write the offset. 
+ if (!DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition(bufferToWrite, + parentPos, nodePos, writingPos)) { + return false; + } + // Write code points + if (!DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition(bufferToWrite, + codePoints, codePointCount, writingPos)) { + return false; + } + // Write probability when the probability is a valid probability, which means this node is + // terminal. + if (probability != NOT_A_PROBABILITY) { + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(bufferToWrite, + probability, writingPos)) { + return false; + } + } + // Write children position + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(bufferToWrite, + childrenPos, writingPos)) { + return false; + } + // Copy shortcut list when the originalShortcutListPos is valid dictionary position. + if (originalShortcutListPos != NOT_A_DICT_POS) { + int fromPos = originalShortcutListPos; + if (!mShortcutPolicy->copyAllShortcutsAndReturnIfSucceededOrNot(bufferToWrite, &fromPos, + writingPos)) { + return false; + } + } + // Copy bigram list when the originalBigramListPos is valid dictionary position. + int bigramCount = 0; + if (originalBigramListPos != NOT_A_DICT_POS) { + int fromPos = originalBigramListPos; + if (!mBigramPolicy->copyAllBigrams(bufferToWrite, &fromPos, writingPos, &bigramCount)) { + return false; + } + } + // Create node flags and write them. 
+ PatriciaTrieReadingUtils::NodeFlags nodeFlags = + PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, + probability != NOT_A_PROBABILITY /* isTerminal */, + originalShortcutListPos != NOT_A_DICT_POS /* hasShortcutTargets */, + bigramCount > 0 /* hasBigrams */, codePointCount > 1 /* hasMultipleChars */, + CHILDREN_POSITION_FIELD_SIZE); + int flagsFieldPos = nodePos; + if (!DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(bufferToWrite, nodeFlags, + &flagsFieldPos)) { + return false; + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBuffer( + BufferWithExtendableBuffer *const bufferToWrite, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(bufferToWrite, false /* isBlacklisted */, + false /* isNotAWord */, parentPos, codePoints, codePointCount, probability, + NOT_A_DICT_POS /* childrenPos */, NOT_A_DICT_POS /* originalBigramsPos */, + NOT_A_DICT_POS /* originalShortcutPos */, writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::writePtNodeToBufferByCopyingPtNodeInfo( + BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos) { + return writePtNodeWithFullInfoToBuffer(bufferToWrite, originalNode->isBlacklisted(), + originalNode->isNotAWord(), parentPos, codePoints, codePointCount, probability, + originalNode->getChildrenPos(), originalNode->getBigramsPos(), + originalNode->getShortcutPos(), writingPos); +} + +bool DynamicPatriciaTrieWritingHelper::createAndInsertNodeIntoPtNodeArray(const int parentPos, + const int *const nodeCodePoints, const int nodeCodePointCount, const int probability, + int *const forwardLinkFieldPos) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + if 
(!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, forwardLinkFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentPos, nodeCodePoints, nodeCodePointCount, + probability); +} + +bool DynamicPatriciaTrieWritingHelper::setPtNodeProbability( + const DynamicPatriciaTrieNodeReader *const originalPtNode, const int probability, + const int *const codePoints) { + if (originalPtNode->isTerminal()) { + // Overwrites the probability. + int probabilityFieldPos = originalPtNode->getProbabilityFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition(mBuffer, + probability, &probabilityFieldPos)) { + return false; + } + } else { + // Make the node terminal and write the probability. + int movedPos = mBuffer->getTailPosition(); + if (!markNodeAsMovedAndSetPosition(originalPtNode, movedPos, movedPos)) { + return false; + } + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, originalPtNode, + originalPtNode->getParentPos(), codePoints, originalPtNode->getCodePointCount(), + probability, &movedPos)) { + return false; + } + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount) { + const int newPtNodeArrayPos = mBuffer->getTailPosition(); + int childrenPosFieldPos = parentNode->getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + newPtNodeArrayPos, &childrenPosFieldPos)) { + return false; + } + return createNewPtNodeArrayWithAChildPtNode(parentNode->getHeadPos(), codePoints, + codePointCount, probability); +} + +bool DynamicPatriciaTrieWritingHelper::createNewPtNodeArrayWithAChildPtNode( + const int parentPtNodePos, const int *const nodeCodePoints, const int nodeCodePointCount, + const int probability) { + int 
writingPos = mBuffer->getTailPosition(); + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + 1 /* arraySize */, &writingPos)) { + return false; + } + if (!writePtNodeToBuffer(mBuffer, parentPtNodePos, nodeCodePoints, nodeCodePointCount, + probability, &writingPos)) { + return false; + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + return true; +} + +// Returns whether the dictionary updating was succeeded or not. +bool DynamicPatriciaTrieWritingHelper::reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount) { + // When addsExtraChild is true, split the reallocating PtNode and add new child. + // Reallocating PtNode: abcde, newNode: abcxy. + // abc (1st, not terminal) __ de (2nd) + // \_ xy (extra child, terminal) + // Otherwise, this method makes 1st part terminal and write probabilityOfNewPtNode. + // Reallocating PtNode: abcde, newNode: abc. + // abc (1st, terminal) __ de (2nd) + const bool addsExtraChild = newNodeCodePointCount > overlappingCodePointCount; + const int firstPartOfReallocatedPtNodePos = mBuffer->getTailPosition(); + int writingPos = firstPartOfReallocatedPtNodePos; + // Write the 1st part of the reallocating node. The children position will be updated later + // with actual children position. + const int newProbability = addsExtraChild ? NOT_A_PROBABILITY : probabilityOfNewPtNode; + if (!writePtNodeToBuffer(mBuffer, reallocatingPtNode->getParentPos(), + reallocatingPtNodeCodePoints, overlappingCodePointCount, newProbability, + &writingPos)) { + return false; + } + const int actualChildrenPos = writingPos; + // Create new children PtNode array. 
+ const size_t newPtNodeCount = addsExtraChild ? 2 : 1; + if (!DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition(mBuffer, + newPtNodeCount, &writingPos)) { + return false; + } + // Write the 2nd part of the reallocating node. + const int secondPartOfReallocatedPtNodePos = writingPos; + if (!writePtNodeToBufferByCopyingPtNodeInfo(mBuffer, reallocatingPtNode, + firstPartOfReallocatedPtNodePos, + reallocatingPtNodeCodePoints + overlappingCodePointCount, + reallocatingPtNode->getCodePointCount() - overlappingCodePointCount, + reallocatingPtNode->getProbability(), &writingPos)) { + return false; + } + if (addsExtraChild) { + if (!writePtNodeToBuffer(mBuffer, firstPartOfReallocatedPtNodePos, + newNodeCodePoints + overlappingCodePointCount, + newNodeCodePointCount - overlappingCodePointCount, probabilityOfNewPtNode, + &writingPos)) { + return false; + } + } + if (!DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition(mBuffer, + NOT_A_DICT_POS /* forwardLinkPos */, &writingPos)) { + return false; + } + // Update original reallocatingPtNode as moved. + if (!markNodeAsMovedAndSetPosition(reallocatingPtNode, firstPartOfReallocatedPtNodePos, + secondPartOfReallocatedPtNodePos)) { + return false; + } + // Load node info. Information of the 1st part will be fetched. + DynamicPatriciaTrieNodeReader nodeReader(mBuffer, mBigramPolicy, mShortcutPolicy); + nodeReader.fetchNodeInfoInBufferFromPtNodePos(firstPartOfReallocatedPtNodePos); + // Update children position. 
+ int childrenPosFieldPos = nodeReader.getChildrenPosFieldPos(); + if (!DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition(mBuffer, + actualChildrenPos, &childrenPosFieldPos)) { + return false; + } + return true; +} + +bool DynamicPatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, + BufferWithExtendableBuffer *const bufferToWrite) { + DynamicPatriciaTrieReadingHelper readingHelper(mBuffer, mBigramPolicy, mShortcutPolicy); + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners + ::TraversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted + traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted( + this, mBuffer); + if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( + &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) { + return false; + } + + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateBigramProbability + traversePolicyToUpdateBigramProbability(mBigramPolicy); + if (!readingHelper.traverseAllPtNodesInPostorderDepthFirstManner( + &traversePolicyToUpdateBigramProbability)) { + return false; + } + + // Mapping from positions in mBuffer to positions in bufferToWrite. + DictPositionRelocationMap dictPositionRelocationMap; + readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToPlaceAndWriteValidPtNodesToBuffer + traversePolicyToPlaceAndWriteValidPtNodesToBuffer(this, bufferToWrite, + &dictPositionRelocationMap); + if (!readingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + &traversePolicyToPlaceAndWriteValidPtNodesToBuffer)) { + return false; + } + + // Create policy instance for the GCed dictionary. 
+ DynamicShortcutListPolicy newDictShortcutPolicy(bufferToWrite); + DynamicBigramListPolicy newDictBigramPolicy(bufferToWrite, &newDictShortcutPolicy); + // Create reading helper for the GCed dictionary. + DynamicPatriciaTrieReadingHelper newDictReadingHelper(bufferToWrite, &newDictBigramPolicy, + &newDictShortcutPolicy); + newDictReadingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos); + DynamicPatriciaTrieGcEventListeners::TraversePolicyToUpdateAllPositionFields + traversePolicyToUpdateAllPositionFields(this, &newDictBigramPolicy, bufferToWrite, + &dictPositionRelocationMap); + if (!newDictReadingHelper.traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( + &traversePolicyToUpdateAllPositionFields)) { + return false; + } + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h new file mode 100644 index 000000000..fe1b2437a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_helper.h @@ -0,0 +1,128 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H + +#include <stdint.h> + +#include "defines.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +class BufferWithExtendableBuffer; +class DynamicBigramListPolicy; +class DynamicPatriciaTrieNodeReader; +class DynamicPatriciaTrieReadingHelper; +class DynamicShortcutListPolicy; +class HeaderPolicy; + +class DynamicPatriciaTrieWritingHelper { + public: + typedef hash_map_compat<int, int> PtNodeArrayPositionRelocationMap; + typedef hash_map_compat<int, int> PtNodePositionRelocationMap; + struct DictPositionRelocationMap { + public: + DictPositionRelocationMap() + : mPtNodeArrayPositionRelocationMap(), mPtNodePositionRelocationMap() {} + + PtNodeArrayPositionRelocationMap mPtNodeArrayPositionRelocationMap; + PtNodePositionRelocationMap mPtNodePositionRelocationMap; + + private: + DISALLOW_COPY_AND_ASSIGN(DictPositionRelocationMap); + }; + + DynamicPatriciaTrieWritingHelper(BufferWithExtendableBuffer *const buffer, + DynamicBigramListPolicy *const bigramPolicy, + DynamicShortcutListPolicy *const shortcutPolicy) + : mBuffer(buffer), mBigramPolicy(bigramPolicy), mShortcutPolicy(shortcutPolicy) {} + + ~DynamicPatriciaTrieWritingHelper() {} + + // Add a word to the dictionary. If the word already exists, update the probability. + bool addUnigramWord(DynamicPatriciaTrieReadingHelper *const readingHelper, + const int *const wordCodePoints, const int codePointCount, const int probability); + + // Add a bigram relation from word0Pos to word1Pos. + bool addBigramWords(const int word0Pos, const int word1Pos, const int probability); + + // Remove a bigram relation from word0Pos to word1Pos. 
+ bool removeBigramWords(const int word0Pos, const int word1Pos); + + void writeToDictFile(const char *const fileName, const HeaderPolicy *const headerPolicy); + + void writeToDictFileWithGC(const int rootPtNodeArrayPos, const char *const fileName, + const HeaderPolicy *const headerPolicy); + + // CAVEAT: This method must be called only from inner classes of + // DynamicPatriciaTrieGcEventListeners. + bool markNodeAsDeleted(const DynamicPatriciaTrieNodeReader *const nodeToUpdate); + + // CAVEAT: This method must be called only from this class or inner classes of + // DynamicPatriciaTrieGcEventListeners. + bool writePtNodeToBufferByCopyingPtNodeInfo(BufferWithExtendableBuffer *const bufferToWrite, + const DynamicPatriciaTrieNodeReader *const originalNode, const int parentPos, + const int *const codePoints, const int codePointCount, const int probability, + int *const writingPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingHelper); + + static const int CHILDREN_POSITION_FIELD_SIZE; + static const size_t MAX_DICTIONARY_SIZE; + + BufferWithExtendableBuffer *const mBuffer; + DynamicBigramListPolicy *const mBigramPolicy; + DynamicShortcutListPolicy *const mShortcutPolicy; + + bool markNodeAsMovedAndSetPosition(const DynamicPatriciaTrieNodeReader *const nodeToUpdate, + const int movedPos, const int bigramLinkedNodePos); + + bool writePtNodeWithFullInfoToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const bool isBlacklisted, const bool isNotAWord, + const int parentPos, const int *const codePoints, const int codePointCount, + const int probability, const int childrenPos, const int originalBigramListPos, + const int originalShortcutListPos, int *const writingPos); + + bool writePtNodeToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const int parentPos, const int *const codePoints, const int codePointCount, + const int probability, int *const writingPos); + + bool createAndInsertNodeIntoPtNodeArray(const int parentPos, 
const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability, int *const forwardLinkFieldPos); + + bool setPtNodeProbability(const DynamicPatriciaTrieNodeReader *const originalNode, + const int probability, const int *const codePoints); + + bool createChildrenPtNodeArrayAndAChildPtNode( + const DynamicPatriciaTrieNodeReader *const parentNode, const int probability, + const int *const codePoints, const int codePointCount); + + bool createNewPtNodeArrayWithAChildPtNode(const int parentPos, const int *const nodeCodePoints, + const int nodeCodePointCount, const int probability); + + bool reallocatePtNodeAndAddNewPtNodes( + const DynamicPatriciaTrieNodeReader *const reallocatingPtNode, + const int *const reallocatingPtNodeCodePoints, const int overlappingCodePointCount, + const int probabilityOfNewPtNode, const int *const newNodeCodePoints, + const int newNodeCodePointCount); + + bool runGC(const int rootPtNodeArrayPos, BufferWithExtendableBuffer *const bufferToWrite); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_HELPER_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp new file mode 100644 index 000000000..30ff10cd6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" + +#include <cstddef> +#include <cstdlib> +#include <stdint.h> + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD = 0x7F; +const size_t DynamicPatriciaTrieWritingUtils::MAX_PTNODE_ARRAY_SIZE = 0x7FFF; +const int DynamicPatriciaTrieWritingUtils::SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE = 2; +const int DynamicPatriciaTrieWritingUtils::LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG = 0x8000; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_FIELD_SIZE = 3; +const int DynamicPatriciaTrieWritingUtils::MAX_DICT_OFFSET_VALUE = 0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::MIN_DICT_OFFSET_VALUE = -0x7FFFFF; +const int DynamicPatriciaTrieWritingUtils::DICT_OFFSET_NEGATIVE_FLAG = 0x800000; +const int DynamicPatriciaTrieWritingUtils::PROBABILITY_FIELD_SIZE = 1; +const int DynamicPatriciaTrieWritingUtils::NODE_FLAG_FIELD_SIZE = 1; + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeEmptyDictionary( + BufferWithExtendableBuffer *const buffer, const int rootPos) { + int writingPos = rootPos; + if (!writePtNodeArraySizeAndAdvancePosition(buffer, 0 /* arraySize */, &writingPos)) { + return false; + } + return writeForwardLinkPositionAndAdvancePosition(buffer, NOT_A_DICT_POS /* forwardLinkPos */, + &writingPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos) { + return writeDictOffset(buffer, forwardLinkPos, (*forwardLinkFieldPos), forwardLinkFieldPos); +} + +/* static */ bool 
DynamicPatriciaTrieWritingUtils::writePtNodeArraySizeAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const size_t arraySize, + int *const arraySizeFieldPos) { + // Currently, every array size field is created with LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE to + // simplify the updating process. + // TODO: Use SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE for small arrays. + /*if (arraySize <= MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD) { + return buffer->writeUintAndAdvancePosition(arraySize, SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else */ + if (arraySize <= MAX_PTNODE_ARRAY_SIZE) { + uint32_t data = arraySize | LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + return buffer->writeUintAndAdvancePosition(data, LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE, + arraySizeFieldPos); + } else { + AKLOGI("PtNode array size cannot be written because arraySize is too large: %zu", + arraySize); + ASSERT(false); + return false; + } +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, int *const nodeFlagsFieldPos) { + return buffer->writeUintAndAdvancePosition(nodeFlags, NODE_FLAG_FIELD_SIZE, nodeFlagsFieldPos); +} + +// Note that parentOffset is offset from node's head position. 
+/* static */ bool DynamicPatriciaTrieWritingUtils::writeParentPosOffsetAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int parentPos, const int basePos, + int *const parentPosFieldPos) { + return writeDictOffset(buffer, parentPos, basePos, parentPosFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeCodePointsAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int *const codePoints, + const int codePointCount, int *const codePointFieldPos) { + if (codePointCount <= 0) { + AKLOGI("code points cannot be written because codePointCount is invalid: %d", + codePointCount); + ASSERT(false); + return false; + } + const bool hasMultipleCodePoints = codePointCount > 1; + return buffer->writeCodePointsAndAdvancePosition(codePoints, codePointCount, + hasMultipleCodePoints, codePointFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeProbabilityAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int probability, + int *const probabilityFieldPos) { + if (probability < 0 || probability > MAX_PROBABILITY) { + AKLOGI("probability cannot be written because the probability is invalid: %d", + probability); + ASSERT(false); + return false; + } + return buffer->writeUintAndAdvancePosition(probability, PROBABILITY_FIELD_SIZE, + probabilityFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeChildrenPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int childrenPosition, + int *const childrenPositionFieldPos) { + return writeDictOffset(buffer, childrenPosition, (*childrenPositionFieldPos), + childrenPositionFieldPos); +} + +/* static */ bool DynamicPatriciaTrieWritingUtils::writeDictOffset( + BufferWithExtendableBuffer *const buffer, const int targetPos, const int basePos, + int *const offsetFieldPos) { + int offset = targetPos - basePos; + if (targetPos == NOT_A_DICT_POS) { + offset = DynamicPatriciaTrieReadingUtils::DICT_OFFSET_INVALID; + } 
else if (offset == 0) { + offset = DynamicPatriciaTrieReadingUtils::DICT_OFFSET_ZERO_OFFSET; + } + if (offset > MAX_DICT_OFFSET_VALUE || offset < MIN_DICT_OFFSET_VALUE) { + AKLOGI("offset cannot be written because the offset is too large or too small: %d", + offset); + ASSERT(false); + return false; + } + uint32_t data = 0; + if (offset >= 0) { + data = offset; + } else { + data = abs(offset) | DICT_OFFSET_NEGATIVE_FLAG; + } + return buffer->writeUintAndAdvancePosition(data, DICT_OFFSET_FIELD_SIZE, offsetFieldPos); +} +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h new file mode 100644 index 000000000..af76bc6b5 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h @@ -0,0 +1,76 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H + +#include <cstddef> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class DynamicPatriciaTrieWritingUtils { + public: + static const int NODE_FLAG_FIELD_SIZE; + + static bool writeEmptyDictionary(BufferWithExtendableBuffer *const buffer, const int rootPos); + + static bool writeForwardLinkPositionAndAdvancePosition( + BufferWithExtendableBuffer *const buffer, const int forwardLinkPos, + int *const forwardLinkFieldPos); + + static bool writePtNodeArraySizeAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const size_t arraySize, int *const arraySizeFieldPos); + + static bool writeFlagsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const DynamicPatriciaTrieReadingUtils::NodeFlags nodeFlags, + int *const nodeFlagsFieldPos); + + static bool writeParentPosOffsetAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int parentPosition, const int basePos, int *const parentPosFieldPos); + + static bool writeCodePointsAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int *const codePoints, const int codePointCount, int *const codePointFieldPos); + + static bool writeProbabilityAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int probability, int *const probabilityFieldPos); + + static bool writeChildrenPositionAndAdvancePosition(BufferWithExtendableBuffer *const buffer, + const int childrenPosition, int *const childrenPositionFieldPos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieWritingUtils); + + static const size_t MAX_PTNODE_ARRAY_SIZE_TO_USE_SMALL_SIZE_FIELD; + static const size_t MAX_PTNODE_ARRAY_SIZE; + static const int SMALL_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE; + static const int 
LARGE_PTNODE_ARRAY_SIZE_FIELD_SIZE_FLAG; + static const int DICT_OFFSET_FIELD_SIZE; + static const int MAX_DICT_OFFSET_VALUE; + static const int MIN_DICT_OFFSET_VALUE; + static const int DICT_OFFSET_NEGATIVE_FLAG; + static const int PROBABILITY_FIELD_SIZE; + + static bool writeDictOffset(BufferWithExtendableBuffer *const buffer, const int targetPos, + const int basePos, int *const offsetFieldPos); +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_WRITING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp new file mode 100644 index 000000000..7bbeacaa0 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/header/header_policy.h" + +#include <cstddef> +#include <cstdio> +#include <ctime> + +namespace latinime { + + +// Note that these are corresponding definitions in Java side in FormatSpec.FileHeader. 
+const char *const HeaderPolicy::MULTIPLE_WORDS_DEMOTION_RATE_KEY = "MULTIPLE_WORDS_DEMOTION_RATE"; +const char *const HeaderPolicy::USES_FORGETTING_CURVE_KEY = "USES_FORGETTING_CURVE"; +const char *const HeaderPolicy::LAST_UPDATED_TIME_KEY = "date"; +const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100; +const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f; + +// Used for logging. Question mark is used to indicate that the key is not found. +void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue, + int outValueSize) const { + if (outValueSize <= 0) return; + if (outValueSize == 1) { + outValue[0] = '\0'; + return; + } + std::vector<int> keyCodePointVector; + HeaderReadWriteUtils::insertCharactersIntoVector(key, &keyCodePointVector); + HeaderReadWriteUtils::AttributeMap::const_iterator it = mAttributeMap.find(keyCodePointVector); + if (it == mAttributeMap.end()) { + // The key was not found. + outValue[0] = '?'; + outValue[1] = '\0'; + return; + } + const int terminalIndex = min(static_cast<int>(it->second.size()), outValueSize - 1); + for (int i = 0; i < terminalIndex; ++i) { + outValue[i] = it->second[i]; + } + outValue[terminalIndex] = '\0'; +} + +float HeaderPolicy::readMultipleWordCostMultiplier() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(MULTIPLE_WORDS_DEMOTION_RATE_KEY, &keyVector); + const int demotionRate = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, + &keyVector, DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE); + if (demotionRate <= 0) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + return MULTIPLE_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(demotionRate); +} + +bool HeaderPolicy::readUsesForgettingCurveFlag() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(USES_FORGETTING_CURVE_KEY, &keyVector); + return HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, &keyVector, + false 
/* defaultValue */); +} + +// Returns current time when the key is not found or the value is invalid. +int HeaderPolicy::readLastUpdatedTime() const { + std::vector<int> keyVector; + HeaderReadWriteUtils::insertCharactersIntoVector(LAST_UPDATED_TIME_KEY, &keyVector); + return HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap, &keyVector, + time(0) /* defaultValue */); +} + +bool HeaderPolicy::writeHeaderToBuffer(BufferWithExtendableBuffer *const bufferToWrite, + const bool updatesLastUpdatedTime) const { + int writingPos = 0; + if (!HeaderReadWriteUtils::writeDictionaryVersion(bufferToWrite, mDictFormatVersion, + &writingPos)) { + return false; + } + if (!HeaderReadWriteUtils::writeDictionaryFlags(bufferToWrite, mDictionaryFlags, + &writingPos)) { + return false; + } + // Temporarily writes a dummy header size. + int headerSizeFieldPos = writingPos; + if (!HeaderReadWriteUtils::writeDictionaryHeaderSize(bufferToWrite, 0 /* size */, + &writingPos)) { + return false; + } + if (updatesLastUpdatedTime) { + // Set current time as a last updated time. + HeaderReadWriteUtils::AttributeMap attributeMapTowrite(mAttributeMap); + std::vector<int> updatedTimekey; + HeaderReadWriteUtils::insertCharactersIntoVector(LAST_UPDATED_TIME_KEY, &updatedTimekey); + HeaderReadWriteUtils::setIntAttribute(&attributeMapTowrite, &updatedTimekey, time(0)); + if (!HeaderReadWriteUtils::writeHeaderAttributes(bufferToWrite, &attributeMapTowrite, + &writingPos)) { + return false; + } + } else { + if (!HeaderReadWriteUtils::writeHeaderAttributes(bufferToWrite, &mAttributeMap, + &writingPos)) { + return false; + } + } + // Writes an actual header size. 
+ if (!HeaderReadWriteUtils::writeDictionaryHeaderSize(bufferToWrite, writingPos, + &headerSizeFieldPos)) { + return false; + } + return true; +} + +/* static */ HeaderReadWriteUtils::AttributeMap + HeaderPolicy::createAttributeMapAndReadAllAttributes(const uint8_t *const dictBuf) { + HeaderReadWriteUtils::AttributeMap attributeMap; + HeaderReadWriteUtils::fetchAllHeaderAttributes(dictBuf, &attributeMap); + return attributeMap; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h new file mode 100644 index 000000000..e97c08ca4 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h @@ -0,0 +1,114 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_HEADER_POLICY_H +#define LATINIME_HEADER_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_header_structure_policy.h" +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +namespace latinime { + +class HeaderPolicy : public DictionaryHeaderStructurePolicy { + public: + // Reads information from existing dictionary buffer. 
+ HeaderPolicy(const uint8_t *const dictBuf, const int dictSize) + : mDictFormatVersion(FormatUtils::detectFormatVersion(dictBuf, dictSize)), + mDictionaryFlags(HeaderReadWriteUtils::getFlags(dictBuf)), + mSize(HeaderReadWriteUtils::getHeaderSize(dictBuf)), + mAttributeMap(createAttributeMapAndReadAllAttributes(dictBuf)), + mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), + mUsesForgettingCurve(readUsesForgettingCurveFlag()), + mLastUpdatedTime(readLastUpdatedTime()) {} + + // Constructs header information using an attribute map. The read*() initializers below read mAttributeMap, which is initialized before them. + HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion, + const HeaderReadWriteUtils::AttributeMap *const attributeMap) + : mDictFormatVersion(dictFormatVersion), + mDictionaryFlags(HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap( + attributeMap)), mSize(0), mAttributeMap(*attributeMap), + mMultiWordCostMultiplier(readMultipleWordCostMultiplier()), + mUsesForgettingCurve(readUsesForgettingCurveFlag()), + mLastUpdatedTime(readLastUpdatedTime()) {} + + ~HeaderPolicy() {} + + AK_FORCE_INLINE int getSize() const { + return mSize; + } + + AK_FORCE_INLINE bool supportsDynamicUpdate() const { + return HeaderReadWriteUtils::supportsDynamicUpdate(mDictionaryFlags); + } + + AK_FORCE_INLINE bool requiresGermanUmlautProcessing() const { + return HeaderReadWriteUtils::requiresGermanUmlautProcessing(mDictionaryFlags); + } + + AK_FORCE_INLINE bool requiresFrenchLigatureProcessing() const { + return HeaderReadWriteUtils::requiresFrenchLigatureProcessing(mDictionaryFlags); + } + + AK_FORCE_INLINE float getMultiWordCostMultiplier() const { + return mMultiWordCostMultiplier; + } + + AK_FORCE_INLINE bool usesForgettingCurve() const { + return mUsesForgettingCurve; + } + + AK_FORCE_INLINE int getLastUpdatedTime() const { + return mLastUpdatedTime; + } + + void readHeaderValueOrQuestionMark(const char *const key, + int *outValue, int outValueSize) const; + + bool writeHeaderToBuffer(BufferWithExtendableBuffer *const 
bufferToWrite, + const bool updatesLastUpdatedTime) const; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderPolicy); + + static const char *const MULTIPLE_WORDS_DEMOTION_RATE_KEY; + static const char *const USES_FORGETTING_CURVE_KEY; + static const char *const LAST_UPDATED_TIME_KEY; + static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE; + static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE; + + const FormatUtils::FORMAT_VERSION mDictFormatVersion; + const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags; + const int mSize; + HeaderReadWriteUtils::AttributeMap mAttributeMap; + const float mMultiWordCostMultiplier; + const bool mUsesForgettingCurve; + const int mLastUpdatedTime; + + float readMultipleWordCostMultiplier() const; + + bool readUsesForgettingCurveFlag() const; + + int readLastUpdatedTime() const; + + static HeaderReadWriteUtils::AttributeMap createAttributeMapAndReadAllAttributes( + const uint8_t *const dictBuf); +}; +} // namespace latinime +#endif /* LATINIME_HEADER_POLICY_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp new file mode 100644 index 000000000..3b1c78085 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp @@ -0,0 +1,215 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" + +#include <cctype> +#include <cstdio> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_KEY_LENGTH = 256; +const int HeaderReadWriteUtils::MAX_ATTRIBUTE_VALUE_LENGTH = 256; + +const int HeaderReadWriteUtils::HEADER_MAGIC_NUMBER_SIZE = 4; +const int HeaderReadWriteUtils::HEADER_DICTIONARY_VERSION_SIZE = 2; +const int HeaderReadWriteUtils::HEADER_FLAG_SIZE = 2; +const int HeaderReadWriteUtils::HEADER_SIZE_FIELD_SIZE = 4; + +const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0; +// Flags for special processing +// Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAG) or +// something very bad (like, the apocalypse) will happen. Please update both at the same time. +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::GERMAN_UMLAUT_PROCESSING_FLAG = 0x1; +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::SUPPORTS_DYNAMIC_UPDATE_FLAG = 0x2; +const HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::FRENCH_LIGATURE_PROCESSING_FLAG = 0x4; + +// Note that these are corresponding definitions in Java side in FormatSpec.FileHeader. 
+const char *const HeaderReadWriteUtils::SUPPORTS_DYNAMIC_UPDATE_KEY = "SUPPORTS_DYNAMIC_UPDATE"; +const char *const HeaderReadWriteUtils::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY = + "REQUIRES_GERMAN_UMLAUT_PROCESSING"; +const char *const HeaderReadWriteUtils::REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY = + "REQUIRES_FRENCH_LIGATURE_PROCESSING"; + +/* static */ int HeaderReadWriteUtils::getHeaderSize(const uint8_t *const dictBuf) { + // See the format of the header in the comment in + // BinaryDictionaryFormatUtils::detectFormatVersion() + return ByteArrayUtils::readUint32(dictBuf, HEADER_MAGIC_NUMBER_SIZE + + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE); +} + +/* static */ HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::getFlags(const uint8_t *const dictBuf) { + return ByteArrayUtils::readUint16(dictBuf, + HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE); +} + +/* static */ HeaderReadWriteUtils::DictionaryFlags + HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap( + const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + AttributeMap::key_type key; + insertCharactersIntoVector(REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY, &key); + const bool requiresGermanUmlautProcessing = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + key.clear(); + insertCharactersIntoVector(REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY, &key); + const bool requiresFrenchLigatureProcessing = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + key.clear(); + insertCharactersIntoVector(SUPPORTS_DYNAMIC_UPDATE_KEY, &key); + const bool supportsDynamicUpdate = readBoolAttributeValue(attributeMap, &key, + false /* defaultValue */); + DictionaryFlags dictflags = NO_FLAGS; + dictflags |= requiresGermanUmlautProcessing ? GERMAN_UMLAUT_PROCESSING_FLAG : 0; + dictflags |= requiresFrenchLigatureProcessing ? FRENCH_LIGATURE_PROCESSING_FLAG : 0; + dictflags |= supportsDynamicUpdate ? 
SUPPORTS_DYNAMIC_UPDATE_FLAG : 0; + return dictflags; +} + +/* static */ void HeaderReadWriteUtils::fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes) { + const int headerSize = getHeaderSize(dictBuf); + int pos = getHeaderOptionsPosition(); + if (pos == NOT_A_DICT_POS) { + // The header doesn't have header options. + return; + } + int keyBuffer[MAX_ATTRIBUTE_KEY_LENGTH]; + int valueBuffer[MAX_ATTRIBUTE_VALUE_LENGTH]; + while (pos < headerSize) { + const int keyLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_KEY_LENGTH, keyBuffer, &pos); + std::vector<int> key; + key.insert(key.end(), keyBuffer, keyBuffer + keyLength); + const int valueLength = ByteArrayUtils::readStringAndAdvancePosition(dictBuf, + MAX_ATTRIBUTE_VALUE_LENGTH, valueBuffer, &pos); + std::vector<int> value; + value.insert(value.end(), valueBuffer, valueBuffer + valueLength); + headerAttributes->insert(AttributeMap::value_type(key, value)); + } +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryVersion( + BufferWithExtendableBuffer *const buffer, const FormatUtils::FORMAT_VERSION version, + int *const writingPos) { + if (!buffer->writeUintAndAdvancePosition(FormatUtils::MAGIC_NUMBER, HEADER_MAGIC_NUMBER_SIZE, + writingPos)) { + return false; + } + switch (version) { + case FormatUtils::VERSION_2: + // Version 2 dictionary writing is not supported. 
+ return false; + case FormatUtils::VERSION_3: + return buffer->writeUintAndAdvancePosition(3 /* data */, + HEADER_DICTIONARY_VERSION_SIZE, writingPos); + default: + return false; + } +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryFlags( + BufferWithExtendableBuffer *const buffer, const DictionaryFlags flags, + int *const writingPos) { + return buffer->writeUintAndAdvancePosition(flags, HEADER_FLAG_SIZE, writingPos); +} + +/* static */ bool HeaderReadWriteUtils::writeDictionaryHeaderSize( + BufferWithExtendableBuffer *const buffer, const int size, int *const writingPos) { + return buffer->writeUintAndAdvancePosition(size, HEADER_SIZE_FIELD_SIZE, writingPos); +} + +/* static */ bool HeaderReadWriteUtils::writeHeaderAttributes( + BufferWithExtendableBuffer *const buffer, const AttributeMap *const headerAttributes, + int *const writingPos) { + for (AttributeMap::const_iterator it = headerAttributes->begin(); + it != headerAttributes->end(); ++it) { + // Write a key. + if (!buffer->writeCodePointsAndAdvancePosition(&(it->first.at(0)), it->first.size(), + true /* writesTerminator */, writingPos)) { + return false; + } + // Write a value. + if (!buffer->writeCodePointsAndAdvancePosition(&(it->second.at(0)), it->second.size(), + true /* writesTerminator */, writingPos)) { + return false; + } + } + return true; +} + +/* static */ void HeaderReadWriteUtils::setBoolAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool value) { + setIntAttribute(headerAttributes, key, value ? 
1 : 0); +} + +/* static */ void HeaderReadWriteUtils::setIntAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int value) { + AttributeMap::mapped_type valueVector; + char charBuf[LARGEST_INT_DIGIT_COUNT + 1]; + snprintf(charBuf, LARGEST_INT_DIGIT_COUNT + 1, "%d", value); + insertCharactersIntoVector(charBuf, &valueVector); + (*headerAttributes)[*key] = valueVector; +} + +/* static */ bool HeaderReadWriteUtils::readBoolAttributeValue( + const AttributeMap *const headerAttributes, const AttributeMap::key_type *const key, + const bool defaultValue) { + const int intDefaultValue = defaultValue ? 1 : 0; + const int intValue = readIntAttributeValue(headerAttributes, key, intDefaultValue); + return intValue != 0; +} + +/* static */ int HeaderReadWriteUtils::readIntAttributeValue( + const AttributeMap *const headerAttributes, const AttributeMap::key_type *const key, + const int defaultValue) { + AttributeMap::const_iterator it = headerAttributes->find(*key); + if (it != headerAttributes->end()) { + int value = 0; + bool isNegative = false; + for (size_t i = 0; i < it->second.size(); ++i) { + if (i == 0 && it->second.at(i) == '-') { + isNegative = true; + } else { + if (!isdigit(it->second.at(i))) { + // If not a number. + return defaultValue; + } + value *= 10; + value += it->second.at(i) - '0'; + } + } + return isNegative ? 
-value : value; + } + return defaultValue; +} + +/* static */ void HeaderReadWriteUtils::insertCharactersIntoVector(const char *const characters, + std::vector<int> *const vector) { + for (int i = 0; characters[i]; ++i) { + vector->push_back(characters[i]); + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h new file mode 100644 index 000000000..caa5097f6 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_HEADER_READ_WRITE_UTILS_H +#define LATINIME_HEADER_READ_WRITE_UTILS_H + +#include <map> +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class HeaderReadWriteUtils { + public: + typedef uint16_t DictionaryFlags; + typedef std::map<std::vector<int>, std::vector<int> > AttributeMap; + + static int getHeaderSize(const uint8_t *const dictBuf); + + static DictionaryFlags getFlags(const uint8_t *const dictBuf); + + static AK_FORCE_INLINE bool supportsDynamicUpdate(const DictionaryFlags flags) { + return (flags & SUPPORTS_DYNAMIC_UPDATE_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresGermanUmlautProcessing(const DictionaryFlags flags) { + return (flags & GERMAN_UMLAUT_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresFrenchLigatureProcessing(const DictionaryFlags flags) { + return (flags & FRENCH_LIGATURE_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE int getHeaderOptionsPosition() { + return HEADER_MAGIC_NUMBER_SIZE + HEADER_DICTIONARY_VERSION_SIZE + HEADER_FLAG_SIZE + + HEADER_SIZE_FIELD_SIZE; + } + + static DictionaryFlags createAndGetDictionaryFlagsUsingAttributeMap( + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static void fetchAllHeaderAttributes(const uint8_t *const dictBuf, + AttributeMap *const headerAttributes); + + static bool writeDictionaryVersion(BufferWithExtendableBuffer *const buffer, + const FormatUtils::FORMAT_VERSION version, int *const writingPos); + + static bool writeDictionaryFlags(BufferWithExtendableBuffer *const buffer, + const DictionaryFlags flags, int *const writingPos); + + static bool writeDictionaryHeaderSize(BufferWithExtendableBuffer *const buffer, + const int size, int *const writingPos); + + static bool writeHeaderAttributes(BufferWithExtendableBuffer *const buffer, + const AttributeMap *const headerAttributes, int 
*const writingPos); + + /** + * Methods for header attributes. + */ + static void setBoolAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool value); + + static void setIntAttribute(AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int value); + + static bool readBoolAttributeValue(const AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const bool defaultValue); + + static int readIntAttributeValue(const AttributeMap *const headerAttributes, + const AttributeMap::key_type *const key, const int defaultValue); + + static void insertCharactersIntoVector(const char *const characters, + AttributeMap::key_type *const key); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(HeaderReadWriteUtils); + + static const int MAX_ATTRIBUTE_KEY_LENGTH; + static const int MAX_ATTRIBUTE_VALUE_LENGTH; + + static const int HEADER_MAGIC_NUMBER_SIZE; + static const int HEADER_DICTIONARY_VERSION_SIZE; + static const int HEADER_FLAG_SIZE; + static const int HEADER_SIZE_FIELD_SIZE; + + static const DictionaryFlags NO_FLAGS; + // Flags for special processing + // Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAGS) or + // something very bad (like, the apocalypse) will happen. Please update both at the same time. 
+ static const DictionaryFlags GERMAN_UMLAUT_PROCESSING_FLAG; + static const DictionaryFlags SUPPORTS_DYNAMIC_UPDATE_FLAG; + static const DictionaryFlags FRENCH_LIGATURE_PROCESSING_FLAG; + + static const char *const SUPPORTS_DYNAMIC_UPDATE_KEY; + static const char *const REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY; + static const char *const REQUIRES_FRENCH_LIGATURE_PROCESSING_KEY; +}; +} +#endif /* LATINIME_HEADER_READ_WRITE_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp new file mode 100644 index 000000000..8a84bd261 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp @@ -0,0 +1,433 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +#include "suggest/policyimpl/dictionary/patricia_trie_policy.h" + +#include "defines.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/probability_utils.h" + +namespace latinime { + +void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, + DicNodeVector *const childDicNodes) const { + if (!dicNode->hasChildren()) { + return; + } + int nextPos = dicNode->getChildrenPos(); + if (nextPos < 0 || nextPos >= mDictBufferSize) { + AKLOGE("Children PtNode array position is invalid. pos: %d, dict size: %d", + nextPos, mDictBufferSize); + ASSERT(false); + return; + } + const int childCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( + mDictRoot, &nextPos); + for (int i = 0; i < childCount; i++) { + if (nextPos < 0 || nextPos >= mDictBufferSize) { + AKLOGE("Child PtNode position is invalid. pos: %d, dict size: %d, childCount: %d / %d", + nextPos, mDictBufferSize, i, childCount); + ASSERT(false); + return; + } + nextPos = createAndGetLeavingChildNode(dicNode, nextPos, childDicNodes); + } +} + +// This retrieves code points and the probability of the word by its terminal position. +// Due to the fact that words are ordered in the dictionary in a strict breadth-first order, +// it is possible to check for this with advantageous complexity. For each node, we search +// for PtNodes with children and compare the children position with the position we look for. +// When we shoot the position we look for, it means the word we look for is in the children +// of the previous PtNode. 
The only tricky part is the fact that if we arrive at the end of a +// PtNode array with the last PtNode's children position still less than what we are searching for, +// we must descend the last PtNode's children (for example, if the word we are searching for starts +// with a z, it's the last PtNode of the root array, so all children addresses will be smaller +// than the position we look for, and we have to descend the z node). +/* Parameters : + * ptNodePos: the byte position of the terminal PtNode of the word we are searching for (this is + * what is stored as the "bigram position" in each bigram) + * outCodePoints: an array to write the found word, with MAX_WORD_LENGTH size. + * outUnigramProbability: a pointer to an int to write the probability into. + * Return value : the code point count, or 0 if the word was not found. + */ +// TODO: Split this function to be more readable +int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( + const int ptNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const { + int pos = getRootPosition(); + int wordPos = 0; + // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will + // only traverse nodes that are actually a part of the terminal we are searching, so each time + // we enter this loop we are one depth level further than last time. + // The only reason we count nodes is because we want to reduce the probability of infinite + // looping in case there is a bug. Since we know there is an upper bound to the depth we are + // supposed to traverse, it does not hurt to count iterations. + for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { + int lastCandidatePtNodePos = 0; + // Let's loop through PtNodes in this PtNode array searching for either the terminal + // or one of its ascendants. 
+ for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( + mDictRoot, &pos); ptNodeCount > 0; --ptNodeCount) { + const int startPos = pos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + const int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + if (ptNodePos == startPos) { + // We found the position. Copy the rest of the code points in the buffer and return + // the length. + outCodePoints[wordPos] = character; + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + // We count code points in order to avoid infinite loops if the file is broken + // or if there is some other bug + int charCount = maxCodePointCount; + while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { + outCodePoints[++wordPos] = nextChar; + nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + } + } + *outUnigramProbability = + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + &pos); + return ++wordPos; + } + // We need to skip past this PtNode, so skip any remaining code points after the + // first and possibly the probability. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + } + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + // The fact that this PtNode has children is very important. Since we already know + // that this PtNode does not match, if it has no children we know it is irrelevant + // to what we are searching for. 
+ const bool hasChildren = PatriciaTrieReadingUtils::hasChildrenInFlags(flags); + // We will write in `found' whether we have passed the children position we are + // searching for. For example if we search for "beer", the children of b are less + // than the address we are searching for and the children of c are greater. When we + // come here for c, we realize this is too big, and that we should descend b. + bool found; + if (hasChildren) { + int currentPos = pos; + // Here comes the tricky part. First, read the children position. + const int childrenPos = PatriciaTrieReadingUtils + ::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &currentPos); + if (childrenPos > ptNodePos) { + // If the children pos is greater than the position, it means the previous + // PtNode, whose position is stored in lastCandidatePtNodePos, was the right + // one. + found = true; + } else if (1 >= ptNodeCount) { + // However if we are on the LAST PtNode of this array, and we have NOT shot the + // position we should descend THIS node. So we trick the lastCandidatePtNodePos + // so that we will descend this PtNode, not the previous one. + lastCandidatePtNodePos = startPos; + found = true; + } else { + // Else, we should continue looking. + found = false; + } + } else { + // Even if we don't have children here, we could still be on the last PtNode of + // this array. If this is the case, we should descend the last PtNode that had + // children, and their position is already in lastCandidatePtNodePos. + found = (1 >= ptNodeCount); + } + + if (found) { + // Okay, we found the PtNode we should descend. Its position is in + // the lastCandidatePtNodePos variable, so we just re-read it. 
+ if (0 != lastCandidatePtNodePos) { + const PatriciaTrieReadingUtils::NodeFlags lastFlags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + const int lastChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + // We copy all the characters in this PtNode to the buffer + outCodePoints[wordPos] = lastChar; + if (PatriciaTrieReadingUtils::hasMultipleChars(lastFlags)) { + int nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + int charCount = maxCodePointCount; + while (-1 != nextChar && --charCount > 0) { + outCodePoints[++wordPos] = nextChar; + nextChar = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &lastCandidatePtNodePos); + } + } + ++wordPos; + // Now we only need to branch to the children address. Skip the probability if + // it's there, read pos, and break to resume the search at pos. + if (PatriciaTrieReadingUtils::isTerminal(lastFlags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, + &lastCandidatePtNodePos); + } + pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, lastFlags, &lastCandidatePtNodePos); + break; + } else { + // Here is a little tricky part: we come here if we found out that all children + // addresses in this PtNode are bigger than the address we are searching for. + // Should we conclude the word is not in the dictionary? No! It could still be + // one of the remaining PtNodes in this array, so we have to keep looking in + // this array until we find it (or we realize it's not there either, in which + // case it's actually not in the dictionary). Pass the end of this PtNode, + // ready to start the next one. 
+ if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + } else { + // If we did not find it, we should record the last children address for the next + // iteration. + if (hasChildren) lastCandidatePtNodePos = startPos; + // Now skip the end of this PtNode (children pos and the attributes if any) so that + // our pos is after the end of this PtNode, at the start of the next one. + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + + } + } + // If we have looked through all the PtNodes and found no match, the ptNodePos is + // not the position of a terminal in this dictionary. + return 0; +} + +// This function gets the position of the terminal node of the exact matching word in the +// dictionary. If no match is found, it returns NOT_A_DICT_POS. +int PatriciaTriePolicy::getTerminalNodePositionOfWord(const int *const inWord, + const int length, const bool forceLowerCaseSearch) const { + int pos = getRootPosition(); + int wordPos = 0; + + while (true) { + // If we already traversed the tree further than the word is long, that means + // there was no match (or we would have found it). + if (wordPos >= length) return NOT_A_DICT_POS; + int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition(mDictRoot, + &pos); + const int wChar = forceLowerCaseSearch + ? 
CharUtils::toLowerCase(inWord[wordPos]) : inWord[wordPos]; + while (true) { + // If there are no more PtNodes in this array, it means we could not + // find a matching character for this depth, therefore there is no match. + if (0 >= ptNodeCount) return NOT_A_DICT_POS; + const int ptNodePos = pos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + int character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(mDictRoot, + &pos); + if (character == wChar) { + // This is the correct PtNode. Only one PtNode may start with the same char within + // a PtNode array, so either we found our match in this array, or there is + // no match and we can return NOT_A_DICT_POS. So we will check all the + // characters in this PtNode indeed does match. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition(mDictRoot, + &pos); + while (NOT_A_CODE_POINT != character) { + ++wordPos; + // If we shoot the length of the word we search for, or if we find a single + // character that does not match, as explained above, it means the word is + // not in the dictionary (by virtue of this PtNode being the only one to + // match the word on the first character, but not matching the whole word). + if (wordPos >= length) return NOT_A_DICT_POS; + if (inWord[wordPos] != character) return NOT_A_DICT_POS; + character = PatriciaTrieReadingUtils::getCodePointAndAdvancePosition( + mDictRoot, &pos); + } + } + // If we come here we know that so far, we do match. Either we are on a terminal + // and we match the length, in which case we found it, or we traverse children. + // If we don't match the length AND don't have children, then a word in the + // dictionary fully matches a prefix of the searched word but not the full word. 
+ ++wordPos; + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + if (wordPos == length) { + return ptNodePos; + } + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (!PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + return NOT_A_DICT_POS; + } + // We have children and we are still shorter than the word we are searching for, so + // we need to traverse children. Put the pointer on the children position, and + // break + pos = PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, + flags, &pos); + break; + } else { + // This PtNode does not match, so skip the remaining part and go to the next. + if (PatriciaTrieReadingUtils::hasMultipleChars(flags)) { + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, + &pos); + } + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, + flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + mBigramListPolicy.skipAllBigrams(&pos); + } + } + --ptNodeCount; + } + } +} + +int PatriciaTriePolicy::getProbability(const int unigramProbability, + const int bigramProbability) const { + if (unigramProbability == NOT_A_PROBABILITY) { + return NOT_A_PROBABILITY; + } else if (bigramProbability == NOT_A_PROBABILITY) { + return ProbabilityUtils::backoff(unigramProbability); + } else { + return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, + bigramProbability); + } +} + +int PatriciaTriePolicy::getUnigramProbabilityOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_PROBABILITY; + } + int pos = ptNodePos; + const PatriciaTrieReadingUtils::NodeFlags flags 
= + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + if (!PatriciaTrieReadingUtils::isTerminal(flags)) { + return NOT_A_PROBABILITY; + } + if (PatriciaTrieReadingUtils::isNotAWord(flags) + || PatriciaTrieReadingUtils::isBlacklisted(flags)) { + // If this is not a word, or if it's a blacklisted entry, it should behave as + // having no probability outside of the suggestion process (where it should be used + // for shortcuts). + return NOT_A_PROBABILITY; + } + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + return getProbability(PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition( + mDictRoot, &pos), NOT_A_PROBABILITY); +} + +int PatriciaTriePolicy::getShortcutPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + int pos = ptNodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + if (!PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + return NOT_A_DICT_POS; + } + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &pos); + } + return pos; +} + +int PatriciaTriePolicy::getBigramsPositionOfPtNode(const int ptNodePos) const { + if (ptNodePos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + int pos = ptNodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + if (!PatriciaTrieReadingUtils::hasBigrams(flags)) { + return NOT_A_DICT_POS; + } + PatriciaTrieReadingUtils::skipCharacters(mDictRoot, flags, MAX_WORD_LENGTH, &pos); + if (PatriciaTrieReadingUtils::isTerminal(flags)) 
{ + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(mDictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + mShortcutListPolicy.skipAllShortcuts(&pos);; + } + return pos; +} + +int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode, + const int ptNodePos, DicNodeVector *childDicNodes) const { + int pos = ptNodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(mDictRoot, &pos); + int mergedNodeCodePoints[MAX_WORD_LENGTH]; + const int mergedNodeCodePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + mDictRoot, flags, MAX_WORD_LENGTH, mergedNodeCodePoints, &pos); + const int probability = (PatriciaTrieReadingUtils::isTerminal(flags))? + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mDictRoot, &pos) + : NOT_A_PROBABILITY; + const int childrenPos = PatriciaTrieReadingUtils::hasChildrenInFlags(flags) ? + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + mDictRoot, flags, &pos) : NOT_A_DICT_POS; + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + getShortcutsStructurePolicy()->skipAllShortcuts(&pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + getBigramsStructurePolicy()->skipAllBigrams(&pos); + } + if (mergedNodeCodePointCount <= 0) { + AKLOGE("Empty PtNode is not allowed. 
Code point count: %d", mergedNodeCodePointCount); + ASSERT(false); + return pos; + } + childDicNodes->pushLeavingChild(dicNode, ptNodePos, childrenPos, probability, + PatriciaTrieReadingUtils::isTerminal(flags), + PatriciaTrieReadingUtils::hasChildrenInFlags(flags), + PatriciaTrieReadingUtils::isBlacklisted(flags) || + PatriciaTrieReadingUtils::isNotAWord(flags), + mergedNodeCodePointCount, mergedNodeCodePoints); + return pos; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h new file mode 100644 index 000000000..f1de914cb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_PATRICIA_TRIE_POLICY_H +#define LATINIME_PATRICIA_TRIE_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h" +#include "suggest/policyimpl/dictionary/bigram/bigram_list_policy.h" +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h" +#include "suggest/policyimpl/dictionary/utils/mmapped_buffer.h" + +namespace latinime { + +class DicNode; +class DicNodeVector; + +class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { + public: + PatriciaTriePolicy(const MmappedBuffer *const buffer) + : mBuffer(buffer), mHeaderPolicy(mBuffer->getBuffer(), buffer->getBufferSize()), + mDictRoot(mBuffer->getBuffer() + mHeaderPolicy.getSize()), + mDictBufferSize(mBuffer->getBufferSize() - mHeaderPolicy.getSize()), + mBigramListPolicy(mDictRoot), mShortcutListPolicy(mDictRoot) {} + + ~PatriciaTriePolicy() { + delete mBuffer; + } + + AK_FORCE_INLINE int getRootPosition() const { + return 0; + } + + void createAndGetAllChildNodes(const DicNode *const dicNode, + DicNodeVector *const childDicNodes) const; + + int getCodePointsAndProbabilityAndReturnCodePointCount( + const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; + + int getTerminalNodePositionOfWord(const int *const inWord, + const int length, const bool forceLowerCaseSearch) const; + + int getProbability(const int unigramProbability, const int bigramProbability) const; + + int getUnigramProbabilityOfPtNode(const int ptNodePos) const; + + int getShortcutPositionOfPtNode(const int ptNodePos) const; + + int getBigramsPositionOfPtNode(const int ptNodePos) const; + + const DictionaryHeaderStructurePolicy *getHeaderStructurePolicy() const { + return &mHeaderPolicy; + } + + const DictionaryBigramsStructurePolicy *getBigramsStructurePolicy() const { + return 
&mBigramListPolicy; + } + + const DictionaryShortcutsStructurePolicy *getShortcutsStructurePolicy() const { + return &mShortcutListPolicy; + } + + bool addUnigramWord(const int *const word, const int length, const int probability) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: addUnigramWord() is called for non-updatable dictionary."); + return false; + } + + bool addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: addBigramWords() is called for non-updatable dictionary."); + return false; + } + + bool removeBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: removeBigramWords() is called for non-updatable dictionary."); + return false; + } + + void flush(const char *const filePath) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: flush() is called for non-updatable dictionary."); + } + + void flushWithGC(const char *const filePath) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: flushWithGC() is called for non-updatable dictionary."); + } + + bool needsToRunGC() const { + // This method should not be called for non-updatable dictionary. 
+ AKLOGI("Warning: needsToRunGC() is called for non-updatable dictionary."); + return false; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTriePolicy); + + const MmappedBuffer *const mBuffer; + const HeaderPolicy mHeaderPolicy; + const uint8_t *const mDictRoot; + const int mDictBufferSize; + const BigramListPolicy mBigramListPolicy; + const ShortcutListPolicy mShortcutListPolicy; + + int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, + DicNodeVector *const childDicNodes) const; +}; +} // namespace latinime +#endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp new file mode 100644 index 000000000..7df55815f --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp @@ -0,0 +1,133 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +typedef PatriciaTrieReadingUtils PtReadingUtils; + +const PtReadingUtils::NodeFlags PtReadingUtils::MASK_CHILDREN_POSITION_TYPE = 0xC0; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_NOPOSITION = 0x00; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_ONEBYTE = 0x40; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_TWOBYTES = 0x80; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_CHILDREN_POSITION_TYPE_THREEBYTES = 0xC0; + +// Flag for single/multiple char group +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_MULTIPLE_CHARS = 0x20; +// Flag for terminal PtNodes +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_TERMINAL = 0x10; +// Flag for shortcut targets presence +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08; +// Flag for bigram presence +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04; +// Flag for non-words (typically, shortcut only entries) +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02; +// Flag for blacklist +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; + +/* static */ int PtReadingUtils::getPtNodeArraySizeAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t firstByte = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + if (firstByte < 0x80) { + return firstByte; + } else { + return ((firstByte & 0x7F) << 8) ^ ByteArrayUtils::readUint8AndAdvancePosition( + buffer, pos); + } +} + +/* static */ PtReadingUtils::NodeFlags PtReadingUtils::getFlagsAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); +} + +/* static */ int 
PtReadingUtils::getCodePointAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); +} + +// Returns the number of read characters. +/* static */ int PtReadingUtils::getCharsAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { + int length = 0; + if (hasMultipleChars(flags)) { + length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, + pos); + } else { + const int codePoint = getCodePointAndAdvancePosition(buffer, pos); + if (codePoint == NOT_A_CODE_POINT) { + // CAVEAT: codePoint == NOT_A_CODE_POINT means the code point is + // CHARACTER_ARRAY_TERMINATOR. The code point must not be CHARACTER_ARRAY_TERMINATOR + // when the PtNode has a single code point. + length = 0; + AKLOGE("codePoint is NOT_A_CODE_POINT. pos: %d, codePoint: 0x%x, buffer[pos - 1]: 0x%x", + *pos - 1, codePoint, buffer[*pos - 1]); + ASSERT(false); + } else if (maxLength > 0) { + outBuffer[0] = codePoint; + length = 1; + } + } + return length; +} + +// Returns the number of skipped characters. 
+/* static */ int PtReadingUtils::skipCharacters(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const pos) { + if (hasMultipleChars(flags)) { + return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); + } else { + if (maxLength > 0) { + getCodePointAndAdvancePosition(buffer, pos); + return 1; + } else { + return 0; + } + } +} + +/* static */ int PtReadingUtils::readProbabilityAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); +} + +/* static */ int PtReadingUtils::readChildrenPositionAndAdvancePosition( + const uint8_t *const buffer, const NodeFlags flags, int *const pos) { + const int base = *pos; + int offset = 0; + switch (MASK_CHILDREN_POSITION_TYPE & flags) { + case FLAG_CHILDREN_POSITION_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + break; + case FLAG_CHILDREN_POSITION_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer, pos); + break; + case FLAG_CHILDREN_POSITION_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer, pos); + break; + default: + // If we come here, it means we asked for the children of a word with + // no children. + return NOT_A_DICT_POS; + } + return base + offset; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h new file mode 100644 index 000000000..8420ee95a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h @@ -0,0 +1,120 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_PATRICIA_TRIE_READING_UTILS_H +#define LATINIME_PATRICIA_TRIE_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class PatriciaTrieReadingUtils { + public: + typedef uint8_t NodeFlags; + + static int getPtNodeArraySizeAndAdvancePosition(const uint8_t *const buffer, int *const pos); + + static NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, int *const pos); + + static int getCodePointAndAdvancePosition(const uint8_t *const buffer, int *const pos); + + // Returns the number of read characters. + static int getCharsAndAdvancePosition(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const outBuffer, int *const pos); + + // Returns the number of skipped characters. 
+ static int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const pos); + + static int readProbabilityAndAdvancePosition(const uint8_t *const buffer, int *const pos); + + static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, int *const pos); + + /** + * Node Flags + */ + static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) { + return (flags & FLAG_IS_BLACKLISTED) != 0; + } + + static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) { + return (flags & FLAG_IS_NOT_A_WORD) != 0; + } + + static AK_FORCE_INLINE bool isTerminal(const NodeFlags flags) { + return (flags & FLAG_IS_TERMINAL) != 0; + } + + static AK_FORCE_INLINE bool hasShortcutTargets(const NodeFlags flags) { + return (flags & FLAG_HAS_SHORTCUT_TARGETS) != 0; + } + + static AK_FORCE_INLINE bool hasBigrams(const NodeFlags flags) { + return (flags & FLAG_HAS_BIGRAMS) != 0; + } + + static AK_FORCE_INLINE bool hasMultipleChars(const NodeFlags flags) { + return (flags & FLAG_HAS_MULTIPLE_CHARS) != 0; + } + + static AK_FORCE_INLINE bool hasChildrenInFlags(const NodeFlags flags) { + return FLAG_CHILDREN_POSITION_TYPE_NOPOSITION != (MASK_CHILDREN_POSITION_TYPE & flags); + } + + static AK_FORCE_INLINE NodeFlags createAndGetFlags(const bool isBlacklisted, + const bool isNotAWord, const bool isTerminal, const bool hasShortcutTargets, + const bool hasBigrams, const bool hasMultipleChars, + const int childrenPositionFieldSize) { + NodeFlags nodeFlags = 0; + nodeFlags = isBlacklisted ? (nodeFlags | FLAG_IS_BLACKLISTED) : nodeFlags; + nodeFlags = isNotAWord ? (nodeFlags | FLAG_IS_NOT_A_WORD) : nodeFlags; + nodeFlags = isTerminal ? (nodeFlags | FLAG_IS_TERMINAL) : nodeFlags; + nodeFlags = hasShortcutTargets ? (nodeFlags | FLAG_HAS_SHORTCUT_TARGETS) : nodeFlags; + nodeFlags = hasBigrams ? (nodeFlags | FLAG_HAS_BIGRAMS) : nodeFlags; + nodeFlags = hasMultipleChars ? 
(nodeFlags | FLAG_HAS_MULTIPLE_CHARS) : nodeFlags; + if (childrenPositionFieldSize == 1) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_ONEBYTE; + } else if (childrenPositionFieldSize == 2) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_TWOBYTES; + } else if (childrenPositionFieldSize == 3) { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_THREEBYTES; + } else { + nodeFlags |= FLAG_CHILDREN_POSITION_TYPE_NOPOSITION; + } + return nodeFlags; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); + + static const NodeFlags MASK_CHILDREN_POSITION_TYPE; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_NOPOSITION; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_ONEBYTE; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_TWOBYTES; + static const NodeFlags FLAG_CHILDREN_POSITION_TYPE_THREEBYTES; + + static const NodeFlags FLAG_HAS_MULTIPLE_CHARS; + static const NodeFlags FLAG_IS_TERMINAL; + static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS; + static const NodeFlags FLAG_HAS_BIGRAMS; + static const NodeFlags FLAG_IS_NOT_A_WORD; + static const NodeFlags FLAG_IS_BLACKLISTED; +}; +} // namespace latinime +#endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h new file mode 100644 index 000000000..bd3211f6a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/dynamic_shortcut_list_policy.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H +#define LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +/* + * This is a dynamic version of ShortcutListPolicy and supports an additional buffer. + */ +class DynamicShortcutListPolicy : public DictionaryShortcutsStructurePolicy { + public: + explicit DynamicShortcutListPolicy(const BufferWithExtendableBuffer *const buffer) + : mBuffer(buffer) {} + + ~DynamicShortcutListPolicy() {} + + int getStartPos(const int pos) const { + if (pos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + return pos + ShortcutListReadingUtils::getShortcutListSizeFieldSize(); + } + + void getNextShortcut(const int maxCodePointCount, int *const outCodePoint, + int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, + int *const pos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*pos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *pos -= mBuffer->getOriginalBufferSize(); + } + const ShortcutListReadingUtils::ShortcutFlags flags = + ShortcutListReadingUtils::getFlagsAndForwardPointer(buffer, pos); + if (outHasNext) { + *outHasNext = ShortcutListReadingUtils::hasNext(flags); + } + if (outIsWhitelist) { + 
*outIsWhitelist = ShortcutListReadingUtils::isWhitelist(flags); + } + if (outCodePoint) { + *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( + buffer, maxCodePointCount, outCodePoint, pos); + } + if (usesAdditionalBuffer) { + *pos += mBuffer->getOriginalBufferSize(); + } + } + + void skipAllShortcuts(int *const pos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*pos); + const uint8_t *const buffer = mBuffer->getBuffer(usesAdditionalBuffer); + if (usesAdditionalBuffer) { + *pos -= mBuffer->getOriginalBufferSize(); + } + const int shortcutListSize = ShortcutListReadingUtils + ::getShortcutListSizeAndForwardPointer(buffer, pos); + *pos += shortcutListSize; + if (usesAdditionalBuffer) { + *pos += mBuffer->getOriginalBufferSize(); + } + } + + // Copy shortcuts from the shortcut list that starts at fromPos in mBuffer to toPos in + // bufferToWrite and advance these positions after the shortcut lists. This returns whether + // the copy was succeeded or not. + bool copyAllShortcutsAndReturnIfSucceededOrNot(BufferWithExtendableBuffer *const bufferToWrite, + int *const fromPos, int *const toPos) const { + const bool usesAdditionalBuffer = mBuffer->isInAdditionalBuffer(*fromPos); + if (usesAdditionalBuffer) { + *fromPos -= mBuffer->getOriginalBufferSize(); + } + const int shortcutListSize = ShortcutListReadingUtils + ::getShortcutListSizeAndForwardPointer(mBuffer->getBuffer(usesAdditionalBuffer), + fromPos); + // Copy shortcut list size. + if (!bufferToWrite->writeUintAndAdvancePosition( + shortcutListSize + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), + ShortcutListReadingUtils::getShortcutListSizeFieldSize(), toPos)) { + return false; + } + // Copy shortcut list. 
+ for (int i = 0; i < shortcutListSize; ++i) { + const uint8_t data = ByteArrayUtils::readUint8AndAdvancePosition( + mBuffer->getBuffer(usesAdditionalBuffer), fromPos); + if (!bufferToWrite->writeUintAndAdvancePosition(data, 1 /* size */, toPos)) { + return false; + } + } + if (usesAdditionalBuffer) { + *fromPos += mBuffer->getOriginalBufferSize(); + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicShortcutListPolicy); + + const BufferWithExtendableBuffer *const mBuffer; +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h new file mode 100644 index 000000000..d73f73953 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_policy.h @@ -0,0 +1,73 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_SHORTCUT_LIST_POLICY_H +#define LATINIME_SHORTCUT_LIST_POLICY_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/policy/dictionary_shortcuts_structure_policy.h" +#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h" + +namespace latinime { + +class ShortcutListPolicy : public DictionaryShortcutsStructurePolicy { + public: + explicit ShortcutListPolicy(const uint8_t *const shortcutBuf) + : mShortcutsBuf(shortcutBuf) {} + + ~ShortcutListPolicy() {} + + int getStartPos(const int pos) const { + if (pos == NOT_A_DICT_POS) { + return NOT_A_DICT_POS; + } + int listPos = pos; + ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer(mShortcutsBuf, &listPos); + return listPos; + } + + void getNextShortcut(const int maxCodePointCount, int *const outCodePoint, + int *const outCodePointCount, bool *const outIsWhitelist, bool *const outHasNext, + int *const pos) const { + const ShortcutListReadingUtils::ShortcutFlags flags = + ShortcutListReadingUtils::getFlagsAndForwardPointer(mShortcutsBuf, pos); + if (outHasNext) { + *outHasNext = ShortcutListReadingUtils::hasNext(flags); + } + if (outIsWhitelist) { + *outIsWhitelist = ShortcutListReadingUtils::isWhitelist(flags); + } + if (outCodePoint) { + *outCodePointCount = ShortcutListReadingUtils::readShortcutTarget( + mShortcutsBuf, maxCodePointCount, outCodePoint, pos); + } + } + + void skipAllShortcuts(int *const pos) const { + const int shortcutListSize = ShortcutListReadingUtils + ::getShortcutListSizeAndForwardPointer(mShortcutsBuf, pos); + *pos += shortcutListSize; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListPolicy); + + const uint8_t *const mShortcutsBuf; +}; +} // namespace latinime +#endif // LATINIME_SHORTCUT_LIST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp new file mode 100644 
index 000000000..847dcdee5 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h" + +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +// Flag for presence of more attributes +const ShortcutListReadingUtils::ShortcutFlags + ShortcutListReadingUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; +// Mask for attribute probability, stored on 4 bits inside the flags byte. +const ShortcutListReadingUtils::ShortcutFlags + ShortcutListReadingUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int ShortcutListReadingUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; +// The numeric value of the shortcut probability that means 'whitelist'. +const int ShortcutListReadingUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; + +/* static */ ShortcutListReadingUtils::ShortcutFlags + ShortcutListReadingUtils::getFlagsAndForwardPointer(const uint8_t *const dictRoot, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(dictRoot, pos); +} + +/* static */ int ShortcutListReadingUtils::getShortcutListSizeAndForwardPointer( + const uint8_t *const dictRoot, int *const pos) { + // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. 
+ return ByteArrayUtils::readUint16AndAdvancePosition(dictRoot, pos) + - SHORTCUT_LIST_SIZE_FIELD_SIZE; +} + +/* static */ int ShortcutListReadingUtils::readShortcutTarget( + const uint8_t *const dictRoot, const int maxLength, int *const outWord, int *const pos) { + return ByteArrayUtils::readStringAndAdvancePosition(dictRoot, maxLength, outWord, pos); +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h new file mode 100644 index 000000000..a83ed5a50 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/shortcut/shortcut_list_reading_utils.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_SHORTCUT_LIST_READING_UTILS_H +#define LATINIME_SHORTCUT_LIST_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class ShortcutListReadingUtils { + public: + typedef uint8_t ShortcutFlags; + + static ShortcutFlags getFlagsAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + + static AK_FORCE_INLINE int getProbabilityFromFlags(const ShortcutFlags flags) { + return flags & MASK_ATTRIBUTE_PROBABILITY; + } + + static AK_FORCE_INLINE bool hasNext(const ShortcutFlags flags) { + return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; + } + + // This method returns the size of the shortcut list region excluding the shortcut list size + // field at the beginning. + static int getShortcutListSizeAndForwardPointer(const uint8_t *const dictRoot, int *const pos); + + static AK_FORCE_INLINE int getShortcutListSizeFieldSize() { + return SHORTCUT_LIST_SIZE_FIELD_SIZE; + } + + static AK_FORCE_INLINE void skipShortcuts(const uint8_t *const dictRoot, int *const pos) { + const int shortcutListSize = getShortcutListSizeAndForwardPointer(dictRoot, pos); + *pos += shortcutListSize; + } + + static AK_FORCE_INLINE bool isWhitelist(const ShortcutFlags flags) { + return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; + } + + static int readShortcutTarget(const uint8_t *const dictRoot, const int maxLength, + int *const outWord, int *const pos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutListReadingUtils); + + static const ShortcutFlags FLAG_ATTRIBUTE_HAS_NEXT; + static const ShortcutFlags MASK_ATTRIBUTE_PROBABILITY; + static const int SHORTCUT_LIST_SIZE_FIELD_SIZE; + static const int WHITELIST_SHORTCUT_PROBABILITY; +}; +} // namespace latinime +#endif // LATINIME_SHORTCUT_LIST_READING_UTILS_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp new file mode 100644 
index 000000000..f692882f2 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.cpp @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" + +namespace latinime { + +const size_t BufferWithExtendableBuffer::MAX_ADDITIONAL_BUFFER_SIZE = 1024 * 1024; +const int BufferWithExtendableBuffer::NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE = 90; +// TODO: Needs to allocate larger memory corresponding to the current vector size. +const size_t BufferWithExtendableBuffer::EXTEND_ADDITIONAL_BUFFER_SIZE_STEP = 128 * 1024; + +bool BufferWithExtendableBuffer::writeUintAndAdvancePosition(const uint32_t data, const int size, + int *const pos) { + if (!(size >= 1 && size <= 4)) { + AKLOGI("writeUintAndAdvancePosition() is called with invalid size: %d", size); + ASSERT(false); + return false; + } + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? 
&mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeUintAndAdvancePosition(buffer, data, size, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::writeCodePointsAndAdvancePosition(const int *const codePoints, + const int codePointCount, const bool writesTerminator ,int *const pos) { + const size_t size = ByteArrayUtils::calculateRequiredByteCountToStoreCodePoints( + codePoints, codePointCount, writesTerminator); + if (!checkAndPrepareWriting(*pos, size)) { + return false; + } + const bool usesAdditionalBuffer = isInAdditionalBuffer(*pos); + uint8_t *const buffer = usesAdditionalBuffer ? &mAdditionalBuffer[0] : mOriginalBuffer; + if (usesAdditionalBuffer) { + *pos -= mOriginalBufferSize; + } + ByteArrayUtils::writeCodePointsAndAdvancePosition(buffer, codePoints, codePointCount, + writesTerminator, pos); + if (usesAdditionalBuffer) { + *pos += mOriginalBufferSize; + } + return true; +} + +bool BufferWithExtendableBuffer::extendBuffer() { + const size_t sizeAfterExtending = + mAdditionalBuffer.size() + EXTEND_ADDITIONAL_BUFFER_SIZE_STEP; + if (sizeAfterExtending > mMaxAdditionalBufferSize) { + return false; + } + mAdditionalBuffer.resize(mAdditionalBuffer.size() + EXTEND_ADDITIONAL_BUFFER_SIZE_STEP); + return true; +} + +bool BufferWithExtendableBuffer::checkAndPrepareWriting(const int pos, const int size) { + if (isInAdditionalBuffer(pos)) { + const int tailPosition = getTailPosition(); + if (pos == tailPosition) { + // Append data to the tail. + if (pos + size > static_cast<int>(mAdditionalBuffer.size()) + mOriginalBufferSize) { + // Need to extend buffer. + if (!extendBuffer()) { + return false; + } + } + mUsedAdditionalBufferSize += size; + } else if (pos + size > tailPosition) { + // The access will beyond the tail of used region. 
+ return false; + } + } else { + if (pos < 0 || mOriginalBufferSize < pos + size) { + // Invalid position or violate the boundary. + return false; + } + } + return true; +} + +} diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h new file mode 100644 index 000000000..17d2e39c2 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h @@ -0,0 +1,103 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H +#define LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H + +#include <cstddef> +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +// This is used as a buffer that can be extended for updatable dictionaries. +// To optimize performance, raw pointer is directly used for reading buffer. The position has to be +// adjusted to access additional buffer. On the other hand, this class does not provide writable +// raw pointer but provides several methods that handle boundary checking for writing data. 
+class BufferWithExtendableBuffer { + public: + BufferWithExtendableBuffer(uint8_t *const originalBuffer, const int originalBufferSize, + const int maxAdditionalBufferSize = MAX_ADDITIONAL_BUFFER_SIZE) + : mOriginalBuffer(originalBuffer), mOriginalBufferSize(originalBufferSize), + mAdditionalBuffer(EXTEND_ADDITIONAL_BUFFER_SIZE_STEP), mUsedAdditionalBufferSize(0), + mMaxAdditionalBufferSize(maxAdditionalBufferSize) {} + + AK_FORCE_INLINE int getTailPosition() const { + return mOriginalBufferSize + mUsedAdditionalBufferSize; + } + + /** + * For reading. + */ + AK_FORCE_INLINE bool isInAdditionalBuffer(const int position) const { + return position >= mOriginalBufferSize; + } + + // TODO: Resolve the issue that the address can be changed when the vector is resized. + // CAVEAT!: Be careful about array out of bound access with buffers + AK_FORCE_INLINE const uint8_t *getBuffer(const bool usesAdditionalBuffer) const { + if (usesAdditionalBuffer) { + return &mAdditionalBuffer[0]; + } else { + return mOriginalBuffer; + } + } + + AK_FORCE_INLINE int getOriginalBufferSize() const { + return mOriginalBufferSize; + } + + AK_FORCE_INLINE bool isNearSizeLimit() const { + return mAdditionalBuffer.size() >= ((mMaxAdditionalBufferSize + * NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE) / 100); + } + + /** + * For writing. + * + * Writing is allowed for original buffer, already written region of additional buffer and the + * tail of additional buffer. 
+ */ + bool writeUintAndAdvancePosition(const uint32_t data, const int size, int *const pos); + + bool writeCodePointsAndAdvancePosition(const int *const codePoints, const int codePointCount, + const bool writesTerminator, int *const pos); + + private: + DISALLOW_COPY_AND_ASSIGN(BufferWithExtendableBuffer); + + static const size_t MAX_ADDITIONAL_BUFFER_SIZE; + static const int NEAR_BUFFER_LIMIT_THRESHOLD_PERCENTILE; + static const size_t EXTEND_ADDITIONAL_BUFFER_SIZE_STEP; + + uint8_t *const mOriginalBuffer; + const int mOriginalBufferSize; + std::vector<uint8_t> mAdditionalBuffer; + int mUsedAdditionalBufferSize; + const size_t mMaxAdditionalBufferSize; + + // Return if the buffer is successfully extended or not. + bool extendBuffer(); + + // Returns if it is possible to write size-bytes from pos. When pos is at the tail position of + // the additional buffer, try extending the buffer. + bool checkAndPrepareWriting(const int pos, const int size); +}; +} +#endif /* LATINIME_BUFFER_WITH_EXTENDABLE_BUFFER_H */ diff --git a/native/jni/src/suggest/core/dictionary/byte_array_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp index 68b1d5d15..1833e8832 100644 --- a/native/jni/src/suggest/core/dictionary/byte_array_utils.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.cpp @@ -14,11 +14,12 @@ * limitations under the License. 
*/ -#include "suggest/core/dictionary/byte_array_utils.h" +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" namespace latinime { -const uint8_t ByteArrayUtils::MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; +const uint8_t ByteArrayUtils::MINIMUM_ONE_BYTE_CHARACTER_VALUE = 0x20; +const uint8_t ByteArrayUtils::MAXIMUM_ONE_BYTE_CHARACTER_VALUE = 0xFF; const uint8_t ByteArrayUtils::CHARACTER_ARRAY_TERMINATOR = 0x1F; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h new file mode 100644 index 000000000..0c1576818 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/byte_array_utils.h @@ -0,0 +1,261 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BYTE_ARRAY_UTILS_H +#define LATINIME_BYTE_ARRAY_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +/** + * Utility methods for reading byte arrays. + */ +class ByteArrayUtils { + public: + /** + * Integer writing + * + * Each method write a corresponding size integer in a big endian manner. + */ + static AK_FORCE_INLINE void writeUintAndAdvancePosition(uint8_t *const buffer, + const uint32_t data, const int size, int *const pos) { + // size must be in 1 to 4. 
+ ASSERT(size >= 1 && size <= 4); + switch (size) { + case 1: + ByteArrayUtils::writeUint8AndAdvancePosition(buffer, data, pos); + return; + case 2: + ByteArrayUtils::writeUint16AndAdvancePosition(buffer, data, pos); + return; + case 3: + ByteArrayUtils::writeUint24AndAdvancePosition(buffer, data, pos); + return; + case 4: + ByteArrayUtils::writeUint32AndAdvancePosition(buffer, data, pos); + return; + default: + break; + } + } + + /** + * Integer reading + * + * Each method read a corresponding size integer in a big endian manner. + */ + static AK_FORCE_INLINE uint32_t readUint32(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 24) ^ (buffer[pos + 1] << 16) + ^ (buffer[pos + 2] << 8) ^ buffer[pos + 3]; + } + + static AK_FORCE_INLINE uint32_t readUint24(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 16) ^ (buffer[pos + 1] << 8) ^ buffer[pos + 2]; + } + + static AK_FORCE_INLINE uint16_t readUint16(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 8) ^ buffer[pos + 1]; + } + + static AK_FORCE_INLINE uint8_t readUint8(const uint8_t *const buffer, const int pos) { + return buffer[pos]; + } + + static AK_FORCE_INLINE uint32_t readUint32AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint32_t value = readUint32(buffer, *pos); + *pos += 4; + return value; + } + + static AK_FORCE_INLINE int readSint24AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t value = readUint8(buffer, *pos); + if (value < 0x80) { + return readUint24AndAdvancePosition(buffer, pos); + } else { + (*pos)++; + return -(((value & 0x7F) << 16) ^ readUint16AndAdvancePosition(buffer, pos)); + } + } + + static AK_FORCE_INLINE uint32_t readUint24AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint32_t value = readUint24(buffer, *pos); + *pos += 3; + return value; + } + + static AK_FORCE_INLINE uint16_t readUint16AndAdvancePosition( + const uint8_t *const 
buffer, int *const pos) { + const uint16_t value = readUint16(buffer, *pos); + *pos += 2; + return value; + } + + static AK_FORCE_INLINE uint8_t readUint8AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return buffer[(*pos)++]; + } + + /** + * Code Point Reading + * + * 1 byte = bbbbbbbb match + * case 000xxxxx: xxxxx << 16 + next byte << 8 + next byte + * else: if 00011111 (= 0x1F) : this is the terminator. This is a relevant choice because + * unicode code points range from 0 to 0x10FFFF, so any 3-byte value starting with + * 00011111 would be outside unicode. + * else: iso-latin-1 code + * This allows for the whole unicode range to be encoded, including chars outside of + * the BMP. Also everything in the iso-latin-1 charset is only 1 byte, except control + * characters which should never happen anyway (and still work, but take 3 bytes). + */ + static AK_FORCE_INLINE int readCodePoint(const uint8_t *const buffer, const int pos) { + int p = pos; + return readCodePointAndAdvancePosition(buffer, &p); + } + + static AK_FORCE_INLINE int readCodePointAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t firstByte = readUint8(buffer, *pos); + if (firstByte < MINIMUM_ONE_BYTE_CHARACTER_VALUE) { + if (firstByte == CHARACTER_ARRAY_TERMINATOR) { + *pos += 1; + return NOT_A_CODE_POINT; + } else { + return readUint24AndAdvancePosition(buffer, pos); + } + } else { + *pos += 1; + return firstByte; + } + } + + /** + * String (array of code points) Reading + * + * Reads code points until the terminator is found. + */ + // Returns the length of the string. 
+ static int readStringAndAdvancePosition(const uint8_t *const buffer, + const int maxLength, int *const outBuffer, int *const pos) { + int length = 0; + int codePoint = readCodePointAndAdvancePosition(buffer, pos); + while (NOT_A_CODE_POINT != codePoint && length < maxLength) { + outBuffer[length++] = codePoint; + codePoint = readCodePointAndAdvancePosition(buffer, pos); + } + return length; + } + + // Advances the position and returns the length of the string. + static int advancePositionToBehindString( + const uint8_t *const buffer, const int maxLength, int *const pos) { + int length = 0; + int codePoint = readCodePointAndAdvancePosition(buffer, pos); + while (NOT_A_CODE_POINT != codePoint && length < maxLength) { + codePoint = readCodePointAndAdvancePosition(buffer, pos); + length++; + } + return length; + } + + /** + * String (array of code points) Writing + */ + static void writeCodePointsAndAdvancePosition(uint8_t *const buffer, + const int *const codePoints, const int codePointCount, const bool writesTerminator, + int *const pos) { + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMUM_ONE_BYTE_CHARACTER_VALUE + || codePoint > MAXIMUM_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + writeUint24AndAdvancePosition(buffer, codePoint, pos); + } else { + // one byte character. 
+ writeUint8AndAdvancePosition(buffer, codePoint, pos); + } + } + if (writesTerminator) { + writeUint8AndAdvancePosition(buffer, CHARACTER_ARRAY_TERMINATOR, pos); + } + } + + static int calculateRequiredByteCountToStoreCodePoints(const int *const codePoints, + const int codePointCount, const bool writesTerminator) { + int byteCount = 0; + for (int i = 0; i < codePointCount; ++i) { + const int codePoint = codePoints[i]; + if (codePoint == NOT_A_CODE_POINT || codePoint == CHARACTER_ARRAY_TERMINATOR) { + break; + } else if (codePoint < MINIMUM_ONE_BYTE_CHARACTER_VALUE + || codePoint > MAXIMUM_ONE_BYTE_CHARACTER_VALUE) { + // three bytes character. + byteCount += 3; + } else { + // one byte character. + byteCount += 1; + } + } + if (writesTerminator) { + // The terminator is one byte. + byteCount += 1; + } + return byteCount; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ByteArrayUtils); + + static const uint8_t MINIMUM_ONE_BYTE_CHARACTER_VALUE; + static const uint8_t MAXIMUM_ONE_BYTE_CHARACTER_VALUE; + static const uint8_t CHARACTER_ARRAY_TERMINATOR; + + static AK_FORCE_INLINE void writeUint32AndAdvancePosition(uint8_t *const buffer, + const uint32_t data, int *const pos) { + buffer[(*pos)++] = (data >> 24) & 0xFF; + buffer[(*pos)++] = (data >> 16) & 0xFF; + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint24AndAdvancePosition(uint8_t *const buffer, + const uint32_t data, int *const pos) { + buffer[(*pos)++] = (data >> 16) & 0xFF; + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint16AndAdvancePosition(uint8_t *const buffer, + const uint16_t data, int *const pos) { + buffer[(*pos)++] = (data >> 8) & 0xFF; + buffer[(*pos)++] = data & 0xFF; + } + + static AK_FORCE_INLINE void writeUint8AndAdvancePosition(uint8_t *const buffer, + const uint8_t data, int *const pos) { + buffer[(*pos)++] = data & 0xFF; + } +}; +} // namespace latinime 
+#endif /* LATINIME_BYTE_ARRAY_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp new file mode 100644 index 000000000..2e4ec2e1d --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h" + +#include <cstdio> +#include <cstring> + +#include "suggest/policyimpl/dictionary/header/header_policy.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_writing_utils.h" +#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h" +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +namespace latinime { + +const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = ".tmp"; + +/* static */ bool DictFileWritingUtils::createEmptyDictFile(const char *const filePath, + const int dictVersion, const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + switch (dictVersion) { + case 3: + return createEmptyV3DictFile(filePath, attributeMap); + default: + // Only version 3 dictionary is supported for now. 
+ return false; + } +} + +/* static */ bool DictFileWritingUtils::createEmptyV3DictFile(const char *const filePath, + const HeaderReadWriteUtils::AttributeMap *const attributeMap) { + BufferWithExtendableBuffer headerBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + HeaderPolicy headerPolicy(FormatUtils::VERSION_3, attributeMap); + headerPolicy.writeHeaderToBuffer(&headerBuffer, true /* updatesLastUpdatedTime */); + BufferWithExtendableBuffer bodyBuffer(0 /* originalBuffer */, 0 /* originalBufferSize */); + if (!DynamicPatriciaTrieWritingUtils::writeEmptyDictionary(&bodyBuffer, 0 /* rootPos */)) { + return false; + } + return flushAllHeaderAndBodyToFile(filePath, &headerBuffer, &bodyBuffer); +} + +/* static */ bool DictFileWritingUtils::flushAllHeaderAndBodyToFile(const char *const filePath, + BufferWithExtendableBuffer *const dictHeader, BufferWithExtendableBuffer *const dictBody) { + const int tmpFileNameBufSize = strlen(filePath) + + strlen(TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE) + 1 /* terminator */; + // Name of a temporary file used for writing that is a connected string of original name and + // TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE. + char tmpFileName[tmpFileNameBufSize]; + snprintf(tmpFileName, tmpFileNameBufSize, "%s%s", filePath, + TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE); + FILE *const file = fopen(tmpFileName, "wb"); + if (!file) { + AKLOGE("Dictionary file %s cannnot be opened.", tmpFileName); + ASSERT(false); + return false; + } + // Write the dictionary header. + if (!writeBufferToFile(file, dictHeader)) { + remove(tmpFileName); + AKLOGE("Dictionary header cannnot be written. size: %d", dictHeader->getTailPosition()); + ASSERT(false); + return false; + } + // Write the dictionary body. + if (!writeBufferToFile(file, dictBody)) { + remove(tmpFileName); + AKLOGE("Dictionary body cannnot be written. 
size: %d", dictBody->getTailPosition()); + ASSERT(false); + return false; + } + fclose(file); + rename(tmpFileName, filePath); + return true; +} + +// This closes file pointer when an error is caused and returns whether the writing was succeeded +// or not. +/* static */ bool DictFileWritingUtils::writeBufferToFile(FILE *const file, + const BufferWithExtendableBuffer *const buffer) { + const int originalBufSize = buffer->getOriginalBufferSize(); + if (originalBufSize > 0 && fwrite(buffer->getBuffer(false /* usesAdditionalBuffer */), + originalBufSize, 1, file) < 1) { + fclose(file); + return false; + } + const int additionalBufSize = buffer->getTailPosition() - buffer->getOriginalBufferSize(); + if (additionalBufSize > 0 && fwrite(buffer->getBuffer(true /* usesAdditionalBuffer */), + additionalBufSize, 1, file) < 1) { + fclose(file); + return false; + } + return true; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h new file mode 100644 index 000000000..bd4ac66fd --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef LATINIME_DICT_FILE_WRITING_UTILS_H +#define LATINIME_DICT_FILE_WRITING_UTILS_H + +#include <cstdio> + +#include "defines.h" +#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h" + +namespace latinime { + +class BufferWithExtendableBuffer; + +class DictFileWritingUtils { + public: + static bool createEmptyDictFile(const char *const filePath, const int dictVersion, + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static bool flushAllHeaderAndBodyToFile(const char *const filePath, + BufferWithExtendableBuffer *const dictHeader, + BufferWithExtendableBuffer *const dictBody); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DictFileWritingUtils); + + static const char *const TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE; + + static bool createEmptyV3DictFile(const char *const filePath, + const HeaderReadWriteUtils::AttributeMap *const attributeMap); + + static bool writeBufferToFile(FILE *const file, + const BufferWithExtendableBuffer *const buffer); +}; +} // namespace latinime +#endif /* LATINIME_DICT_FILE_WRITING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp new file mode 100644 index 000000000..1d77d5c27 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.cpp @@ -0,0 +1,56 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/utils/format_utils.h" + +#include "suggest/policyimpl/dictionary/utils/byte_array_utils.h" + +namespace latinime { + +const uint32_t FormatUtils::MAGIC_NUMBER = 0x9BC13AFE; + +// Magic number (4 bytes), version (2 bytes), flags (2 bytes), header size (4 bytes) = 12 +const int FormatUtils::DICTIONARY_MINIMUM_SIZE = 12; + +/* static */ FormatUtils::FORMAT_VERSION FormatUtils::detectFormatVersion( + const uint8_t *const dict, const int dictSize) { + // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) { + return UNKNOWN_VERSION; + } + const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); + switch (magicNumber) { + case MAGIC_NUMBER: + // Version 2 header is as follows: + // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE + // Dictionary format version number (2 bytes) + // Options (2 bytes) + // Header size (4 bytes) : integer, big endian + if (ByteArrayUtils::readUint16(dict, 4) == 2) { + return VERSION_2; + } else if (ByteArrayUtils::readUint16(dict, 4) == 3) { + return VERSION_3; + } else { + return UNKNOWN_VERSION; + } + default: + return UNKNOWN_VERSION; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h new file mode 100644 index 000000000..79ed0de29 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/format_utils.h @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_FORMAT_UTILS_H +#define LATINIME_FORMAT_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +/** + * Methods to handle binary dictionary format version. + */ +class FormatUtils { + public: + enum FORMAT_VERSION { + VERSION_2, + VERSION_3, + UNKNOWN_VERSION + }; + + // 32 bit magic number is stored at the beginning of the dictionary header to reject + // unsupported or obsolete dictionary formats. + static const uint32_t MAGIC_NUMBER; + + static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(FormatUtils); + + static const int DICTIONARY_MINIMUM_SIZE; +}; +} // namespace latinime +#endif /* LATINIME_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h new file mode 100644 index 000000000..6b69116eb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/mmapped_buffer.h @@ -0,0 +1,102 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_MMAPPED_BUFFER_H +#define LATINIME_MMAPPED_BUFFER_H + +#include <cerrno> +#include <fcntl.h> +#include <stdint.h> +#include <sys/mman.h> +#include <unistd.h> + +#include "defines.h" + +namespace latinime { + +class MmappedBuffer { + public: + static MmappedBuffer* openBuffer(const char *const path, const int bufferOffset, + const int bufferSize, const bool isUpdatable) { + const int openMode = isUpdatable ? O_RDWR : O_RDONLY; + const int mmapFd = open(path, openMode); + if (mmapFd < 0) { + AKLOGE("DICT: Can't open the source. path=%s errno=%d", path, errno); + return 0; + } + const int pagesize = getpagesize(); + const int offset = bufferOffset % pagesize; + int alignedOffset = bufferOffset - offset; + int alignedSize = bufferSize + offset; + const int protMode = isUpdatable ? PROT_READ | PROT_WRITE : PROT_READ; + void *const mmappedBuffer = mmap(0, alignedSize, protMode, MAP_PRIVATE, mmapFd, + alignedOffset); + if (mmappedBuffer == MAP_FAILED) { + AKLOGE("DICT: Can't mmap dictionary. errno=%d", errno); + close(mmapFd); + return 0; + } + uint8_t *const buffer = static_cast<uint8_t *>(mmappedBuffer) + offset; + if (!buffer) { + AKLOGE("DICT: buffer is null"); + close(mmapFd); + return 0; + } + return new MmappedBuffer(buffer, bufferSize, mmappedBuffer, alignedSize, mmapFd, + isUpdatable); + } + + ~MmappedBuffer() { + int ret = munmap(mMmappedBuffer, mAlignedSize); + if (ret != 0) { + AKLOGE("DICT: Failure in munmap. ret=%d errno=%d", ret, errno); + } + ret = close(mMmapFd); + if (ret != 0) { + AKLOGE("DICT: Failure in close. 
ret=%d errno=%d", ret, errno); + } + } + + AK_FORCE_INLINE uint8_t *getBuffer() const { + return mBuffer; + } + + AK_FORCE_INLINE int getBufferSize() const { + return mBufferSize; + } + + AK_FORCE_INLINE bool isUpdatable() const { + return mIsUpdatable; + } + + private: + AK_FORCE_INLINE MmappedBuffer(uint8_t *const buffer, const int bufferSize, + void *const mmappedBuffer, const int alignedSize, const int mmapFd, + const bool isUpdatable) + : mBuffer(buffer), mBufferSize(bufferSize), mMmappedBuffer(mmappedBuffer), + mAlignedSize(alignedSize), mMmapFd(mmapFd), mIsUpdatable(isUpdatable) {} + + DISALLOW_IMPLICIT_CONSTRUCTORS(MmappedBuffer); + + uint8_t *const mBuffer; + const int mBufferSize; + void *const mMmappedBuffer; + const int mAlignedSize; + const int mMmapFd; + const bool mIsUpdatable; +}; +} +#endif /* LATINIME_MMAPPED_BUFFER_H */ diff --git a/native/jni/src/suggest/core/dictionary/probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h index 14d2f8436..21fe355b8 100644 --- a/native/jni/src/suggest/core/dictionary/probability_utils.h +++ b/native/jni/src/suggest/policyimpl/dictionary/utils/probability_utils.h @@ -17,7 +17,6 @@ #ifndef LATINIME_PROBABILITY_UTILS_H #define LATINIME_PROBABILITY_UTILS_H -#include <map> #include <stdint.h> #include "defines.h" @@ -42,31 +41,13 @@ class ProbabilityUtils { // the unigram probability to be the median value of the 17th step from the top. A value of // 0 for the bigram probability represents the middle of the 16th step from the top, // while a value of 15 represents the middle of the top step. - // See makedict.BinaryDictInputOutput for details. + // See makedict.BinaryDictEncoder#makeBigramFlags for details. const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability) / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY); return unigramProbability + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); } - // This returns a probability in log space. 
- static AK_FORCE_INLINE int getProbability(const int position, - const std::map<int, int> *const bigramMap, - const uint8_t *bigramFilter, const int unigramProbability) { - if (!bigramMap || !bigramFilter) { - return backoff(unigramProbability); - } - if (!isInFilter(bigramFilter, position)){ - return backoff(unigramProbability); - } - const std::map<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); - if (bigramProbabilityIt != bigramMap->end()) { - const int bigramProbability = bigramProbabilityIt->second; - return computeProbabilityForBigram(unigramProbability, bigramProbability); - } - return backoff(unigramProbability); - } - private: DISALLOW_IMPLICIT_CONSTRUCTORS(ProbabilityUtils); }; diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index f87989286..ecceb60d3 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -22,31 +22,36 @@ const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; -const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 125; +// TODO: Unlimit max cache dic node size +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f; -const float ScoringParams::PROXIMITY_COST = 0.086f; -const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f; +const float ScoringParams::PROXIMITY_COST = 0.095f; +const float ScoringParams::FIRST_CHAR_PROXIMITY_COST = 0.102f; +const float ScoringParams::FIRST_PROXIMITY_COST = 0.019f; const float ScoringParams::OMISSION_COST = 0.458f; const 
float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; const float ScoringParams::INSERTION_COST = 0.730f; +const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; +const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; -const float ScoringParams::TRANSPOSITION_COST = 0.516f; +const float ScoringParams::TRANSPOSITION_COST = 0.526f; const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.319f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; -const float ScoringParams::SUBSTITUTION_COST = 0.403f; +const float ScoringParams::SUBSTITUTION_COST = 0.383f; const float ScoringParams::COST_NEW_WORD = 0.042f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.25f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.545f; const float ScoringParams::COST_LOOKAHEAD = 0.073f; -const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.105f; -const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.038f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.444f; +const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.093f; +const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.041f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.447f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; -const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.06f; +const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.045f; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 53ac999c1..7d4b5c3c7 100644 --- 
a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -29,6 +29,7 @@ class ScoringParams { static const int THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; static const float AUTOCORRECT_OUTPUT_THRESHOLD; static const int MAX_CACHE_DIC_NODE_SIZE; + static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; static const int THRESHOLD_SHORT_WORD_LENGTH; // Numerically optimized parameters (currently for tap typing only). @@ -36,12 +37,15 @@ class ScoringParams { // TODO: explore optimization of gesture parameters. static const float DISTANCE_WEIGHT_LENGTH; static const float PROXIMITY_COST; + static const float FIRST_CHAR_PROXIMITY_COST; static const float FIRST_PROXIMITY_COST; static const float OMISSION_COST; static const float OMISSION_COST_SAME_CHAR; static const float OMISSION_COST_FIRST_CHAR; static const float INSERTION_COST; + static const float TERMINAL_INSERTION_COST; static const float INSERTION_COST_SAME_CHAR; + static const float INSERTION_COST_PROXIMITY_CHAR; static const float INSERTION_COST_FIRST_CHAR; static const float TRANSPOSITION_COST; static const float SPACE_SUBSTITUTION_COST; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 90e2133e7..56ffcc93e 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -55,10 +55,10 @@ class TypingScoring : public Scoring { const int inputSize, const bool forceCommit) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; - return static_cast<int>((ScoringParams::TYPING_BASE_OUTPUT_SCORE - - (compoundDistance / maxDistance) - + (forceCommit ? 
ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f)) - * SUGGEST_INTERFACE_OUTPUT_SCALE); + const float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE + - compoundDistance / maxDistance + + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f); + return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE); } AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(const int terminalIndex, diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index e21b318e6..89e53f441 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -23,6 +23,7 @@ #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" #include "suggest/core/layout/proximity_info_state.h" +#include "suggest/core/layout/proximity_info_utils.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/policyimpl/typing/scoring_params.h" @@ -136,7 +137,7 @@ class TypingTraversal : public Traversal { return ScoringParams::MAX_SPATIAL_DISTANCE; } - AK_FORCE_INLINE bool allowPartialCommit() const { + AK_FORCE_INLINE bool autoCorrectsToMultiWordSuggestionIfTop() const { return true; } @@ -147,11 +148,12 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool sameAsTyped( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { return traverseSession->getProximityInfoState(0)->sameAsTyped( - dicNode->getOutputWordBuf(), dicNode->getDepth()); + dicNode->getOutputWordBuf(), dicNode->getNodeCodePointCount()); } - AK_FORCE_INLINE int getMaxCacheSize() const { - return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + AK_FORCE_INLINE int getMaxCacheSize(const int inputSize) const { + return (inputSize <= 1) ? 
ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT + : ScoringParams::MAX_CACHE_DIC_NODE_SIZE; } AK_FORCE_INLINE bool isPossibleOmissionChildNode( @@ -159,7 +161,7 @@ class TypingTraversal : public Traversal { const DicNode *const dicNode) const { const ProximityType proximityType = getProximityType(traverseSession, parentDicNode, dicNode); - if (!DicNodeUtils::isProximityChar(proximityType)) { + if (!ProximityInfoUtils::isMatchOrProximityChar(proximityType)) { return false; } return true; @@ -171,7 +173,7 @@ class TypingTraversal : public Traversal { return false; } const int c = dicNode->getOutputWordBuf()[0]; - const bool shortCappedWord = dicNode->getDepth() + const bool shortCappedWord = dicNode->getNodeCodePointCount() < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && CharUtils::isAsciiUpper(c); return !shortCappedWord || probability >= ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index e4c69d1f6..408b12ae9 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -44,6 +44,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, break; case CT_SUBSTITUTION: case CT_INSERTION: + case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: return ET_EDIT_CORRECTION; case CT_NEW_WORD_SPACE_OMITTION: diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 17fa11082..9f0a331e3 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -55,7 +55,7 @@ class TypingWeighting : public Weighting { const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the 
traversal omitted the first letter then the dicNode should now be on the second. - const bool isFirstLetterOmission = dicNode->getDepth() == 2; + const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; float cost = 0.0f; if (isZeroCostOmission) { cost = 0.0f; @@ -74,16 +74,20 @@ class TypingWeighting : public Weighting { // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) - ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); + ->getPointToKeyLength(pointIndex, + CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); - float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST + float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST : ScoringParams::PROXIMITY_COST) : 0.0f; - if (dicNode->getDepth() == 2) { + if (isProximity && dicNode->getProximityCorrectionCount() == 0) { + cost += ScoringParams::FIRST_PROXIMITY_COST; + } + if (dicNode->getNodeCodePointCount() == 2) { // At the second character of the current word, we check if the first char is uppercase // and the word is a second or later word of a multiple word suggestion. We demote it // if so. 
@@ -110,10 +114,10 @@ class TypingWeighting : public Weighting { const int16_t parentPointIndex = parentDicNode->getInputIndex(0); const int prevCodePoint = parentDicNode->getNodeCodePoint(); const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex + 1, prevCodePoint); + parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); const int codePoint = dicNode->getNodeCodePoint(); const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex, codePoint); + parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); const float distance = distance1 + distance2; const float weightedLengthDistance = distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; @@ -122,31 +126,38 @@ class TypingWeighting : public Weighting { float getInsertionCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const { - const int16_t parentPointIndex = parentDicNode->getInputIndex(0); - const int prevCodePoint = - traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); - + const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( + insertedPointIndex); const int currentCodePoint = dicNode->getNodeCodePoint(); const bool sameCodePoint = prevCodePoint == currentCodePoint; + const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) + ->existsAdjacentProximityChars(insertedPointIndex); const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex + 1, currentCodePoint); + insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; - const bool singleChar = dicNode->getDepth() == 1; - const float cost = (singleChar ? 
ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) - + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR - : ScoringParams::INSERTION_COST); + const bool singleChar = dicNode->getNodeCodePointCount() == 1; + float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); + if (sameCodePoint) { + cost += ScoringParams::INSERTION_COST_SAME_CHAR; + } else if (existsAdjacentProximityChars) { + cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; + } else { + cost += ScoringParams::INSERTION_COST; + } return cost + weightedDistance; } - float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const { + float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); } - float getNewWordBigramCost(const DicTraverseSession *const traverseSession, + float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const { - return DicNodeUtils::getBigramNodeImprobability(traverseSession->getBinaryDictionaryInfo(), + return DicNodeUtils::getBigramNodeImprobability( + traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } @@ -163,9 +174,16 @@ class TypingWeighting : public Weighting { float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { - const float languageImprobability = (dicNode->isExactMatch()) ? 
- 0.0f : dicNodeLanguageImprobability; - return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + } + + float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + const int inputIndex = dicNode->getInputIndex(0); + const int inputSize = traverseSession->getInputSize(); + ASSERT(inputIndex < inputSize); + // TODO: Implement more efficient logic + return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); } AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance.h b/native/jni/src/suggest/policyimpl/utils/edit_distance.h index cbbd66894..0871c37ce 100644 --- a/native/jni/src/suggest/policyimpl/utils/edit_distance.h +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance.h @@ -62,6 +62,26 @@ class EditDistance { return dp[(beforeLength + 1) * (afterLength + 1) - 1]; } + AK_FORCE_INLINE static void dumpEditDistance10ForDebug(const float *const editDistanceTable, + const int editDistanceTableWidth, const int outputLength) { + if (DEBUG_DICT) { + AKLOGI("EditDistanceTable"); + for (int i = 0; i <= 10; ++i) { + float c[11]; + for (int j = 0; j <= 10; ++j) { + if (j < editDistanceTableWidth + 1 && i < outputLength + 1) { + c[j] = (editDistanceTable + i * (editDistanceTableWidth + 1))[j]; + } else { + c[j] = -1.0f; + } + } + AKLOGI("[ %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f ]", + c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7], c[8], c[9], c[10]); + (void)c; // To suppress compiler warning + } + } + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(EditDistance); }; diff --git a/native/jni/src/utils/autocorrection_threshold_utils.cpp b/native/jni/src/utils/autocorrection_threshold_utils.cpp new file mode 100644 index 000000000..1f8ee0814 --- /dev/null +++ b/native/jni/src/utils/autocorrection_threshold_utils.cpp @@ -0,0 +1,108 
@@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/autocorrection_threshold_utils.h" + +#include <cmath> + +#include "defines.h" +#include "suggest/policyimpl/utils/edit_distance.h" +#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" + +namespace latinime { + +const int AutocorrectionThresholdUtils::MAX_INITIAL_SCORE = 255; +const int AutocorrectionThresholdUtils::TYPED_LETTER_MULTIPLIER = 2; +const int AutocorrectionThresholdUtils::FULL_WORD_MULTIPLIER = 2; + +/* static */ int AutocorrectionThresholdUtils::editDistance(const int *before, + const int beforeLength, const int *after, const int afterLength) { + const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( + before, beforeLength, after, afterLength); + return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); +} + +// In dictionary.cpp, getSuggestion() method, +// When USE_SUGGEST_INTERFACE_FOR_TYPING is true: +// +// // TODO: Revise the following logic thoroughly by referring to the logic +// // marked as "Otherwise" below. +// SUGGEST_INTERFACE_OUTPUT_SCALE was multiplied to the original suggestion scores to convert +// them to integers. +// score = (int)((original score) * SUGGEST_INTERFACE_OUTPUT_SCALE) +// Undo the scaling here to recover the original score. 
+// normalizedScore = ((float)score) / SUGGEST_INTERFACE_OUTPUT_SCALE +// +// Otherwise: suggestion scores are computed using the below formula. +// original score +// := powf(mTypedLetterMultiplier (this is defined 2), +// (the number of matched characters between typed word and suggested word)) +// * (individual word's score which defined in the unigram dictionary, +// and this score is defined in range [0, 255].) +// Then, the following processing is applied. +// - If the dictionary word is matched up to the point of the user entry +// (full match up to min(before.length(), after.length()) +// => Then multiply by FULL_MATCHED_WORDS_PROMOTION_RATE (this is defined 1.2) +// - If the word is a true full match except for differences in accents or +// capitalization, then treat it as if the score was 255. +// - If before.length() == after.length() +// => multiply by mFullWordMultiplier (this is defined 2)) +// So, maximum original score is powf(2, min(before.length(), after.length())) * 255 * 2 * 1.2 +// For historical reasons we ignore the 1.2 modifier (because the measure for a good +// autocorrection threshold was done at a time when it didn't exist). This doesn't change +// the result. +// So, we can normalize original score by dividing powf(2, min(b.l(),a.l())) * 255 * 2. 
+ +/* static */ float AutocorrectionThresholdUtils::calcNormalizedScore(const int *before, + const int beforeLength, const int *after, const int afterLength, const int score) { + if (0 == beforeLength || 0 == afterLength) { + return 0.0f; + } + const int distance = editDistance(before, beforeLength, after, afterLength); + int spaceCount = 0; + for (int i = 0; i < afterLength; ++i) { + if (after[i] == KEYCODE_SPACE) { + ++spaceCount; + } + } + + if (spaceCount == afterLength) { + return 0.0f; + } + + if (score <= 0 || distance >= afterLength) { + // normalizedScore must be 0.0f (the minimum value) if the score is less than or equal to 0, + // or if the edit distance is larger than or equal to afterLength. + return 0.0f; + } + // add a weight based on edit distance. + const float weight = 1.0f - static_cast<float>(distance) / static_cast<float>(afterLength); + + // TODO: Revise the following logic thoroughly by referring to... + if (true /* USE_SUGGEST_INTERFACE_FOR_TYPING */) { + return (static_cast<float>(score) / SUGGEST_INTERFACE_OUTPUT_SCALE) * weight; + } + // ...this logic. + const float maxScore = score >= S_INT_MAX ? static_cast<float>(S_INT_MAX) + : static_cast<float>(MAX_INITIAL_SCORE) + * powf(static_cast<float>(TYPED_LETTER_MULTIPLIER), + static_cast<float>(min(beforeLength, afterLength - spaceCount))) + * static_cast<float>(FULL_WORD_MULTIPLIER); + + return (static_cast<float>(score) / maxScore) * weight; +} + +} // namespace latinime diff --git a/native/jni/src/utils/autocorrection_threshold_utils.h b/native/jni/src/utils/autocorrection_threshold_utils.h new file mode 100644 index 000000000..c7537a6a5 --- /dev/null +++ b/native/jni/src/utils/autocorrection_threshold_utils.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H +#define LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H + +#include "defines.h" + +namespace latinime { + +class AutocorrectionThresholdUtils { + public: + static float calcNormalizedScore(const int *before, const int beforeLength, + const int *after, const int afterLength, const int score); + static int editDistance(const int *before, const int beforeLength, const int *after, + const int afterLength); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(AutocorrectionThresholdUtils); + + static const int MAX_INITIAL_SCORE; + static const int TYPED_LETTER_MULTIPLIER; + static const int FULL_WORD_MULTIPLIER; +}; +} // namespace latinime +#endif // LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H diff --git a/native/jni/src/utils/log_utils.cpp b/native/jni/src/utils/log_utils.cpp new file mode 100644 index 000000000..5ab2b2862 --- /dev/null +++ b/native/jni/src/utils/log_utils.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "log_utils.h" + +#include <cstdio> +#include <stdarg.h> + +#include "defines.h" + +namespace latinime { + /* static */ void LogUtils::logToJava(JNIEnv *const env, const char *const format, ...) { + static const char *TAG = "LatinIME:LogUtils"; + const jclass androidUtilLogClass = env->FindClass("android/util/Log"); + if (!androidUtilLogClass) { + // If we can't find the class, we are probably in off-device testing, and + // it's expected. Regardless, logging is not essential to functionality, so + // we should just return. However, FindClass has thrown an exception behind + // our back and there is no way to prevent it from doing that, so we clear + // the exception before we return. + env->ExceptionClear(); + return; + } + const jmethodID logDotIMethodId = env->GetStaticMethodID(androidUtilLogClass, "i", + "(Ljava/lang/String;Ljava/lang/String;)I"); + if (!logDotIMethodId) { + env->ExceptionClear(); + if (androidUtilLogClass) env->DeleteLocalRef(androidUtilLogClass); + return; + } + const jstring javaTag = env->NewStringUTF(TAG); + + static const int DEFAULT_LINE_SIZE = 128; + char fixedSizeCString[DEFAULT_LINE_SIZE]; + va_list argList; + va_start(argList, format); + // Get the necessary size. Add 1 for the 0 terminator. + const int size = vsnprintf(fixedSizeCString, DEFAULT_LINE_SIZE, format, argList) + 1; + va_end(argList); + + jstring javaString; + if (size <= DEFAULT_LINE_SIZE) { + // The buffer was large enough. + javaString = env->NewStringUTF(fixedSizeCString); + } else { + // The buffer was not large enough. 
+ va_start(argList, format); + char variableSizeCString[size]; + vsnprintf(variableSizeCString, size, format, argList); + va_end(argList); + javaString = env->NewStringUTF(variableSizeCString); + } + + env->CallStaticIntMethod(androidUtilLogClass, logDotIMethodId, javaTag, javaString); + if (javaString) env->DeleteLocalRef(javaString); + if (javaTag) env->DeleteLocalRef(javaTag); + if (androidUtilLogClass) env->DeleteLocalRef(androidUtilLogClass); + } +} diff --git a/native/jni/src/utils/log_utils.h b/native/jni/src/utils/log_utils.h new file mode 100644 index 000000000..6ac16d91a --- /dev/null +++ b/native/jni/src/utils/log_utils.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_LOG_UTILS_H +#define LATINIME_LOG_UTILS_H + +#include "defines.h" +#include "jni.h" + +namespace latinime { + +class LogUtils { + public: + static void logToJava(JNIEnv *const env, const char *const format, ...) +#ifdef __GNUC__ + __attribute__ ((format (printf, 2, 3))) +#endif // __GNUC__ + ; + + private: + DISALLOW_COPY_AND_ASSIGN(LogUtils); +}; +} // namespace latinime +#endif // LATINIME_LOG_UTILS_H |