diff options
Diffstat (limited to 'native/jni/src')
104 files changed, 3821 insertions, 4866 deletions
diff --git a/native/jni/src/bloom_filter.h b/native/jni/src/bloom_filter.h deleted file mode 100644 index bcce1f7ea..000000000 --- a/native/jni/src/bloom_filter.h +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_BLOOM_FILTER_H -#define LATINIME_BLOOM_FILTER_H - -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -// TODO: uint32_t position -static inline void setInFilter(uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - filter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); -} - -// TODO: uint32_t position -static inline bool isInFilter(const uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - return filter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7)); -} -} // namespace latinime -#endif // LATINIME_BLOOM_FILTER_H diff --git a/native/jni/src/char_utils.h b/native/jni/src/char_utils.h deleted file mode 100644 index b429f40b2..000000000 --- a/native/jni/src/char_utils.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (C) 2010 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_CHAR_UTILS_H -#define LATINIME_CHAR_UTILS_H - -#include <cctype> - -#include "defines.h" - -namespace latinime { - -inline static bool isAsciiUpper(int c) { - // Note: isupper(...) reports false positives for some Cyrillic characters, causing them to - // be incorrectly lower-cased using toAsciiLower(...) rather than latin_tolower(...). - return (c >= 'A' && c <= 'Z'); -} - -inline static int toAsciiLower(int c) { - return c - 'A' + 'a'; -} - -inline static bool isAscii(int c) { - return isascii(c) != 0; -} - -unsigned short latin_tolower(const unsigned short c); - -/** - * Table mapping most combined Latin, Greek, and Cyrillic characters - * to their base characters. If c is in range, BASE_CHARS[c] == c - * if c is not a combined character, or the base character if it - * is combined. - */ -static const int BASE_CHARS_SIZE = 0x0500; -extern const unsigned short BASE_CHARS[BASE_CHARS_SIZE]; - -inline static int toBaseCodePoint(int c) { - if (c < BASE_CHARS_SIZE) { - return static_cast<int>(BASE_CHARS[c]); - } - return c; -} - -AK_FORCE_INLINE static int toLowerCase(const int c) { - if (isAsciiUpper(c)) { - return toAsciiLower(c); - } - if (isAscii(c)) { - return c; - } - return static_cast<int>(latin_tolower(static_cast<unsigned short>(c))); -} - -AK_FORCE_INLINE static int toBaseLowerCase(const int c) { - return toLowerCase(toBaseCodePoint(c)); -} - -inline static bool isIntentionalOmissionCodePoint(const int codePoint) { - // TODO: Do not hardcode here - return codePoint == KEYCODE_SINGLE_QUOTE || codePoint == KEYCODE_HYPHEN_MINUS; -} - -inline static int getCodePointCount(const int arraySize, const int *const codePoints) { - int size = 0; - for (; size < arraySize; ++size) { - if (codePoints[size] == '\0') { - break; - } - } - return size; -} - -} // namespace latinime -#endif // LATINIME_CHAR_UTILS_H diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp deleted file mode 100644 index 61bf3f619..000000000 --- a/native/jni/src/correction.cpp +++ /dev/null @@ -1,979 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define LOG_TAG "LatinIME: correction.cpp" - -#include <cmath> - -#include "char_utils.h" -#include "correction.h" -#include "defines.h" -#include "proximity_info_state.h" -#include "suggest_utils.h" -#include "suggest/policyimpl/utils/edit_distance.h" -#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" - -namespace latinime { - -class ProximityInfo; - -///////////////////////////// -// edit distance funcitons // -///////////////////////////// - -inline static void initEditDistance(int *editDistanceTable) { - for (int i = 0; i <= MAX_WORD_LENGTH; ++i) { - editDistanceTable[i] = i; - } -} - -inline static void dumpEditDistance10ForDebug(int *editDistanceTable, - const int editDistanceTableWidth, const int outputLength) { - if (DEBUG_DICT) { - AKLOGI("EditDistanceTable"); - for (int i = 0; i <= 10; ++i) { - int c[11]; - for (int j = 0; j <= 10; ++j) { - if (j < editDistanceTableWidth + 1 && i < outputLength + 1) { - c[j] = (editDistanceTable + i * (editDistanceTableWidth + 1))[j]; - } else { - c[j] = -1; - } - } - AKLOGI("[ %d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d ]", - c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7], c[8], c[9], c[10]); - (void)c; // To suppress compiler warning - } - } -} - -inline static int getCurrentEditDistance(int *editDistanceTable, const int editDistanceTableWidth, - const int outputLength, const int inputSize) { - if (DEBUG_EDIT_DISTANCE) { - AKLOGI("getCurrentEditDistance %d, %d", inputSize, outputLength); - } - return editDistanceTable[(editDistanceTableWidth + 1) * (outputLength) + inputSize]; -} - -//////////////// -// Correction // -//////////////// - -void Correction::resetCorrection() { - mTotalTraverseCount = 0; -} - -void Correction::initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth) { - mProximityInfo = pi; - mInputSize = inputSize; - mMaxDepth = maxDepth; - mMaxEditDistance = mInputSize < 5 ? 2 : mInputSize / 2; - // TODO: This is not supposed to be required. Check what's going wrong with - // editDistance[0 ~ MAX_WORD_LENGTH] - initEditDistance(mEditDistanceTable); -} - -void Correction::initCorrectionState( - const int rootPos, const int childCount, const bool traverseAll) { - latinime::initCorrectionState(mCorrectionStates, rootPos, childCount, traverseAll); - // TODO: remove - mCorrectionStates[0].mTransposedPos = mTransposedPos; - mCorrectionStates[0].mExcessivePos = mExcessivePos; - mCorrectionStates[0].mSkipPos = mSkipPos; -} - -void Correction::setCorrectionParams(const int skipPos, const int excessivePos, - const int transposedPos, const int spaceProximityPos, const int missingSpacePos, - const bool useFullEditDistance, const bool doAutoCompletion, const int maxErrors) { - // TODO: remove - mTransposedPos = transposedPos; - mExcessivePos = excessivePos; - mSkipPos = skipPos; - // TODO: remove - mCorrectionStates[0].mTransposedPos = transposedPos; - mCorrectionStates[0].mExcessivePos = excessivePos; - mCorrectionStates[0].mSkipPos = skipPos; - - mSpaceProximityPos = spaceProximityPos; - mMissingSpacePos = missingSpacePos; - mUseFullEditDistance = useFullEditDistance; - mDoAutoCompletion = doAutoCompletion; - mMaxErrors = maxErrors; -} - -void Correction::checkState() const { - if (DEBUG_DICT) { - int inputCount = 0; - if (mSkipPos >= 0) ++inputCount; - if (mExcessivePos >= 0) ++inputCount; - if (mTransposedPos >= 0) ++inputCount; - } -} - -bool Correction::sameAsTyped() const { - return mProximityInfoState.sameAsTyped(mWord, mOutputIndex); -} - -int Correction::getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, - const int wordCount, const bool isSpaceProximity, const int *word) const { - return Correction::RankingAlgorithm::calcFreqForSplitMultipleWords(freqArray, wordLengthArray, - wordCount, this, isSpaceProximity, word); -} - -int Correction::getFinalProbability(const int probability, int **word, int *wordLength) { - return getFinalProbabilityInternal(probability, word, wordLength, mInputSize); -} - -int Correction::getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, - const int inputSize) { - return getFinalProbabilityInternal(probability, word, wordLength, inputSize); -} - -bool Correction::initProcessState(const int outputIndex) { - if (mCorrectionStates[outputIndex].mChildCount <= 0) { - return false; - } - mOutputIndex = outputIndex; - --(mCorrectionStates[outputIndex].mChildCount); - mInputIndex = mCorrectionStates[outputIndex].mInputIndex; - mNeedsToTraverseAllNodes = mCorrectionStates[outputIndex].mNeedsToTraverseAllNodes; - - mEquivalentCharCount = mCorrectionStates[outputIndex].mEquivalentCharCount; - mProximityCount = mCorrectionStates[outputIndex].mProximityCount; - mTransposedCount = mCorrectionStates[outputIndex].mTransposedCount; - mExcessiveCount = mCorrectionStates[outputIndex].mExcessiveCount; - mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount; - mLastCharExceeded = mCorrectionStates[outputIndex].mLastCharExceeded; - - mTransposedPos = mCorrectionStates[outputIndex].mTransposedPos; - mExcessivePos = mCorrectionStates[outputIndex].mExcessivePos; - mSkipPos = mCorrectionStates[outputIndex].mSkipPos; - - mMatching = false; - mProximityMatching = false; - mAdditionalProximityMatching = false; - mTransposing = false; - mExceeding = false; - mSkipping = false; - - return true; -} - -int Correction::goDownTree(const int parentIndex, const int childCount, const int firstChildPos) { - mCorrectionStates[mOutputIndex].mParentIndex = parentIndex; - mCorrectionStates[mOutputIndex].mChildCount = childCount; - mCorrectionStates[mOutputIndex].mSiblingPos = firstChildPos; - return mOutputIndex; -} - -// TODO: remove -int Correction::getInputIndex() const { - return mInputIndex; -} - -bool Correction::needsToPrune() const { - // TODO: use edit distance here - return mOutputIndex - 1 >= mMaxDepth || mProximityCount > mMaxEditDistance - // Allow one char longer word for missing character - || (!mDoAutoCompletion && (mOutputIndex > mInputSize)); -} - -inline static bool isEquivalentChar(ProximityType type) { - return type == MATCH_CHAR; -} - -inline static bool isProximityCharOrEquivalentChar(ProximityType type) { - return type == MATCH_CHAR || type == PROXIMITY_CHAR; -} - -Correction::CorrectionType Correction::processCharAndCalcState(const int c, const bool isTerminal) { - const int correctionCount = (mSkippedCount + mExcessiveCount + mTransposedCount); - if (correctionCount > mMaxErrors) { - return processUnrelatedCorrectionType(); - } - - // TODO: Change the limit if we'll allow two or more corrections - const bool noCorrectionsHappenedSoFar = correctionCount == 0; - const bool canTryCorrection = noCorrectionsHappenedSoFar; - int proximityIndex = 0; - mDistances[mOutputIndex] = NOT_A_DISTANCE; - - // Skip checking this node - if (mNeedsToTraverseAllNodes || isSingleQuote(c)) { - bool incremented = false; - if (mLastCharExceeded && mInputIndex == mInputSize - 1) { - // TODO: Do not check the proximity if EditDistance exceeds the threshold - const ProximityType matchId = mProximityInfoState.getProximityType( - mInputIndex, c, true, &proximityIndex); - if (isEquivalentChar(matchId)) { - mLastCharExceeded = false; - --mExcessiveCount; - mDistances[mOutputIndex] = - mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, 0); - } else if (matchId == PROXIMITY_CHAR) { - mLastCharExceeded = false; - --mExcessiveCount; - ++mProximityCount; - mDistances[mOutputIndex] = mProximityInfoState.getNormalizedSquaredDistance( - mInputIndex, proximityIndex); - } - if (!isSingleQuote(c)) { - incrementInputIndex(); - incremented = true; - } - } - return processSkipChar(c, isTerminal, incremented); - } - - // Check possible corrections. - if (mExcessivePos >= 0) { - if (mExcessiveCount == 0 && mExcessivePos < mOutputIndex) { - mExcessivePos = mOutputIndex; - } - if (mExcessivePos < mInputSize - 1) { - mExceeding = mExcessivePos == mInputIndex && canTryCorrection; - } - } - - if (mSkipPos >= 0) { - if (mSkippedCount == 0 && mSkipPos < mOutputIndex) { - if (DEBUG_DICT) { - // TODO: Enable this assertion. - //ASSERT(mSkipPos == mOutputIndex - 1); - } - mSkipPos = mOutputIndex; - } - mSkipping = mSkipPos == mOutputIndex && canTryCorrection; - } - - if (mTransposedPos >= 0) { - if (mTransposedCount == 0 && mTransposedPos < mOutputIndex) { - mTransposedPos = mOutputIndex; - } - if (mTransposedPos < mInputSize - 1) { - mTransposing = mInputIndex == mTransposedPos && canTryCorrection; - } - } - - bool secondTransposing = false; - if (mTransposedCount % 2 == 1) { - if (isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex - 1, c, false))) { - ++mTransposedCount; - secondTransposing = true; - } else if (mCorrectionStates[mOutputIndex].mExceeding) { - --mTransposedCount; - ++mExcessiveCount; - --mExcessivePos; - incrementInputIndex(); - } else { - --mTransposedCount; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("UNRELATED(0): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processUnrelatedCorrectionType(); - } - } - - // TODO: Change the limit if we'll allow two or more proximity chars with corrections - // Work around: When the mMaxErrors is 1, we only allow just one error - // including proximity correction. - const bool checkProximityChars = (mMaxErrors > 1) - ? (noCorrectionsHappenedSoFar || mProximityCount == 0) - : (noCorrectionsHappenedSoFar && mProximityCount == 0); - - ProximityType matchedProximityCharId = secondTransposing - ? MATCH_CHAR - : mProximityInfoState.getProximityType( - mInputIndex, c, checkProximityChars, &proximityIndex); - - if (SUBSTITUTION_CHAR == matchedProximityCharId - || ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - if (canTryCorrection && mOutputIndex > 0 - && mCorrectionStates[mOutputIndex].mProximityMatching - && mCorrectionStates[mOutputIndex].mExceeding - && isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex, mWord[mOutputIndex - 1], false))) { - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("CONVERSION p->e %c", mWord[mOutputIndex - 1]); - } - // Conversion p->e - // Example: - // wearth -> earth - // px -> (E)mmmmm - ++mExcessiveCount; - --mProximityCount; - mExcessivePos = mOutputIndex - 1; - ++mInputIndex; - // Here, we are doing something equivalent to matchedProximityCharId, - // but we already know that "excessive char correction" just happened - // so that we just need to check "mProximityCount == 0". - matchedProximityCharId = mProximityInfoState.getProximityType( - mInputIndex, c, mProximityCount == 0, &proximityIndex); - } - } - - if (SUBSTITUTION_CHAR == matchedProximityCharId - || ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - if (ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - mAdditionalProximityMatching = true; - } - // TODO: Optimize - // As the current char turned out to be an unrelated char, - // we will try other correction-types. Please note that mCorrectionStates[mOutputIndex] - // here refers to the previous state. - if (mInputIndex < mInputSize - 1 && mOutputIndex > 0 && mTransposedCount > 0 - && !mCorrectionStates[mOutputIndex].mTransposing - && mCorrectionStates[mOutputIndex - 1].mTransposing - && isEquivalentChar(mProximityInfoState.getProximityType( - mInputIndex, mWord[mOutputIndex - 1], false)) - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // Conversion t->e - // Example: - // occaisional -> occa sional - // mmmmttx -> mmmm(E)mmmmmm - mTransposedCount -= 2; - ++mExcessiveCount; - ++mInputIndex; - } else if (mOutputIndex > 0 && mInputIndex > 0 && mTransposedCount > 0 - && !mCorrectionStates[mOutputIndex].mTransposing - && mCorrectionStates[mOutputIndex - 1].mTransposing - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex - 1, c, false))) { - // Conversion t->s - // Example: - // chcolate -> chocolate - // mmttx -> mmsmmmmmm - mTransposedCount -= 2; - ++mSkippedCount; - --mInputIndex; - } else if (canTryCorrection && mInputIndex > 0 - && mCorrectionStates[mOutputIndex].mProximityMatching - && mCorrectionStates[mOutputIndex].mSkipping - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex - 1, c, false))) { - // Conversion p->s - // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of - // proximity chars of "s", but it should rather be handled as a skipped char. - ++mSkippedCount; - --mProximityCount; - return processSkipChar(c, isTerminal, false); - } else if (mInputIndex - 1 < mInputSize - && mSkippedCount > 0 - && mCorrectionStates[mOutputIndex].mSkipping - && mCorrectionStates[mOutputIndex].mAdditionalProximityMatching - && isProximityCharOrEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // Conversion s->a - incrementInputIndex(); - --mSkippedCount; - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO; - } else if ((mExceeding || mTransposing) && mInputIndex - 1 < mInputSize - && isEquivalentChar( - mProximityInfoState.getProximityType(mInputIndex + 1, c, false))) { - // 1.2. Excessive or transpose correction - if (mTransposing) { - ++mTransposedCount; - } else { - ++mExcessiveCount; - incrementInputIndex(); - } - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - if (mTransposing) { - AKLOGI("TRANSPOSE: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } else { - AKLOGI("EXCEED: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } - } else if (mSkipping) { - // 3. Skip correction - ++mSkippedCount; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("SKIP: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processSkipChar(c, isTerminal, false); - } else if (ADDITIONAL_PROXIMITY_CHAR == matchedProximityCharId) { - // As a last resort, use additional proximity characters - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("ADDITIONALPROX: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } else { - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("UNRELATED(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return processUnrelatedCorrectionType(); - } - } else if (secondTransposing) { - // If inputIndex is greater than mInputSize, that means there is no - // proximity chars. So, we don't need to check proximity. - mMatching = true; - } else if (isEquivalentChar(matchedProximityCharId)) { - mMatching = true; - ++mEquivalentCharCount; - mDistances[mOutputIndex] = mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, 0); - } else if (PROXIMITY_CHAR == matchedProximityCharId) { - mProximityMatching = true; - ++mProximityCount; - mDistances[mOutputIndex] = - mProximityInfoState.getNormalizedSquaredDistance(mInputIndex, proximityIndex); - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 - || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - AKLOGI("PROX: %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - } - - addCharToCurrentWord(c); - - // 4. Last char excessive correction - mLastCharExceeded = mExcessiveCount == 0 && mSkippedCount == 0 && mTransposedCount == 0 - && mProximityCount == 0 && (mInputIndex == mInputSize - 2); - const bool isSameAsUserTypedLength = (mInputSize == mInputIndex + 1) || mLastCharExceeded; - if (mLastCharExceeded) { - ++mExcessiveCount; - } - - // Start traversing all nodes after the index exceeds the user typed length - if (isSameAsUserTypedLength) { - startToTraverseAllNodes(); - } - - const bool needsToTryOnTerminalForTheLastPossibleExcessiveChar = - mExceeding && mInputIndex == mInputSize - 2; - - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - incrementInputIndex(); - // Also, the next char is one "virtual node" depth more than this char. - incrementOutputIndex(); - - if ((needsToTryOnTerminalForTheLastPossibleExcessiveChar - || isSameAsUserTypedLength) && isTerminal) { - mTerminalInputIndex = mInputIndex - 1; - mTerminalOutputIndex = mOutputIndex - 1; - if (DEBUG_CORRECTION - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == mInputSize) - && (MIN_OUTPUT_INDEX_FOR_DEBUG <= 0 || MIN_OUTPUT_INDEX_FOR_DEBUG < mOutputIndex)) { - DUMP_WORD(mWord, mOutputIndex); - AKLOGI("ONTERMINAL(1): %d, %d, %d, %d, %c", mProximityCount, mSkippedCount, - mTransposedCount, mExcessiveCount, c); - } - return ON_TERMINAL; - } else { - mTerminalInputIndex = mInputIndex - 1; - mTerminalOutputIndex = mOutputIndex - 1; - return NOT_ON_TERMINAL; - } -} - -inline static int getQuoteCount(const int *word, const int length) { - int quoteCount = 0; - for (int i = 0; i < length; ++i) { - if (word[i] == KEYCODE_SINGLE_QUOTE) { - ++quoteCount; - } - } - return quoteCount; -} - -inline static bool isUpperCase(unsigned short c) { - return isAsciiUpper(toBaseCodePoint(c)); -} - -////////////////////// -// RankingAlgorithm // -////////////////////// - -/* static */ int Correction::RankingAlgorithm::calculateFinalProbability(const int inputIndex, - const int outputIndex, const int freq, int *editDistanceTable, const Correction *correction, - const int inputSize) { - const int excessivePos = correction->getExcessivePos(); - const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; - const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER; - const ProximityInfoState *proximityInfoState = &correction->mProximityInfoState; - const int skippedCount = correction->mSkippedCount; - const int transposedCount = correction->mTransposedCount / 2; - const int excessiveCount = correction->mExcessiveCount + correction->mTransposedCount % 2; - const int proximityMatchedCount = correction->mProximityCount; - const bool lastCharExceeded = correction->mLastCharExceeded; - const bool useFullEditDistance = correction->mUseFullEditDistance; - const int outputLength = outputIndex + 1; - if (skippedCount >= inputSize || inputSize == 0) { - return -1; - } - - // TODO: find more robust way - bool sameLength = lastCharExceeded ? (inputSize == inputIndex + 2) - : (inputSize == inputIndex + 1); - - // TODO: use mExcessiveCount - const int matchCount = inputSize - correction->mProximityCount - excessiveCount; - - const int *word = correction->mWord; - const bool skipped = skippedCount > 0; - - const int quoteDiffCount = max(0, getQuoteCount(word, outputLength) - - getQuoteCount(proximityInfoState->getPrimaryInputWord(), inputSize)); - - // TODO: Calculate edit distance for transposed and excessive - int ed = 0; - if (DEBUG_DICT_FULL) { - dumpEditDistance10ForDebug(editDistanceTable, correction->mInputSize, outputLength); - } - int adjustedProximityMatchedCount = proximityMatchedCount; - - int finalFreq = freq; - - if (DEBUG_CORRECTION_FREQ - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == inputSize)) { - AKLOGI("FinalFreq0: %d", finalFreq); - } - // TODO: Optimize this. - if (transposedCount > 0 || proximityMatchedCount > 0 || skipped || excessiveCount > 0) { - ed = getCurrentEditDistance(editDistanceTable, correction->mInputSize, outputLength, - inputSize) - transposedCount; - - const int matchWeight = powerIntCapped(typedLetterMultiplier, - max(inputSize, outputLength) - ed); - multiplyIntCapped(matchWeight, &finalFreq); - - // TODO: Demote further if there are two or more excessive chars with longer user input? - if (inputSize > outputLength) { - multiplyRate(INPUT_EXCEEDS_OUTPUT_DEMOTION_RATE, &finalFreq); - } - - ed = max(0, ed - quoteDiffCount); - adjustedProximityMatchedCount = min(max(0, ed - (outputLength - inputSize)), - proximityMatchedCount); - if (transposedCount <= 0) { - if (ed == 1 && (inputSize == outputLength - 1 || inputSize == outputLength + 1)) { - // Promote a word with just one skipped or excessive char - if (sameLength) { - multiplyRate(WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_RATE - + WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_MULTIPLIER * outputLength, - &finalFreq); - } else { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - } - } else if (ed == 0) { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - sameLength = true; - } - } - } else { - const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount); - multiplyIntCapped(matchWeight, &finalFreq); - } - - if (proximityInfoState->getProximityType(0, word[0], true) == SUBSTITUTION_CHAR) { - multiplyRate(FIRST_CHAR_DIFFERENT_DEMOTION_RATE, &finalFreq); - } - - /////////////////////////////////////////////// - // Promotion and Demotion for each correction - - // Demotion for a word with missing character - if (skipped) { - const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE - * (10 * inputSize - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X) - / (10 * inputSize - - WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X + 10); - if (DEBUG_DICT_FULL) { - AKLOGI("Demotion rate for missing character is %d.", demotionRate); - } - multiplyRate(demotionRate, &finalFreq); - } - - // Demotion for a word with transposed character - if (transposedCount > 0) multiplyRate( - WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE, &finalFreq); - - // Demotion for a word with excessive character - if (excessiveCount > 0) { - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE, &finalFreq); - if (!lastCharExceeded && !proximityInfoState->existsAdjacentProximityChars(excessivePos)) { - if (DEBUG_DICT_FULL) { - AKLOGI("Double excessive demotion"); - } - // If an excessive character is not adjacent to the left char or the right char, - // we will demote this word. - multiplyRate(WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE, &finalFreq); - } - } - - int additionalProximityCount = 0; - // Demote additional proximity characters - for (int i = 0; i < outputLength; ++i) { - const int squaredDistance = correction->mDistances[i]; - if (squaredDistance == ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO) { - ++additionalProximityCount; - } - } - - const bool performTouchPositionCorrection = - CALIBRATE_SCORE_BY_TOUCH_COORDINATES - && proximityInfoState->touchPositionCorrectionEnabled() - && skippedCount == 0 && excessiveCount == 0 && transposedCount == 0 - && additionalProximityCount == 0; - - // Score calibration by touch coordinates is being done only for pure-fat finger typing error - // cases. - // TODO: Remove this constraint. - if (performTouchPositionCorrection) { - for (int i = 0; i < outputLength; ++i) { - const int squaredDistance = correction->mDistances[i]; - if (i < adjustedProximityMatchedCount) { - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - } - const float factor = - SuggestUtils::getLengthScalingFactor(static_cast<float>(squaredDistance)); - if (factor > 0.0f) { - multiplyRate(static_cast<int>(factor * 100.0f), &finalFreq); - } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } - } - } else { - // Promotion for a word with proximity characters - for (int i = 0; i < adjustedProximityMatchedCount; ++i) { - // A word with proximity corrections - if (DEBUG_DICT_FULL) { - AKLOGI("Found a proximity correction."); - } - multiplyIntCapped(typedLetterMultiplier, &finalFreq); - if (i < additionalProximityCount) { - multiplyRate(WORDS_WITH_ADDITIONAL_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } else { - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq); - } - } - } - - // If the user types too many(three or more) proximity characters with additional proximity - // character,do not treat as the same length word. - if (sameLength && additionalProximityCount > 0 && (adjustedProximityMatchedCount >= 3 - || transposedCount > 0 || skipped || excessiveCount > 0)) { - sameLength = false; - } - - const int errorCount = adjustedProximityMatchedCount > 0 - ? adjustedProximityMatchedCount - : (proximityMatchedCount + transposedCount); - multiplyRate( - 100 - CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE * errorCount / inputSize, &finalFreq); - - // Promotion for an exactly matched word - if (ed == 0) { - // Full exact match - if (sameLength && transposedCount == 0 && !skipped && excessiveCount == 0 - && quoteDiffCount == 0 && additionalProximityCount == 0) { - finalFreq = capped255MultForFullMatchAccentsOrCapitalizationDifference(finalFreq); - } - } - - // Promote a word with no correction - if (proximityMatchedCount == 0 && transposedCount == 0 && !skipped && excessiveCount == 0 - && additionalProximityCount == 0) { - multiplyRate(FULL_MATCHED_WORDS_PROMOTION_RATE, &finalFreq); - } - - // TODO: Check excessive count and transposed count - // TODO: Remove this if possible - /* - If the last character of the user input word is the same as the next character - of the output word, and also all of characters of the user input are matched - to the output word, we'll promote that word a bit because - that word can be considered the combination of skipped and matched characters. - This means that the 'sm' pattern wins over the 'ma' pattern. - e.g.) - shel -> shell [mmmma] or [mmmsm] - hel -> hello [mmmaa] or [mmsma] - m ... matching - s ... skipping - a ... traversing all - t ... transposing - e ... exceeding - p ... proximity matching - */ - if (matchCount == inputSize && matchCount >= 2 && !skipped - && word[matchCount] == word[matchCount - 1]) { - multiplyRate(WORDS_WITH_MATCH_SKIP_PROMOTION_RATE, &finalFreq); - } - - // TODO: Do not use sameLength? - if (sameLength) { - multiplyIntCapped(fullWordMultiplier, &finalFreq); - } - - if (useFullEditDistance && outputLength > inputSize + 1) { - const int diff = outputLength - inputSize - 1; - const int divider = diff < 31 ? 1 << diff : S_INT_MAX; - finalFreq = divider > finalFreq ? 1 : finalFreq / divider; - } - - if (DEBUG_DICT_FULL) { - AKLOGI("calc: %d, %d", outputLength, sameLength); - } - - if (DEBUG_CORRECTION_FREQ - && (INPUTLENGTH_FOR_DEBUG <= 0 || INPUTLENGTH_FOR_DEBUG == inputSize)) { - DUMP_WORD(correction->getPrimaryInputWord(), inputSize); - DUMP_WORD(correction->mWord, outputLength); - AKLOGI("FinalFreq: [P%d, S%d, T%d, E%d, A%d] %d, %d, %d, %d, %d, %d", proximityMatchedCount, - skippedCount, transposedCount, excessiveCount, additionalProximityCount, - outputLength, lastCharExceeded, sameLength, quoteDiffCount, ed, finalFreq); - } - - return finalFreq; -} - -/* static */ int Correction::RankingAlgorithm::calcFreqForSplitMultipleWords(const int *freqArray, - const int *wordLengthArray, const int wordCount, const Correction *correction, - const bool isSpaceProximity, const int *word) { - const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER; - - bool firstCapitalizedWordDemotion = false; - bool secondCapitalizedWordDemotion = false; - - { - // TODO: Handle multiple capitalized word demotion properly - const int firstWordLength = wordLengthArray[0]; - const int secondWordLength = wordLengthArray[1]; - if (firstWordLength >= 2) { - firstCapitalizedWordDemotion = isUpperCase(word[0]); - } - - if (secondWordLength >= 2) { - // FIXME: word[firstWordLength + 1] is incorrect. - secondCapitalizedWordDemotion = isUpperCase(word[firstWordLength + 1]); - } - } - - - const bool capitalizedWordDemotion = - firstCapitalizedWordDemotion ^ secondCapitalizedWordDemotion; - - int totalLength = 0; - int totalFreq = 0; - for (int i = 0; i < wordCount; ++i) { - const int wordLength = wordLengthArray[i]; - if (wordLength <= 0) { - return 0; - } - totalLength += wordLength; - const int demotionRate = 100 - TWO_WORDS_CORRECTION_DEMOTION_BASE / (wordLength + 1); - int tempFirstFreq = freqArray[i]; - multiplyRate(demotionRate, &tempFirstFreq); - totalFreq += tempFirstFreq; - } - - if (totalLength <= 0 || totalFreq <= 0) { - return 0; - } - - // TODO: Currently totalFreq is adjusted to two word metrix. - // Promote pairFreq with multiplying by 2, because the word length is the same as the typed - // length. - totalFreq = totalFreq * 2 / wordCount; - if (wordCount > 2) { - // Safety net for 3+ words -- Caveats: many heuristics and workarounds here. - int oneLengthCounter = 0; - int twoLengthCounter = 0; - for (int i = 0; i < wordCount; ++i) { - const int wordLength = wordLengthArray[i]; - // TODO: Use bigram instead of this safety net - if (i < wordCount - 1) { - const int nextWordLength = wordLengthArray[i + 1]; - if (wordLength == 1 && nextWordLength == 2) { - // Safety net to filter 1 length and 2 length sequential words - return 0; - } - } - const int freq = freqArray[i]; - // Demote too short weak words - if (wordLength <= 4 && freq <= SUPPRESS_SHORT_MULTIPLE_WORDS_THRESHOLD_FREQ) { - multiplyRate(100 * freq / MAX_PROBABILITY, &totalFreq); - } - if (wordLength == 1) { - ++oneLengthCounter; - } else if (wordLength == 2) { - ++twoLengthCounter; - } - if (oneLengthCounter >= 2 || (oneLengthCounter + twoLengthCounter) >= 4) { - // Safety net to filter too many short words - return 0; - } - } - multiplyRate(MULTIPLE_WORDS_DEMOTION_RATE, &totalFreq); - } - - // This is a workaround to try offsetting the not-enough-demotion which will be done in - // calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) - // but we demoted only (1 - 1 / (length + 1)) so we will additionally adjust freq by - // (1 - 1 / length) / (1 - 1 / (length + 1)) = (1 - 1 / (length * length)) - const int normalizedScoreNotEnoughDemotionAdjustment = 100 - 100 / (totalLength * totalLength); - multiplyRate(normalizedScoreNotEnoughDemotionAdjustment, &totalFreq); - - // At this moment, totalFreq is calculated by the following formula: - // (firstFreq * (1 - 1 / (firstWordLength + 1)) + secondFreq * (1 - 1 / (secondWordLength + 1))) - // * (1 - 1 / totalLength) / (1 - 1 / (totalLength + 1)) - - multiplyIntCapped(powerIntCapped(typedLetterMultiplier, totalLength), &totalFreq); - - // This is another workaround to offset the demotion which will be done in - // calcNormalizedScore in Utils.java. - // In calcNormalizedScore the score will be demoted by (1 - 1 / length) so we have to promote - // the same amount because we already have adjusted the synthetic freq of this "missing or - // mistyped space" suggestion candidate above in this method. - const int normalizedScoreDemotionRateOffset = (100 + 100 / totalLength); - multiplyRate(normalizedScoreDemotionRateOffset, &totalFreq); - - if (isSpaceProximity) { - // A word pair with one space proximity correction - if (DEBUG_DICT) { - AKLOGI("Found a word pair with space proximity correction."); - } - multiplyIntCapped(typedLetterMultiplier, &totalFreq); - multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &totalFreq); - } - - if (isSpaceProximity) { - multiplyRate(WORDS_WITH_MISTYPED_SPACE_DEMOTION_RATE, &totalFreq); - } else { - multiplyRate(WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE, &totalFreq); - } - - if (capitalizedWordDemotion) { - multiplyRate(TWO_WORDS_CAPITALIZED_DEMOTION_RATE, &totalFreq); - } - - if (DEBUG_CORRECTION_FREQ) { - AKLOGI("Multiple words (%d, %d) (%d, %d) %d, %d", freqArray[0], freqArray[1], - wordLengthArray[0], wordLengthArray[1], capitalizedWordDemotion, totalFreq); - DUMP_WORD(word, wordLengthArray[0]); - } - - return totalFreq; -} - -/* static */ int Correction::RankingAlgorithm::editDistance(const int *before, - const int beforeLength, const int *after, const int afterLength) { - const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( - before, beforeLength, after, afterLength); - return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); -} - - -// In dictionary.cpp, getSuggestion() method, -// When USE_SUGGEST_INTERFACE_FOR_TYPING is true: -// SUGGEST_INTERFACE_OUTPUT_SCALE was multiplied to the original suggestion scores to convert -// them to integers. -// score = (int)((original score) * SUGGEST_INTERFACE_OUTPUT_SCALE) -// Undo the scaling here to recover the original score. -// normalizedScore = ((float)score) / SUGGEST_INTERFACE_OUTPUT_SCALE -// Otherwise: suggestion scores are computed using the below formula. -// original score -// := powf(mTypedLetterMultiplier (this is defined 2), -// (the number of matched characters between typed word and suggested word)) -// * (individual word's score which defined in the unigram dictionary, -// and this score is defined in range [0, 255].) -// Then, the following processing is applied. -// - If the dictionary word is matched up to the point of the user entry -// (full match up to min(before.length(), after.length()) -// => Then multiply by FULL_MATCHED_WORDS_PROMOTION_RATE (this is defined 1.2) -// - If the word is a true full match except for differences in accents or -// capitalization, then treat it as if the score was 255. -// - If before.length() == after.length() -// => multiply by mFullWordMultiplier (this is defined 2)) -// So, maximum original score is powf(2, min(before.length(), after.length())) * 255 * 2 * 1.2 -// For historical reasons we ignore the 1.2 modifier (because the measure for a good -// autocorrection threshold was done at a time when it didn't exist). This doesn't change -// the result. -// So, we can normalize original score by dividing powf(2, min(b.l(),a.l())) * 255 * 2. - -/* static */ float Correction::RankingAlgorithm::calcNormalizedScore(const int *before, - const int beforeLength, const int *after, const int afterLength, const int score) { - if (0 == beforeLength || 0 == afterLength) { - return 0.0f; - } - const int distance = editDistance(before, beforeLength, after, afterLength); - int spaceCount = 0; - for (int i = 0; i < afterLength; ++i) { - if (after[i] == KEYCODE_SPACE) { - ++spaceCount; - } - } - - if (spaceCount == afterLength) { - return 0.0f; - } - - // add a weight based on edit distance. - // distance <= max(afterLength, beforeLength) == afterLength, - // so, 0 <= distance / afterLength <= 1 - const float weight = 1.0f - static_cast<float>(distance) / static_cast<float>(afterLength); - - if (USE_SUGGEST_INTERFACE_FOR_TYPING) { - return (static_cast<float>(score) / SUGGEST_INTERFACE_OUTPUT_SCALE) * weight; - } - const float maxScore = score >= S_INT_MAX ? static_cast<float>(S_INT_MAX) - : static_cast<float>(MAX_INITIAL_SCORE) - * powf(static_cast<float>(TYPED_LETTER_MULTIPLIER), - static_cast<float>(min(beforeLength, afterLength - spaceCount))) - * static_cast<float>(FULL_WORD_MULTIPLIER); - - return (static_cast<float>(score) / maxScore) * weight; -} -} // namespace latinime diff --git a/native/jni/src/correction.h b/native/jni/src/correction.h deleted file mode 100644 index a9e9b48a6..000000000 --- a/native/jni/src/correction.h +++ /dev/null @@ -1,376 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_CORRECTION_H -#define LATINIME_CORRECTION_H - -#include <cstring> // for memset() - -#include "correction_state.h" -#include "defines.h" -#include "proximity_info_state.h" - -namespace latinime { - -class ProximityInfo; - -class Correction { - public: - typedef enum { - TRAVERSE_ALL_ON_TERMINAL, - TRAVERSE_ALL_NOT_ON_TERMINAL, - UNRELATED, - ON_TERMINAL, - NOT_ON_TERMINAL - } CorrectionType; - - Correction() - : mProximityInfo(0), mUseFullEditDistance(false), mDoAutoCompletion(false), - mMaxEditDistance(0), mMaxDepth(0), mInputSize(0), mSpaceProximityPos(0), - mMissingSpacePos(0), mTerminalInputIndex(0), mTerminalOutputIndex(0), mMaxErrors(0), - mTotalTraverseCount(0), mNeedsToTraverseAllNodes(false), mOutputIndex(0), - mInputIndex(0), mEquivalentCharCount(0), mProximityCount(0), mExcessiveCount(0), - mTransposedCount(0), mSkippedCount(0), mTransposedPos(0), mExcessivePos(0), - mSkipPos(0), mLastCharExceeded(false), mMatching(false), mProximityMatching(false), - mAdditionalProximityMatching(false), mExceeding(false), mTransposing(false), - mSkipping(false), mProximityInfoState() { - memset(mWord, 0, sizeof(mWord)); - memset(mDistances, 0, sizeof(mDistances)); - memset(mEditDistanceTable, 0, sizeof(mEditDistanceTable)); - // NOTE: mCorrectionStates is an array of instances. - // No need to initialize it explicitly here. - } - - // Non virtual inline destructor -- never inherit this class - ~Correction() {} - void resetCorrection(); - void initCorrection(const ProximityInfo *pi, const int inputSize, const int maxDepth); - void initCorrectionState(const int rootPos, const int childCount, const bool traverseAll); - - // TODO: remove - void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos, - const int spaceProximityPos, const int missingSpacePos, const bool useFullEditDistance, - const bool doAutoCompletion, const int maxErrors); - void checkState() const; - bool sameAsTyped() const; - bool initProcessState(const int index); - - int getInputIndex() const; - - bool needsToPrune() const; - - int pushAndGetTotalTraverseCount() { - return ++mTotalTraverseCount; - } - - int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, - const int wordCount, const bool isSpaceProximity, const int *word) const; - int getFinalProbability(const int probability, int **word, int *wordLength); - int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength, - const int inputSize); - - CorrectionType processCharAndCalcState(const int c, const bool isTerminal); - - ///////////////////////// - // Tree helper methods - int goDownTree(const int parentIndex, const int childCount, const int firstChildPos); - - inline int getTreeSiblingPos(const int index) const { - return mCorrectionStates[index].mSiblingPos; - } - - inline void setTreeSiblingPos(const int index, const int pos) { - mCorrectionStates[index].mSiblingPos = pos; - } - - inline int getTreeParentIndex(const int index) const { - return mCorrectionStates[index].mParentIndex; - } - - class RankingAlgorithm { - public: - static int calculateFinalProbability(const int inputIndex, const int depth, - const int probability, int *editDistanceTable, const Correction *correction, - const int inputSize); - static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray, - const int wordCount, const Correction *correction, const bool isSpaceProximity, - const int *word); - static float calcNormalizedScore(const int *before, const int beforeLength, - const int *after, const int afterLength, const int score); - static int editDistance(const int *before, const int beforeLength, const int *after, - const int afterLength); - private: - static const int MAX_INITIAL_SCORE = 255; - }; - - // proximity info state - void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes, - const int inputSize, const int *xCoordinates, const int *yCoordinates) { - mProximityInfoState.initInputParams(0, static_cast<float>(MAX_VALUE_FOR_WEIGHTING), - proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false); - } - - const int *getPrimaryInputWord() const { - return mProximityInfoState.getPrimaryInputWord(); - } - - int getPrimaryCodePointAt(const int index) const { - return mProximityInfoState.getPrimaryCodePointAt(index); - } - - private: - DISALLOW_COPY_AND_ASSIGN(Correction); - - ///////////////////////// - // static inline utils // - ///////////////////////// - static const int TWO_31ST_DIV_255 = S_INT_MAX / 255; - static inline int capped255MultForFullMatchAccentsOrCapitalizationDifference(const int num) { - return (num < TWO_31ST_DIV_255 ? 255 * num : S_INT_MAX); - } - - static const int TWO_31ST_DIV_2 = S_INT_MAX / 2; - AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) { - const int temp = *base; - if (temp != S_INT_MAX) { - // Branch if multiplier == 2 for the optimization - if (multiplier < 0) { - if (DEBUG_DICT) { - ASSERT(false); - } - AKLOGI("--- Invalid multiplier: %d", multiplier); - } else if (multiplier == 0) { - *base = 0; - } else if (multiplier == 2) { - *base = TWO_31ST_DIV_2 >= temp ? temp << 1 : S_INT_MAX; - } else { - // TODO: This overflow check gives a wrong answer when, for example, - // temp = 2^16 + 1 and multiplier = 2^17 + 1. - // Fix this behavior. - const int tempRetval = temp * multiplier; - *base = tempRetval >= temp ? tempRetval : S_INT_MAX; - } - } - } - - AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) { - if (n <= 0) return 1; - if (base == 2) { - return n < 31 ? 1 << n : S_INT_MAX; - } - int ret = base; - for (int i = 1; i < n; ++i) multiplyIntCapped(base, &ret); - return ret; - } - - AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) { - if (*freq != S_INT_MAX) { - if (*freq > 1000000) { - *freq /= 100; - multiplyIntCapped(rate, freq); - } else { - multiplyIntCapped(rate, freq); - *freq /= 100; - } - } - } - - inline int getSpaceProximityPos() const { - return mSpaceProximityPos; - } - inline int getMissingSpacePos() const { - return mMissingSpacePos; - } - - inline int getSkipPos() const { - return mSkipPos; - } - - inline int getExcessivePos() const { - return mExcessivePos; - } - - inline int getTransposedPos() const { - return mTransposedPos; - } - - inline void incrementInputIndex(); - inline void incrementOutputIndex(); - inline void startToTraverseAllNodes(); - inline bool isSingleQuote(const int c); - inline CorrectionType processSkipChar(const int c, const bool isTerminal, - const bool inputIndexIncremented); - inline CorrectionType processUnrelatedCorrectionType(); - inline void addCharToCurrentWord(const int c); - inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength, - const int inputSize); - - static const int TYPED_LETTER_MULTIPLIER = 2; - static const int FULL_WORD_MULTIPLIER = 2; - const ProximityInfo *mProximityInfo; - - bool mUseFullEditDistance; - bool mDoAutoCompletion; - int mMaxEditDistance; - int mMaxDepth; - int mInputSize; - int mSpaceProximityPos; - int mMissingSpacePos; - int mTerminalInputIndex; - int mTerminalOutputIndex; - int mMaxErrors; - - int mTotalTraverseCount; - - // The following arrays are state buffer. - int mWord[MAX_WORD_LENGTH]; - int mDistances[MAX_WORD_LENGTH]; - - // Edit distance calculation requires a buffer with (N+1)^2 length for the input length N. - // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot. - int mEditDistanceTable[(MAX_WORD_LENGTH + 1) * (MAX_WORD_LENGTH + 1)]; - - CorrectionState mCorrectionStates[MAX_WORD_LENGTH]; - - // The following member variables are being used as cache values of the correction state. - bool mNeedsToTraverseAllNodes; - int mOutputIndex; - int mInputIndex; - - int mEquivalentCharCount; - int mProximityCount; - int mExcessiveCount; - int mTransposedCount; - int mSkippedCount; - - int mTransposedPos; - int mExcessivePos; - int mSkipPos; - - bool mLastCharExceeded; - - bool mMatching; - bool mProximityMatching; - bool mAdditionalProximityMatching; - bool mExceeding; - bool mTransposing; - bool mSkipping; - ProximityInfoState mProximityInfoState; -}; - -inline void Correction::incrementInputIndex() { - ++mInputIndex; -} - -AK_FORCE_INLINE void Correction::incrementOutputIndex() { - ++mOutputIndex; - mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex; - mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount; - mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos; - mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex; - mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes; - - mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount; - mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount; - mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount; - mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount; - mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount; - - mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos; - mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos; - mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos; - - mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded; - - mCorrectionStates[mOutputIndex].mMatching = mMatching; - mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching; - mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching; - mCorrectionStates[mOutputIndex].mTransposing = mTransposing; - mCorrectionStates[mOutputIndex].mExceeding = mExceeding; - mCorrectionStates[mOutputIndex].mSkipping = mSkipping; -} - -inline void Correction::startToTraverseAllNodes() { - mNeedsToTraverseAllNodes = true; -} - -AK_FORCE_INLINE bool Correction::isSingleQuote(const int c) { - const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex); - return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE); -} - -AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c, - const bool isTerminal, const bool inputIndexIncremented) { - addCharToCurrentWord(c); - mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0); - mTerminalOutputIndex = mOutputIndex; - incrementOutputIndex(); - if (mNeedsToTraverseAllNodes && isTerminal) { - return TRAVERSE_ALL_ON_TERMINAL; - } - return TRAVERSE_ALL_NOT_ON_TERMINAL; -} - -inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() { - // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType - mTerminalInputIndex = mInputIndex; - mTerminalOutputIndex = mOutputIndex; - return UNRELATED; -} - -AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input, - const int inputSize, const int *output, const int outputLength) { - // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH] is not touched. - // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j]. - // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated, - // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize]. - int *const current = editDistanceTable + outputLength * (inputSize + 1); - const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1); - const int *const prevprev = - outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0; - current[0] = outputLength; - const int co = toBaseLowerCase(output[outputLength - 1]); - const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0; - for (int i = 1; i <= inputSize; ++i) { - const int ci = toBaseLowerCase(input[i - 1]); - const int cost = (ci == co) ? 0 : 1; - current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost)); - if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) { - current[i] = min(current[i], prevprev[i - 2] + 1); - } - } -} - -AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) { - mWord[mOutputIndex] = c; - const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord(); - calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord, - mOutputIndex + 1); -} - -inline int Correction::getFinalProbabilityInternal(const int probability, int **word, - int *wordLength, const int inputSize) { - const int outputIndex = mTerminalOutputIndex; - const int inputIndex = mTerminalInputIndex; - *wordLength = outputIndex + 1; - *word = mWord; - int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability( - inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize); - return finalProbability; -} - -} // namespace latinime -#endif // LATINIME_CORRECTION_H diff --git a/native/jni/src/correction_state.h b/native/jni/src/correction_state.h deleted file mode 100644 index a63d4aa94..000000000 --- a/native/jni/src/correction_state.h +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_CORRECTION_STATE_H -#define LATINIME_CORRECTION_STATE_H - -#include <stdint.h> - -#include "defines.h" - -namespace latinime { - -struct CorrectionState { - int mParentIndex; - int mSiblingPos; - uint16_t mChildCount; - uint8_t mInputIndex; - - uint8_t mEquivalentCharCount; - uint8_t mProximityCount; - uint8_t mTransposedCount; - uint8_t mExcessiveCount; - uint8_t mSkippedCount; - - int8_t mTransposedPos; - int8_t mExcessivePos; - int8_t mSkipPos; // should be signed - - // TODO: int? - bool mLastCharExceeded; - - bool mMatching; - bool mTransposing; - bool mExceeding; - bool mSkipping; - bool mProximityMatching; - bool mAdditionalProximityMatching; - - bool mNeedsToTraverseAllNodes; -}; - -inline static void initCorrectionState(CorrectionState *state, const int rootPos, - const uint16_t childCount, const bool traverseAll) { - state->mParentIndex = -1; - state->mChildCount = childCount; - state->mInputIndex = 0; - state->mSiblingPos = rootPos; - state->mNeedsToTraverseAllNodes = traverseAll; - - state->mTransposedPos = -1; - state->mExcessivePos = -1; - state->mSkipPos = -1; - - state->mEquivalentCharCount = 0; - state->mProximityCount = 0; - state->mTransposedCount = 0; - state->mExcessiveCount = 0; - state->mSkippedCount = 0; - - state->mLastCharExceeded = false; - - state->mMatching = false; - state->mProximityMatching = false; - state->mTransposing = false; - state->mExceeding = false; - state->mSkipping = false; - state->mAdditionalProximityMatching = false; -} -} // namespace latinime -#endif // LATINIME_CORRECTION_STATE_H diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index eb59744f6..34a646f80 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -35,46 +35,74 @@ // Must be equal to ProximityInfo.MAX_PROXIMITY_CHARS_SIZE in Java #define MAX_PROXIMITY_CHARS_SIZE 16 #define ADDITIONAL_PROXIMITY_CHAR_DELIMITER_CODE 2 +#define NELEMS(x) (sizeof(x) / sizeof((x)[0])) -#if defined(FLAG_DO_PROFILE) || defined(FLAG_DBG) -#include <android/log.h> -#ifndef LOG_TAG -#define LOG_TAG "LatinIME: " -#endif // LOG_TAG -#define AKLOGE(fmt, ...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, fmt, ##__VA_ARGS__) -#define AKLOGI(fmt, ...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, fmt, ##__VA_ARGS__) - -#define DUMP_RESULT(words, frequencies) do { dumpResult(words, frequencies); } while (0) -#define DUMP_WORD(word, length) do { dumpWord(word, length); } while (0) -#define INTS_TO_CHARS(input, length, output) do { \ - intArrayToCharArray(input, length, output); } while (0) - -// TODO: Support full UTF-8 conversion -AK_FORCE_INLINE static int intArrayToCharArray(const int *source, const int sourceSize, - char *dest) { +AK_FORCE_INLINE static int intArrayToCharArray(const int *const source, const int sourceSize, + char *dest, const int destSize) { + // We want to always terminate with a 0 char, so stop one short of the length to make + // sure there is room. + const int destLimit = destSize - 1; int si = 0; int di = 0; - while (si < sourceSize && di < MAX_WORD_LENGTH - 1 && 0 != source[si]) { + while (si < sourceSize && di < destLimit && 0 != source[si]) { const int codePoint = source[si++]; - if (codePoint < 0x7F) { + if (codePoint < 0x7F) { // One byte dest[di++] = codePoint; - } else if (codePoint < 0x7FF) { + } else if (codePoint < 0x7FF) { // Two bytes + if (di + 1 >= destLimit) break; dest[di++] = 0xC0 + (codePoint >> 6); dest[di++] = 0x80 + (codePoint & 0x3F); - } else if (codePoint < 0xFFFF) { + } else if (codePoint < 0xFFFF) { // Three bytes + if (di + 2 >= destLimit) break; dest[di++] = 0xE0 + (codePoint >> 12); - dest[di++] = 0x80 + ((codePoint & 0xFC0) >> 6); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = 0x80 + (codePoint & 0x3F); + } else if (codePoint <= 0x1FFFFF) { // Four bytes + if (di + 3 >= destLimit) break; + dest[di++] = 0xF0 + (codePoint >> 18); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); dest[di++] = 0x80 + (codePoint & 0x3F); + } else if (codePoint <= 0x3FFFFFF) { // Five bytes + if (di + 4 >= destLimit) break; + dest[di++] = 0xF8 + (codePoint >> 24); + dest[di++] = 0x80 + ((codePoint >> 18) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = codePoint & 0x3F; + } else if (codePoint <= 0x7FFFFFFF) { // Six bytes + if (di + 5 >= destLimit) break; + dest[di++] = 0xFC + (codePoint >> 30); + dest[di++] = 0x80 + ((codePoint >> 24) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 18) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 12) & 0x3F); + dest[di++] = 0x80 + ((codePoint >> 6) & 0x3F); + dest[di++] = codePoint & 0x3F; + } else { + // Not a code point... skip. } } dest[di] = 0; return di; } +#if defined(FLAG_DO_PROFILE) || defined(FLAG_DBG) +#include <android/log.h> +#ifndef LOG_TAG +#define LOG_TAG "LatinIME: " +#endif // LOG_TAG +#define AKLOGE(fmt, ...) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, fmt, ##__VA_ARGS__) +#define AKLOGI(fmt, ...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, fmt, ##__VA_ARGS__) + +#define DUMP_RESULT(words, frequencies) do { dumpResult(words, frequencies); } while (0) +#define DUMP_WORD(word, length) do { dumpWord(word, length); } while (0) +#define INTS_TO_CHARS(input, length, output, outlength) do { \ + intArrayToCharArray(input, length, output, outlength); } while (0) + static inline void dumpWordInfo(const int *word, const int length, const int rank, const int probability) { static char charBuf[50]; - const int N = intArrayToCharArray(word, length, charBuf); + const int N = intArrayToCharArray(word, length, charBuf, NELEMS(charBuf)); if (N > 1) { AKLOGI("%2d [ %s ] (%d)", rank, charBuf, probability); } @@ -90,7 +118,7 @@ static inline void dumpResult(const int *outWords, const int *frequencies) { static AK_FORCE_INLINE void dumpWord(const int *word, const int length) { static char charBuf[50]; - const int N = intArrayToCharArray(word, length, charBuf); + const int N = intArrayToCharArray(word, length, charBuf, NELEMS(charBuf)); if (N > 1) { AKLOGI("[ %s ]", charBuf); } @@ -203,14 +231,12 @@ static inline void prof_out(void) { #define DEBUG_DICT true #define DEBUG_DICT_FULL false #define DEBUG_EDIT_DISTANCE false -#define DEBUG_SHOW_FOUND_WORD false #define DEBUG_NODE DEBUG_DICT_FULL #define DEBUG_TRACE DEBUG_DICT_FULL #define DEBUG_PROXIMITY_INFO false #define DEBUG_PROXIMITY_CHARS false #define DEBUG_CORRECTION false #define DEBUG_CORRECTION_FREQ false -#define DEBUG_WORDS_PRIORITY_QUEUE false #define DEBUG_SAMPLING_POINTS false #define DEBUG_POINTS_PROBABILITY false #define DEBUG_DOUBLE_LETTER false @@ -229,14 +255,12 @@ static inline void prof_out(void) { #define DEBUG_DICT false #define DEBUG_DICT_FULL false #define DEBUG_EDIT_DISTANCE false -#define DEBUG_SHOW_FOUND_WORD false #define DEBUG_NODE false #define DEBUG_TRACE false #define DEBUG_PROXIMITY_INFO false #define DEBUG_PROXIMITY_CHARS false #define DEBUG_CORRECTION false #define DEBUG_CORRECTION_FREQ false -#define DEBUG_WORDS_PRIORITY_QUEUE false #define DEBUG_SAMPLING_POINTS false #define DEBUG_POINTS_PROBABILITY false #define DEBUG_DOUBLE_LETTER false @@ -268,82 +292,25 @@ static inline void prof_out(void) { // of the binary dictionary where a {key,value} string pair scheme is used. #define LARGEST_INT_DIGIT_COUNT 11 -// Define this to use mmap() for dictionary loading. Undefine to use malloc() instead of mmap(). -// We measured and compared performance of both, and found mmap() is fairly good in terms of -// loading time, and acceptable even for several initial lookups which involve page faults. -#define USE_MMAP_FOR_DICTIONARY - -#define NOT_VALID_WORD (-99) +#define NOT_A_VALID_WORD_POS (-99) #define NOT_A_CODE_POINT (-1) #define NOT_A_DISTANCE (-1) #define NOT_A_COORDINATE (-1) -#define MATCH_CHAR_WITHOUT_DISTANCE_INFO (-2) -#define PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO (-3) -#define ADDITIONAL_PROXIMITY_CHAR_DISTANCE_INFO (-4) #define NOT_AN_INDEX (-1) #define NOT_A_PROBABILITY (-1) +#define NOT_A_DICT_POS (S_INT_MIN) #define KEYCODE_SPACE ' ' #define KEYCODE_SINGLE_QUOTE '\'' #define KEYCODE_HYPHEN_MINUS '-' -#define CALIBRATE_SCORE_BY_TOUCH_COORDINATES true -#define SUGGEST_MULTIPLE_WORDS true -#define USE_SUGGEST_INTERFACE_FOR_TYPING true #define SUGGEST_INTERFACE_OUTPUT_SCALE 1000000.0f - -// The following "rate"s are used as a multiplier before dividing by 100, so they are in percent. -#define WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE 80 -#define WORDS_WITH_MISSING_CHARACTER_DEMOTION_START_POS_10X 12 -#define WORDS_WITH_MISSING_SPACE_CHARACTER_DEMOTION_RATE 58 -#define WORDS_WITH_MISTYPED_SPACE_DEMOTION_RATE 50 -#define WORDS_WITH_EXCESSIVE_CHARACTER_DEMOTION_RATE 75 -#define WORDS_WITH_EXCESSIVE_CHARACTER_OUT_OF_PROXIMITY_DEMOTION_RATE 75 -#define WORDS_WITH_TRANSPOSED_CHARACTERS_DEMOTION_RATE 70 -#define FULL_MATCHED_WORDS_PROMOTION_RATE 120 -#define WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE 90 -#define WORDS_WITH_ADDITIONAL_PROXIMITY_CHARACTER_DEMOTION_RATE 70 -#define WORDS_WITH_MATCH_SKIP_PROMOTION_RATE 105 -#define WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_RATE 148 -#define WORDS_WITH_JUST_ONE_CORRECTION_PROMOTION_MULTIPLIER 3 -#define CORRECTION_COUNT_RATE_DEMOTION_RATE_BASE 45 -#define INPUT_EXCEEDS_OUTPUT_DEMOTION_RATE 70 -#define FIRST_CHAR_DIFFERENT_DEMOTION_RATE 96 -#define TWO_WORDS_CAPITALIZED_DEMOTION_RATE 50 -#define TWO_WORDS_CORRECTION_DEMOTION_BASE 80 -#define TWO_WORDS_PLUS_OTHER_ERROR_CORRECTION_DEMOTION_DIVIDER 1 -#define ZERO_DISTANCE_PROMOTION_RATE 110.0f -#define NEUTRAL_SCORE_SQUARED_RADIUS 8.0f -#define HALF_SCORE_SQUARED_RADIUS 32.0f #define MAX_PROBABILITY 255 #define MAX_BIGRAM_ENCODED_PROBABILITY 15 // Assuming locale strings such as en_US, sr-Latn etc. #define MAX_LOCALE_STRING_LENGTH 10 -// Word limit for sub queues used in WordsPriorityQueuePool. Sub queues are temporary queues used -// for better performance. -// Holds up to 1 candidate for each word -#define SUB_QUEUE_MAX_WORDS 1 -#define SUB_QUEUE_MAX_COUNT 10 -#define SUB_QUEUE_MIN_WORD_LENGTH 4 -// TODO: Extend this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_WORDS 5 -// TODO: Remove this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_WORD_LENGTH 12 -// TODO: Remove this limitation -#define MULTIPLE_WORDS_SUGGESTION_MAX_TOTAL_TRAVERSE_COUNT 45 -#define MULTIPLE_WORDS_DEMOTION_RATE 80 -#define MIN_INPUT_LENGTH_FOR_THREE_OR_MORE_WORDS_CORRECTION 6 - -#define TWO_WORDS_CORRECTION_WITH_OTHER_ERROR_THRESHOLD 0.35f -#define START_TWO_WORDS_CORRECTION_THRESHOLD 0.185f -/* heuristic... This should be changed if we change the unit of the probability. */ -#define SUPPRESS_SHORT_MULTIPLE_WORDS_THRESHOLD_FREQ (MAX_PROBABILITY * 58 / 100) - -#define MAX_DEPTH_MULTIPLIER 3 -#define FIRST_WORD_INDEX 0 - // Max value for length, distance and probability which are used in weighting // TODO: Remove #define MAX_VALUE_FOR_WEIGHTING 10000000 @@ -351,48 +318,20 @@ static inline void prof_out(void) { // The max number of the keys in one keyboard layout #define MAX_KEY_COUNT_IN_A_KEYBOARD 64 -// TODO: Reduce this constant if possible; check the maximum number of digraphs in the same -// word in the dictionary for languages with digraphs, like German and French -#define DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH 5 - -#define MIN_USER_TYPED_LENGTH_FOR_MULTIPLE_WORD_SUGGESTION 3 - // TODO: Remove #define MAX_POINTER_COUNT 1 #define MAX_POINTER_COUNT_G 2 -// Size, in bytes, of the bloom filter index for bigrams -// 128 gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, -// where k is the number of hash functions, n the number of bigrams, and m the number of -// bits we can test. -// At the moment 100 is the maximum number of bigrams for a word with the current -// dictionaries, so n = 100. 1024 buckets give us m = 1024. -// With 1 hash function, our false positive rate is about 9.3%, which should be enough for -// our uses since we are only using this to increase average performance. For the record, -// k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, -// and m = 4096 gives 2.4%. -#define BIGRAM_FILTER_BYTE_SIZE 128 -// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest -// prime under 128 * 8. -#define BIGRAM_FILTER_MODULO 1021 -#if BIGRAM_FILTER_BYTE_SIZE * 8 < BIGRAM_FILTER_MODULO -#error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE" -#endif - -// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could -// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage. -// Also, there are diminishing returns since the most frequently used bigrams are typically near -// the beginning of the input and are thus the first ones to be cached. Note that these bigrams -// are reset for each new composing word. -#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25 -// Most common previous word contexts currently have 100 bigrams -#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100 +// Queue IDs and size for DicNodesCache +#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_ACTIVE 0 +#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_NEXT_ACTIVE 1 +#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_TERMINAL 2 +#define DIC_NODES_CACHE_INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 +#define DIC_NODES_CACHE_PRIORITY_QUEUES_SIZE 4 template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; } template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; } -#define NELEMS(x) (sizeof(x) / sizeof((x)[0])) - // DEBUG #define INPUTLENGTH_FOR_DEBUG (-1) #define MIN_OUTPUT_INDEX_FOR_DEBUG (-1) @@ -442,6 +381,7 @@ typedef enum { CT_TRANSPOSITION, CT_COMPLETION, CT_TERMINAL, + CT_TERMINAL_INSERTION, // Create new word with space omission CT_NEW_WORD_SPACE_OMITTION, // Create new word with space substitution diff --git a/native/jni/src/dic_traverse_wrapper.h b/native/jni/src/dic_traverse_wrapper.h deleted file mode 100644 index 1108a45c8..000000000 --- a/native/jni/src/dic_traverse_wrapper.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_DIC_TRAVERSE_WRAPPER_H -#define LATINIME_DIC_TRAVERSE_WRAPPER_H - -#include "defines.h" -#include "jni.h" - -namespace latinime { -class Dictionary; -// TODO: Remove -class DicTraverseWrapper { - public: - static void *getDicTraverseSession(JNIEnv *env, jstring locale) { - if (sDicTraverseSessionFactoryMethod) { - return sDicTraverseSessionFactoryMethod(env, locale); - } - return 0; - } - static void initDicTraverseSession(void *traverseSession, const Dictionary *const dictionary, - const int *prevWord, const int prevWordLength) { - if (sDicTraverseSessionInitMethod) { - sDicTraverseSessionInitMethod(traverseSession, dictionary, prevWord, prevWordLength); - } - } - static void releaseDicTraverseSession(void *traverseSession) { - if (sDicTraverseSessionReleaseMethod) { - sDicTraverseSessionReleaseMethod(traverseSession); - } - } - static void setTraverseSessionFactoryMethod(void *(*factoryMethod)(JNIEnv *, jstring)) { - sDicTraverseSessionFactoryMethod = factoryMethod; - } - static void setTraverseSessionInitMethod( - void (*initMethod)(void *, const Dictionary *const, const int *, const int)) { - sDicTraverseSessionInitMethod = initMethod; - } - static void setTraverseSessionReleaseMethod(void (*releaseMethod)(void *)) { - sDicTraverseSessionReleaseMethod = releaseMethod; - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseWrapper); - static void *(*sDicTraverseSessionFactoryMethod)(JNIEnv *, jstring); - static void (*sDicTraverseSessionInitMethod)( - void *, const Dictionary *const, const int *, const int); - static void (*sDicTraverseSessionReleaseMethod)(void *); -}; -} // namespace latinime -#endif // LATINIME_DIC_TRAVERSE_WRAPPER_H diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp deleted file mode 100644 index dadb2bab2..000000000 --- a/native/jni/src/dictionary.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (C) 2009, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define LOG_TAG "LatinIME: dictionary.cpp" - -#include "dictionary.h" - -#include <map> // TODO: remove -#include <stdint.h> - -#include "bigram_dictionary.h" -#include "binary_format.h" -#include "defines.h" -#include "dic_traverse_wrapper.h" -#include "suggest/core/suggest.h" -#include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" -#include "suggest/policyimpl/typing/typing_suggest_policy_factory.h" -#include "unigram_dictionary.h" - -namespace latinime { - -Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) - : mDict(static_cast<unsigned char *>(dict)), - mOffsetDict((static_cast<unsigned char *>(dict)) - + BinaryFormat::getHeaderSize(mDict, dictSize)), - mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), - mUnigramDictionary(new UnigramDictionary(mOffsetDict, - BinaryFormat::getFlags(mDict, dictSize))), - mBigramDictionary(new BigramDictionary(mOffsetDict)), - mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), - mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { -} - -Dictionary::~Dictionary() { - delete mUnigramDictionary; - delete mBigramDictionary; - delete mGestureSuggest; - delete mTypingSuggest; -} - -int Dictionary::getSuggestions(ProximityInfo *proximityInfo, void *traverseSession, - int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, int *prevWordCodePoints, int prevWordLength, int commitPoint, bool isGesture, - bool useFullEditDistance, int *outWords, int *frequencies, int *spaceIndices, - int *outputTypes) const { - int result = 0; - if (isGesture) { - DicTraverseWrapper::initDicTraverseSession( - traverseSession, this, prevWordCodePoints, prevWordLength); - result = mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, - ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, outWords, - frequencies, spaceIndices, outputTypes); - if (DEBUG_DICT) { - DUMP_RESULT(outWords, frequencies); - } - return result; - } else { - if (USE_SUGGEST_INTERFACE_FOR_TYPING) { - DicTraverseWrapper::initDicTraverseSession( - traverseSession, this, prevWordCodePoints, prevWordLength); - result = mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, - ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, - outWords, frequencies, spaceIndices, outputTypes); - if (DEBUG_DICT) { - DUMP_RESULT(outWords, frequencies); - } - return result; - } else { - std::map<int, int> bigramMap; - uint8_t bigramFilter[BIGRAM_FILTER_BYTE_SIZE]; - mBigramDictionary->fillBigramAddressToProbabilityMapAndFilter(prevWordCodePoints, - prevWordLength, &bigramMap, bigramFilter); - result = mUnigramDictionary->getSuggestions(proximityInfo, xcoordinates, ycoordinates, - inputCodePoints, inputSize, &bigramMap, bigramFilter, useFullEditDistance, - outWords, frequencies, outputTypes); - return result; - } - } -} - -int Dictionary::getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, - int *outWords, int *frequencies, int *outputTypes) const { - if (length <= 0) return 0; - return mBigramDictionary->getBigrams(word, length, inputCodePoints, inputSize, outWords, - frequencies, outputTypes); -} - -int Dictionary::getProbability(const int *word, int length) const { - return mUnigramDictionary->getProbability(word, length); -} - -bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const { - return mBigramDictionary->isValidBigram(word1, length1, word2, length2); -} - -int Dictionary::getDictFlags() const { - return mUnigramDictionary->getDictFlags(); -} - -} // namespace latinime diff --git a/native/jni/src/geometry_utils.h b/native/jni/src/geometry_utils.h deleted file mode 100644 index 4cbb127e8..000000000 --- a/native/jni/src/geometry_utils.h +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_GEOMETRY_UTILS_H -#define LATINIME_GEOMETRY_UTILS_H - -#include <cmath> - -#include "defines.h" - -#define ROUND_FLOAT_10000(f) ((f) < 1000.0f && (f) > 0.001f) \ - ? (floorf((f) * 10000.0f) / 10000.0f) : (f) - -namespace latinime { - -static inline float SQUARE_FLOAT(const float x) { return x * x; } - -static AK_FORCE_INLINE float getAngle(const int x1, const int y1, const int x2, const int y2) { - const int dx = x1 - x2; - const int dy = y1 - y2; - if (dx == 0 && dy == 0) return 0.0f; - return atan2f(static_cast<float>(dy), static_cast<float>(dx)); -} - -static AK_FORCE_INLINE float getAngleDiff(const float a1, const float a2) { - const float deltaA = fabsf(a1 - a2); - const float diff = ROUND_FLOAT_10000(deltaA); - if (diff > M_PI_F) { - const float normalizedDiff = 2.0f * M_PI_F - diff; - return ROUND_FLOAT_10000(normalizedDiff); - } - return diff; -} - -static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2, - const int y2) { - return static_cast<int>(hypotf(static_cast<float>(x1 - x2), static_cast<float>(y1 - y2))); -} -} // namespace latinime -#endif // LATINIME_GEOMETRY_UTILS_H diff --git a/native/jni/src/multi_bigram_map.h b/native/jni/src/multi_bigram_map.h deleted file mode 100644 index 7e1b6301f..000000000 --- a/native/jni/src/multi_bigram_map.h +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (C) 2013 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_MULTI_BIGRAM_MAP_H -#define LATINIME_MULTI_BIGRAM_MAP_H - -#include <cstring> -#include <stdint.h> - -#include "defines.h" -#include "binary_format.h" -#include "hash_map_compat.h" - -namespace latinime { - -// Class for caching bigram maps for multiple previous word contexts. This is useful since the -// algorithm needs to look up the set of bigrams for every word pair that occurs in every -// multi-word suggestion. -class MultiBigramMap { - public: - MultiBigramMap() : mBigramMaps() {} - ~MultiBigramMap() {} - - // Look up the bigram probability for the given word pair from the cached bigram maps. - // Also caches the bigrams if there is space remaining and they have not been cached already. - int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition, - const int nextWordPosition, const int unigramProbability) { - hash_map_compat<int, BigramMap>::const_iterator mapPosition = - mBigramMaps.find(wordPosition); - if (mapPosition != mBigramMaps.end()) { - return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); - } - if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(dicRoot, wordPosition); - return mBigramMaps[wordPosition].getBigramProbability( - nextWordPosition, unigramProbability); - } - return BinaryFormat::getBigramProbability( - dicRoot, wordPosition, nextWordPosition, unigramProbability); - } - - void clear() { - mBigramMaps.clear(); - } - - private: - DISALLOW_COPY_AND_ASSIGN(MultiBigramMap); - - class BigramMap { - public: - BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {} - ~BigramMap() {} - - void init(const uint8_t *const dicRoot, int position) { - BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap); - } - - inline int getBigramProbability(const int nextWordPosition, const int unigramProbability) - const { - return BinaryFormat::getBigramProbabilityFromHashMap( - nextWordPosition, &mBigramMap, unigramProbability); - } - - private: - // Note: Default copy constructor needed for use in hash_map. - hash_map_compat<int, int> mBigramMap; - }; - - void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) { - mBigramMaps[position].init(dicRoot, position); - } - - hash_map_compat<int, BigramMap> mBigramMaps; -}; -} // namespace latinime -#endif // LATINIME_MULTI_BIGRAM_MAP_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node.cpp b/native/jni/src/suggest/core/dicnode/dic_node.cpp index 8c48c587b..de088c7d0 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "dic_node.h" +#include "suggest/core/dicnode/dic_node.h" namespace latinime { diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 4225bb3e5..cdd9f59aa 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -17,26 +17,27 @@ #ifndef LATINIME_DIC_NODE_H #define LATINIME_DIC_NODE_H -#include "char_utils.h" #include "defines.h" -#include "dic_node_state.h" -#include "dic_node_profiler.h" -#include "dic_node_properties.h" -#include "dic_node_release_listener.h" -#include "digraph_utils.h" +#include "suggest/core/dicnode/dic_node_profiler.h" +#include "suggest/core/dicnode/dic_node_release_listener.h" +#include "suggest/core/dicnode/internal/dic_node_state.h" +#include "suggest/core/dicnode/internal/dic_node_properties.h" +#include "suggest/core/dictionary/digraph_utils.h" +#include "utils/char_utils.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ do { char charBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) #define DUMP_WORD_AND_SCORE(header) \ do { char charBuf[50]; char prevWordCharBuf[50]; \ - INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(getOutputWordBuf(), getNodeCodePointCount(), charBuf, NELEMS(charBuf)); \ INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ - mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf, \ + NELEMS(prevWordCharBuf)); \ AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ @@ -51,6 +52,11 @@ namespace latinime { // This struct is purely a bucket to return values. No instances of this struct should be kept. struct DicNode_InputStateG { + DicNode_InputStateG() + : mNeedsToUpdateInputStateG(false), mPointerId(0), mInputIndex(0), + mPrevCodePoint(0), mTerminalDiffCost(0.0f), mRawLength(0.0f), + mDoubleLetterLevel(NOT_A_DOUBLE_LETTER) {} + bool mNeedsToUpdateInputStateG; int mPointerId; int16_t mInputIndex; @@ -92,7 +98,6 @@ class DicNode { DicNode &operator=(const DicNode &dicNode); virtual ~DicNode() {} - // TODO: minimize arguments by looking binary_format // Init for copy void initByCopy(const DicNode *dicNode) { mIsUsed = true; @@ -102,35 +107,28 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // TODO: minimize arguments by looking binary_format // Init for root with prevWordNodePos which is used for bigram - void initAsRoot(const int pos, const int childrenPos, const int childrenCount, - const int prevWordNodePos) { + void initAsRoot(const int rootGroupPos, const int prevWordNodePos) { mIsUsed = true; mIsCachedForNextSuggestion = false; mDicNodeProperties.init( - pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); mDicNodeState.init(prevWordNodePos); PROF_NODE_RESET(mProfiler); } - void initAsPassingChild(DicNode *parentNode) { - mIsUsed = true; - mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; - const int c = parentNode->getNodeTypedCodePoint(); - mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); - mDicNodeState.init(&parentNode->mDicNodeState); - PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); - } - - // TODO: minimize arguments by looking binary_format // Init for root with previous word - void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, - const int childrenCount) { + void initAsRootWithPreviousWord(DicNode *dicNode, const int rootGroupPos) { mIsUsed = true; - mIsCachedForNextSuggestion = false; + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( - pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + NOT_A_VALID_WORD_POS /* pos */, rootGroupPos, NOT_A_CODE_POINT /* nodeCodePoint */, + NOT_A_PROBABILITY /* probability */, false /* isTerminal */, + true /* hasChildren */, false /* isBlacklistedOrNotAWord */, 0 /* depth */, + 0 /* terminalDepth */); // TODO: Move to dicNodeState? mDicNodeState.mDicNodeStateOutput.init(); // reset for next word mDicNodeState.mDicNodeStateInput.init( @@ -150,21 +148,28 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - // TODO: minimize arguments by looking binary_format - void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos, - const int attributesPos, const int siblingPos, const int nodeCodePoint, - const int childrenCount, const int probability, const int bigramProbability, - const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, - const uint16_t additionalSubwordLength, const int *additionalSubword) { + void initAsPassingChild(DicNode *parentNode) { + mIsUsed = true; + mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; + const int c = parentNode->getNodeTypedCodePoint(); + mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); + mDicNodeState.init(&parentNode->mDicNodeState); + PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); + } + + void initAsChild(const DicNode *const dicNode, const int pos, const int childrenPos, + const int probability, const bool isTerminal, const bool hasChildren, + const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { mIsUsed = true; - uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); + uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1); mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; const uint16_t newLeavingDepth = static_cast<uint16_t>( - dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength); - mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint, - childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars, - hasChildren, newDepth, newLeavingDepth); - mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword); + dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount); + mDicNodeProperties.init(pos, childrenPos, mergedNodeCodePoints[0], probability, + isTerminal, hasChildren, isBlacklistedOrNotAWord, newDepth, newLeavingDepth); + mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount, + mergedNodeCodePoints); PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } @@ -180,7 +185,7 @@ class DicNode { } bool isRoot() const { - return getDepth() == 0; + return getNodeCodePointCount() == 0; } bool hasChildren() const { @@ -188,12 +193,12 @@ class DicNode { } bool isLeavingNode() const { - ASSERT(getDepth() <= getLeavingDepth()); - return getDepth() == getLeavingDepth(); + ASSERT(getNodeCodePointCount() <= mDicNodeProperties.getLeavingDepth()); + return getNodeCodePointCount() == mDicNodeProperties.getLeavingDepth(); } AK_FORCE_INLINE bool isFirstLetter() const { - return getDepth() == 1; + return getNodeCodePointCount() == 1; } bool isCached() const { @@ -206,26 +211,30 @@ class DicNode { // Used to expand the node in DicNodeUtils int getNodeTypedCodePoint() const { - return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); + return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getNodeCodePointCount()); } - bool isImpossibleBigramWord() const { - if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) { - return true; + // Check if the current word and the previous word can be considered as a valid multiple word + // suggestion. + bool isValidMultipleWordSuggestion() const { + if (isBlacklistedOrNotAWord()) { + return false; } + // Treat suggestion as invalid if the current and the previous word are single character + // words. const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; - const int currentWordLen = getDepth(); - return (prevWordLen == 1 && currentWordLen == 1); + const int currentWordLen = getNodeCodePointCount(); + return (prevWordLen != 1 || currentWordLen != 1); } bool isFirstCharUppercase() const { const int c = getOutputWordBuf()[0]; - return isAsciiUpper(c); + return CharUtils::isAsciiUpper(c); } bool isFirstWord() const { - return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD; + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_A_VALID_WORD_POS; } bool isCompletion(const int inputSize) const { @@ -251,37 +260,27 @@ class DicNode { return mDicNodeProperties.getChildrenPos(); } - // Used in DicNodeUtils - int getChildrenCount() const { - return mDicNodeProperties.getChildrenCount(); - } - - // Used in DicNodeUtils int getProbability() const { return mDicNodeProperties.getProbability(); } AK_FORCE_INLINE bool isTerminalWordNode() const { const bool isTerminalNodes = mDicNodeProperties.isTerminal(); - const int currentNodeDepth = getDepth(); + const int currentNodeDepth = getNodeCodePointCount(); const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; } bool shouldBeFilterdBySafetyNetForBigram() const { - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); } - uint16_t getLeavingDepth() const { - return mDicNodeProperties.getLeavingDepth(); - } - bool isTotalInputSizeExceedingLimit() const { const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const int currentWordDepth = getDepth(); + const int currentWordDepth = getNodeCodePointCount(); // TODO: 3 can be 2? Needs to be investigated. // TODO: Have a const variable for 3 (or 2) return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; @@ -316,7 +315,7 @@ class DicNode { void outputResult(int *dest) const { const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); - const uint16_t currentDepth = getDepth(); + const uint16_t currentDepth = getNodeCodePointCount(); DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, prevWordLength, getOutputWordBuf(), currentDepth, dest); DUMP_WORD_AND_SCORE("OUTPUT"); @@ -330,12 +329,12 @@ class DicNode { return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0; } - float getProximityCorrectionCount() const { - return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount()); + int getProximityCorrectionCount() const { + return mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount(); } - float getEditCorrectionCount() const { - return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()); + int getEditCorrectionCount() const { + return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount(); } // Used to prune nodes @@ -365,7 +364,7 @@ class DicNode { } AK_FORCE_INLINE const int *getOutputWordBuf() const { - return mDicNodeState.mDicNodeStateOutput.mWordBuf; + return mDicNodeState.mDicNodeStateOutput.mCodePointsBuf; } int getPrevCodePointG(int pointerId) const { @@ -375,7 +374,7 @@ class DicNode { // Whether the current codepoint can be an intentional omission, in which case the traversal // algorithm will always check for a possible omission here. bool canBeIntentionalOmission() const { - return isIntentionalOmissionCodePoint(getNodeCodePoint()); + return CharUtils::isIntentionalOmissionCodePoint(getNodeCodePoint()); } // Whether the omission is so frequent that it should incur zero cost. @@ -467,16 +466,17 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.isExactMatch(); } - uint8_t getFlags() const { - return mDicNodeProperties.getFlags(); + bool isBlacklistedOrNotAWord() const { + return mDicNodeProperties.isBlacklistedOrNotAWord(); } - int getAttributesPos() const { - return mDicNodeProperties.getAttributesPos(); + inline uint16_t getNodeCodePointCount() const { + return mDicNodeProperties.getDepth(); } - inline uint16_t getDepth() const { - return mDicNodeProperties.getDepth(); + // Returns code point count including spaces + inline uint16_t getTotalNodeCodePointCount() const { + return getNodeCodePointCount() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); } AK_FORCE_INLINE void dump(const char *tag) const { @@ -503,6 +503,12 @@ class DicNode { if (!right->isUsed()) { return false; } + // Promote exact matches to prevent them from being pruned. + const bool leftExactMatch = isExactMatch(); + const bool rightExactMatch = right->isExactMatch(); + if (leftExactMatch != rightExactMatch) { + return leftExactMatch; + } const float diff = right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); static const float MIN_DIFF = 0.000001f; @@ -511,8 +517,8 @@ class DicNode { } else if (diff < -MIN_DIFF) { return false; } - const int depth = getDepth(); - const int depthDiff = right->getDepth() - depth; + const int depth = getNodeCodePointCount(); + const int depthDiff = right->getNodeCodePointCount() - depth; if (depthDiff != 0) { return depthDiff > 0; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h index d3f28a8bd..2a486b804 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h @@ -21,10 +21,11 @@ #include <vector> #include "defines.h" -#include "dic_node.h" -#include "dic_node_release_listener.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_release_listener.h" -#define MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY 200 +// The biggest value among MAX_CACHE_DIC_NODE_SIZE, MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT, ... +#define MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY 310 namespace latinime { diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h index 90f75d0c6..1f4d2570e 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h @@ -31,6 +31,7 @@ #define PROF_TRANSPOSITION(profiler) profiler.profTransposition() #define PROF_NEARESTKEY(profiler) profiler.profNearestKey() #define PROF_TERMINAL(profiler) profiler.profTerminal() +#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion() #define PROF_NEW_WORD(profiler) profiler.profNewWord() #define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram() #define PROF_NODE_RESET(profiler) profiler.reset() @@ -47,6 +48,7 @@ #define PROF_TRANSPOSITION(profiler) #define PROF_NEARESTKEY(profiler) #define PROF_TERMINAL(profiler) +#define PROF_TERMINAL_INSERTION(profiler) #define PROF_NEW_WORD(profiler) #define PROF_NEW_WORD_BIGRAM(profiler) #define PROF_NODE_RESET(profiler) @@ -62,7 +64,7 @@ class DicNodeProfiler { : mProfOmission(0), mProfInsertion(0), mProfTransposition(0), mProfAdditionalProximity(0), mProfSubstitution(0), mProfSpaceSubstitution(0), mProfSpaceOmission(0), - mProfMatch(0), mProfCompletion(0), mProfTerminal(0), + mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0), mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {} int mProfOmission; @@ -75,6 +77,7 @@ class DicNodeProfiler { int mProfMatch; int mProfCompletion; int mProfTerminal; + int mProfTerminalInsertion; int mProfNearestKey; int mProfNewWord; int mProfNewWordBigram; @@ -123,6 +126,10 @@ class DicNodeProfiler { ++mProfTerminal; } + void profTerminalInsertion() { + ++mProfTerminalInsertion; + } + void profNewWord() { ++mProfNewWord; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_proximity_filter.h b/native/jni/src/suggest/core/dicnode/dic_node_proximity_filter.h new file mode 100644 index 000000000..1a39f2ef3 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_proximity_filter.h @@ -0,0 +1,58 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_PROXIMITY_FILTER_H +#define LATINIME_DIC_NODE_PROXIMITY_FILTER_H + +#include "defines.h" +#include "suggest/core/layout/proximity_info_state.h" +#include "suggest/core/layout/proximity_info_utils.h" +#include "suggest/core/policy/dictionary_structure_policy.h" + +namespace latinime { + +class DicNodeProximityFilter : public DictionaryStructurePolicy::NodeFilter { + public: + DicNodeProximityFilter(const ProximityInfoState *const pInfoState, + const int pointIndex, const bool exactOnly) + : mProximityInfoState(pInfoState), mPointIndex(pointIndex), mExactOnly(exactOnly) {} + + bool isFilteredOut(const int codePoint) const { + return !isProximityCodePoint(codePoint); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeProximityFilter); + + const ProximityInfoState *const mProximityInfoState; + const int mPointIndex; + const bool mExactOnly; + + // TODO: Move to proximity info state + bool isProximityCodePoint(const int codePoint) const { + if (!mProximityInfoState) { + return true; + } + if (mExactOnly) { + return mProximityInfoState->getPrimaryCodePointAt(mPointIndex) == codePoint; + } + const ProximityType matchedId = mProximityInfoState->getProximityType( + mPointIndex, codePoint, true /* checkProximityChars */); + return ProximityInfoUtils::isMatchOrProximityChar(matchedId); + } +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_PROXIMITY_FILTER_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h index 2a81c3cae..2ca4f21bd 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h @@ -21,6 +21,8 @@ namespace latinime { +class DicNode; + class DicNodeReleaseListener { public: DicNodeReleaseListener() {} diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 5357c3773..6b4ef2fea 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -14,16 +14,18 @@ * limitations under the License. */ +#include "suggest/core/dicnode/dic_node_utils.h" + #include <cstring> -#include <vector> -#include "binary_format.h" -#include "dic_node.h" -#include "dic_node_utils.h" -#include "dic_node_vector.h" -#include "multi_bigram_map.h" -#include "proximity_info.h" -#include "proximity_info_state.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_proximity_filter.h" +#include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/multi_bigram_map.h" +#include "suggest/core/dictionary/probability_utils.h" +#include "suggest/core/policy/dictionary_structure_policy.h" +#include "utils/char_utils.h" namespace latinime { @@ -31,22 +33,17 @@ namespace latinime { // Node initialization utils // /////////////////////////////// -/* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot, - const int prevWordNodePos, DicNode *newRootNode) { - int curPos = rootPos; - const int pos = curPos; - const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); - const int childrenPos = curPos; - newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos); +/* static */ void DicNodeUtils::initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int prevWordNodePos, DicNode *const newRootNode) { + newRootNode->initAsRoot(binaryDictionaryInfo->getStructurePolicy()->getRootPosition(), + prevWordNodePos); } -/*static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos, - const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) { - int curPos = rootPos; - const int pos = curPos; - const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); - const int childrenPos = curPos; - newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount); +/*static */ void DicNodeUtils::initAsRootWithPreviousWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + DicNode *const prevWordLastNode, DicNode *const newRootNode) { + newRootNode->initAsRootWithPreviousWord( + prevWordLastNode, binaryDictionaryInfo->getStructurePolicy()->getRootPosition()); } /* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { @@ -58,130 +55,35 @@ namespace latinime { /////////////////////////////////// /* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, + const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes) { // Passing multiple chars node. No need to traverse child const int codePoint = dicNode->getNodeTypedCodePoint(); - const int baseLowerCaseCodePoint = toBaseLowerCase(codePoint); - const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); - if (isMatch || isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { + const int baseLowerCaseCodePoint = CharUtils::toBaseLowerCase(codePoint); + if (!childrenFilter->isFilteredOut(codePoint) + || CharUtils::isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { childDicNodes->pushPassingChild(dicNode); } } -/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, - const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState, - const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { - int nextPos = pos; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); - const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); - - int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); - ASSERT(NOT_A_CODE_POINT != codePoint); - const int nodeCodePoint = codePoint; - // TODO: optimize this - int additionalWordBuf[MAX_WORD_LENGTH]; - uint16_t additionalSubwordLength = 0; - additionalWordBuf[additionalSubwordLength++] = codePoint; - - do { - const int nextCodePoint = hasMultipleChars - ? BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos) : NOT_A_CODE_POINT; - const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); - if (!isLastChar) { - additionalWordBuf[additionalSubwordLength++] = nextCodePoint; - } - codePoint = nextCodePoint; - } while (NOT_A_CODE_POINT != codePoint); - - const int probability = - isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1; - pos = BinaryFormat::skipProbability(flags, pos); - int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0; - const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); - const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos); - - if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) { - return siblingPos; - } - if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) { - return siblingPos; - } - const int childrenCount = hasChildren - ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0; - childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos, - nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, - hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf); - return siblingPos; -} - -/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, - const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { - const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; - if (filterSize <= 0) { - return false; - } - if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX - || isIntentionalOmissionCodePoint(nodeCodePoint))) { - // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never - // filtered. - return false; - } - const int lowerCodePoint = toLowerCase(nodeCodePoint); - const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint); - // TODO: Avoid linear search - for (int i = 0; i < filterSize; ++i) { - // Checking if a normalized code point is in filter characters when pInfo is not - // null. When pInfo is null, nodeCodePoint is used to check filtering without - // normalizing. - if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint - || (*codePointsFilter)[i] == baseLowerCodePoint)) - || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { - return false; - } - } - return true; -} - -/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, - const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, - const bool exactOnly, const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { - const int terminalDepth = dicNode->getLeavingDepth(); - const int childCount = dicNode->getChildrenCount(); - int nextPos = dicNode->getChildrenPos(); - for (int i = 0; i < childCount; i++) { - const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; - nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState, - pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes); - if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { - // All code points have been found. - break; - } - } -} - -/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, - DicNodeVector *childDicNodes) { - getProximityChildDicNodes(dicNode, dicRoot, 0, 0, false, childDicNodes); +/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes) { + getProximityChildDicNodes(dicNode, binaryDictionaryInfo, 0, 0, false, childDicNodes); } /* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, - const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, - bool exactOnly, DicNodeVector *childDicNodes) { + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, + DicNodeVector *childDicNodes) { if (dicNode->isTotalInputSizeExceedingLimit()) { return; } + const DicNodeProximityFilter childrenFilter(pInfoState, pointIndex, exactOnly); if (!dicNode->isLeavingNode()) { - DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, - childDicNodes); + DicNodeUtils::createAndGetPassingChildNode(dicNode, &childrenFilter, childDicNodes); } else { - DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex, - exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */, - childDicNodes); + binaryDictionaryInfo->getStructurePolicy()->createAndGetAllChildNodes(dicNode, + binaryDictionaryInfo, &childrenFilter, childDicNodes); } } @@ -191,49 +93,35 @@ namespace latinime { /** * Computes the combined bigram / unigram cost for the given dicNode. */ -/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot, +/* static */ float DicNodeUtils::getBigramNodeImprobability( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const DicNode *const node, MultiBigramMap *multiBigramMap) { - if (node->isImpossibleBigramWord()) { + if (node->hasMultipleWords() && !node->isValidMultipleWordSuggestion()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap); + const int probability = getBigramNodeProbability(binaryDictionaryInfo, node, multiBigramMap); // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. const float cost = static_cast<float>(MAX_PROBABILITY - probability) / static_cast<float>(MAX_PROBABILITY); return cost; } -/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot, +/* static */ int DicNodeUtils::getBigramNodeProbability( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const DicNode *const node, MultiBigramMap *multiBigramMap) { const int unigramProbability = node->getProbability(); const int wordPos = node->getPos(); const int prevWordPos = node->getPrevWordPos(); - if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { - // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. - return backoff(unigramProbability); + if (NOT_A_VALID_WORD_POS == wordPos || NOT_A_VALID_WORD_POS == prevWordPos) { + // Note: Normally wordPos comes from the dictionary and should never equal + // NOT_A_VALID_WORD_POS. + return ProbabilityUtils::backoff(unigramProbability); } if (multiBigramMap) { return multiBigramMap->getBigramProbability( - dicRoot, prevWordPos, wordPos, unigramProbability); - } - return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability); -} - -/////////////////////////////////////// -// Bigram / Unigram dictionary utils // -/////////////////////////////////////// - -/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, - const int pointIndex, const bool exactOnly, const int nodeCodePoint) { - if (!pInfoState) { - return true; - } - if (exactOnly) { - return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; + binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability); } - const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, - true /* checkProximityChars */); - return isProximityChar(matchedId); + return ProbabilityUtils::backoff(unigramProbability); } //////////////// @@ -262,7 +150,7 @@ namespace latinime { } actualLength1 = i + 1; } - actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); + actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0); memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); return actualLength0 + actualLength1; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 5bc542d05..4f12b29f4 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -18,15 +18,15 @@ #define LATINIME_DIC_NODE_UTILS_H #include <stdint.h> -#include <vector> #include "defines.h" namespace latinime { +class BinaryDictionaryInfo; class DicNode; +class DicNodeProximityFilter; class DicNodeVector; -class ProximityInfo; class ProximityInfoState; class MultiBigramMap; @@ -34,48 +34,30 @@ class DicNodeUtils { public: static int appendTwoWords(const int *src0, const int16_t length0, const int *src1, const int16_t length1, int *dest); - static void initAsRoot(const int rootPos, const uint8_t *const dicRoot, + static void initAsRoot(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int prevWordNodePos, DicNode *newRootNode); - static void initAsRootWithPreviousWord(const int rootPos, const uint8_t *const dicRoot, + static void initAsRootWithPreviousWord(const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNode *prevWordLastNode, DicNode *newRootNode); static void initByCopy(DicNode *srcNode, DicNode *destNode); - static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, - DicNodeVector *childDicNodes); - static float getBigramNodeImprobability(const uint8_t *const dicRoot, + static void getAllChildDicNodes(DicNode *dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, DicNodeVector *childDicNodes); + static float getBigramNodeImprobability(const BinaryDictionaryInfo *const binaryDictionaryInfo, const DicNode *const node, MultiBigramMap *const multiBigramMap); - static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo, - const std::vector<int> *const codePointsFilter); // TODO: Move to private - static void getProximityChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, + static void getProximityChildDicNodes(DicNode *dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, DicNodeVector *childDicNodes); - // TODO: Move to proximity info - static bool isProximityChar(ProximityType type) { - return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR; - } - private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; - static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node, - MultiBigramMap *multiBigramMap); - static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState, - const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes); - static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot, - const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, - const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - static int createAndGetLeavingChildNode(DicNode *dicNode, int pos, const uint8_t *const dicRoot, - const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex, - const bool exactOnly, const std::vector<int> *const codePointsFilter, - const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - - // TODO: Move to proximity info - static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, - const bool exactOnly, const int nodeCodePoint); + static int getBigramNodeProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const DicNode *const node, MultiBigramMap *multiBigramMap); + static void createAndGetPassingChildNode(DicNode *dicNode, + const DicNodeProximityFilter *const childrenFilter, DicNodeVector *childDicNodes); }; } // namespace latinime #endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h index ca07edaee..42addae8d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -20,7 +20,7 @@ #include <vector> #include "defines.h" -#include "dic_node.h" +#include "suggest/core/dicnode/dic_node.h" namespace latinime { @@ -62,17 +62,15 @@ class DicNodeVector { mDicNodes.back().initAsPassingChild(dicNode); } - void pushLeavingChild(DicNode *dicNode, const int pos, const uint8_t flags, - const int childrenPos, const int attributesPos, const int siblingPos, - const int nodeCodePoint, const int childrenCount, const int probability, - const int bigramProbability, const bool isTerminal, const bool hasMultipleChars, - const bool hasChildren, const uint16_t additionalSubwordLength, - const int *additionalSubword) { + void pushLeavingChild(const DicNode *const dicNode, const int pos, const int childrenPos, + const int probability, const bool isTerminal, const bool hasChildren, + const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { ASSERT(!mLock); mDicNodes.push_back(mEmptyNode); - mDicNodes.back().initAsChild(dicNode, pos, flags, childrenPos, attributesPos, siblingPos, - nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, - hasMultipleChars, hasChildren, additionalSubwordLength, additionalSubword); + mDicNodes.back().initAsChild(dicNode, pos, childrenPos, probability, isTerminal, + hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount, + mergedNodeCodePoints); } DicNode *operator[](const int id) { diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp index b9a60780b..c3d2a2e74 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -17,9 +17,9 @@ #include <list> #include "defines.h" -#include "dic_node_priority_queue.h" -#include "dic_node_utils.h" -#include "dic_nodes_cache.h" +#include "suggest/core/dicnode/dic_node_priority_queue.h" +#include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/dicnode/dic_nodes_cache.h" namespace latinime { diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h index a62aa422a..7aab0906e 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -20,13 +20,7 @@ #include <stdint.h> #include "defines.h" -#include "dic_node_priority_queue.h" - -#define INITIAL_QUEUE_ID_ACTIVE 0 -#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1 -#define INITIAL_QUEUE_ID_TERMINAL 2 -#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 -#define PRIORITY_QUEUES_SIZE 4 +#include "suggest/core/dicnode/dic_node_priority_queue.h" namespace latinime { @@ -38,11 +32,12 @@ class DicNode; class DicNodesCache { public: AK_FORCE_INLINE DicNodesCache() - : mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]), - mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]), - mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]), - mCachedDicNodesForContinuousSuggestion( - &mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), + : mActiveDicNodes(&mDicNodePriorityQueues[DIC_NODES_CACHE_INITIAL_QUEUE_ID_ACTIVE]), + mNextActiveDicNodes(&mDicNodePriorityQueues[ + DIC_NODES_CACHE_INITIAL_QUEUE_ID_NEXT_ACTIVE]), + mTerminalDicNodes(&mDicNodePriorityQueues[DIC_NODES_CACHE_INITIAL_QUEUE_ID_TERMINAL]), + mCachedDicNodesForContinuousSuggestion(&mDicNodePriorityQueues[ + DIC_NODES_CACHE_INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), mInputIndex(0), mLastCachedInputIndex(0) { } @@ -147,9 +142,8 @@ class DicNodesCache { mCachedDicNodesForContinuousSuggestion->dump(); } mInputIndex = mLastCachedInputIndex; - mCachedDicNodesForContinuousSuggestion = - moveNodesAndReturnReusableEmptyQueue( - mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); + mCachedDicNodesForContinuousSuggestion = moveNodesAndReturnReusableEmptyQueue( + mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); } AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue( @@ -169,7 +163,7 @@ class DicNodesCache { mTerminalDicNodes->clear(); } - DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE]; + DicNodePriorityQueue mDicNodePriorityQueues[DIC_NODES_CACHE_PRIORITY_QUEUES_SIZE]; // Active dicNodes currently being expanded. DicNodePriorityQueue *mActiveDicNodes; // Next dicNodes to be expanded. diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h index 63a6b1340..9e0f62ceb 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_properties.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h @@ -19,7 +19,6 @@ #include <stdint.h> -#include "binary_format.h" #include "defines.h" namespace latinime { @@ -27,53 +26,40 @@ namespace latinime { /** * Node for traversing the lexicon trie. */ +// TODO: Introduce a dictionary node class which has attribute members required to understand the +// dictionary structure. class DicNodeProperties { public: AK_FORCE_INLINE DicNodeProperties() - : mPos(0), mFlags(0), mChildrenPos(0), mAttributesPos(0), mSiblingPos(0), - mChildrenCount(0), mProbability(0), mBigramProbability(0), mNodeCodePoint(0), - mDepth(0), mLeavingDepth(0), mIsTerminal(false), mHasMultipleChars(false), - mHasChildren(false) { - } + : mPos(0), mChildrenPos(0), mProbability(0), mNodeCodePoint(0), mIsTerminal(false), + mHasChildren(false), mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {} virtual ~DicNodeProperties() {} // Should be called only once per DicNode is initialized. - void init(const int pos, const uint8_t flags, const int childrenPos, const int attributesPos, - const int siblingPos, const int nodeCodePoint, const int childrenCount, - const int probability, const int bigramProbability, const bool isTerminal, - const bool hasMultipleChars, const bool hasChildren, const uint16_t depth, - const uint16_t terminalDepth) { + void init(const int pos, const int childrenPos, const int nodeCodePoint, const int probability, + const bool isTerminal, const bool hasChildren, const bool isBlacklistedOrNotAWord, + const uint16_t depth, const uint16_t leavingDepth) { mPos = pos; - mFlags = flags; mChildrenPos = childrenPos; - mAttributesPos = attributesPos; - mSiblingPos = siblingPos; mNodeCodePoint = nodeCodePoint; - mChildrenCount = childrenCount; mProbability = probability; - mBigramProbability = bigramProbability; mIsTerminal = isTerminal; - mHasMultipleChars = hasMultipleChars; mHasChildren = hasChildren; + mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord; mDepth = depth; - mLeavingDepth = terminalDepth; + mLeavingDepth = leavingDepth; } // Init for copy void init(const DicNodeProperties *const nodeProp) { mPos = nodeProp->mPos; - mFlags = nodeProp->mFlags; mChildrenPos = nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; - mSiblingPos = nodeProp->mSiblingPos; mNodeCodePoint = nodeProp->mNodeCodePoint; - mChildrenCount = nodeProp->mChildrenCount; mProbability = nodeProp->mProbability; - mBigramProbability = nodeProp->mBigramProbability; mIsTerminal = nodeProp->mIsTerminal; - mHasMultipleChars = nodeProp->mHasMultipleChars; mHasChildren = nodeProp->mHasChildren; + mIsBlacklistedOrNotAWord = nodeProp->mIsBlacklistedOrNotAWord; mDepth = nodeProp->mDepth; mLeavingDepth = nodeProp->mLeavingDepth; } @@ -81,17 +67,12 @@ class DicNodeProperties { // Init as passing child void init(const DicNodeProperties *const nodeProp, const int codePoint) { mPos = nodeProp->mPos; - mFlags = nodeProp->mFlags; mChildrenPos = nodeProp->mChildrenPos; - mAttributesPos = nodeProp->mAttributesPos; - mSiblingPos = nodeProp->mSiblingPos; mNodeCodePoint = codePoint; // Overwrite the node char of a passing child - mChildrenCount = nodeProp->mChildrenCount; mProbability = nodeProp->mProbability; - mBigramProbability = nodeProp->mBigramProbability; mIsTerminal = nodeProp->mIsTerminal; - mHasMultipleChars = nodeProp->mHasMultipleChars; mHasChildren = nodeProp->mHasChildren; + mIsBlacklistedOrNotAWord = nodeProp->mIsBlacklistedOrNotAWord; mDepth = nodeProp->mDepth + 1; // Increment the depth of a passing child mLeavingDepth = nodeProp->mLeavingDepth; } @@ -100,22 +81,10 @@ class DicNodeProperties { return mPos; } - uint8_t getFlags() const { - return mFlags; - } - int getChildrenPos() const { return mChildrenPos; } - int getAttributesPos() const { - return mAttributesPos; - } - - int getChildrenCount() const { - return mChildrenCount; - } - int getProbability() const { return mProbability; } @@ -137,42 +106,27 @@ class DicNodeProperties { return mIsTerminal; } - bool hasMultipleChars() const { - return mHasMultipleChars; - } - bool hasChildren() const { - return mChildrenCount > 0 || mDepth != mLeavingDepth; + return mHasChildren || mDepth != mLeavingDepth; } - bool hasBlacklistedOrNotAWordFlag() const { - return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags); + bool isBlacklistedOrNotAWord() const { + return mIsBlacklistedOrNotAWord; } private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok // for this class - - // Not used - int getSiblingPos() const { - return mSiblingPos; - } - int mPos; - uint8_t mFlags; int mChildrenPos; - int mAttributesPos; - int mSiblingPos; - int mChildrenCount; int mProbability; - int mBigramProbability; // not used for now int mNodeCodePoint; - uint16_t mDepth; - uint16_t mLeavingDepth; bool mIsTerminal; - bool mHasMultipleChars; bool mHasChildren; + bool mIsBlacklistedOrNotAWord; + uint16_t mDepth; + uint16_t mLeavingDepth; }; } // namespace latinime #endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state.h index 239b63c32..b0fddb724 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state.h @@ -18,10 +18,10 @@ #define LATINIME_DIC_NODE_STATE_H #include "defines.h" -#include "dic_node_state_input.h" -#include "dic_node_state_output.h" -#include "dic_node_state_prevword.h" -#include "dic_node_state_scoring.h" +#include "suggest/core/dicnode/internal/dic_node_state_input.h" +#include "suggest/core/dicnode/internal/dic_node_state_output.h" +#include "suggest/core/dicnode/internal/dic_node_state_prevword.h" +#include "suggest/core/dicnode/internal/dic_node_state_scoring.h" namespace latinime { @@ -55,11 +55,12 @@ class DicNodeState { mDicNodeStateScoring.init(&src->mDicNodeStateScoring); } - // Init by copy and adding subword - void init(const DicNodeState *const src, const uint16_t additionalSubwordLength, - const int *const additionalSubword) { + // Init by copy and adding merged node code points. + void init(const DicNodeState *const src, const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { init(src); - mDicNodeStateOutput.addSubword(additionalSubwordLength, additionalSubword); + mDicNodeStateOutput.addMergedNodeCodePoints( + mergedNodeCodePointCount, mergedNodeCodePoints); } private: diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_input.h index bbd9435b5..bbd9435b5 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_input.h diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_output.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h index 1d4f50a06..45c7f5cf9 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_output.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_output.h @@ -26,50 +26,52 @@ namespace latinime { class DicNodeStateOutput { public: - DicNodeStateOutput() : mOutputtedLength(0) { + DicNodeStateOutput() : mOutputtedCodePointCount(0) { init(); } virtual ~DicNodeStateOutput() {} void init() { - mOutputtedLength = 0; - mWordBuf[0] = 0; + mOutputtedCodePointCount = 0; + mCodePointsBuf[0] = 0; } void init(const DicNodeStateOutput *const stateOutput) { - memcpy(mWordBuf, stateOutput->mWordBuf, - stateOutput->mOutputtedLength * sizeof(mWordBuf[0])); - mOutputtedLength = stateOutput->mOutputtedLength; - if (mOutputtedLength < MAX_WORD_LENGTH) { - mWordBuf[mOutputtedLength] = 0; + memcpy(mCodePointsBuf, stateOutput->mCodePointsBuf, + stateOutput->mOutputtedCodePointCount * sizeof(mCodePointsBuf[0])); + mOutputtedCodePointCount = stateOutput->mOutputtedCodePointCount; + if (mOutputtedCodePointCount < MAX_WORD_LENGTH) { + mCodePointsBuf[mOutputtedCodePointCount] = 0; } } - void addSubword(const uint16_t additionalSubwordLength, const int *const additionalSubword) { - if (additionalSubword) { - memcpy(&mWordBuf[mOutputtedLength], additionalSubword, - additionalSubwordLength * sizeof(mWordBuf[0])); - mOutputtedLength = static_cast<uint16_t>(mOutputtedLength + additionalSubwordLength); - if (mOutputtedLength < MAX_WORD_LENGTH) { - mWordBuf[mOutputtedLength] = 0; + void addMergedNodeCodePoints(const uint16_t mergedNodeCodePointCount, + const int *const mergedNodeCodePoints) { + if (mergedNodeCodePoints) { + memcpy(&mCodePointsBuf[mOutputtedCodePointCount], mergedNodeCodePoints, + mergedNodeCodePointCount * sizeof(mCodePointsBuf[0])); + mOutputtedCodePointCount = static_cast<uint16_t>( + mOutputtedCodePointCount + mergedNodeCodePointCount); + if (mOutputtedCodePointCount < MAX_WORD_LENGTH) { + mCodePointsBuf[mOutputtedCodePointCount] = 0; } } } // TODO: Remove - int getCodePointAt(const int id) const { - return mWordBuf[id]; + int getCodePointAt(const int index) const { + return mCodePointsBuf[index]; } // TODO: Move to private - int mWordBuf[MAX_WORD_LENGTH]; + int mCodePointsBuf[MAX_WORD_LENGTH]; private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok // for this class - uint16_t mOutputtedLength; + uint16_t mOutputtedCodePointCount; }; } // namespace latinime #endif // LATINIME_DIC_NODE_STATE_OUTPUT_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h index e3b892bda..5854f4f6e 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_prevword.h @@ -21,7 +21,7 @@ #include <stdint.h> #include "defines.h" -#include "dic_node_utils.h" +#include "suggest/core/dicnode/dic_node_utils.h" namespace latinime { @@ -29,7 +29,7 @@ class DicNodeStatePrevWord { public: AK_FORCE_INLINE DicNodeStatePrevWord() : mPrevWordCount(0), mPrevWordLength(0), mPrevWordStart(0), mPrevWordProbability(0), - mPrevWordNodePos(0) { + mPrevWordNodePos(NOT_A_VALID_WORD_POS) { memset(mPrevWord, 0, sizeof(mPrevWord)); memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); } @@ -41,7 +41,7 @@ class DicNodeStatePrevWord { mPrevWordCount = 0; mPrevWordStart = 0; mPrevWordProbability = -1; - mPrevWordNodePos = NOT_VALID_WORD; + mPrevWordNodePos = NOT_A_VALID_WORD_POS; memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h index dca9d60da..4c884225a 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_state_scoring.h @@ -20,7 +20,7 @@ #include <stdint.h> #include "defines.h" -#include "digraph_utils.h" +#include "suggest/core/dictionary/digraph_utils.h" namespace latinime { diff --git a/native/jni/src/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index 9053e7226..532c769c6 100644 --- a/native/jni/src/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -19,15 +19,18 @@ #define LOG_TAG "LatinIME: bigram_dictionary.cpp" #include "bigram_dictionary.h" -#include "binary_format.h" -#include "bloom_filter.h" -#include "char_utils.h" + #include "defines.h" -#include "dictionary.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/dictionary.h" +#include "suggest/core/dictionary/probability_utils.h" +#include "utils/char_utils.h" namespace latinime { -BigramDictionary::BigramDictionary(const uint8_t *const streamStart) : DICT_ROOT(streamStart) { +BigramDictionary::BigramDictionary(const BinaryDictionaryInfo *const binaryDictionaryInfo) + : mBinaryDictionaryInfo(binaryDictionaryInfo) { if (DEBUG_DICT) { AKLOGI("BigramDictionary - constructor"); } @@ -51,7 +54,7 @@ void BigramDictionary::addWordBigram(int *word, int length, int probability, int int insertAt = 0; while (insertAt < MAX_RESULTS) { if (probability > bigramProbability[insertAt] || (bigramProbability[insertAt] == probability - && length < getCodePointCount(MAX_WORD_LENGTH, + && length < CharUtils::getCodePointCount(MAX_WORD_LENGTH, bigramCodePoints + insertAt * MAX_WORD_LENGTH))) { break; } @@ -97,97 +100,61 @@ void BigramDictionary::addWordBigram(int *word, int length, int probability, int * and the bigrams are used to boost unigram result scores, it makes little sense to * reduce their scope to the ones that match the first letter. */ -int BigramDictionary::getBigrams(const int *prevWord, int prevWordLength, int *inputCodePoints, +int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, int *inputCodePoints, int inputSize, int *bigramCodePoints, int *bigramProbability, int *outputTypes) const { // TODO: remove unused arguments, and refrain from storing stuff in members of this class // TODO: have "in" arguments before "out" ones, and make out args explicit in the name - const uint8_t *const root = DICT_ROOT; int pos = getBigramListPositionForWord(prevWord, prevWordLength, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) { + if (NOT_A_DICT_POS == pos) { // If no bigrams for this exact word, search again in lower case. pos = getBigramListPositionForWord(prevWord, prevWordLength, true /* forceLowerCaseSearch */); } // If still no bigrams, we really don't have them! - if (0 == pos) return 0; - uint8_t bigramFlags; + if (NOT_A_DICT_POS == pos) return 0; + int bigramCount = 0; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - int bigramBuffer[MAX_WORD_LENGTH]; - int unigramProbability = 0; - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH, - bigramBuffer, &unigramProbability); + int unigramProbability = 0; + int bigramBuffer[MAX_WORD_LENGTH]; + BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + const int length = mBinaryDictionaryInfo->getStructurePolicy()-> + getCodePointsAndProbabilityAndReturnCodePointCount( + mBinaryDictionaryInfo, bigramsIt.getBigramPos(), MAX_WORD_LENGTH, + bigramBuffer, &unigramProbability); // inputSize == 0 means we are trying to find bigram predictions. if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { - const int bigramProbabilityTemp = - BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + const int bigramProbabilityTemp = bigramsIt.getProbability(); // Due to space constraints, the probability for bigrams is approximate - the lower the // unigram probability, the worse the precision. The theoritical maximum error in // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 // in very bad cases. This means that sometimes, we'll see some bigrams interverted // here, but it can't get too bad. - const int probability = BinaryFormat::computeProbabilityForBigram( + const int probability = ProbabilityUtils::computeProbabilityForBigram( unigramProbability, bigramProbabilityTemp); addWordBigram(bigramBuffer, length, probability, bigramProbability, bigramCodePoints, outputTypes); ++bigramCount; } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + } return min(bigramCount, MAX_RESULTS); } // Returns a pointer to the start of the bigram list. -// If the word is not found or has no bigrams, this function returns 0. +// If the word is not found or has no bigrams, this function returns NOT_A_DICT_POS. int BigramDictionary::getBigramListPositionForWord(const int *prevWord, const int prevWordLength, const bool forceLowerCaseSearch) const { - if (0 >= prevWordLength) return 0; - const uint8_t *const root = DICT_ROOT; - int pos = BinaryFormat::getTerminalPosition(root, prevWord, prevWordLength, - forceLowerCaseSearch); - - if (NOT_VALID_WORD == pos) return 0; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) return 0; - if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { - BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } else { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } - pos = BinaryFormat::skipProbability(flags, pos); - pos = BinaryFormat::skipChildrenPosition(flags, pos); - pos = BinaryFormat::skipShortcuts(root, flags, pos); - return pos; -} - -void BigramDictionary::fillBigramAddressToProbabilityMapAndFilter(const int *prevWord, - const int prevWordLength, std::map<int, int> *map, uint8_t *filter) const { - memset(filter, 0, BIGRAM_FILTER_BYTE_SIZE); - const uint8_t *const root = DICT_ROOT; - int pos = getBigramListPositionForWord(prevWord, prevWordLength, - false /* forceLowerCaseSearch */); - if (0 == pos) { - // If no bigrams for this exact string, search again in lower case. - pos = getBigramListPositionForWord(prevWord, prevWordLength, - true /* forceLowerCaseSearch */); - } - if (0 == pos) return; - - uint8_t bigramFlags; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - const int probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - (*map)[bigramPos] = probability; - setInFilter(filter, bigramPos); - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + if (0 >= prevWordLength) return NOT_A_DICT_POS; + int pos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( + mBinaryDictionaryInfo, prevWord, prevWordLength, forceLowerCaseSearch); + if (NOT_A_VALID_WORD_POS == pos) return NOT_A_DICT_POS; + return mBinaryDictionaryInfo->getStructurePolicy()->getBigramsPositionOfNode( + mBinaryDictionaryInfo, pos); } bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { @@ -195,9 +162,9 @@ bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) cons // what user typed. int maxAlt = MAX_ALTERNATIVES; - const int firstBaseLowerCodePoint = toBaseLowerCase(*word); + const int firstBaseLowerCodePoint = CharUtils::toBaseLowerCase(*word); while (maxAlt > 0) { - if (toBaseLowerCase(*inputCodePoints) == firstBaseLowerCodePoint) { + if (CharUtils::toBaseLowerCase(*inputCodePoints) == firstBaseLowerCodePoint) { return true; } inputCodePoints++; @@ -206,24 +173,22 @@ bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) cons return false; } -bool BigramDictionary::isValidBigram(const int *word1, int length1, const int *word2, - int length2) const { - const uint8_t *const root = DICT_ROOT; - int pos = getBigramListPositionForWord(word1, length1, false /* forceLowerCaseSearch */); +bool BigramDictionary::isValidBigram(const int *word0, int length0, const int *word1, + int length1) const { + int pos = getBigramListPositionForWord(word0, length0, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams - if (0 == pos) return false; - int nextWordPos = BinaryFormat::getTerminalPosition(root, word2, length2, - false /* forceLowerCaseSearch */); - if (NOT_VALID_WORD == nextWordPos) return false; - uint8_t bigramFlags; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - if (bigramPos == nextWordPos) { + if (NOT_A_DICT_POS == pos) return false; + int nextWordPos = mBinaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( + mBinaryDictionaryInfo, word1, length1, false /* forceLowerCaseSearch */); + if (NOT_A_VALID_WORD_POS == nextWordPos) return false; + + BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPos) { return true; } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + } return false; } diff --git a/native/jni/src/bigram_dictionary.h b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h index b86e564c3..7706a2c22 100644 --- a/native/jni/src/bigram_dictionary.h +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h @@ -17,31 +17,31 @@ #ifndef LATINIME_BIGRAM_DICTIONARY_H #define LATINIME_BIGRAM_DICTIONARY_H -#include <map> -#include <stdint.h> - #include "defines.h" namespace latinime { +class BinaryDictionaryInfo; + class BigramDictionary { public: - BigramDictionary(const uint8_t *const streamStart); - int getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, - int *frequencies, int *outputTypes) const; - void fillBigramAddressToProbabilityMapAndFilter(const int *prevWord, const int prevWordLength, - std::map<int, int> *map, uint8_t *filter) const; + BigramDictionary(const BinaryDictionaryInfo *const binaryDictionaryInfo); + + int getPredictions(const int *word, int length, int *inputCodePoints, int inputSize, + int *outWords, int *frequencies, int *outputTypes) const; bool isValidBigram(const int *word1, int length1, const int *word2, int length2) const; ~BigramDictionary(); + private: DISALLOW_IMPLICIT_CONSTRUCTORS(BigramDictionary); + void addWordBigram(int *word, int length, int probability, int *bigramProbability, int *bigramCodePoints, int *outputTypes) const; bool checkFirstCharacter(int *word, int *inputCodePoints) const; int getBigramListPositionForWord(const int *prevWord, const int prevWordLength, const bool forceLowerCaseSearch) const; - const uint8_t *const DICT_ROOT; + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; // TODO: Re-implement proximity correction for bigram correction static const int MAX_ALTERNATIVES = 1; }; diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h new file mode 100644 index 000000000..8cbb12998 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H +#define LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" + +namespace latinime { + +class BinaryDictionaryBigramsIterator { + public: + BinaryDictionaryBigramsIterator( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int pos) + : mBinaryDictionaryInfo(binaryDictionaryInfo), mPos(pos), mBigramFlags(0), + mBigramPos(NOT_A_DICT_POS), mHasNext(pos != NOT_A_DICT_POS) {} + + AK_FORCE_INLINE bool hasNext() const { + return mHasNext; + } + + AK_FORCE_INLINE void next() { + mBigramFlags = BinaryDictionaryTerminalAttributesReadingUtils::getFlagsAndForwardPointer( + mBinaryDictionaryInfo, &mPos); + mBigramPos = + BinaryDictionaryTerminalAttributesReadingUtils::getBigramAddressAndForwardPointer( + mBinaryDictionaryInfo, mBigramFlags, &mPos); + mHasNext = BinaryDictionaryTerminalAttributesReadingUtils::hasNext(mBigramFlags); + } + + AK_FORCE_INLINE int getProbability() const { + return BinaryDictionaryTerminalAttributesReadingUtils::getProbabilityFromFlags( + mBigramFlags); + } + + AK_FORCE_INLINE int getBigramPos() const { + return mBigramPos; + } + + AK_FORCE_INLINE int getFlags() const { + return mBigramFlags; + } + + private: + DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryBigramsIterator); + + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; + int mPos; + BinaryDictionaryTerminalAttributesReadingUtils::BigramFlags mBigramFlags; + int mBigramPos; + bool mHasNext; +}; +} // namespace latinime +#endif // LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.cpp new file mode 100644 index 000000000..5d14a0554 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.cpp @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/binary_dictionary_format_utils.h" + +namespace latinime { + +/** + * Dictionary size + */ +// Any file smaller than this is not a dictionary. +const int BinaryDictionaryFormatUtils::DICTIONARY_MINIMUM_SIZE = 4; + +/** + * Format versions + */ + +// The versions of Latin IME that only handle format version 1 only test for the magic +// number, so we had to change it so that version 2 files would be rejected by older +// implementations. On this occasion, we made the magic number 32 bits long. +const uint32_t BinaryDictionaryFormatUtils::HEADER_VERSION_2_MAGIC_NUMBER = 0x9BC13AFE; +// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 +const int BinaryDictionaryFormatUtils::HEADER_VERSION_2_MINIMUM_SIZE = 12; + +/* static */ BinaryDictionaryFormatUtils::FORMAT_VERSION + BinaryDictionaryFormatUtils::detectFormatVersion(const uint8_t *const dict, + const int dictSize) { + // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) { + return UNKNOWN_VERSION; + } + const uint32_t magicNumber = ByteArrayUtils::readUint32(dict, 0); + switch (magicNumber) { + case HEADER_VERSION_2_MAGIC_NUMBER: + // Version 2 header are at least 12 bytes long. + // If this header has the version 2 magic number but is less than 12 bytes long, + // then it's an unknown format and we need to avoid confidently reading the next bytes. + if (dictSize < HEADER_VERSION_2_MINIMUM_SIZE) { + return UNKNOWN_VERSION; + } + // Version 2 header is as follows: + // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE + // Version number (2 bytes) + // Options (2 bytes) + // Header size (4 bytes) : integer, big endian + if (ByteArrayUtils::readUint16(dict, 4) == 2) { + return VERSION_2; + } else if (ByteArrayUtils::readUint16(dict, 4) == 3) { + // TODO: Support version 3 dictionary. + return UNKNOWN_VERSION; + } else { + return UNKNOWN_VERSION; + } + default: + return UNKNOWN_VERSION; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.h new file mode 100644 index 000000000..830684c70 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_format_utils.h @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_FORMAT_UTILS_H +#define LATINIME_BINARY_DICTIONARY_FORMAT_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +/** + * Methods to handle binary dictionary format version. + * + * Currently, we have a file with a similar name, binary_format.h. binary_format.h contains binary + * reading methods and utility methods for various purposes. + * On the other hand, this file deals with only about dictionary format version. + */ +class BinaryDictionaryFormatUtils { + public: + enum FORMAT_VERSION { + VERSION_2, + VERSION_3, + UNKNOWN_VERSION + }; + + static FORMAT_VERSION detectFormatVersion(const uint8_t *const dict, const int dictSize); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryFormatUtils); + + static const int DICTIONARY_MINIMUM_SIZE; + static const uint32_t HEADER_VERSION_2_MAGIC_NUMBER; + static const int HEADER_VERSION_2_MINIMUM_SIZE; +}; +} // namespace latinime +#endif /* LATINIME_BINARY_DICTIONARY_FORMAT_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_header.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_header.cpp new file mode 100644 index 000000000..91c643a5f --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_header.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/binary_dictionary_header.h" + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" + +namespace latinime { + +const char *const BinaryDictionaryHeader::MULTIPLE_WORDS_DEMOTION_RATE_KEY = + "MULTIPLE_WORDS_DEMOTION_RATE"; +const float BinaryDictionaryHeader::DEFAULT_MULTI_WORD_COST_MULTIPLIER = 1.0f; +const float BinaryDictionaryHeader::MULTI_WORD_COST_MULTIPLIER_SCALE = 100.0f; + +BinaryDictionaryHeader::BinaryDictionaryHeader( + const BinaryDictionaryInfo *const binaryDictionaryInfo) + : mBinaryDictionaryInfo(binaryDictionaryInfo), + mDictionaryFlags(BinaryDictionaryHeaderReadingUtils::getFlags(binaryDictionaryInfo)), + mSize(BinaryDictionaryHeaderReadingUtils::getHeaderSize(binaryDictionaryInfo)), + mMultiWordCostMultiplier(readMultiWordCostMultiplier()) {} + +float BinaryDictionaryHeader::readMultiWordCostMultiplier() const { + const int headerValue = BinaryDictionaryHeaderReadingUtils::readHeaderValueInt( + mBinaryDictionaryInfo, MULTIPLE_WORDS_DEMOTION_RATE_KEY); + if (headerValue == S_INT_MIN) { + // not found + return DEFAULT_MULTI_WORD_COST_MULTIPLIER; + } + if (headerValue <= 0) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + return MULTI_WORD_COST_MULTIPLIER_SCALE / static_cast<float>(headerValue); +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_header.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_header.h new file mode 100644 index 000000000..240512bce --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_header.h @@ -0,0 +1,85 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_HEADER_H +#define LATINIME_BINARY_DICTIONARY_HEADER_H + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_header_reading_utils.h" + +namespace latinime { + +class BinaryDictionaryInfo; + +/** + * This class abstracts dictionary header structures and provide interface to access dictionary + * header information. + */ +class BinaryDictionaryHeader { + public: + explicit BinaryDictionaryHeader(const BinaryDictionaryInfo *const binaryDictionaryInfo); + + AK_FORCE_INLINE int getSize() const { + return mSize; + } + + AK_FORCE_INLINE bool supportsDynamicUpdate() const { + return BinaryDictionaryHeaderReadingUtils::supportsDynamicUpdate(mDictionaryFlags); + } + + AK_FORCE_INLINE bool requiresGermanUmlautProcessing() const { + return BinaryDictionaryHeaderReadingUtils::requiresGermanUmlautProcessing(mDictionaryFlags); + } + + AK_FORCE_INLINE bool requiresFrenchLigatureProcessing() const { + return BinaryDictionaryHeaderReadingUtils::requiresFrenchLigatureProcessing( + mDictionaryFlags); + } + + AK_FORCE_INLINE float getMultiWordCostMultiplier() const { + return mMultiWordCostMultiplier; + } + + AK_FORCE_INLINE void readHeaderValueOrQuestionMark(const char *const key, + int *outValue, int outValueSize) const { + if (outValueSize <= 0) return; + if (outValueSize == 1) { + outValue[0] = '\0'; + return; + } + if (!BinaryDictionaryHeaderReadingUtils::readHeaderValue(mBinaryDictionaryInfo, + key, outValue, outValueSize)) { + outValue[0] = '?'; + outValue[1] = '\0'; + } + } + + private: + DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryHeader); + + static const char *const MULTIPLE_WORDS_DEMOTION_RATE_KEY; + static const float DEFAULT_MULTI_WORD_COST_MULTIPLIER; + static const float MULTI_WORD_COST_MULTIPLIER_SCALE; + + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; + const BinaryDictionaryHeaderReadingUtils::DictionaryFlags mDictionaryFlags; + const int mSize; + const float mMultiWordCostMultiplier; + + float readMultiWordCostMultiplier() const; +}; +} // namespace latinime +#endif // LATINIME_BINARY_DICTIONARY_HEADER_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.cpp new file mode 100644 index 000000000..a57b0f859 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.cpp @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/binary_dictionary_header_reading_utils.h" + +#include <cctype> +#include <cstdlib> + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" + +namespace latinime { + +const int BinaryDictionaryHeaderReadingUtils::MAX_OPTION_KEY_LENGTH = 256; + +const int BinaryDictionaryHeaderReadingUtils::VERSION_2_HEADER_MAGIC_NUMBER_SIZE = 4; +const int BinaryDictionaryHeaderReadingUtils::VERSION_2_HEADER_DICTIONARY_VERSION_SIZE = 2; +const int BinaryDictionaryHeaderReadingUtils::VERSION_2_HEADER_FLAG_SIZE = 2; +const int BinaryDictionaryHeaderReadingUtils::VERSION_2_HEADER_SIZE_FIELD_SIZE = 4; + +const BinaryDictionaryHeaderReadingUtils::DictionaryFlags + BinaryDictionaryHeaderReadingUtils::NO_FLAGS = 0; +// Flags for special processing +// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or +// something very bad (like, the apocalypse) will happen. Please update both at the same time. +const BinaryDictionaryHeaderReadingUtils::DictionaryFlags + BinaryDictionaryHeaderReadingUtils::GERMAN_UMLAUT_PROCESSING_FLAG = 0x1; +const BinaryDictionaryHeaderReadingUtils::DictionaryFlags + BinaryDictionaryHeaderReadingUtils::SUPPORTS_DYNAMIC_UPDATE_FLAG = 0x2; +const BinaryDictionaryHeaderReadingUtils::DictionaryFlags + BinaryDictionaryHeaderReadingUtils::FRENCH_LIGATURE_PROCESSING_FLAG = 0x4; + +/* static */ int BinaryDictionaryHeaderReadingUtils::getHeaderSize( + const BinaryDictionaryInfo *const binaryDictionaryInfo) { + switch (getHeaderVersion(binaryDictionaryInfo->getFormat())) { + case HEADER_VERSION_2: + // See the format of the header in the comment in + // BinaryDictionaryFormatUtils::detectFormatVersion() + return ByteArrayUtils::readUint32(binaryDictionaryInfo->getDictBuf(), + VERSION_2_HEADER_MAGIC_NUMBER_SIZE + VERSION_2_HEADER_DICTIONARY_VERSION_SIZE + + VERSION_2_HEADER_FLAG_SIZE); + default: + return S_INT_MAX; + } +} + +/* static */ BinaryDictionaryHeaderReadingUtils::DictionaryFlags + BinaryDictionaryHeaderReadingUtils::getFlags( + const BinaryDictionaryInfo *const binaryDictionaryInfo) { + switch (getHeaderVersion(binaryDictionaryInfo->getFormat())) { + case HEADER_VERSION_2: + return ByteArrayUtils::readUint16(binaryDictionaryInfo->getDictBuf(), + VERSION_2_HEADER_MAGIC_NUMBER_SIZE + VERSION_2_HEADER_DICTIONARY_VERSION_SIZE); + default: + return NO_FLAGS; + } +} + +// Returns if the key is found or not and reads the found value into outValue. +/* static */ bool BinaryDictionaryHeaderReadingUtils::readHeaderValue( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const char *const key, int *outValue, const int outValueSize) { + if (outValueSize <= 0) { + return false; + } + const int headerSize = getHeaderSize(binaryDictionaryInfo); + int pos = getHeaderOptionsPosition(binaryDictionaryInfo->getFormat()); + if (pos == NOT_A_DICT_POS) { + // The header doesn't have header options. + return false; + } + while (pos < headerSize) { + if(ByteArrayUtils::compareStringInBufferWithCharArray( + binaryDictionaryInfo->getDictBuf(), key, headerSize - pos, &pos) == 0) { + // The key was found. + const int length = ByteArrayUtils::readStringAndAdvancePosition( + binaryDictionaryInfo->getDictBuf(), outValueSize, outValue, &pos); + // Add a 0 terminator to the string. + outValue[length < outValueSize ? length : outValueSize - 1] = '\0'; + return true; + } + ByteArrayUtils::advancePositionToBehindString( + binaryDictionaryInfo->getDictBuf(), headerSize - pos, &pos); + } + // The key was not found. + return false; +} + +/* static */ int BinaryDictionaryHeaderReadingUtils::readHeaderValueInt( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const char *const key) { + const int bufferSize = LARGEST_INT_DIGIT_COUNT; + int intBuffer[bufferSize]; + char charBuffer[bufferSize]; + if (!readHeaderValue(binaryDictionaryInfo, key, intBuffer, bufferSize)) { + return S_INT_MIN; + } + for (int i = 0; i < bufferSize; ++i) { + charBuffer[i] = intBuffer[i]; + if (charBuffer[i] == '0') { + break; + } + if (!isdigit(charBuffer[i])) { + // If not a number, return S_INT_MIN + return S_INT_MIN; + } + } + return atoi(charBuffer); +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.h new file mode 100644 index 000000000..61748227e --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_header_reading_utils.h @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_HEADER_READING_UTILS_H +#define LATINIME_DICTIONARY_HEADER_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_format_utils.h" + +namespace latinime { + +class BinaryDictionaryInfo; + +class BinaryDictionaryHeaderReadingUtils { + public: + typedef uint16_t DictionaryFlags; + + static const int MAX_OPTION_KEY_LENGTH; + + static int getHeaderSize(const BinaryDictionaryInfo *const binaryDictionaryInfo); + + static DictionaryFlags getFlags(const BinaryDictionaryInfo *const binaryDictionaryInfo); + + static AK_FORCE_INLINE bool supportsDynamicUpdate(const DictionaryFlags flags) { + return (flags & SUPPORTS_DYNAMIC_UPDATE_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresGermanUmlautProcessing(const DictionaryFlags flags) { + return (flags & GERMAN_UMLAUT_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE bool requiresFrenchLigatureProcessing(const DictionaryFlags flags) { + return (flags & FRENCH_LIGATURE_PROCESSING_FLAG) != 0; + } + + static AK_FORCE_INLINE int getHeaderOptionsPosition( + const BinaryDictionaryFormatUtils::FORMAT_VERSION dictionaryFormat) { + switch (getHeaderVersion(dictionaryFormat)) { + case HEADER_VERSION_2: + return VERSION_2_HEADER_MAGIC_NUMBER_SIZE + VERSION_2_HEADER_DICTIONARY_VERSION_SIZE + + VERSION_2_HEADER_FLAG_SIZE + VERSION_2_HEADER_SIZE_FIELD_SIZE; + break; + default: + return NOT_A_DICT_POS; + } + } + + static bool readHeaderValue( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const char *const key, int *outValue, const int outValueSize); + + static int readHeaderValueInt( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const char *const key); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryHeaderReadingUtils); + + enum HEADER_VERSION { + HEADER_VERSION_2, + UNKNOWN_HEADER_VERSION + }; + + static const int VERSION_2_HEADER_MAGIC_NUMBER_SIZE; + static const int VERSION_2_HEADER_DICTIONARY_VERSION_SIZE; + static const int VERSION_2_HEADER_FLAG_SIZE; + static const int VERSION_2_HEADER_SIZE_FIELD_SIZE; + + static const DictionaryFlags NO_FLAGS; + // Flags for special processing + // Those *must* match the flags in makedict (FormatSpec#*_PROCESSING_FLAGS) or + // something very bad (like, the apocalypse) will happen. Please update both at the same time. + static const DictionaryFlags GERMAN_UMLAUT_PROCESSING_FLAG; + static const DictionaryFlags SUPPORTS_DYNAMIC_UPDATE_FLAG; + static const DictionaryFlags FRENCH_LIGATURE_PROCESSING_FLAG; + static const DictionaryFlags CONTAINS_BIGRAMS_FLAG; + + static HEADER_VERSION getHeaderVersion( + const BinaryDictionaryFormatUtils::FORMAT_VERSION formatVersion) { + switch(formatVersion) { + case BinaryDictionaryFormatUtils::VERSION_2: + // Fall through + case BinaryDictionaryFormatUtils::VERSION_3: + return HEADER_VERSION_2; + default: + return UNKNOWN_HEADER_VERSION; + } + } +}; +} +#endif /* LATINIME_DICTIONARY_HEADER_READING_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h new file mode 100644 index 000000000..cbea18f90 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_info.h @@ -0,0 +1,124 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_INFO_H +#define LATINIME_BINARY_DICTIONARY_INFO_H + +#include <stdint.h> + +#include "defines.h" +#include "jni.h" +#include "suggest/core/dictionary/binary_dictionary_format_utils.h" +#include "suggest/core/dictionary/binary_dictionary_header.h" +#include "suggest/policyimpl/dictionary/dictionary_structure_policy_factory.h" +#include "utils/log_utils.h" + +namespace latinime { + +class BinaryDictionaryInfo { + public: + AK_FORCE_INLINE BinaryDictionaryInfo(JNIEnv *env, const uint8_t *const dictBuf, + const int dictSize, const int mmapFd, const int dictBufOffset, const bool isUpdatable) + : mDictBuf(dictBuf), mDictSize(dictSize), mMmapFd(mmapFd), + mDictBufOffset(dictBufOffset), mIsUpdatable(isUpdatable), + mDictionaryFormat(BinaryDictionaryFormatUtils::detectFormatVersion( + mDictBuf, mDictSize)), + mDictionaryHeader(this), mDictRoot(mDictBuf + mDictionaryHeader.getSize()), + mStructurePolicy(DictionaryStructurePolicyFactory::getDictionaryStructurePolicy( + mDictionaryFormat)) { + logDictionaryInfo(env); + } + + AK_FORCE_INLINE const uint8_t *getDictBuf() const { + return mDictBuf; + } + + AK_FORCE_INLINE int getDictSize() const { + return mDictSize; + } + + AK_FORCE_INLINE int getMmapFd() const { + return mMmapFd; + } + + AK_FORCE_INLINE int getDictBufOffset() const { + return mDictBufOffset; + } + + AK_FORCE_INLINE const uint8_t *getDictRoot() const { + return mDictRoot; + } + + AK_FORCE_INLINE BinaryDictionaryFormatUtils::FORMAT_VERSION getFormat() const { + return mDictionaryFormat; + } + + AK_FORCE_INLINE const BinaryDictionaryHeader *getHeader() const { + return &mDictionaryHeader; + } + + AK_FORCE_INLINE bool isDynamicallyUpdatable() const { + // TODO: Support dynamic dictionary formats. + const bool isUpdatableDictionaryFormat = false; + return mIsUpdatable && isUpdatableDictionaryFormat; + } + + AK_FORCE_INLINE const DictionaryStructurePolicy *getStructurePolicy() const { + return mStructurePolicy; + } + + private: + DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryInfo); + + const uint8_t *const mDictBuf; + const int mDictSize; + const int mMmapFd; + const int mDictBufOffset; + const bool mIsUpdatable; + const BinaryDictionaryFormatUtils::FORMAT_VERSION mDictionaryFormat; + const BinaryDictionaryHeader mDictionaryHeader; + const uint8_t *const mDictRoot; + const DictionaryStructurePolicy *const mStructurePolicy; + + AK_FORCE_INLINE void logDictionaryInfo(JNIEnv *const env) const { + const int BUFFER_SIZE = 16; + int dictionaryIdCodePointBuffer[BUFFER_SIZE]; + int versionStringCodePointBuffer[BUFFER_SIZE]; + int dateStringCodePointBuffer[BUFFER_SIZE]; + mDictionaryHeader.readHeaderValueOrQuestionMark("dictionary", + dictionaryIdCodePointBuffer, BUFFER_SIZE); + mDictionaryHeader.readHeaderValueOrQuestionMark("version", + versionStringCodePointBuffer, BUFFER_SIZE); + mDictionaryHeader.readHeaderValueOrQuestionMark("date", + dateStringCodePointBuffer, BUFFER_SIZE); + + char dictionaryIdCharBuffer[BUFFER_SIZE]; + char versionStringCharBuffer[BUFFER_SIZE]; + char dateStringCharBuffer[BUFFER_SIZE]; + intArrayToCharArray(dictionaryIdCodePointBuffer, BUFFER_SIZE, + dictionaryIdCharBuffer, BUFFER_SIZE); + intArrayToCharArray(versionStringCodePointBuffer, BUFFER_SIZE, + versionStringCharBuffer, BUFFER_SIZE); + intArrayToCharArray(dateStringCodePointBuffer, BUFFER_SIZE, + dateStringCharBuffer, BUFFER_SIZE); + + LogUtils::logToJava(env, + "Dictionary info: dictionary = %s ; version = %s ; date = %s ; filesize = %i", + dictionaryIdCharBuffer, versionStringCharBuffer, dateStringCharBuffer, mDictSize); + } +}; +} +#endif /* LATINIME_BINARY_DICTIONARY_INFO_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.cpp new file mode 100644 index 000000000..20b77b3b2 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" + +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +typedef BinaryDictionaryTerminalAttributesReadingUtils TaUtils; + +const TaUtils::TerminalAttributeFlags TaUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; +// Flag for presence of more attributes +const TaUtils::TerminalAttributeFlags TaUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; +// Mask for attribute probability, stored on 4 bits inside the flags byte. +const TaUtils::TerminalAttributeFlags TaUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int TaUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; +const int TaUtils::SHORTCUT_LIST_SIZE_FIELD_SIZE = 2; +// The numeric value of the shortcut probability that means 'whitelist'. +const int TaUtils::WHITELIST_SHORTCUT_PROBABILITY = 15; + +/* static */ int TaUtils::getBigramAddressAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const TerminalAttributeFlags flags, + int *const pos) { + int offset = 0; + const int origin = *pos; + switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8AndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16AndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24AndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + } + if (isOffsetNegative(flags)) { + return origin - offset; + } else { + return origin + offset; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h new file mode 100644 index 000000000..375fc7dff --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h @@ -0,0 +1,123 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H +#define LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +class BinaryDictionaryTerminalAttributesReadingUtils { + public: + typedef uint8_t TerminalAttributeFlags; + typedef TerminalAttributeFlags BigramFlags; + typedef TerminalAttributeFlags ShortcutFlags; + + static AK_FORCE_INLINE TerminalAttributeFlags getFlagsAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + } + + static AK_FORCE_INLINE int getProbabilityFromFlags(const TerminalAttributeFlags flags) { + return flags & MASK_ATTRIBUTE_PROBABILITY; + } + + static AK_FORCE_INLINE bool hasNext(const TerminalAttributeFlags flags) { + return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; + } + + // Bigrams reading methods + static AK_FORCE_INLINE void skipExistingBigrams( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + BigramFlags flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + while (hasNext(flags)) { + *pos += attributeAddressSize(flags); + flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + } + *pos += attributeAddressSize(flags); + } + + static int getBigramAddressAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const BigramFlags flags, + int *const pos); + + // Shortcuts reading methods + // This method returns the size of the shortcut list region excluding the shortcut list size + // field at the beginning. + static AK_FORCE_INLINE int getShortcutListSizeAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + // readUint16andAdvancePosition() returns an offset *including* the uint16 field itself. + return ByteArrayUtils::readUint16AndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos) - SHORTCUT_LIST_SIZE_FIELD_SIZE; + } + + static AK_FORCE_INLINE void skipShortcuts( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + const int shortcutListSize = getShortcutListSizeAndForwardPointer( + binaryDictionaryInfo, pos); + *pos += shortcutListSize; + } + + static AK_FORCE_INLINE bool isWhitelist(const ShortcutFlags flags) { + return getProbabilityFromFlags(flags) == WHITELIST_SHORTCUT_PROBABILITY; + } + + static AK_FORCE_INLINE int readShortcutTarget( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int maxLength, + int *const outWord, int *const pos) { + return ByteArrayUtils::readStringAndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), maxLength, outWord, pos); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryTerminalAttributesReadingUtils); + + static const TerminalAttributeFlags MASK_ATTRIBUTE_ADDRESS_TYPE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + static const TerminalAttributeFlags FLAG_ATTRIBUTE_HAS_NEXT; + static const TerminalAttributeFlags MASK_ATTRIBUTE_PROBABILITY; + static const int ATTRIBUTE_ADDRESS_SHIFT; + static const int SHORTCUT_LIST_SIZE_FIELD_SIZE; + static const int WHITELIST_SHORTCUT_PROBABILITY; + + static AK_FORCE_INLINE bool isOffsetNegative(const TerminalAttributeFlags flags) { + return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; + } + + static AK_FORCE_INLINE int attributeAddressSize(const TerminalAttributeFlags flags) { + return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; + /* Note: this is a value-dependant optimization of what may probably be + more readably written this way: + switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; + default: return 0; + } + */ + } +}; +} +#endif /* LATINIME_BINARY_DICTIONARY_TERMINAL_ATTRIBUTES_READING_UTILS_H */ diff --git a/native/jni/src/dic_traverse_wrapper.cpp b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp index 88ca9fa0d..4ae474e0c 100644 --- a/native/jni/src/dic_traverse_wrapper.cpp +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp @@ -1,5 +1,5 @@ /* - * Copyright (C) 2012, The Android Open Source Project + * Copyright (C) 2013, The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,13 +14,12 @@ * limitations under the License. */ -#define LOG_TAG "LatinIME: jni: Session" - -#include "dic_traverse_wrapper.h" +#include "suggest/core/dictionary/bloom_filter.h" namespace latinime { -void *(*DicTraverseWrapper::sDicTraverseSessionFactoryMethod)(JNIEnv *, jstring) = 0; -void (*DicTraverseWrapper::sDicTraverseSessionReleaseMethod)(void *) = 0; -void (*DicTraverseWrapper::sDicTraverseSessionInitMethod)( - void *, const Dictionary *const, const int *, const int) = 0; + +// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest +// prime under 128 * 8. +const int BloomFilter::BIGRAM_FILTER_MODULO = 1021; + } // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/bloom_filter.h b/native/jni/src/suggest/core/dictionary/bloom_filter.h new file mode 100644 index 000000000..5205456a8 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BLOOM_FILTER_H +#define LATINIME_BLOOM_FILTER_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +// This bloom filter is used for optimizing bigram retrieval. +// Execution times with previous word "this" are as follows: +// without bloom filter (use only hash_map): +// Total 147792.34 (sum of others 147771.57) +// with bloom filter: +// Total 145900.64 (sum of others 145874.30) +// always read binary dictionary: +// Total 148603.14 (sum of others 148579.90) +class BloomFilter { + public: + BloomFilter() { + ASSERT(BIGRAM_FILTER_BYTE_SIZE * 8 >= BIGRAM_FILTER_MODULO); + } + + // TODO: uint32_t position + AK_FORCE_INLINE void setInFilter(const int32_t position) { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + mFilter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); + } + + // TODO: uint32_t position + AK_FORCE_INLINE bool isInFilter(const int32_t position) const { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + return (mFilter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7))) != 0; + } + + private: + // Size, in bytes, of the bloom filter index for bigrams + // 128 gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, + // where k is the number of hash functions, n the number of bigrams, and m the number of + // bits we can test. + // At the moment 100 is the maximum number of bigrams for a word with the current + // dictionaries, so n = 100. 1024 buckets give us m = 1024. + // With 1 hash function, our false positive rate is about 9.3%, which should be enough for + // our uses since we are only using this to increase average performance. For the record, + // k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, + // and m = 4096 gives 2.4%. + // This is assigned here because it is used for array size. + static const int BIGRAM_FILTER_BYTE_SIZE = 128; + static const int BIGRAM_FILTER_MODULO; + + uint8_t mFilter[BIGRAM_FILTER_BYTE_SIZE]; +}; +} // namespace latinime +#endif // LATINIME_BLOOM_FILTER_H diff --git a/native/jni/src/suggest/core/dictionary/byte_array_utils.cpp b/native/jni/src/suggest/core/dictionary/byte_array_utils.cpp new file mode 100644 index 000000000..68b1d5d15 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/byte_array_utils.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +const uint8_t ByteArrayUtils::MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; +const uint8_t ByteArrayUtils::CHARACTER_ARRAY_TERMINATOR = 0x1F; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/byte_array_utils.h b/native/jni/src/suggest/core/dictionary/byte_array_utils.h new file mode 100644 index 000000000..75ccfc766 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/byte_array_utils.h @@ -0,0 +1,192 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BYTE_ARRAY_UTILS_H +#define LATINIME_BYTE_ARRAY_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +/** + * Utility methods for reading byte arrays. + */ +class ByteArrayUtils { + public: + /** + * Integer + * + * Each method read a corresponding size integer in a big endian manner. + */ + static AK_FORCE_INLINE uint32_t readUint32(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 24) ^ (buffer[pos + 1] << 16) + ^ (buffer[pos + 2] << 8) ^ buffer[pos + 3]; + } + + static AK_FORCE_INLINE uint32_t readUint24(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 16) ^ (buffer[pos + 1] << 8) ^ buffer[pos + 2]; + } + + static AK_FORCE_INLINE uint16_t readUint16(const uint8_t *const buffer, const int pos) { + return (buffer[pos] << 8) ^ buffer[pos + 1]; + } + + static AK_FORCE_INLINE uint8_t readUint8(const uint8_t *const buffer, const int pos) { + return buffer[pos]; + } + + static AK_FORCE_INLINE uint32_t readUint32AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint32_t value = readUint32(buffer, *pos); + *pos += 4; + return value; + } + + static AK_FORCE_INLINE int readSint24AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t value = readUint8(buffer, *pos); + if (value < 0x80) { + return readUint24AndAdvancePosition(buffer, pos); + } else { + (*pos)++; + return -(((value & 0x7F) << 16) ^ readUint16AndAdvancePosition(buffer, pos)); + } + } + + static AK_FORCE_INLINE uint32_t readUint24AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint32_t value = readUint24(buffer, *pos); + *pos += 3; + return value; + } + + static AK_FORCE_INLINE uint16_t readUint16AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint16_t value = readUint16(buffer, *pos); + *pos += 2; + return value; + } + + static AK_FORCE_INLINE uint8_t readUint8AndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + return buffer[(*pos)++]; + } + + /** + * Code Point + * + * 1 byte = bbbbbbbb match + * case 000xxxxx: xxxxx << 16 + next byte << 8 + next byte + * else: if 00011111 (= 0x1F) : this is the terminator. This is a relevant choice because + * unicode code points range from 0 to 0x10FFFF, so any 3-byte value starting with + * 00011111 would be outside unicode. + * else: iso-latin-1 code + * This allows for the whole unicode range to be encoded, including chars outside of + * the BMP. Also everything in the iso-latin-1 charset is only 1 byte, except control + * characters which should never happen anyway (and still work, but take 3 bytes). + */ + static AK_FORCE_INLINE int readCodePoint(const uint8_t *const buffer, const int pos) { + int p = pos; + return readCodePointAndAdvancePosition(buffer, &p); + } + + static AK_FORCE_INLINE int readCodePointAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t firstByte = readUint8(buffer, *pos); + if (firstByte < MINIMAL_ONE_BYTE_CHARACTER_VALUE) { + if (firstByte == CHARACTER_ARRAY_TERMINATOR) { + *pos += 1; + return NOT_A_CODE_POINT; + } else { + return readUint24AndAdvancePosition(buffer, pos); + } + } else { + *pos += 1; + return firstByte; + } + } + + /** + * String (array of code points) + * + * Reads code points until the terminator is found. + */ + // Returns the length of the string. + static int readStringAndAdvancePosition(const uint8_t *const buffer, + const int maxLength, int *const outBuffer, int *const pos) { + int length = 0; + int codePoint = readCodePointAndAdvancePosition(buffer, pos); + while (NOT_A_CODE_POINT != codePoint && length < maxLength) { + outBuffer[length++] = codePoint; + codePoint = readCodePointAndAdvancePosition(buffer, pos); + } + return length; + } + + // Advances the position and returns the length of the string. + static int advancePositionToBehindString( + const uint8_t *const buffer, const int maxLength, int *const pos) { + int length = 0; + int codePoint = readCodePointAndAdvancePosition(buffer, pos); + while (NOT_A_CODE_POINT != codePoint && length < maxLength) { + codePoint = readCodePointAndAdvancePosition(buffer, pos); + } + return length; + } + + // Returns an integer less than, equal to, or greater than zero when string starting from pos + // in buffer is less than, match, or is greater than charArray. + static AK_FORCE_INLINE int compareStringInBufferWithCharArray(const uint8_t *const buffer, + const char *const charArray, const int maxLength, int *const pos) { + int index = 0; + int codePoint = readCodePointAndAdvancePosition(buffer, pos); + const uint8_t *const uint8CharArrayForComparison = + reinterpret_cast<const uint8_t *>(charArray); + while (NOT_A_CODE_POINT != codePoint + && '\0' != uint8CharArrayForComparison[index] && index < maxLength) { + if (codePoint != uint8CharArrayForComparison[index]) { + // Different character is found. + // Skip the rest of the string in the buffer. + advancePositionToBehindString(buffer, maxLength - index, pos); + return codePoint - uint8CharArrayForComparison[index]; + } + // Advance + codePoint = readCodePointAndAdvancePosition(buffer, pos); + ++index; + } + if (NOT_A_CODE_POINT != codePoint && index < maxLength) { + // Skip the rest of the string in the buffer. + advancePositionToBehindString(buffer, maxLength - index, pos); + } + if (NOT_A_CODE_POINT == codePoint && '\0' == uint8CharArrayForComparison[index]) { + // When both of the last characters are terminals, we consider the string in the buffer + // matches the given char array + return 0; + } else { + return codePoint - uint8CharArrayForComparison[index]; + } + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ByteArrayUtils); + + static const uint8_t MINIMAL_ONE_BYTE_CHARACTER_VALUE; + static const uint8_t CHARACTER_ARRAY_TERMINATOR; +}; +} // namespace latinime +#endif /* LATINIME_BYTE_ARRAY_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp new file mode 100644 index 000000000..4f5d29f6a --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -0,0 +1,130 @@ +/* + * Copyright (C) 2009, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define LOG_TAG "LatinIME: dictionary.cpp" + +#include "suggest/core/dictionary/dictionary.h" + +#include <map> // TODO: remove +#include <stdint.h> + +#include "defines.h" +#include "jni.h" +#include "suggest/core/dictionary/bigram_dictionary.h" +#include "suggest/core/session/dic_traverse_session.h" +#include "suggest/core/suggest.h" +#include "suggest/core/suggest_options.h" +#include "suggest/policyimpl/gesture/gesture_suggest_policy_factory.h" +#include "suggest/policyimpl/typing/typing_suggest_policy_factory.h" + +namespace latinime { + +Dictionary::Dictionary(JNIEnv *env, void *dict, int dictSize, int mmapFd, + int dictBufOffset, bool isUpdatable) + : mBinaryDictionaryInfo(env, static_cast<const uint8_t *>(dict), dictSize, mmapFd, + dictBufOffset, isUpdatable), + mBigramDictionary(new BigramDictionary(&mBinaryDictionaryInfo)), + mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), + mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { +} + +Dictionary::~Dictionary() { + delete mBigramDictionary; + delete mGestureSuggest; + delete mTypingSuggest; +} + +int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, + int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, + int inputSize, int *prevWordCodePoints, int prevWordLength, int commitPoint, + const SuggestOptions *const suggestOptions, int *outWords, int *frequencies, + int *spaceIndices, int *outputTypes) const { + int result = 0; + if (suggestOptions->isGesture()) { + DicTraverseSession::initSessionInstance( + traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); + result = mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, + ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, outWords, + frequencies, spaceIndices, outputTypes); + if (DEBUG_DICT) { + DUMP_RESULT(outWords, frequencies); + } + return result; + } else { + DicTraverseSession::initSessionInstance( + traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); + result = mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, + ycoordinates, times, pointerIds, inputCodePoints, inputSize, commitPoint, + outWords, frequencies, spaceIndices, outputTypes); + if (DEBUG_DICT) { + DUMP_RESULT(outWords, frequencies); + } + return result; + } +} + +int Dictionary::getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, + int *outWords, int *frequencies, int *outputTypes) const { + if (length <= 0) return 0; + return mBigramDictionary->getPredictions(word, length, inputCodePoints, inputSize, outWords, + frequencies, outputTypes); +} + +int Dictionary::getProbability(const int *word, int length) const { + const DictionaryStructurePolicy *const structurePolicy = + mBinaryDictionaryInfo.getStructurePolicy(); + int pos = structurePolicy->getTerminalNodePositionOfWord(&mBinaryDictionaryInfo, word, length, + false /* forceLowerCaseSearch */); + if (NOT_A_VALID_WORD_POS == pos) { + return NOT_A_PROBABILITY; + } + return structurePolicy->getUnigramProbability(&mBinaryDictionaryInfo, pos); +} + +bool Dictionary::isValidBigram(const int *word0, int length0, const int *word1, int length1) const { + return mBigramDictionary->isValidBigram(word0, length0, word1, length1); +} + +void Dictionary::addUnigramWord(const int *const word, const int length, const int probability) { + if (!mBinaryDictionaryInfo.isDynamicallyUpdatable()) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: Dictionary::addUnigramWord() is called for non-updatable dictionary."); + return; + } + // TODO: Support dynamic update +} + +void Dictionary::addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability) { + if (!mBinaryDictionaryInfo.isDynamicallyUpdatable()) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: Dictionary::addBigramWords() is called for non-updatable dictionary."); + return; + } + // TODO: Support dynamic update +} + +void Dictionary::removeBigramWords(const int *const word0, const int length0, + const int *const word1, const int length1) { + if (!mBinaryDictionaryInfo.isDynamicallyUpdatable()) { + // This method should not be called for non-updatable dictionary. + AKLOGI("Warning: Dictionary::removeBigramWords() is called for non-updatable dictionary."); + return; + } + // TODO: Support dynamic update +} + +} // namespace latinime diff --git a/native/jni/src/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 2ad5b6c0b..1bf24a85b 100644 --- a/native/jni/src/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -20,13 +20,16 @@ #include <stdint.h> #include "defines.h" +#include "jni.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" namespace latinime { class BigramDictionary; +class DicTraverseSession; class ProximityInfo; class SuggestInterface; -class UnigramDictionary; +class SuggestOptions; class Dictionary { public: @@ -41,48 +44,49 @@ class Dictionary { static const int KIND_APP_DEFINED = 6; // Suggested by the application static const int KIND_SHORTCUT = 7; // A shortcut static const int KIND_PREDICTION = 8; // A prediction (== a suggestion with no input) + // KIND_RESUMED: A resumed suggestion (comes from a span, currently this type is used only + // in java for re-correction) + static const int KIND_RESUMED = 9; + static const int KIND_OOV_CORRECTION = 10; // Most probable string correction static const int KIND_MASK_FLAGS = 0xFFFFFF00; // Mask to get the flags static const int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000; static const int KIND_FLAG_EXACT_MATCH = 0x40000000; - Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust); + Dictionary(JNIEnv *env, void *dict, int dictSize, int mmapFd, int dictBufOffset, + bool isUpdatable); - int getSuggestions(ProximityInfo *proximityInfo, void *traverseSession, int *xcoordinates, - int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - int *prevWordCodePoints, int prevWordLength, int commitPoint, bool isGesture, - bool useFullEditDistance, int *outWords, int *frequencies, int *spaceIndices, - int *outputTypes) const; + int getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, + int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, + int inputSize, int *prevWordCodePoints, int prevWordLength, int commitPoint, + const SuggestOptions *const suggestOptions, int *outWords, int *frequencies, + int *spaceIndices, int *outputTypes) const; int getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, int *frequencies, int *outputTypes) const; int getProbability(const int *word, int length) const; - bool isValidBigram(const int *word1, int length1, const int *word2, int length2) const; - const uint8_t *getDict() const { // required to release dictionary buffer - return mDict; - } - const uint8_t *getOffsetDict() const { - return mOffsetDict; + + bool isValidBigram(const int *word0, int length0, const int *word1, int length1) const; + + void addUnigramWord(const int *const word, const int length, const int probability); + + void addBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1, const int probability); + + void removeBigramWords(const int *const word0, const int length0, const int *const word1, + const int length1); + + const BinaryDictionaryInfo *getBinaryDictionaryInfo() const { + return &mBinaryDictionaryInfo; } - int getDictSize() const { return mDictSize; } - int getMmapFd() const { return mMmapFd; } - int getDictBufAdjust() const { return mDictBufAdjust; } - int getDictFlags() const; + virtual ~Dictionary(); private: DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary); - const uint8_t *mDict; - const uint8_t *mOffsetDict; - - // Used only for the mmap version of dictionary loading, but we use these as dummy variables - // also for the malloc version. - const int mDictSize; - const int mMmapFd; - const int mDictBufAdjust; - const UnigramDictionary *mUnigramDictionary; + const BinaryDictionaryInfo mBinaryDictionaryInfo; const BigramDictionary *mBigramDictionary; SuggestInterface *mGestureSuggest; SuggestInterface *mTypingSuggest; diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/suggest/core/dictionary/digraph_utils.cpp index 083442669..af378b1b7 100644 --- a/native/jni/src/digraph_utils.cpp +++ b/native/jni/src/suggest/core/dictionary/digraph_utils.cpp @@ -14,10 +14,13 @@ * limitations under the License. */ -#include "char_utils.h" -#include "binary_format.h" +#include "suggest/core/dictionary/digraph_utils.h" + +#include <cstdlib> + #include "defines.h" -#include "digraph_utils.h" +#include "suggest/core/dictionary/binary_dictionary_header.h" +#include "utils/char_utils.h" namespace latinime { @@ -32,8 +35,8 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES }; /* static */ bool DigraphUtils::hasDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint) { - const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + const BinaryDictionaryHeader *const header, const int compositeGlyphCodePoint) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(header); if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) { return true; } @@ -42,24 +45,16 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = // Returns the digraph type associated with the given dictionary. /* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary( - const int dictFlags) { - if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) { + const BinaryDictionaryHeader *const header) { + if (header->requiresGermanUmlautProcessing()) { return DIGRAPH_TYPE_GERMAN_UMLAUT; } - if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) { + if (header->requiresFrenchLigatureProcessing()) { return DIGRAPH_TYPE_FRENCH_LIGATURES; } return DIGRAPH_TYPE_NONE; } -// Retrieves the set of all digraphs associated with the given dictionary flags. -// Returns the size of the digraph array, or 0 if none exist. -/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const DigraphUtils::digraph_t **const digraphs) { - const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); - return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs); -} - // Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index // (which specifies the first or second codepoint in the digraph). /* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint, @@ -121,9 +116,9 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = /* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint( const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) { const DigraphUtils::digraph_t *digraphs = 0; - const int compositeGlyphLowerCodePoint = toLowerCase(compositeGlyphCodePoint); + const int compositeGlyphLowerCodePoint = CharUtils::toLowerCase(compositeGlyphCodePoint); const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs); + DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize(digraphType, &digraphs); for (int i = 0; i < digraphsSize; i++) { if (digraphs[i].compositeGlyph == compositeGlyphLowerCodePoint) { return &digraphs[i]; diff --git a/native/jni/src/digraph_utils.h b/native/jni/src/suggest/core/dictionary/digraph_utils.h index 94435228e..9d74fe3a6 100644 --- a/native/jni/src/digraph_utils.h +++ b/native/jni/src/suggest/core/dictionary/digraph_utils.h @@ -17,8 +17,12 @@ #ifndef DIGRAPH_UTILS_H #define DIGRAPH_UTILS_H +#include "defines.h" + namespace latinime { +class BinaryDictionaryHeader; + class DigraphUtils { public: typedef enum { @@ -35,17 +39,14 @@ class DigraphUtils { typedef struct { int first; int second; int compositeGlyph; } digraph_t; - static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint); - static int getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const digraph_t **const digraphs); - static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint, - const DigraphCodePointIndex digraphCodePointIndex); + static bool hasDigraphForCodePoint( + const BinaryDictionaryHeader *const header, const int compositeGlyphCodePoint); static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils); - static DigraphType getDigraphTypeForDictionary(const int dictFlags); + static DigraphType getDigraphTypeForDictionary(const BinaryDictionaryHeader *const header); static int getAllDigraphsForDigraphTypeAndReturnSize( const DigraphType digraphType, const digraph_t **const digraphs); static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint); diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp new file mode 100644 index 000000000..b1d2f4b4d --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/multi_bigram_map.h" + +#include <cstddef> + +namespace latinime { + +// Max number of bigram maps (previous word contexts) to be cached. Increasing this number +// could improve bigram lookup speed for multi-word suggestions, but at the cost of more memory +// usage. Also, there are diminishing returns since the most frequently used bigrams are +// typically near the beginning of the input and are thus the first ones to be cached. Note +// that these bigrams are reset for each new composing word. +const size_t MultiBigramMap::MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP = 25; + +// Most common previous word contexts currently have 100 bigrams +const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = 100; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h new file mode 100644 index 000000000..d5eafe1bf --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_MULTI_BIGRAM_MAP_H +#define LATINIME_MULTI_BIGRAM_MAP_H + +#include <cstddef> + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/bloom_filter.h" +#include "suggest/core/dictionary/probability_utils.h" +#include "utils/hash_map_compat.h" + +namespace latinime { + +// Class for caching bigram maps for multiple previous word contexts. This is useful since the +// algorithm needs to look up the set of bigrams for every word pair that occurs in every +// multi-word suggestion. +class MultiBigramMap { + public: + MultiBigramMap() : mBigramMaps() {} + ~MultiBigramMap() {} + + // Look up the bigram probability for the given word pair from the cached bigram maps. + // Also caches the bigrams if there is space remaining and they have not been cached already. + int getBigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int wordPosition, const int nextWordPosition, const int unigramProbability) { + hash_map_compat<int, BigramMap>::const_iterator mapPosition = + mBigramMaps.find(wordPosition); + if (mapPosition != mBigramMaps.end()) { + return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + } + if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { + addBigramsForWordPosition(binaryDictionaryInfo, wordPosition); + return mBigramMaps[wordPosition].getBigramProbability( + nextWordPosition, unigramProbability); + } + return readBigramProbabilityFromBinaryDictionary(binaryDictionaryInfo, + wordPosition, nextWordPosition, unigramProbability); + } + + void clear() { + mBigramMaps.clear(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(MultiBigramMap); + + class BigramMap { + public: + BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP), mBloomFilter() {} + ~BigramMap() {} + + void init(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) { + const int bigramsListPos = binaryDictionaryInfo->getStructurePolicy()-> + getBigramsPositionOfNode(binaryDictionaryInfo, nodePos); + BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + mBigramMap[bigramsIt.getBigramPos()] = bigramsIt.getProbability(); + mBloomFilter.setInFilter(bigramsIt.getBigramPos()); + } + } + + AK_FORCE_INLINE int getBigramProbability( + const int nextWordPosition, const int unigramProbability) const { + if (mBloomFilter.isInFilter(nextWordPosition)) { + const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = + mBigramMap.find(nextWordPosition); + if (bigramProbabilityIt != mBigramMap.end()) { + const int bigramProbability = bigramProbabilityIt->second; + return ProbabilityUtils::computeProbabilityForBigram( + unigramProbability, bigramProbability); + } + } + return ProbabilityUtils::backoff(unigramProbability); + } + + private: + // NOTE: The BigramMap class doesn't use DISALLOW_COPY_AND_ASSIGN() because its default + // copy constructor is needed for use in hash_map. + static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; + hash_map_compat<int, int> mBigramMap; + BloomFilter mBloomFilter; + }; + + AK_FORCE_INLINE void addBigramsForWordPosition( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int position) { + mBigramMaps[position].init(binaryDictionaryInfo, position); + } + + AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, + const int nextWordPosition, const int unigramProbability) { + const int bigramsListPos = binaryDictionaryInfo->getStructurePolicy()-> + getBigramsPositionOfNode(binaryDictionaryInfo, nodePos); + BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); + while (bigramsIt.hasNext()) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPosition) { + return ProbabilityUtils::computeProbabilityForBigram( + unigramProbability, bigramsIt.getProbability()); + } + } + return ProbabilityUtils::backoff(unigramProbability); + } + + static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; + hash_map_compat<int, BigramMap> mBigramMaps; +}; +} // namespace latinime +#endif // LATINIME_MULTI_BIGRAM_MAP_H diff --git a/native/jni/src/suggest/core/dictionary/probability_utils.h b/native/jni/src/suggest/core/dictionary/probability_utils.h new file mode 100644 index 000000000..f450087d8 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/probability_utils.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_PROBABILITY_UTILS_H +#define LATINIME_PROBABILITY_UTILS_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class ProbabilityUtils { + public: + static AK_FORCE_INLINE int backoff(const int unigramProbability) { + return unigramProbability; + // For some reason, applying the backoff weight gives bad results in tests. To apply the + // backoff weight, we divide the probability by 2, which in our storing format means + // decreasing the score by 8. + // TODO: figure out what's wrong with this. + // return unigramProbability > 8 ? + // unigramProbability - 8 : (0 == unigramProbability ? 0 : 8); + } + + static AK_FORCE_INLINE int computeProbabilityForBigram( + const int unigramProbability, const int bigramProbability) { + // We divide the range [unigramProbability..255] in 16.5 steps - in other words, we want + // the unigram probability to be the median value of the 17th step from the top. A value of + // 0 for the bigram probability represents the middle of the 16th step from the top, + // while a value of 15 represents the middle of the top step. + // See makedict.BinaryDictInputOutput for details. + const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability) + / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY); + return unigramProbability + + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ProbabilityUtils); +}; +} +#endif /* LATINIME_PROBABILITY_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/shortcut_utils.h b/native/jni/src/suggest/core/dictionary/shortcut_utils.h index c411408ec..3c2180937 100644 --- a/native/jni/src/suggest/core/dictionary/shortcut_utils.h +++ b/native/jni/src/suggest/core/dictionary/shortcut_utils.h @@ -19,7 +19,7 @@ #include "defines.h" #include "suggest/core/dicnode/dic_node_utils.h" -#include "terminal_attributes.h" +#include "suggest/core/dictionary/terminal_attributes.h" namespace latinime { @@ -29,15 +29,15 @@ class ShortcutUtils { int outputWordIndex, const int finalScore, int *const outputCodePoints, int *const frequencies, int *const outputTypes, const bool sameAsTyped) { TerminalAttributes::ShortcutIterator iterator = terminalAttributes->getShortcutIterator(); + int shortcutTarget[MAX_WORD_LENGTH]; while (iterator.hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { - int shortcutTarget[MAX_WORD_LENGTH]; - int shortcutProbability; - const int shortcutTargetStringLength = iterator.getNextShortcutTarget( - MAX_WORD_LENGTH, shortcutTarget, &shortcutProbability); + bool isWhilelist; + int shortcutTargetStringLength; + iterator.nextShortcutTarget(MAX_WORD_LENGTH, shortcutTarget, + &shortcutTargetStringLength, &isWhilelist); int shortcutScore; int kind; - if (shortcutProbability == BinaryFormat::WHITELIST_SHORTCUT_PROBABILITY - && sameAsTyped) { + if (isWhilelist && sameAsTyped) { shortcutScore = S_INT_MAX; kind = Dictionary::KIND_WHITELIST; } else { diff --git a/native/jni/src/suggest/core/dictionary/terminal_attributes.h b/native/jni/src/suggest/core/dictionary/terminal_attributes.h new file mode 100644 index 000000000..0da6504eb --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/terminal_attributes.h @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TERMINAL_ATTRIBUTES_H +#define LATINIME_TERMINAL_ATTRIBUTES_H + +#include <stdint.h> + +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" + +namespace latinime { + +/** + * This class encapsulates information about a terminal that allows to + * retrieve local node attributes like the list of shortcuts without + * exposing the format structure to the client. + */ +class TerminalAttributes { + public: + class ShortcutIterator { + public: + ShortcutIterator(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int shortcutPos, const bool hasShortcutList) + : mBinaryDictionaryInfo(binaryDictionaryInfo), mPos(shortcutPos), + mHasNextShortcutTarget(hasShortcutList) {} + + inline bool hasNextShortcutTarget() const { + return mHasNextShortcutTarget; + } + + // Gets the shortcut target itself as an int string and put it to outTarget, put its length + // to outTargetLength, put whether it is whitelist to outIsWhitelist. + AK_FORCE_INLINE void nextShortcutTarget( + const int maxDepth, int *const outTarget, int *const outTargetLength, + bool *const outIsWhitelist) { + const BinaryDictionaryTerminalAttributesReadingUtils::ShortcutFlags flags = + BinaryDictionaryTerminalAttributesReadingUtils::getFlagsAndForwardPointer( + mBinaryDictionaryInfo, &mPos); + mHasNextShortcutTarget = + BinaryDictionaryTerminalAttributesReadingUtils::hasNext(flags); + if (outIsWhitelist) { + *outIsWhitelist = + BinaryDictionaryTerminalAttributesReadingUtils::isWhitelist(flags); + } + if (outTargetLength) { + *outTargetLength = + BinaryDictionaryTerminalAttributesReadingUtils::readShortcutTarget( + mBinaryDictionaryInfo, maxDepth, outTarget, &mPos); + } + } + + private: + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; + int mPos; + bool mHasNextShortcutTarget; + }; + + TerminalAttributes(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int shortcutPos) + : mBinaryDictionaryInfo(binaryDictionaryInfo), mShortcutListSizePos(shortcutPos) {} + + inline ShortcutIterator getShortcutIterator() const { + int shortcutPos = mShortcutListSizePos; + const bool hasShortcutList = shortcutPos != NOT_A_DICT_POS; + if (hasShortcutList) { + BinaryDictionaryTerminalAttributesReadingUtils::getShortcutListSizeAndForwardPointer( + mBinaryDictionaryInfo, &shortcutPos); + } + // shortcutPos is never used if hasShortcutList is false. + return ShortcutIterator(mBinaryDictionaryInfo, shortcutPos, hasShortcutList); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(TerminalAttributes); + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; + const int mShortcutListSizePos; +}; +} // namespace latinime +#endif // LATINIME_TERMINAL_ATTRIBUTES_H diff --git a/native/jni/src/additional_proximity_chars.cpp b/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp index 661c50e91..34b8b37b0 100644 --- a/native/jni/src/additional_proximity_chars.cpp +++ b/native/jni/src/suggest/core/layout/additional_proximity_chars.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "additional_proximity_chars.h" +#include "suggest/core/layout/additional_proximity_chars.h" namespace latinime { // TODO: Stop using hardcoded additional proximity characters. diff --git a/native/jni/src/additional_proximity_chars.h b/native/jni/src/suggest/core/layout/additional_proximity_chars.h index a88fd6cea..a88fd6cea 100644 --- a/native/jni/src/additional_proximity_chars.h +++ b/native/jni/src/suggest/core/layout/additional_proximity_chars.h diff --git a/native/jni/src/suggest/core/layout/geometry_utils.h b/native/jni/src/suggest/core/layout/geometry_utils.h new file mode 100644 index 000000000..b667df68f --- /dev/null +++ b/native/jni/src/suggest/core/layout/geometry_utils.h @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_GEOMETRY_UTILS_H +#define LATINIME_GEOMETRY_UTILS_H + +#include <cmath> + +#include "defines.h" + +#define ROUND_FLOAT_10000(f) ((f) < 1000.0f && (f) > 0.001f) \ + ? (floorf((f) * 10000.0f) / 10000.0f) : (f) + +namespace latinime { + +class GeometryUtils { + public: + static inline float SQUARE_FLOAT(const float x) { return x * x; } + + static AK_FORCE_INLINE float getAngle(const int x1, const int y1, const int x2, const int y2) { + const int dx = x1 - x2; + const int dy = y1 - y2; + if (dx == 0 && dy == 0) return 0.0f; + return atan2f(static_cast<float>(dy), static_cast<float>(dx)); + } + + static AK_FORCE_INLINE float getAngleDiff(const float a1, const float a2) { + const float deltaA = fabsf(a1 - a2); + const float diff = ROUND_FLOAT_10000(deltaA); + if (diff > M_PI_F) { + const float normalizedDiff = 2.0f * M_PI_F - diff; + return ROUND_FLOAT_10000(normalizedDiff); + } + return diff; + } + + static AK_FORCE_INLINE int getDistanceInt(const int x1, const int y1, const int x2, + const int y2) { + return static_cast<int>(hypotf(static_cast<float>(x1 - x2), static_cast<float>(y1 - y2))); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(GeometryUtils); +}; +} // namespace latinime +#endif // LATINIME_GEOMETRY_UTILS_H diff --git a/native/jni/src/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp index 88d670d61..e64476d82 100644 --- a/native/jni/src/proximity_info.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info.cpp @@ -14,18 +14,19 @@ * limitations under the License. */ +#define LOG_TAG "LatinIME: proximity_info.cpp" + +#include "suggest/core/layout/proximity_info.h" + #include <cstring> #include <cmath> -#define LOG_TAG "LatinIME: proximity_info.cpp" - -#include "additional_proximity_chars.h" -#include "char_utils.h" #include "defines.h" -#include "geometry_utils.h" #include "jni.h" -#include "proximity_info.h" -#include "proximity_info_params.h" +#include "suggest/core/layout/additional_proximity_chars.h" +#include "suggest/core/layout/geometry_utils.h" +#include "suggest/core/layout/proximity_info_params.h" +#include "utils/char_utils.h" namespace latinime { @@ -58,7 +59,7 @@ ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr, MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth), MOST_COMMON_KEY_HEIGHT(mostCommonKeyHeight), NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f + - SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) / + GeometryUtils::SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) / static_cast<float>(mostCommonKeyWidth))), CELL_WIDTH((keyboardWidth + gridWidth - 1) / gridWidth), CELL_HEIGHT((keyboardHeight + gridHeight - 1) / gridHeight), @@ -133,24 +134,13 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const { } float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, const float verticalScale) const { - const bool correctTouchPosition = hasTouchPositionCorrectionData(); - const float centerX = static_cast<float>(correctTouchPosition ? getSweetSpotCenterXAt(keyId) - : getKeyCenterXOfKeyIdG(keyId)); - const float visualKeyCenterY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId)); - float centerY; - if (correctTouchPosition) { - const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId)); - const float gapY = sweetSpotCenterY - visualKeyCenterY; - centerY = visualKeyCenterY + gapY * verticalScale; - } else { - centerY = visualKeyCenterY; - } + const int keyId, const int x, const int y, const bool isGeometric) const { + const float centerX = static_cast<float>(getKeyCenterXOfKeyIdG(keyId, x, isGeometric)); + const float centerY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId, y, isGeometric)); const float touchX = static_cast<float>(x); const float touchY = static_cast<float>(y); - const float keyWidth = static_cast<float>(getMostCommonKeyWidth()); return ProximityInfoUtils::getSquaredDistanceFloat(centerX, centerY, touchX, touchY) - / SQUARE_FLOAT(keyWidth); + / GeometryUtils::SQUARE_FLOAT(static_cast<float>(getMostCommonKeyWidth())); } int ProximityInfo::getCodePointOf(const int keyIndex) const { @@ -164,44 +154,91 @@ void ProximityInfo::initializeG() { // TODO: Optimize for (int i = 0; i < KEY_COUNT; ++i) { const int code = mKeyCodePoints[i]; - const int lowerCode = toLowerCase(code); + const int lowerCode = CharUtils::toLowerCase(code); mCenterXsG[i] = mKeyXCoordinates[i] + mKeyWidths[i] / 2; mCenterYsG[i] = mKeyYCoordinates[i] + mKeyHeights[i] / 2; + if (hasTouchPositionCorrectionData()) { + // Computes sweet spot center points for geometric input. + const float verticalScale = ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G; + const float sweetSpotCenterY = static_cast<float>(mSweetSpotCenterYs[i]); + const float gapY = sweetSpotCenterY - mCenterYsG[i]; + mSweetSpotCenterYsG[i] = static_cast<int>(mCenterYsG[i] + gapY * verticalScale); + } mCodeToKeyMap[lowerCode] = i; mKeyIndexToCodePointG[i] = lowerCode; } for (int i = 0; i < KEY_COUNT; i++) { mKeyKeyDistancesG[i][i] = 0; for (int j = i + 1; j < KEY_COUNT; j++) { - mKeyKeyDistancesG[i][j] = getDistanceInt( - mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + if (hasTouchPositionCorrectionData()) { + // Computes distances using sweet spots if they exist. + // We have two types of Y coordinate sweet spots, for geometric and for the others. + // The sweet spots for geometric input are used for calculating key-key distances + // here. + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mSweetSpotCenterXs[i], mSweetSpotCenterYsG[i], + mSweetSpotCenterXs[j], mSweetSpotCenterYsG[j]); + } else { + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + } mKeyKeyDistancesG[j][i] = mKeyKeyDistancesG[i][j]; } } } -int ProximityInfo::getKeyCenterXOfCodePointG(int charCode) const { - return getKeyCenterXOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterYOfCodePointG(int charCode) const { - return getKeyCenterYOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterXOfKeyIdG(int keyId) const { - if (keyId >= 0) { - return mCenterXsG[keyId]; +// referencePointX is used only for keys wider than most common key width. When the referencePointX +// is NOT_A_COORDINATE, this method calculates the return value without using the line segment. +// isGeometric is currently not used because we don't have extra X coordinates sweet spots for +// geometric input. +int ProximityInfo::getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const { + if (keyId < 0) { + return 0; + } + int centerX = (hasTouchPositionCorrectionData()) ? static_cast<int>(mSweetSpotCenterXs[keyId]) + : mCenterXsG[keyId]; + const int keyWidth = mKeyWidths[keyId]; + if (referencePointX != NOT_A_COORDINATE + && keyWidth > getMostCommonKeyWidth()) { + // For keys wider than most common keys, we use a line segment instead of the center point; + // thus, centerX is adjusted depending on referencePointX. + const int keyWidthHalfDiff = (keyWidth - getMostCommonKeyWidth()) / 2; + if (referencePointX < centerX - keyWidthHalfDiff) { + centerX -= keyWidthHalfDiff; + } else if (referencePointX > centerX + keyWidthHalfDiff) { + centerX += keyWidthHalfDiff; + } else { + centerX = referencePointX; + } } - return 0; + return centerX; } -int ProximityInfo::getKeyCenterYOfKeyIdG(int keyId) const { - if (keyId >= 0) { - return mCenterYsG[keyId]; +// When the referencePointY is NOT_A_COORDINATE, this method calculates the return value without +// using the line segment. +int ProximityInfo::getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const { + // TODO: Remove "isGeometric" and have separate "proximity_info"s for gesture and typing. + if (keyId < 0) { + return 0; + } + int centerY; + if (!hasTouchPositionCorrectionData()) { + centerY = mCenterYsG[keyId]; + } else if (isGeometric) { + centerY = static_cast<int>(mSweetSpotCenterYsG[keyId]); + } else { + centerY = static_cast<int>(mSweetSpotCenterYs[keyId]); + } + if (referencePointY != NOT_A_COORDINATE && + centerY + mKeyHeights[keyId] > KEYBOARD_HEIGHT && centerY < referencePointY) { + // When the distance between center point and bottom edge of the keyboard is shorter than + // the key height, we assume the key is located at the bottom row of the keyboard. + // The center point is extended to the bottom edge for such keys. + return referencePointY; } - return 0; + return centerY; } int ProximityInfo::getKeyKeyDistanceG(const int keyId0, const int keyId1) const { diff --git a/native/jni/src/proximity_info.h b/native/jni/src/suggest/core/layout/proximity_info.h index deb9ae0de..f25949001 100644 --- a/native/jni/src/proximity_info.h +++ b/native/jni/src/suggest/core/layout/proximity_info.h @@ -18,14 +18,12 @@ #define LATINIME_PROXIMITY_INFO_H #include "defines.h" -#include "hash_map_compat.h" #include "jni.h" -#include "proximity_info_utils.h" +#include "suggest/core/layout/proximity_info_utils.h" +#include "utils/hash_map_compat.h" namespace latinime { -class Correction; - class ProximityInfo { public: ProximityInfo(JNIEnv *env, const jstring localeJStr, @@ -39,9 +37,7 @@ class ProximityInfo { bool hasSpaceProximity(const int x, const int y) const; int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const; float getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, - const float verticalScale) const; - bool sameAsTyped(const unsigned short *word, int length) const; + const int keyId, const int x, const int y, const bool isGeometric) const; int getCodePointOf(const int keyIndex) const; bool hasSweetSpotData(const int keyIndex) const { // When there are no calibration data for a key, @@ -68,10 +64,10 @@ class ProximityInfo { int getKeyboardHeight() const { return KEYBOARD_HEIGHT; } float getKeyboardHypotenuse() const { return KEYBOARD_HYPOTENUSE; } - int getKeyCenterXOfCodePointG(int charCode) const; - int getKeyCenterYOfCodePointG(int charCode) const; - int getKeyCenterXOfKeyIdG(int keyId) const; - int getKeyCenterYOfKeyIdG(int keyId) const; + int getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const; + int getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const; int getKeyKeyDistanceG(int keyId0, int keyId1) const; AK_FORCE_INLINE void initializeProximities(const int *const inputCodes, @@ -95,8 +91,6 @@ class ProximityInfo { DISALLOW_IMPLICIT_CONSTRUCTORS(ProximityInfo); void initializeG(); - float calculateNormalizedSquaredDistance(const int keyIndex, const int inputIndex) const; - bool hasInputCoordinates() const; const int GRID_WIDTH; const int GRID_HEIGHT; @@ -120,6 +114,8 @@ class ProximityInfo { int mKeyCodePoints[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterXs[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterYs[MAX_KEY_COUNT_IN_A_KEYBOARD]; + // Sweet spots for geometric input. Note that we have extra sweet spots only for Y coordinates. + float mSweetSpotCenterYsG[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotRadii[MAX_KEY_COUNT_IN_A_KEYBOARD]; hash_map_compat<int, int> mCodeToKeyMap; diff --git a/native/jni/src/proximity_info_params.cpp b/native/jni/src/suggest/core/layout/proximity_info_params.cpp index 2675d9e70..0e887f700 100644 --- a/native/jni/src/proximity_info_params.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_params.cpp @@ -15,7 +15,7 @@ */ #include "defines.h" -#include "proximity_info_params.h" +#include "suggest/core/layout/proximity_info_params.h" namespace latinime { const float ProximityInfoParams::NOT_A_DISTANCE_FLOAT = -1.0f; diff --git a/native/jni/src/proximity_info_params.h b/native/jni/src/suggest/core/layout/proximity_info_params.h index 4e47f7308..4e47f7308 100644 --- a/native/jni/src/proximity_info_params.h +++ b/native/jni/src/suggest/core/layout/proximity_info_params.h diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp index cc5b736bd..7780efdfd 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp @@ -14,17 +14,19 @@ * limitations under the License. */ +#define LOG_TAG "LatinIME: proximity_info_state.cpp" + +#include "suggest/core/layout/proximity_info_state.h" + #include <cstring> // for memset() and memcpy() #include <sstream> // for debug prints #include <vector> -#define LOG_TAG "LatinIME: proximity_info_state.cpp" - #include "defines.h" -#include "geometry_utils.h" -#include "proximity_info.h" -#include "proximity_info_state.h" -#include "proximity_info_state_utils.h" +#include "suggest/core/layout/geometry_utils.h" +#include "suggest/core/layout/proximity_info.h" +#include "suggest/core/layout/proximity_info_state_utils.h" +#include "utils/char_utils.h" namespace latinime { @@ -95,15 +97,10 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi pushTouchPointStartIndex, lastSavedInputSize); } - // TODO: Remove the dependency of "isGeometric" - const float verticalSweetSpotScale = isGeometric - ? ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G - : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE; - if (xCoordinates && yCoordinates) { mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo, mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times, - pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId, + pointerIds, inputSize, isGeometric, pointerId, pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache, &mSampledInputIndice); } @@ -121,7 +118,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, - lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, + lastSavedInputSize, isGeometric, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. @@ -154,11 +151,6 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (!isGeometric && pointerId == 0) { ProximityInfoStateUtils::initPrimaryInputWord( inputSize, mInputProximities, mPrimaryInputWord); - if (mTouchPositionCorrectionEnabled) { - ProximityInfoStateUtils::initNormalizedSquaredDistances( - mProximityInfo, inputSize, xCoordinates, yCoordinates, mInputProximities, - &mSampledInputXs, &mSampledInputYs, mNormalizedSquaredDistances); - } } if (DEBUG_GEO_FULL) { AKLOGI("ProximityState init finished: %d points out of %d", mSampledInputSize, inputSize); @@ -174,7 +166,7 @@ float ProximityInfoState::getPointToKeyLength( const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; return min(mSampledNormalizedSquaredLengthCache[index], mMaxPointToKeyLength); } - if (isIntentionalOmissionCodePoint(codePoint)) { + if (CharUtils::isIntentionalOmissionCodePoint(codePoint)) { return 0.0f; } // If the char is not a key on the keyboard then return the max length. @@ -202,7 +194,7 @@ ProximityType ProximityInfoState::getProximityType(const int index, const int co const bool checkProximityChars, int *proximityIndex) const { const int *currentCodePoints = getProximityCodePointsAt(index); const int firstCodePoint = currentCodePoints[0]; - const int baseLowerC = toBaseLowerCase(codePoint); + const int baseLowerC = CharUtils::toBaseLowerCase(codePoint); // The first char in the array is what user typed. If it matches right away, that means the // user typed that same char for this pos. @@ -214,7 +206,7 @@ ProximityType ProximityInfoState::getProximityType(const int index, const int co // If the non-accented, lowercased version of that first character matches c, then we have a // non-accented version of the accented character the user typed. Treat it as a close char. - if (toBaseLowerCase(firstCodePoint) == baseLowerC) { + if (CharUtils::toBaseLowerCase(firstCodePoint) == baseLowerC) { return PROXIMITY_CHAR; } @@ -256,8 +248,8 @@ ProximityType ProximityInfoState::getProximityTypeG(const int index, const int c if (!isUsed()) { return UNRELATED_CHAR; } - const int lowerCodePoint = toLowerCase(codePoint); - const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint); + const int lowerCodePoint = CharUtils::toLowerCase(codePoint); + const int baseLowerCodePoint = CharUtils::toBaseCodePoint(lowerCodePoint); for (int i = 0; i < static_cast<int>(mSampledSearchKeyVectors[index].size()); ++i) { if (mSampledSearchKeyVectors[index][i] == lowerCodePoint || mSampledSearchKeyVectors[index][i] == baseLowerCodePoint) { @@ -277,26 +269,6 @@ float ProximityInfoState::getDirection(const int index0, const int index1) const &mSampledInputXs, &mSampledInputYs, index0, index1); } -float ProximityInfoState::getLineToKeyDistance( - const int from, const int to, const int keyId, const bool extend) const { - if (from < 0 || from > mSampledInputSize - 1) { - return 0.0f; - } - if (to < 0 || to > mSampledInputSize - 1) { - return 0.0f; - } - const int x0 = mSampledInputXs[from]; - const int y0 = mSampledInputYs[from]; - const int x1 = mSampledInputXs[to]; - const int y1 = mSampledInputYs[to]; - - const int keyX = mProximityInfo->getKeyCenterXOfKeyIdG(keyId); - const int keyY = mProximityInfo->getKeyCenterYOfKeyIdG(keyId); - - return ProximityInfoUtils::pointToLineSegSquaredDistanceFloat( - keyX, keyY, x0, y0, x1, y1, extend); -} - float ProximityInfoState::getMostProbableString(int *const codePointBuf) const { memcpy(codePointBuf, mMostProbableString, sizeof(mMostProbableString)); return mMostProbableStringProbability; diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/suggest/core/layout/proximity_info_state.h index bbe8af240..dbcd54488 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state.h @@ -20,11 +20,10 @@ #include <cstring> // for memset() #include <vector> -#include "char_utils.h" #include "defines.h" -#include "hash_map_compat.h" -#include "proximity_info_params.h" -#include "proximity_info_state_utils.h" +#include "suggest/core/layout/proximity_info_params.h" +#include "suggest/core/layout/proximity_info_state_utils.h" +#include "utils/hash_map_compat.h" namespace latinime { @@ -54,7 +53,6 @@ class ProximityInfoState { mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); - memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances)); memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord)); memset(mMostProbableString, 0, sizeof(mMostProbableString)); } @@ -92,7 +90,7 @@ class ProximityInfoState { return false; } - inline bool existsAdjacentProximityChars(const int index) const { + AK_FORCE_INLINE bool existsAdjacentProximityChars(const int index) const { if (index < 0 || index >= mSampledInputSize) return false; const int currentCodePoint = getPrimaryCodePointAt(index); const int leftIndex = index - 1; @@ -107,12 +105,6 @@ class ProximityInfoState { return false; } - inline int getNormalizedSquaredDistance( - const int inputIndex, const int proximityIndex) const { - return mNormalizedSquaredDistances[ - inputIndex * MAX_PROXIMITY_CHARS_SIZE + proximityIndex]; - } - inline const int *getPrimaryInputWord() const { return mPrimaryInputWord; } @@ -191,24 +183,10 @@ class ProximityInfoState { float getProbability(const int index, const int charCode) const; - float getLineToKeyDistance( - const int from, const int to, const int keyId, const bool extend) const; - bool isKeyInSerchKeysAfterIndex(const int index, const int keyId) const; private: DISALLOW_COPY_AND_ASSIGN(ProximityInfoState); - ///////////////////////////////////////// - // Defined in proximity_info_state.cpp // - ///////////////////////////////////////// - float calculateNormalizedSquaredDistance(const int keyIndex, const int inputIndex) const; - - float calculateSquaredDistanceFromSweetSpotCenter( - const int keyIndex, const int inputIndex) const; - - ///////////////////////////////////////// - // Defined here // - ///////////////////////////////////////// inline const int *getProximityCodePointsAt(const int index) const { return ProximityInfoStateUtils::getProximityCodePointsAt(mInputProximities, index); @@ -250,7 +228,6 @@ class ProximityInfoState { std::vector<std::vector<int> > mSampledSearchKeyVectors; bool mTouchPositionCorrectionEnabled; int mInputProximities[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; - int mNormalizedSquaredDistances[MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH]; int mSampledInputSize; int mPrimaryInputWord[MAX_WORD_LENGTH]; float mMostProbableStringProbability; diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index 359673cd8..904671f7f 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -14,16 +14,17 @@ * limitations under the License. */ +#include "suggest/core/layout/proximity_info_state_utils.h" + #include <cmath> #include <cstring> // for memset() #include <sstream> // for debug prints #include <vector> #include "defines.h" -#include "geometry_utils.h" -#include "proximity_info.h" -#include "proximity_info_params.h" -#include "proximity_info_state_utils.h" +#include "suggest/core/layout/geometry_utils.h" +#include "suggest/core/layout/proximity_info.h" +#include "suggest/core/layout/proximity_info_params.h" namespace latinime { @@ -42,8 +43,8 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, const bool isGeometric, - const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, + const int inputSize, const bool isGeometric, const int pointerId, + const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) { if (DEBUG_SAMPLING_POINTS) { @@ -103,16 +104,16 @@ namespace latinime { const int time = times ? times[i] : -1; if (i > 1) { - const float prevAngle = getAngle( + const float prevAngle = GeometryUtils::getAngle( inputXCoordinates[i - 2], inputYCoordinates[i - 2], inputXCoordinates[i - 1], inputYCoordinates[i - 1]); - const float currentAngle = getAngle( + const float currentAngle = GeometryUtils::getAngle( inputXCoordinates[i - 1], inputYCoordinates[i - 1], x, y); - sumAngle += getAngleDiff(prevAngle, currentAngle); + sumAngle += GeometryUtils::getAngleDiff(prevAngle, currentAngle); } if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time, - verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex, + isGeometric, isGeometric /* doSampling */, i == lastInputIndex, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache, sampledInputIndice)) { @@ -157,7 +158,8 @@ namespace latinime { const float sweetSpotCenterY = proximityInfo->getSweetSpotCenterYAt(keyIndex); const float inputX = static_cast<float>((*sampledInputXs)[inputIndex]); const float inputY = static_cast<float>((*sampledInputYs)[inputIndex]); - return SQUARE_FLOAT(inputX - sweetSpotCenterX) + SQUARE_FLOAT(inputY - sweetSpotCenterY); + return GeometryUtils::SQUARE_FLOAT(inputX - sweetSpotCenterX) + + GeometryUtils::SQUARE_FLOAT(inputY - sweetSpotCenterY); } /* static */ float ProximityInfoStateUtils::calculateNormalizedSquaredDistance( @@ -174,55 +176,14 @@ namespace latinime { } const float squaredDistance = calculateSquaredDistanceFromSweetSpotCenter(proximityInfo, sampledInputXs, sampledInputYs, keyIndex, inputIndex); - const float squaredRadius = SQUARE_FLOAT(proximityInfo->getSweetSpotRadiiAt(keyIndex)); + const float squaredRadius = GeometryUtils::SQUARE_FLOAT( + proximityInfo->getSweetSpotRadiiAt(keyIndex)); return squaredDistance / squaredRadius; } -/* static */ void ProximityInfoStateUtils::initNormalizedSquaredDistances( - const ProximityInfo *const proximityInfo, const int inputSize, const int *inputXCoordinates, - const int *inputYCoordinates, const int *const inputProximities, - const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - int *normalizedSquaredDistances) { - memset(normalizedSquaredDistances, NOT_A_DISTANCE, - sizeof(normalizedSquaredDistances[0]) * MAX_PROXIMITY_CHARS_SIZE * MAX_WORD_LENGTH); - const bool hasInputCoordinates = sampledInputXs->size() > 0 && sampledInputYs->size() > 0; - for (int i = 0; i < inputSize; ++i) { - const int *proximityCodePoints = getProximityCodePointsAt(inputProximities, i); - const int primaryKey = proximityCodePoints[0]; - const int x = inputXCoordinates[i]; - const int y = inputYCoordinates[i]; - if (DEBUG_PROXIMITY_CHARS) { - int a = x + y + primaryKey; - a += 0; - AKLOGI("--- Primary = %c, x = %d, y = %d", primaryKey, x, y); - } - for (int j = 0; j < MAX_PROXIMITY_CHARS_SIZE && proximityCodePoints[j] > 0; ++j) { - const int currentCodePoint = proximityCodePoints[j]; - const float squaredDistance = - hasInputCoordinates ? calculateNormalizedSquaredDistance( - proximityInfo, sampledInputXs, sampledInputYs, - proximityInfo->getKeyIndexOf(currentCodePoint), i) : - ProximityInfoParams::NOT_A_DISTANCE_FLOAT; - if (squaredDistance >= 0.0f) { - normalizedSquaredDistances[i * MAX_PROXIMITY_CHARS_SIZE + j] = - static_cast<int>(squaredDistance - * ProximityInfoParams::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR); - } else { - normalizedSquaredDistances[i * MAX_PROXIMITY_CHARS_SIZE + j] = - (j == 0) ? MATCH_CHAR_WITHOUT_DISTANCE_INFO : - PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO; - } - if (DEBUG_PROXIMITY_CHARS) { - AKLOGI("--- Proximity (%d) = %c", j, currentCodePoint); - } - } - } - -} - /* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const int lastSavedInputSize, const float verticalSweetSpotScale, + const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -238,7 +199,7 @@ namespace latinime { const int y = (*sampledInputYs)[i]; const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( - k, x, y, verticalSweetSpotScale); + k, x, y, isGeometric); (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { @@ -285,7 +246,7 @@ namespace latinime { if (i < sampledInputSize - 1 && j >= (*sampledInputIndice)[i + 1]) { break; } - length += getDistanceInt(xCoordinates[j], yCoordinates[j], + length += GeometryUtils::getDistanceInt(xCoordinates[j], yCoordinates[j], xCoordinates[j + 1], yCoordinates[j + 1]); duration += times[j + 1] - times[j]; } @@ -296,7 +257,7 @@ namespace latinime { break; } // TODO: use mSampledLengthCache instead? - length += getDistanceInt(xCoordinates[j], yCoordinates[j], + length += GeometryUtils::getDistanceInt(xCoordinates[j], yCoordinates[j], xCoordinates[j + 1], yCoordinates[j + 1]); duration += times[j + 1] - times[j]; } @@ -349,21 +310,20 @@ namespace latinime { const int y1 = (*sampledInputYs)[index0]; const int x2 = (*sampledInputXs)[index1]; const int y2 = (*sampledInputYs)[index1]; - return getAngle(x1, y1, x2, y2); + return GeometryUtils::getAngle(x1, y1, x2, y2); } // Calculating point to key distance for all near keys and returning the distance between // the given point and the nearest key position. /* static */ float ProximityInfoStateUtils::updateNearKeysDistances( const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, - const int y, const float verticalSweetspotScale, - NearKeysDistanceMap *const currentNearKeysDistances) { + const int y, const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances) { currentNearKeysDistances->clear(); const int keyCount = proximityInfo->getKeyCount(); float nearestKeyDistance = maxPointToKeyLength; for (int k = 0; k < keyCount; ++k) { const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y, - verticalSweetspotScale); + isGeometric); if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) { currentNearKeysDistances->insert(std::pair<int, float>(k, dist)); } @@ -411,9 +371,9 @@ namespace latinime { } const int baseSampleRate = mostCommonKeyWidth; - const int distPrev = getDistanceInt(sampledInputXs->back(), sampledInputYs->back(), - (*sampledInputXs)[size - 2], (*sampledInputYs)[size - 2]) - * ProximityInfoParams::DISTANCE_BASE_SCALE; + const int distPrev = GeometryUtils::getDistanceInt(sampledInputXs->back(), + sampledInputYs->back(), (*sampledInputXs)[size - 2], + (*sampledInputYs)[size - 2]) * ProximityInfoParams::DISTANCE_BASE_SCALE; float score = 0.0f; // Location @@ -425,10 +385,11 @@ namespace latinime { score += ProximityInfoParams::LOCALMIN_DISTANCE_AND_NEAR_TO_KEY_SCORE; } // Angle - const float angle1 = getAngle(x, y, sampledInputXs->back(), sampledInputYs->back()); - const float angle2 = getAngle(sampledInputXs->back(), sampledInputYs->back(), + const float angle1 = GeometryUtils::getAngle(x, y, sampledInputXs->back(), + sampledInputYs->back()); + const float angle2 = GeometryUtils::getAngle(sampledInputXs->back(), sampledInputYs->back(), (*sampledInputXs)[size - 2], (*sampledInputYs)[size - 2]); - const float angleDiff = getAngleDiff(angle1, angle2); + const float angleDiff = GeometryUtils::getAngleDiff(angle1, angle2); // Save corner if (distPrev > baseSampleRate * ProximityInfoParams::CORNER_CHECK_DISTANCE_THRESHOLD_SCALE @@ -443,7 +404,7 @@ namespace latinime { // Returning if previous point is popped or not. /* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y, - const int time, const float verticalSweetSpotScale, const bool doSampling, + const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -457,7 +418,7 @@ namespace latinime { bool popped = false; if (nodeCodePoint < 0 && doSampling) { const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y, - verticalSweetSpotScale, currentNearKeysDistances); + isGeometric, currentNearKeysDistances); const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs); @@ -472,13 +433,13 @@ namespace latinime { } // Check if the last point should be skipped. if (isLastPoint && size > 0) { - if (getDistanceInt(x, y, sampledInputXs->back(), sampledInputYs->back()) + if (GeometryUtils::getDistanceInt(x, y, sampledInputXs->back(), sampledInputYs->back()) * ProximityInfoParams::LAST_POINT_SKIP_DISTANCE_SCALE < mostCommonKeyWidth) { // This point is not used because it's too close to the previous point. if (DEBUG_GEO_FULL) { AKLOGI("p0: size = %zd, x = %d, y = %d, lx = %d, ly = %d, dist = %d, " "width = %d", size, x, y, sampledInputXs->back(), - sampledInputYs->back(), getDistanceInt( + sampledInputYs->back(), GeometryUtils::getDistanceInt( x, y, sampledInputXs->back(), sampledInputYs->back()), mostCommonKeyWidth / ProximityInfoParams::LAST_POINT_SKIP_DISTANCE_SCALE); @@ -491,15 +452,15 @@ namespace latinime { if (nodeCodePoint >= 0 && (x < 0 || y < 0)) { const int keyId = proximityInfo->getKeyIndexOf(nodeCodePoint); if (keyId >= 0) { - x = proximityInfo->getKeyCenterXOfKeyIdG(keyId); - y = proximityInfo->getKeyCenterYOfKeyIdG(keyId); + x = proximityInfo->getKeyCenterXOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); + y = proximityInfo->getKeyCenterYOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); } } // Pushing point information. if (size > 0) { sampledLengthCache->push_back( - sampledLengthCache->back() + getDistanceInt( + sampledLengthCache->back() + GeometryUtils::getDistanceInt( x, y, sampledInputXs->back(), sampledInputYs->back())); } else { sampledLengthCache->push_back(0); @@ -540,7 +501,8 @@ namespace latinime { while (start > 0 && tempBeelineDistance < lookupRadius) { tempTime += times[start] - times[start - 1]; --start; - tempBeelineDistance = getDistanceInt(x0, y0, xCoordinates[start], yCoordinates[start]); + tempBeelineDistance = GeometryUtils::getDistanceInt(x0, y0, xCoordinates[start], + yCoordinates[start]); } // Exclusive unless this is an edge point if (start > 0 && start < actualInputIndex) { @@ -553,7 +515,8 @@ namespace latinime { while (end < (inputSize - 1) && tempBeelineDistance < lookupRadius) { tempTime += times[end + 1] - times[end]; ++end; - tempBeelineDistance = getDistanceInt(x0, y0, xCoordinates[end], yCoordinates[end]); + tempBeelineDistance = GeometryUtils::getDistanceInt(x0, y0, xCoordinates[end], + yCoordinates[end]); } // Exclusive unless this is an edge point if (end > actualInputIndex && end < (inputSize - 1)) { @@ -571,7 +534,7 @@ namespace latinime { const int y2 = yCoordinates[start]; const int x3 = xCoordinates[end]; const int y3 = yCoordinates[end]; - const int beelineDistance = getDistanceInt(x2, y2, x3, y3); + const int beelineDistance = GeometryUtils::getDistanceInt(x2, y2, x3, y3); int adjustedStartTime = times[start]; if (start == 0 && actualInputIndex == 0 && inputSize > 1) { adjustedStartTime += ProximityInfoParams::FIRST_POINT_TIME_OFFSET_MILLIS; @@ -613,7 +576,7 @@ namespace latinime { } const float previousDirection = getDirection(sampledInputXs, sampledInputYs, index - 1, index); const float nextDirection = getDirection(sampledInputXs, sampledInputYs, index, index + 1); - const float directionDiff = getAngleDiff(previousDirection, nextDirection); + const float directionDiff = GeometryUtils::getAngleDiff(previousDirection, nextDirection); return directionDiff; } @@ -636,7 +599,7 @@ namespace latinime { } const float previousDirection = getDirection(sampledInputXs, sampledInputYs, index0, index1); const float nextDirection = getDirection(sampledInputXs, sampledInputYs, index1, index2); - return getAngleDiff(previousDirection, nextDirection); + return GeometryUtils::getAngleDiff(previousDirection, nextDirection); } // This function basically converts from a length to an edit distance. Accordingly, it's obviously diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h index 1837c7ab6..6de970033 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h @@ -21,7 +21,7 @@ #include <vector> #include "defines.h" -#include "hash_map_compat.h" +#include "utils/hash_map_compat.h" namespace latinime { class ProximityInfo; @@ -38,8 +38,7 @@ class ProximityInfoStateUtils { static int updateTouchPoints(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, - const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, + const int *const times, const int *const pointerIds, const int inputSize, const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, @@ -84,8 +83,7 @@ class ProximityInfoStateUtils { const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, - const int sampledInputSize, const int lastSavedInputSize, - const float verticalSweetSpotScale, + const int sampledInputSize, const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -120,7 +118,7 @@ class ProximityInfoStateUtils { static float updateNearKeysDistances(const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, const int y, - const float verticalSweetSpotScale, + const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances); static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -133,7 +131,7 @@ class ProximityInfoStateUtils { std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs); static bool pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, - int y, const int time, const float verticalSweetSpotScale, + int y, const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, diff --git a/native/jni/src/proximity_info_utils.h b/native/jni/src/suggest/core/layout/proximity_info_utils.h index 71c97e325..0e28560fc 100644 --- a/native/jni/src/proximity_info_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_utils.h @@ -19,11 +19,11 @@ #include <cmath> -#include "additional_proximity_chars.h" -#include "char_utils.h" #include "defines.h" -#include "geometry_utils.h" -#include "hash_map_compat.h" +#include "suggest/core/layout/additional_proximity_chars.h" +#include "suggest/core/layout/geometry_utils.h" +#include "utils/char_utils.h" +#include "utils/hash_map_compat.h" namespace latinime { class ProximityInfoUtils { @@ -37,7 +37,7 @@ class ProximityInfoUtils { if (c == NOT_A_CODE_POINT) { return NOT_AN_INDEX; } - const int lowerCode = toLowerCase(c); + const int lowerCode = CharUtils::toLowerCase(c); hash_map_compat<int, int>::const_iterator mapPos = codeToKeyMap->find(lowerCode); if (mapPos != codeToKeyMap->end()) { return mapPos->second; @@ -87,7 +87,7 @@ class ProximityInfoUtils { static inline float getSquaredDistanceFloat(const float x1, const float y1, const float x2, const float y2) { - return SQUARE_FLOAT(x1 - x2) + SQUARE_FLOAT(y1 - y2); + return GeometryUtils::SQUARE_FLOAT(x1 - x2) + GeometryUtils::SQUARE_FLOAT(y1 - y2); } static inline float pointToLineSegSquaredDistanceFloat(const float x, const float y, @@ -98,7 +98,8 @@ class ProximityInfoUtils { const float ray2y = y2 - y1; const float dotProduct = ray1x * ray2x + ray1y * ray2y; - const float lineLengthSqr = SQUARE_FLOAT(ray2x) + SQUARE_FLOAT(ray2y); + const float lineLengthSqr = GeometryUtils::SQUARE_FLOAT(ray2x) + + GeometryUtils::SQUARE_FLOAT(ray2y); const float projectionLengthSqr = dotProduct / lineLengthSqr; float projectionX; @@ -116,17 +117,23 @@ class ProximityInfoUtils { return getSquaredDistanceFloat(x, y, projectionX, projectionY); } + static AK_FORCE_INLINE bool isMatchOrProximityChar(const ProximityType type) { + return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR; + } + // Normal distribution N(u, sigma^2). struct NormalDistribution { public: NormalDistribution(const float u, const float sigma) : mU(u), mSigma(sigma), - mPreComputedNonExpPart(1.0f / sqrtf(2.0f * M_PI_F * SQUARE_FLOAT(sigma))), - mPreComputedExponentPart(-1.0f / (2.0f * SQUARE_FLOAT(sigma))) {} + mPreComputedNonExpPart(1.0f / sqrtf(2.0f * M_PI_F + * GeometryUtils::SQUARE_FLOAT(sigma))), + mPreComputedExponentPart(-1.0f / (2.0f * GeometryUtils::SQUARE_FLOAT(sigma))) {} float getProbabilityDensity(const float x) const { const float shiftedX = x - mU; - return mPreComputedNonExpPart * expf(mPreComputedExponentPart * SQUARE_FLOAT(shiftedX)); + return mPreComputedNonExpPart + * expf(mPreComputedExponentPart * GeometryUtils::SQUARE_FLOAT(shiftedX)); } private: diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest/core/layout/touch_position_correction_utils.h index e053dd662..9130e87d3 100644 --- a/native/jni/src/suggest_utils.h +++ b/native/jni/src/suggest/core/layout/touch_position_correction_utils.h @@ -14,40 +14,15 @@ * limitations under the License. */ -#ifndef LATINIME_SUGGEST_UTILS_H -#define LATINIME_SUGGEST_UTILS_H +#ifndef LATINIME_TOUCH_POSITION_CORRECTION_UTILS_H +#define LATINIME_TOUCH_POSITION_CORRECTION_UTILS_H #include "defines.h" -#include "proximity_info_params.h" +#include "suggest/core/layout/proximity_info_params.h" namespace latinime { -class SuggestUtils { +class TouchPositionCorrectionUtils { public: - // TODO: (OLD) Remove - static float getLengthScalingFactor(const float normalizedSquaredDistance) { - // Promote or demote the score according to the distance from the sweet spot - static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; - static const float B = 1.0f; - static const float C = 0.5f; - static const float MIN = 0.3f; - static const float R1 = NEUTRAL_SCORE_SQUARED_RADIUS; - static const float R2 = HALF_SCORE_SQUARED_RADIUS; - const float x = normalizedSquaredDistance / static_cast<float>( - ProximityInfoParams::NORMALIZED_SQUARED_DISTANCE_SCALING_FACTOR); - const float factor = max((x < R1) - ? (A * (R1 - x) + B * x) / R1 - : (B * (R2 - x) + C * (x - R1)) / (R2 - R1), MIN); - // factor is a piecewise linear function like: - // A -_ . - // ^-_ . - // B \ . - // \_ . - // C ------------. - // . - // 0 R1 R2 . - return factor; - } - static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled, const float normalizedSquaredDistance) { // Promote or demote the score according to the distance from the sweet spot @@ -82,7 +57,7 @@ class SuggestUtils { } } private: - DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestUtils); + DISALLOW_IMPLICIT_CONSTRUCTORS(TouchPositionCorrectionUtils); }; } // namespace latinime -#endif // LATINIME_SUGGEST_UTILS_H +#endif // LATINIME_TOUCH_POSITION_CORRECTION_UTILS_H diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_policy.h new file mode 100644 index 000000000..cc14c982c --- /dev/null +++ b/native/jni/src/suggest/core/policy/dictionary_structure_policy.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_STRUCTURE_POLICY_H +#define LATINIME_DICTIONARY_STRUCTURE_POLICY_H + +#include "defines.h" + +namespace latinime { + +class BinaryDictionaryInfo; +class DicNode; +class DicNodeVector; + +/* + * This class abstracts structure of dictionaries. + * Implement this policy to support additional dictionaries. + */ +class DictionaryStructurePolicy { + public: + // This provides a filtering method for filtering new node. + class NodeFilter { + public: + virtual bool isFilteredOut(const int codePoint) const = 0; + + protected: + NodeFilter() {} + virtual ~NodeFilter() {} + + private: + DISALLOW_COPY_AND_ASSIGN(NodeFilter); + }; + + virtual int getRootPosition() const = 0; + + virtual void createAndGetAllChildNodes(const DicNode *const dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const = 0; + + virtual int getCodePointsAndProbabilityAndReturnCodePointCount( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const = 0; + + virtual int getTerminalNodePositionOfWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, + const int length, const bool forceLowerCaseSearch) const = 0; + + virtual int getUnigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const = 0; + + virtual int getShortcutPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const = 0; + + virtual int getBigramsPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const = 0; + + protected: + DictionaryStructurePolicy() {} + virtual ~DictionaryStructurePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(DictionaryStructurePolicy); +}; +} // namespace latinime +#endif /* LATINIME_DICTIONARY_STRUCTURE_POLICY_H */ diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index c6f66f231..e935533f2 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -45,9 +45,9 @@ class Traversal { const DicNode *const dicNode) const = 0; virtual bool needsToTraverseAllUserInput() const = 0; virtual float getMaxSpatialDistance() const = 0; - virtual bool allowPartialCommit() const = 0; + virtual bool autoCorrectsToMultiWordSuggestionIfTop() const = 0; virtual int getDefaultExpandDicNodeSize() const = 0; - virtual int getMaxCacheSize() const = 0; + virtual int getMaxCacheSize(const int inputSize) const = 0; virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index d01531f07..58729229f 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -16,7 +16,6 @@ #include "suggest/core/policy/weighting.h" -#include "char_utils.h" #include "defines.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_profiler.h" @@ -51,6 +50,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: PROF_TERMINAL(node->mProfiler); return; + case CT_TERMINAL_INSERTION: + PROF_TERMINAL_INSERTION(node->mProfiler); + return; case CT_NEW_WORD_SPACE_SUBSTITUTION: PROF_SPACE_SUBSTITUTION(node->mProfiler); return; @@ -107,13 +109,15 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n // only used for typing return weighting->getSubstitutionCost(); case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordCost(traverseSession, dicNode); + return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: return weighting->getCompletionCost(traverseSession, dicNode); case CT_TERMINAL: return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_TERMINAL_INSERTION: + return weighting->getTerminalInsertionCost(traverseSession, dicNode); case CT_NEW_WORD_SPACE_SUBSTITUTION: return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); case CT_INSERTION: @@ -135,7 +139,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -143,11 +148,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: { const float languageImprobability = DicNodeUtils::getBigramNodeImprobability( - traverseSession->getOffsetDict(), dicNode, multiBigramMap); + traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } + case CT_TERMINAL_INSERTION: + return 0.0f; case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: @@ -162,9 +170,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_OMISSION: return 0; case CT_ADDITIONAL_PROXIMITY: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_SUBSTITUTION: - return 0; + return 0; /* 0 because CT_MATCH will be called */ case CT_NEW_WORD_SPACE_OMITTION: return 0; case CT_MATCH: @@ -173,12 +181,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 1; case CT_TERMINAL: return 0; + case CT_TERMINAL_INSERTION: + return 1; case CT_NEW_WORD_SPACE_SUBSTITUTION: return 1; case CT_INSERTION: - return 2; + return 2; /* look ahead + skip the current char */ case CT_TRANSPOSITION: - return 2; + return 2; /* look ahead + skip the current char */ default: return 0; } diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index 0d2745b40..2d49e98a6 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,10 +56,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const = 0; + virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0; - virtual float getNewWordBigramCost( + virtual float getNewWordBigramLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const = 0; @@ -67,6 +67,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; + virtual float getTerminalInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual float getTerminalLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, float dicNodeLanguageImprobability) const = 0; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 6408f0163..7651b19a0 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -16,68 +16,33 @@ #include "suggest/core/session/dic_traverse_session.h" -#include "binary_format.h" #include "defines.h" -#include "dictionary.h" -#include "dic_traverse_wrapper.h" #include "jni.h" -#include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/dictionary/binary_dictionary_header.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/dictionary.h" namespace latinime { -const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20; - -// A factory method for DicTraverseSession -static void *getSessionInstance(JNIEnv *env, jstring localeStr) { - return new DicTraverseSession(env, localeStr); -} - -// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. -static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary, - const int *prevWord, const int prevWordLength) { - if (traverseSession) { - DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); - tSession->init(dictionary, prevWord, prevWordLength); - } -} - -// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. -static void releaseSessionInstance(void *traverseSession) { - delete static_cast<DicTraverseSession *>(traverseSession); -} - -// An ad-hoc internal class to register the factory method defined above -class TraverseSessionFactoryRegisterer { - public: - TraverseSessionFactoryRegisterer() { - DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance); - DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance); - DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance); - } - private: - DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer); -}; - -// To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor. -static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; - void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, - int prevWordLength) { + int prevWordLength, const SuggestOptions *const suggestOptions) { mDictionary = dictionary; - mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(), - mDictionary->getDictSize()); + const BinaryDictionaryInfo *const binaryDictionaryInfo = + mDictionary->getBinaryDictionaryInfo(); + mMultiWordCostMultiplier = binaryDictionaryInfo->getHeader()->getMultiWordCostMultiplier(); + mSuggestOptions = suggestOptions; if (!prevWord) { - mPrevWordPos = NOT_VALID_WORD; + mPrevWordPos = NOT_A_VALID_WORD_POS; return; } // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. - mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord, - prevWordLength, false /* forceLowerCaseSearch */); - if (mPrevWordPos == NOT_VALID_WORD) { + mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( + binaryDictionaryInfo, prevWord, prevWordLength, false /* forceLowerCaseSearch */); + if (mPrevWordPos == NOT_A_VALID_WORD_POS) { // Check bigrams for lower-cased previous word if original was not found. Useful for // auto-capitalized words like "The [current_word]". - mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord, - prevWordLength, true /* forceLowerCaseSearch */); + mPrevWordPos = binaryDictionaryInfo->getStructurePolicy()->getTerminalNodePositionOfWord( + binaryDictionaryInfo, prevWord, prevWordLength, true /* forceLowerCaseSearch */); } } @@ -91,12 +56,8 @@ void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, maxSpatialDistance, maxPointerCount); } -const uint8_t *DicTraverseSession::getOffsetDict() const { - return mDictionary->getOffsetDict(); -} - -int DicTraverseSession::getDictFlags() const { - return mDictionary->getDictFlags(); +const BinaryDictionaryInfo *DicTraverseSession::getBinaryDictionaryInfo() const { + return mDictionary->getBinaryDictionaryInfo(); } void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index d88be5b88..de57e041a 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -22,20 +22,41 @@ #include "defines.h" #include "jni.h" -#include "multi_bigram_map.h" -#include "proximity_info_state.h" #include "suggest/core/dicnode/dic_nodes_cache.h" +#include "suggest/core/dictionary/multi_bigram_map.h" +#include "suggest/core/layout/proximity_info_state.h" namespace latinime { +class BinaryDictionaryInfo; class Dictionary; class ProximityInfo; +class SuggestOptions; class DicTraverseSession { public: + + // A factory method for DicTraverseSession + static AK_FORCE_INLINE void *getSessionInstance(JNIEnv *env, jstring localeStr) { + return new DicTraverseSession(env, localeStr); + } + + static AK_FORCE_INLINE void initSessionInstance(DicTraverseSession *traverseSession, + const Dictionary *const dictionary, const int *prevWord, const int prevWordLength, + const SuggestOptions *const suggestOptions) { + if (traverseSession) { + DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); + tSession->init(dictionary, prevWord, prevWordLength, suggestOptions); + } + } + + static AK_FORCE_INLINE void releaseSessionInstance(DicTraverseSession *traverseSession) { + delete traverseSession; + } + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) - : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), - mDictionary(0), mDicNodesCache(), mMultiBigramMap(), + : mPrevWordPos(NOT_A_VALID_WORD_POS), mProximityInfo(0), + mDictionary(0), mSuggestOptions(0), mDicNodesCache(), mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. @@ -45,7 +66,8 @@ class DicTraverseSession { // Non virtual inline destructor -- never inherit this class AK_FORCE_INLINE ~DicTraverseSession() {} - void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength); + void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength, + const SuggestOptions *const suggestOptions); // TODO: Remove and merge into init void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, const int inputSize, const int *const inputXs, const int *const inputYs, @@ -54,13 +76,13 @@ class DicTraverseSession { void resetCache(const int nextActiveCacheSize, const int maxWords); // TODO: Remove - const uint8_t *getOffsetDict() const; - int getDictFlags() const; + const BinaryDictionaryInfo *getBinaryDictionaryInfo() const; //-------------------- // getters and setters //-------------------- const ProximityInfo *getProximityInfo() const { return mProximityInfo; } + const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; } int getPrevWordPos() const { return mPrevWordPos; } // TODO: REMOVE void setPrevWordPos(int pos) { mPrevWordPos = pos; } @@ -167,6 +189,7 @@ class DicTraverseSession { int mPrevWordPos; const ProximityInfo *mProximityInfo; const Dictionary *mDictionary; + const SuggestOptions *mSuggestOptions; DicNodesCache mDicNodesCache; // Temporary cache for bigram frequencies diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index a18794850..9376d7b93 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -16,19 +16,19 @@ #include "suggest/core/suggest.h" -#include "char_utils.h" -#include "dictionary.h" -#include "digraph_utils.h" -#include "proximity_info.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_priority_queue.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/dictionary.h" +#include "suggest/core/dictionary/digraph_utils.h" #include "suggest/core/dictionary/shortcut_utils.h" +#include "suggest/core/dictionary/terminal_attributes.h" +#include "suggest/core/layout/proximity_info.h" #include "suggest/core/policy/scoring.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/session/dic_traverse_session.h" -#include "terminal_attributes.h" namespace latinime { @@ -84,9 +84,9 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo if (!traverseSession->getProximityInfoState(0)->isUsed()) { return; } - if (TRAVERSAL->allowPartialCommit()) { - commitPoint = 0; - } + + // Never auto partial commit for now. + commitPoint = 0; if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE && traverseSession->isContinuousSuggestionPossible()) { @@ -103,11 +103,12 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo } } else { // Restart recognition at the root. - traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(), MAX_RESULTS); + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()), + MAX_RESULTS); // Create a new dic node here DicNode rootNode; - DicNodeUtils::initAsRoot(traverseSession->getDicRootPos(), - traverseSession->getOffsetDict(), traverseSession->getPrevWordPos(), &rootNode); + DicNodeUtils::initAsRoot(traverseSession->getBinaryDictionaryInfo(), + traverseSession->getPrevWordPos(), &rootNode); traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); } } @@ -148,6 +149,17 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen &doubleLetterTerminalIndex, &doubleLetterLevel); int maxScore = S_INT_MIN; + // Force autocorrection for obvious long multi-word suggestions when the top suggestion is + // a long multiple words suggestion. + // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. + // traverseSession->isPartiallyCommited() always returns false because we never auto partial + // commit for now. + const bool forceCommitMultiWords = (terminalSize > 0) ? + TRAVERSAL->autoCorrectsToMultiWordSuggestionIfTop() + && (traverseSession->isPartiallyCommited() + || (traverseSession->getInputSize() + >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT + && terminals[0].hasMultipleWords())) : false; // Output suggestion results here for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; ++terminalIndex) { @@ -159,8 +171,6 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + doubleLetterCost; - const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(), - terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; const bool isExactMatch = terminalDicNode->isExactMatch(); const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); @@ -173,27 +183,21 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. - const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord(); + const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. - // Force autocorrection for obvious long multi-word suggestions. - const bool isForceCommitMultiWords = TRAVERSAL->allowPartialCommit() - && (traverseSession->isPartiallyCommited() - || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT - && terminalDicNode->hasMultipleWords())); - const int finalScore = SCORING->calculateFinalScore( compoundDistance, traverseSession->getInputSize(), - isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord())); - + terminalDicNode->isExactMatch() + || (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) + || (isValidWord && SCORING->doesAutoCorrectValidWord())); maxScore = max(maxScore, finalScore); - if (TRAVERSAL->allowPartialCommit()) { - // Index for top typing suggestion should be 0. - if (isValidWord && outputWordIndex == 0) { - terminalDicNode->outputSpacePositionsResult(spaceIndices); - } + // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. + // Index for top typing suggestion should be 0. + if (isValidWord && outputWordIndex == 0) { + terminalDicNode->outputSpacePositionsResult(spaceIndices); } // Don't output invalid words. However, we still need to submit their shortcuts if any. @@ -206,9 +210,18 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen ++outputWordIndex; } - const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); - outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, - finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + if (!terminalDicNode->hasMultipleWords()) { + const BinaryDictionaryInfo *const binaryDictionaryInfo = + traverseSession->getBinaryDictionaryInfo(); + const TerminalAttributes terminalAttributes(traverseSession->getBinaryDictionaryInfo(), + binaryDictionaryInfo->getStructurePolicy()->getShortcutPositionOfNode( + binaryDictionaryInfo, terminalDicNode->getPos())); + // Shortcut is not supported for multiple words suggestions. + // TODO: Check shortcuts during traversal for multiple words suggestions. + const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); + outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, + finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + } DicNode::managedDelete(terminalDicNode); } @@ -285,7 +298,7 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { } DicNodeUtils::getAllChildDicNodes( - &dicNode, traverseSession->getOffsetDict(), &childDicNodes); + &dicNode, traverseSession->getBinaryDictionaryInfo(), &childDicNodes); const int childDicNodesSize = childDicNodes.getSizeAndLock(); for (int i = 0; i < childDicNodesSize; ++i) { @@ -295,7 +308,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { processDicNodeAsMatch(traverseSession, childDicNode); continue; } - if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(), + if (DigraphUtils::hasDigraphForCodePoint( + traverseSession->getBinaryDictionaryInfo()->getHeader(), childDicNode->getNodeCodePoint())) { correctionDicNode.initByCopy(childDicNode); correctionDicNode.advanceDigraphIndex(); @@ -352,17 +366,17 @@ void Suggest::processTerminalDicNode( if (!dicNode->isTerminalWordNode()) { return; } - if (TRAVERSAL->needsToTraverseAllUserInput() - && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { - return; - } - if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { return; } // Create a non-cached node here. DicNode terminalDicNode; DicNodeUtils::initByCopy(dicNode, &terminalDicNode); + if (TRAVERSAL->needsToTraverseAllUserInput() + && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0, + &terminalDicNode, traverseSession->getMultiBigramMap()); + } Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, &terminalDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); @@ -432,7 +446,8 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, void Suggest::processDicNodeAsOmission( DicTraverseSession *traverseSession, DicNode *dicNode) const { DicNodeVector childDicNodes; - DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); + DicNodeUtils::getAllChildDicNodes( + dicNode, traverseSession->getBinaryDictionaryInfo(), &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { @@ -457,7 +472,7 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const { const int16_t pointIndex = dicNode->getInputIndex(0); DicNodeVector childDicNodes; - DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getBinaryDictionaryInfo(), traverseSession->getProximityInfoState(0), pointIndex + 1, true, &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { @@ -475,14 +490,14 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, DicNode *dicNode) const { const int16_t pointIndex = dicNode->getInputIndex(0); DicNodeVector childDicNodes1; - DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getBinaryDictionaryInfo(), traverseSession->getProximityInfoState(0), pointIndex + 1, false, &childDicNodes1); const int childSize1 = childDicNodes1.getSizeAndLock(); for (int i = 0; i < childSize1; i++) { if (childDicNodes1[i]->hasChildren()) { DicNodeVector childDicNodes2; DicNodeUtils::getProximityChildDicNodes( - childDicNodes1[i], traverseSession->getOffsetDict(), + childDicNodes1[i], traverseSession->getBinaryDictionaryInfo(), traverseSession->getProximityInfoState(0), pointIndex, false, &childDicNodes2); const int childSize2 = childDicNodes2.getSizeAndLock(); for (int j = 0; j < childSize2; j++) { @@ -522,12 +537,18 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode // Create a non-cached node here. DicNode newDicNode; - DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), - traverseSession->getOffsetDict(), dicNode, &newDicNode); + DicNodeUtils::initAsRootWithPreviousWord( + traverseSession->getBinaryDictionaryInfo(), dicNode, &newDicNode); const CorrectionType correctionType = spaceSubstitution ? CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getMultiBigramMap()); - traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + // newDicNode is worth continuing to traverse. + // CAVEAT: This pruning is important for speed. Remove this when we can afford not to prune + // here because here is not the right place to do pruning. Pruning should take place only + // in DicNodePriorityQueue. + traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + } } } // namespace latinime diff --git a/native/jni/src/suggest/core/suggest_options.h b/native/jni/src/suggest/core/suggest_options.h new file mode 100644 index 000000000..1b21aafcf --- /dev/null +++ b/native/jni/src/suggest/core/suggest_options.h @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_OPTIONS_H +#define LATINIME_SUGGEST_OPTIONS_H + +#include "defines.h" + +namespace latinime { + +class SuggestOptions{ + public: + SuggestOptions(const int *const options, const int length) + : mOptions(options), mLength(length) {} + + AK_FORCE_INLINE bool isGesture() const { + return getBoolOption(IS_GESTURE); + } + + AK_FORCE_INLINE bool useFullEditDistance() const { + return getBoolOption(USE_FULL_EDIT_DISTANCE); + } + + AK_FORCE_INLINE bool getAdditionalFeaturesBoolOption(const int key) const { + return getBoolOption(key + ADDITIONAL_FEATURES_OPTIONS); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestOptions); + + // Need to update com.android.inputmethod.latin.NativeSuggestOptions when you add, remove or + // reorder options. + static const int IS_GESTURE = 0; + static const int USE_FULL_EDIT_DISTANCE = 1; + // Additional features options are stored after the other options and used as setting values of + // experimental features. + static const int ADDITIONAL_FEATURES_OPTIONS = 2; + + const int *const mOptions; + const int mLength; + + AK_FORCE_INLINE bool isValidKey(const int key) const { + return 0 <= key && key < mLength; + } + + AK_FORCE_INLINE bool getBoolOption(const int key) const { + if (isValidKey(key)) { + return mOptions[key] != 0; + } + return false; + } + + AK_FORCE_INLINE int getIntOption(const int key) const { + if (isValidKey(key)) { + return mOptions[key]; + } + return 0; + } +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_OPTIONS_H diff --git a/native/jni/src/binary_format.h b/native/jni/src/suggest/policyimpl/dictionary/binary_format.h index 98241532f..23f4c7fec 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/suggest/policyimpl/dictionary/binary_format.h @@ -17,13 +17,10 @@ #ifndef LATINIME_BINARY_FORMAT_H #define LATINIME_BINARY_FORMAT_H -#include <cstdlib> -#include <map> #include <stdint.h> -#include "bloom_filter.h" -#include "char_utils.h" -#include "hash_map_compat.h" +#include "suggest/core/dictionary/probability_utils.h" +#include "utils/char_utils.h" namespace latinime { @@ -55,23 +52,10 @@ class BinaryFormat { // Mask for attribute probability, stored on 4 bits inside the flags byte. static const int MASK_ATTRIBUTE_PROBABILITY = 0x0F; - // The numeric value of the shortcut probability that means 'whitelist'. - static const int WHITELIST_SHORTCUT_PROBABILITY = 15; // Mask and flags for attribute address type selection. static const int MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; - static const int UNKNOWN_FORMAT = -1; - static const int SHORTCUT_LIST_SIZE_SIZE = 2; - - static int detectFormat(const uint8_t *const dict, const int dictSize); - static int getHeaderSize(const uint8_t *const dict, const int dictSize); - static int getFlags(const uint8_t *const dict, const int dictSize); - static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const int dictSize, - const char *const key, int *outValue, const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, - const char *const key); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); @@ -84,36 +68,14 @@ class BinaryFormat { const int pos); static int readChildrenPosition(const uint8_t *const dict, const uint8_t flags, const int pos); static bool hasChildrenInFlags(const uint8_t flags); - static int getAttributeAddressAndForwardPointer(const uint8_t *const dict, const uint8_t flags, - int *pos); - static int getAttributeProbabilityFromFlags(const int flags); static int getTerminalPosition(const uint8_t *const root, const int *const inWord, const int length, const bool forceLowerCaseSearch); - static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, - int *outWord, int *outUnigramProbability); - static int computeProbabilityForBigram( - const int unigramProbability, const int bigramProbability); - static int getProbability(const int position, const std::map<int, int> *bigramMap, - const uint8_t *bigramFilter, const int unigramProbability); - static int getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); - static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, - hash_map_compat<int, int> *bigramMap); - static int getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability); - - // Flags for special processing - // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or - // something very bad (like, the apocalypse) will happen. Please update both at the same time. - enum { - REQUIRES_GERMAN_UMLAUT_PROCESSING = 0x1, - REQUIRES_FRENCH_LIGATURES_PROCESSING = 0x4 - }; + static int getCodePointsAndProbabilityAndReturnCodePointCount( + const uint8_t *const root, const int nodePos, const int maxCodePointCount, + int *const outCodePoints, int *const outUnigramProbability); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); - static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; @@ -123,20 +85,6 @@ class BinaryFormat { static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; - // Any file smaller than this is not a dictionary. - static const int DICTIONARY_MINIMUM_SIZE = 4; - // Originally, format version 1 had a 16-bit magic number, then the version number `01' - // then options that must be 0. Hence the first 32-bits of the format are always as follow - // and it's okay to consider them a magic number as a whole. - static const int FORMAT_VERSION_1_MAGIC_NUMBER = 0x78B10100; - static const int FORMAT_VERSION_1_HEADER_SIZE = 5; - // The versions of Latin IME that only handle format version 1 only test for the magic - // number, so we had to change it so that version 2 files would be rejected by older - // implementations. On this occasion, we made the magic number 32 bits long. - static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE - // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 - static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; - static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; static const int CHARACTER_ARRAY_TERMINATOR = 0x1F; @@ -146,122 +94,6 @@ class BinaryFormat { static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); }; -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { - // The magic number is stored big-endian. - // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't - // understand this format. - if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; - const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; - switch (magicNumber) { - case FORMAT_VERSION_1_MAGIC_NUMBER: - // Format 1 header is exactly 5 bytes long and looks like: - // Magic number (2 bytes) 0x78 0xB1 - // Version number (1 byte) 0x01 - // Options (2 bytes) must be 0x00 0x00 - return 1; - case FORMAT_VERSION_2_MAGIC_NUMBER: - // Version 2 dictionaries are at least 12 bytes long (see below details for the header). - // If this dictionary has the version 2 magic number but is less than 12 bytes long, then - // it's an unknown format and we need to avoid confidently reading the next bytes. - if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; - // Format 2 header is as follows: - // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE - // Version number (2 bytes) 0x00 0x02 - // Options (2 bytes) - // Header size (4 bytes) : integer, big endian - return (dict[4] << 8) + dict[5]; - default: - return UNKNOWN_FORMAT; - } -} - -inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { - switch (detectFormat(dict, dictSize)) { - case 1: - return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? - default: - return (dict[6] << 8) + dict[7]; - } -} - -inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { - return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; -} - -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { - switch (detectFormat(dict, dictSize)) { - case 1: - return FORMAT_VERSION_1_HEADER_SIZE; - case 2: - // See the format of the header in the comment in detectFormat() above - return (dict[8] << 24) + (dict[9] << 16) + (dict[10] << 8) + dict[11]; - default: - return S_INT_MAX; - } -} - -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, - const char *const key, int *outValue, const int outValueSize) { - int outValueIndex = 0; - // Only format 2 and above have header attributes as {key,value} string pairs. For prior - // formats, we just return an empty string, as if the key wasn't found. - if (2 <= detectFormat(dict, dictSize)) { - const int headerOptionsOffset = 4 /* magic number */ - + 2 /* dictionary version */ + 2 /* flags */; - const int headerSize = - (dict[headerOptionsOffset] << 24) + (dict[headerOptionsOffset + 1] << 16) - + (dict[headerOptionsOffset + 2] << 8) + dict[headerOptionsOffset + 3]; - const int headerEnd = headerOptionsOffset + 4 + headerSize; - int index = headerOptionsOffset + 4; - while (index < headerEnd) { - int keyIndex = 0; - int codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT) { - if (codePoint != key[keyIndex++]) { - break; - } - codePoint = getCodePointAndForwardPointer(dict, &index); - } - if (codePoint == NOT_A_CODE_POINT && key[keyIndex] == 0) { - // We found the key! Copy and return the value. - codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT && outValueIndex < outValueSize) { - outValue[outValueIndex++] = codePoint; - codePoint = getCodePointAndForwardPointer(dict, &index); - } - // Finished copying. Break to go to the termination code. - break; - } - // We didn't find the key, skip the remainder of it and its value - while (codePoint != NOT_A_CODE_POINT) { - codePoint = getCodePointAndForwardPointer(dict, &index); - } - codePoint = getCodePointAndForwardPointer(dict, &index); - while (codePoint != NOT_A_CODE_POINT) { - codePoint = getCodePointAndForwardPointer(dict, &index); - } - } - // We couldn't find it - fall through and return an empty value. - } - // Put a terminator 0 if possible at all (always unless outValueSize is <= 0) - if (outValueIndex >= outValueSize) outValueIndex = outValueSize - 1; - if (outValueIndex >= 0) outValue[outValueIndex] = 0; -} - -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, - const char *const key) { - const int bufferSize = LARGEST_INT_DIGIT_COUNT; - int intBuffer[bufferSize]; - char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); - for (int i = 0; i < bufferSize; ++i) { - charBuffer[i] = intBuffer[i]; - } - // If not a number, return S_INT_MIN - if (!isdigit(charBuffer[0])) return S_INT_MIN; - return atoi(charBuffer); -} - AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos) { const int msb = dict[(*pos)++]; @@ -269,19 +101,6 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, - const int dictSize) { - const int headerValue = readHeaderValueInt(dict, dictSize, - "MULTIPLE_WORDS_DEMOTION_RATE"); - if (headerValue == S_INT_MIN) { - return 1.0f; - } - if (headerValue <= 0) { - return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); - } - return 100.0f / static_cast<float>(headerValue); -} - inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { return dict[(*pos)++]; } @@ -429,40 +248,8 @@ inline bool BinaryFormat::hasChildrenInFlags(const uint8_t flags) { return (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags)); } -AK_FORCE_INLINE int BinaryFormat::getAttributeAddressAndForwardPointer(const uint8_t *const dict, - const uint8_t flags, int *pos) { - int offset = 0; - const int origin = *pos; - switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { - case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: - offset = dict[origin]; - *pos = origin + 1; - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: - offset = dict[origin] << 8; - offset += dict[origin + 1]; - *pos = origin + 2; - break; - case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: - offset = dict[origin] << 16; - offset += dict[origin + 1] << 8; - offset += dict[origin + 2]; - *pos = origin + 3; - break; - } - if (FLAG_ATTRIBUTE_OFFSET_NEGATIVE & flags) { - return origin - offset; - } else { - return origin + offset; - } -} - -inline int BinaryFormat::getAttributeProbabilityFromFlags(const int flags) { - return flags & MASK_ATTRIBUTE_PROBABILITY; -} - // This function gets the byte position of the last chargroup of the exact matching word in the -// dictionary. If no match is found, it returns NOT_VALID_WORD. +// dictionary. If no match is found, it returns NOT_A_VALID_WORD_POS. AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, const int *const inWord, const int length, const bool forceLowerCaseSearch) { int pos = 0; @@ -471,21 +258,22 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, while (true) { // If we already traversed the tree further than the word is long, there means // there was no match (or we would have found it). - if (wordPos >= length) return NOT_VALID_WORD; + if (wordPos >= length) return NOT_A_VALID_WORD_POS; int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos); - const int wChar = forceLowerCaseSearch ? toLowerCase(inWord[wordPos]) : inWord[wordPos]; + const int wChar = forceLowerCaseSearch + ? CharUtils::toLowerCase(inWord[wordPos]) : inWord[wordPos]; while (true) { // If there are no more character groups in this node, it means we could not // find a matching character for this depth, therefore there is no match. - if (0 >= charGroupCount) return NOT_VALID_WORD; + if (0 >= charGroupCount) return NOT_A_VALID_WORD_POS; const int charGroupPos = pos; const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); int character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); if (character == wChar) { // This is the correct node. Only one character group may start with the same // char within a node, so either we found our match in this node, or there is - // no match and we can return NOT_VALID_WORD. So we will check all the characters - // in this character group indeed does match. + // no match and we can return NOT_A_VALID_WORD_POS. So we will check all the + // characters in this character group indeed does match. if (FLAG_HAS_MULTIPLE_CHARS & flags) { character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); while (NOT_A_CODE_POINT != character) { @@ -494,8 +282,8 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, // character that does not match, as explained above, it means the word is // not in the dictionary (by virtue of this chargroup being the only one to // match the word on the first character, but not matching the whole word). - if (wordPos >= length) return NOT_VALID_WORD; - if (inWord[wordPos] != character) return NOT_VALID_WORD; + if (wordPos >= length) return NOT_A_VALID_WORD_POS; + if (inWord[wordPos] != character) return NOT_A_VALID_WORD_POS; character = BinaryFormat::getCodePointAndForwardPointer(root, &pos); } } @@ -511,7 +299,7 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, pos = BinaryFormat::skipProbability(FLAG_IS_TERMINAL, pos); } if (FLAG_GROUP_ADDRESS_TYPE_NOADDRESS == (MASK_GROUP_ADDRESS_TYPE & flags)) { - return NOT_VALID_WORD; + return NOT_A_VALID_WORD_POS; } // We have children and we are still shorter than the word we are searching for, so // we need to traverse children. Put the pointer on the children position, and @@ -549,8 +337,9 @@ AK_FORCE_INLINE int BinaryFormat::getTerminalPosition(const uint8_t *const root, * outUnigramProbability: a pointer to an int to write the probability into. * Return value : the length of the word, of 0 if the word was not found. */ -AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, const int address, - const int maxDepth, int *outWord, int *outUnigramProbability) { +AK_FORCE_INLINE int BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( + const uint8_t *const root, const int nodePos, const int maxCodePointCount, + int *const outCodePoints, int *const outUnigramProbability) { int pos = 0; int wordPos = 0; @@ -560,7 +349,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co // The only reason we count nodes is because we want to reduce the probability of infinite // looping in case there is a bug. Since we know there is an upper bound to the depth we are // supposed to traverse, it does not hurt to count iterations. - for (int loopCount = maxDepth; loopCount > 0; --loopCount) { + for (int loopCount = maxCodePointCount; loopCount > 0; --loopCount) { int lastCandidateGroupPos = 0; // Let's loop through char groups in this node searching for either the terminal // or one of its ascendants. @@ -569,17 +358,17 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int startPos = pos; const uint8_t flags = getFlagsAndForwardPointer(root, &pos); const int character = getCodePointAndForwardPointer(root, &pos); - if (address == startPos) { + if (nodePos == startPos) { // We found the address. Copy the rest of the word in the buffer and return // the length. - outWord[wordPos] = character; + outCodePoints[wordPos] = character; if (FLAG_HAS_MULTIPLE_CHARS & flags) { int nextChar = getCodePointAndForwardPointer(root, &pos); // We count chars in order to avoid infinite loops if the file is broken or // if there is some other bug - int charCount = maxDepth; + int charCount = maxCodePointCount; while (NOT_A_CODE_POINT != nextChar && --charCount > 0) { - outWord[++wordPos] = nextChar; + outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &pos); } } @@ -606,7 +395,7 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co if (hasChildren) { // Here comes the tricky part. First, read the children position. const int childrenPos = readChildrenPosition(root, flags, pos); - if (childrenPos > address) { + if (childrenPos > nodePos) { // If the children pos is greater than address, it means the previous chargroup, // which address is stored in lastCandidateGroupPos, was the right one. found = true; @@ -636,12 +425,12 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co const int lastChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); // We copy all the characters in this group to the buffer - outWord[wordPos] = lastChar; + outCodePoints[wordPos] = lastChar; if (FLAG_HAS_MULTIPLE_CHARS & lastFlags) { int nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); - int charCount = maxDepth; + int charCount = maxCodePointCount; while (-1 != nextChar && --charCount > 0) { - outWord[++wordPos] = nextChar; + outCodePoints[++wordPos] = nextChar; nextChar = getCodePointAndForwardPointer(root, &lastCandidateGroupPos); } } @@ -677,102 +466,5 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co return 0; } -static inline int backoff(const int unigramProbability) { - return unigramProbability; - // For some reason, applying the backoff weight gives bad results in tests. To apply the - // backoff weight, we divide the probability by 2, which in our storing format means - // decreasing the score by 8. - // TODO: figure out what's wrong with this. - // return unigramProbability > 8 ? unigramProbability - 8 : (0 == unigramProbability ? 0 : 8); -} - -inline int BinaryFormat::computeProbabilityForBigram( - const int unigramProbability, const int bigramProbability) { - // We divide the range [unigramProbability..255] in 16.5 steps - in other words, we want the - // unigram probability to be the median value of the 17th step from the top. A value of - // 0 for the bigram probability represents the middle of the 16th step from the top, - // while a value of 15 represents the middle of the top step. - // See makedict.BinaryDictInputOutput for details. - const float stepSize = static_cast<float>(MAX_PROBABILITY - unigramProbability) - / (1.5f + MAX_BIGRAM_ENCODED_PROBABILITY); - return unigramProbability - + static_cast<int>(static_cast<float>(bigramProbability + 1) * stepSize); -} - -// This returns a probability in log space. -inline int BinaryFormat::getProbability(const int position, const std::map<int, int> *bigramMap, - const uint8_t *bigramFilter, const int unigramProbability) { - if (!bigramMap || !bigramFilter) return backoff(unigramProbability); - if (!isInFilter(bigramFilter, position)) return backoff(unigramProbability); - const std::map<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); - if (bigramProbabilityIt != bigramMap->end()) { - const int bigramProbability = bigramProbabilityIt->second; - return computeProbabilityForBigram(unigramProbability, bigramProbability); - } - return backoff(unigramProbability); -} - -// This returns a probability in log space. -inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { - if (!bigramMap) return backoff(unigramProbability); - const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); - if (bigramProbabilityIt != bigramMap->end()) { - const int bigramProbability = bigramProbabilityIt->second; - return computeProbabilityForBigram(unigramProbability, bigramProbability); - } - return backoff(unigramProbability); -} - -AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( - const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) return; - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, - &position); - (*bigramMap)[bigramPos] = probability; - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); -} - -AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) return backoff(unigramProbability); - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int bigramPos = getAttributeAddressAndForwardPointer( - root, bigramFlags, &position); - if (bigramPos == nextPosition) { - const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - return computeProbabilityForBigram(unigramProbability, bigramProbability); - } - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); - return backoff(unigramProbability); -} - -// Returns a pointer to the start of the bigram list. -AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( - const uint8_t *const root, int position) { - if (NOT_VALID_WORD == position) return 0; - const uint8_t flags = getFlagsAndForwardPointer(root, &position); - if (!(flags & FLAG_HAS_BIGRAMS)) return 0; - if (flags & FLAG_HAS_MULTIPLE_CHARS) { - position = skipOtherCharacters(root, position); - } else { - getCodePointAndForwardPointer(root, &position); - } - position = skipProbability(flags, position); - position = skipChildrenPosition(flags, position); - position = skipShortcuts(root, flags, position); - return position; -} - } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_policy_factory.h b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_policy_factory.h new file mode 100644 index 000000000..c0df89f49 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dictionary_structure_policy_factory.h @@ -0,0 +1,48 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DICTIONARY_STRUCTURE_POLICY_FACTORY_H +#define LATINIME_DICTIONARY_STRUCTURE_POLICY_FACTORY_H + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_format_utils.h" +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h" +#include "suggest/policyimpl/dictionary/patricia_trie_policy.h" + +namespace latinime { + +class DictionaryStructurePolicy; + +class DictionaryStructurePolicyFactory { + public: + static const DictionaryStructurePolicy *getDictionaryStructurePolicy( + const BinaryDictionaryFormatUtils::FORMAT_VERSION dictionaryFormat) { + switch (dictionaryFormat) { + case BinaryDictionaryFormatUtils::VERSION_2: + return PatriciaTriePolicy::getInstance(); + case BinaryDictionaryFormatUtils::VERSION_3: + return DynamicPatriciaTriePolicy::getInstance(); + default: + ASSERT(false); + return 0; + } + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryStructurePolicyFactory); +}; +} // namespace latinime +#endif // LATINIME_DICTIONARY_STRUCTURE_POLICY_FACTORY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp new file mode 100644 index 000000000..c7314ecf1 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/patricia_trie_policy.h" + +#include "defines.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" + +namespace latinime { + +const DynamicPatriciaTriePolicy DynamicPatriciaTriePolicy::sInstance; + +void DynamicPatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const { + // TODO: Implement. +} + +int DynamicPatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const { + // TODO: Implement. + return 0; +} + +int DynamicPatriciaTriePolicy::getTerminalNodePositionOfWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, + const int length, const bool forceLowerCaseSearch) const { + // TODO: Implement. + return NOT_A_DICT_POS; +} + +int DynamicPatriciaTriePolicy::getUnigramProbability( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const { + // TODO: Implement. + return NOT_A_PROBABILITY; +} + +int DynamicPatriciaTriePolicy::getShortcutPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + // TODO: Implement. + return NOT_A_DICT_POS; +} + +int DynamicPatriciaTriePolicy::getBigramsPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + // TODO: Implement. + return NOT_A_DICT_POS; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h new file mode 100644 index 000000000..39dfb86fd --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_policy.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H + +#include "defines.h" +#include "suggest/core/policy/dictionary_structure_policy.h" + +namespace latinime { + +class BinaryDictionaryInfo; +class DicNode; +class DicNodeVector; + +class DynamicPatriciaTriePolicy : public DictionaryStructurePolicy { + public: + static AK_FORCE_INLINE const DynamicPatriciaTriePolicy *getInstance() { + return &sInstance; + } + + AK_FORCE_INLINE int getRootPosition() const { + return 0; + } + + void createAndGetAllChildNodes(const DicNode *const dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const; + + int getCodePointsAndProbabilityAndReturnCodePointCount( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; + + int getTerminalNodePositionOfWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, + const int length, const bool forceLowerCaseSearch) const; + + int getUnigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + int getShortcutPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + int getBigramsPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + private: + DISALLOW_COPY_AND_ASSIGN(DynamicPatriciaTriePolicy); + static const DynamicPatriciaTriePolicy sInstance; + + DynamicPatriciaTriePolicy() {} + ~DynamicPatriciaTriePolicy() {} +}; +} // namespace latinime +#endif // LATINIME_DYNAMIC_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp new file mode 100644 index 000000000..0de6341b0 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.cpp @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h" + +#include "defines.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +typedef DynamicPatriciaTrieReadingUtils DptReadingUtils; + +const DptReadingUtils::NodeFlags DptReadingUtils::MASK_MOVED = 0xC0; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_NOT_MOVED = 0xC0; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_MOVED = 0x40; +const DptReadingUtils::NodeFlags DptReadingUtils::FLAG_IS_DELETED = 0x80; + +/* static */ int DptReadingUtils::readChildrenPositionAndAdvancePosition( + const uint8_t *const buffer, const NodeFlags flags, int *const pos) { + if ((flags & MASK_MOVED) == FLAG_IS_NOT_MOVED) { + const int base = *pos; + return base + ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); + } else { + return NOT_A_DICT_POS; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h new file mode 100644 index 000000000..f44c2651a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/dynamic_patricia_trie_reading_utils.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H +#define LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +class DynamicPatriciaTrieReadingUtils { + public: + typedef uint8_t NodeFlags; + + static AK_FORCE_INLINE int getForwardLinkPosition(const uint8_t *const buffer, const int pos) { + int linkAddressPos = pos; + return ByteArrayUtils::readSint24AndAdvancePosition(buffer, &linkAddressPos); + } + + static AK_FORCE_INLINE bool isValidForwardLinkPosition(const int forwardLinkAddress) { + return forwardLinkAddress != 0; + } + + static AK_FORCE_INLINE int getParentPosAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + const int base = *pos; + return base + ByteArrayUtils::readSint24AndAdvancePosition(buffer, pos); + } + + static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, int *const pos); + + /** + * Node Flags + */ + static AK_FORCE_INLINE bool isMoved(const NodeFlags flags) { + return FLAG_IS_MOVED == (MASK_MOVED & flags); + } + + static AK_FORCE_INLINE bool isDeleted(const NodeFlags flags) { + return FLAG_IS_DELETED == (MASK_MOVED & flags); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicPatriciaTrieReadingUtils); + + static const NodeFlags MASK_MOVED; + static const NodeFlags FLAG_IS_NOT_MOVED; + static const NodeFlags FLAG_IS_MOVED; + static const NodeFlags FLAG_IS_DELETED; +}; +} // namespace latinime +#endif /* LATINIME_DYNAMIC_PATRICIA_TRIE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp new file mode 100644 index 000000000..097f7c86a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "suggest/policyimpl/dictionary/patricia_trie_policy.h" + +#include "defines.h" +#include "suggest/core/dicnode/dic_node.h" +#include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/binary_dictionary_terminal_attributes_reading_utils.h" +#include "suggest/policyimpl/dictionary/binary_format.h" +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +namespace latinime { + +const PatriciaTriePolicy PatriciaTriePolicy::sInstance; + +void PatriciaTriePolicy::createAndGetAllChildNodes(const DicNode *const dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const { + if (!dicNode->hasChildren()) { + return; + } + int nextPos = dicNode->getChildrenPos(); + const int childCount = PatriciaTrieReadingUtils::getGroupCountAndAdvancePosition( + binaryDictionaryInfo->getDictRoot(), &nextPos); + for (int i = 0; i < childCount; i++) { + nextPos = createAndGetLeavingChildNode(dicNode, nextPos, binaryDictionaryInfo, + nodeFilter, childDicNodes); + } +} + +int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const { + return BinaryFormat::getCodePointsAndProbabilityAndReturnCodePointCount( + binaryDictionaryInfo->getDictRoot(), nodePos, + maxCodePointCount, outCodePoints, outUnigramProbability); +} + +int PatriciaTriePolicy::getTerminalNodePositionOfWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, + const int length, const bool forceLowerCaseSearch) const { + return BinaryFormat::getTerminalPosition(binaryDictionaryInfo->getDictRoot(), inWord, + length, forceLowerCaseSearch); +} + +int PatriciaTriePolicy::getUnigramProbability( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_PROBABILITY; + } + const uint8_t *const dictRoot = binaryDictionaryInfo->getDictRoot(); + int pos = nodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictRoot, &pos); + if (!PatriciaTrieReadingUtils::isTerminal(flags)) { + return NOT_A_PROBABILITY; + } + if (PatriciaTrieReadingUtils::isNotAWord(flags) + || PatriciaTrieReadingUtils::isBlacklisted(flags)) { + // If this is not a word, or if it's a blacklisted entry, it should behave as + // having no probability outside of the suggestion process (where it should be used + // for shortcuts). + return NOT_A_PROBABILITY; + } + PatriciaTrieReadingUtils::skipCharacters(dictRoot, flags, MAX_WORD_LENGTH, &pos); + return PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictRoot, &pos); +} + +int PatriciaTriePolicy::getShortcutPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_DICT_POS; + } + const uint8_t *const dictRoot = binaryDictionaryInfo->getDictRoot(); + int pos = nodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictRoot, &pos); + if (!PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + return NOT_A_DICT_POS; + } + PatriciaTrieReadingUtils::skipCharacters(dictRoot, flags, MAX_WORD_LENGTH, &pos); + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(dictRoot, flags, &pos); + } + return pos; +} + +int PatriciaTriePolicy::getBigramsPositionOfNode( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const { + if (nodePos == NOT_A_VALID_WORD_POS) { + return NOT_A_DICT_POS; + } + const uint8_t *const dictRoot = binaryDictionaryInfo->getDictRoot(); + int pos = nodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictRoot, &pos); + if (!PatriciaTrieReadingUtils::hasBigrams(flags)) { + return NOT_A_DICT_POS; + } + PatriciaTrieReadingUtils::skipCharacters(dictRoot, flags, MAX_WORD_LENGTH, &pos); + if (PatriciaTrieReadingUtils::isTerminal(flags)) { + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictRoot, &pos); + } + if (PatriciaTrieReadingUtils::hasChildrenInFlags(flags)) { + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition(dictRoot, flags, &pos); + } + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + BinaryDictionaryTerminalAttributesReadingUtils::skipShortcuts(binaryDictionaryInfo, &pos); + } + return pos; +} + +int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNode, + const int nodePos, const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const childrenFilter, DicNodeVector *childDicNodes) const { + const uint8_t *const dictRoot = binaryDictionaryInfo->getDictRoot(); + int pos = nodePos; + const PatriciaTrieReadingUtils::NodeFlags flags = + PatriciaTrieReadingUtils::getFlagsAndAdvancePosition(dictRoot, &pos); + int mergedNodeCodePoints[MAX_WORD_LENGTH]; + const int mergedNodeCodePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictRoot, flags, MAX_WORD_LENGTH, mergedNodeCodePoints, &pos); + const int probability = (PatriciaTrieReadingUtils::isTerminal(flags))? + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(dictRoot, &pos) + : NOT_A_PROBABILITY; + const int childrenPos = PatriciaTrieReadingUtils::hasChildrenInFlags(flags) ? + PatriciaTrieReadingUtils::readChildrenPositionAndAdvancePosition( + dictRoot, flags, &pos) : NOT_A_DICT_POS; + if (PatriciaTrieReadingUtils::hasShortcutTargets(flags)) { + BinaryDictionaryTerminalAttributesReadingUtils::skipShortcuts(binaryDictionaryInfo, &pos); + } + if (PatriciaTrieReadingUtils::hasBigrams(flags)) { + BinaryDictionaryTerminalAttributesReadingUtils::skipExistingBigrams( + binaryDictionaryInfo, &pos); + } + if (!childrenFilter->isFilteredOut(mergedNodeCodePoints[0])) { + childDicNodes->pushLeavingChild(dicNode, nodePos, childrenPos, probability, + PatriciaTrieReadingUtils::isTerminal(flags), + PatriciaTrieReadingUtils::hasChildrenInFlags(flags), + PatriciaTrieReadingUtils::isBlacklisted(flags) || + PatriciaTrieReadingUtils::isNotAWord(flags), + mergedNodeCodePointCount, mergedNodeCodePoints); + } + return pos; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h new file mode 100644 index 000000000..71f256eee --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_policy.h @@ -0,0 +1,69 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_PATRICIA_TRIE_POLICY_H +#define LATINIME_PATRICIA_TRIE_POLICY_H + +#include "defines.h" +#include "suggest/core/policy/dictionary_structure_policy.h" + +namespace latinime { + +class PatriciaTriePolicy : public DictionaryStructurePolicy { + public: + static AK_FORCE_INLINE const PatriciaTriePolicy *getInstance() { + return &sInstance; + } + + AK_FORCE_INLINE int getRootPosition() const { + return 0; + } + + void createAndGetAllChildNodes(const DicNode *const dicNode, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const; + + int getCodePointsAndProbabilityAndReturnCodePointCount( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int terminalNodePos, const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; + + int getTerminalNodePositionOfWord( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int *const inWord, + const int length, const bool forceLowerCaseSearch) const; + + int getUnigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + int getShortcutPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + int getBigramsPositionOfNode(const BinaryDictionaryInfo *const binaryDictionaryInfo, + const int nodePos) const; + + private: + DISALLOW_COPY_AND_ASSIGN(PatriciaTriePolicy); + static const PatriciaTriePolicy sInstance; + + PatriciaTriePolicy() {} + ~PatriciaTriePolicy() {} + + int createAndGetLeavingChildNode(const DicNode *const dicNode, const int nodePos, + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const NodeFilter *const nodeFilter, DicNodeVector *const childDicNodes) const; +}; +} // namespace latinime +#endif // LATINIME_PATRICIA_TRIE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp new file mode 100644 index 000000000..89e981df8 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.cpp @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/policyimpl/dictionary/patricia_trie_reading_utils.h" + +#include "defines.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +typedef PatriciaTrieReadingUtils PtReadingUtils; + +const PtReadingUtils::NodeFlags PtReadingUtils::MASK_GROUP_ADDRESS_TYPE = 0xC0; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_GROUP_ADDRESS_TYPE_THREEBYTES = 0xC0; + +// Flag for single/multiple char group +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_MULTIPLE_CHARS = 0x20; +// Flag for terminal groups +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_TERMINAL = 0x10; +// Flag for shortcut targets presence +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_SHORTCUT_TARGETS = 0x08; +// Flag for bigram presence +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_HAS_BIGRAMS = 0x04; +// Flag for non-words (typically, shortcut only entries) +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_NOT_A_WORD = 0x02; +// Flag for blacklist +const PtReadingUtils::NodeFlags PtReadingUtils::FLAG_IS_BLACKLISTED = 0x01; + +/* static */ int PtReadingUtils::readChildrenPositionAndAdvancePosition( + const uint8_t *const buffer, const NodeFlags flags, int *const pos) { + const int base = *pos; + int offset = 0; + switch (MASK_GROUP_ADDRESS_TYPE & flags) { + case FLAG_GROUP_ADDRESS_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + break; + case FLAG_GROUP_ADDRESS_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16AndAdvancePosition(buffer, pos); + break; + case FLAG_GROUP_ADDRESS_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24AndAdvancePosition(buffer, pos); + break; + default: + // If we come here, it means we asked for the children of a word with + // no children. + return NOT_A_DICT_POS; + } + return base + offset; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h new file mode 100644 index 000000000..002c3f19b --- /dev/null +++ b/native/jni/src/suggest/policyimpl/dictionary/patricia_trie_reading_utils.h @@ -0,0 +1,139 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_PATRICIA_TRIE_READING_UTILS_H +#define LATINIME_PATRICIA_TRIE_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +class PatriciaTrieReadingUtils { + public: + typedef uint8_t NodeFlags; + + static AK_FORCE_INLINE int getGroupCountAndAdvancePosition( + const uint8_t *const buffer, int *const pos) { + const uint8_t firstByte = ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + if (firstByte < 0x80) { + return firstByte; + } else { + return ((firstByte & 0x7F) << 8) ^ ByteArrayUtils::readUint8AndAdvancePosition( + buffer, pos); + } + } + + static AK_FORCE_INLINE NodeFlags getFlagsAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + } + + static AK_FORCE_INLINE int getCodePointAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readCodePointAndAdvancePosition(buffer, pos); + } + + // Returns the number of read characters. + static AK_FORCE_INLINE int getCharsAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, const int maxLength, int *const outBuffer, int *const pos) { + int length = 0; + if (hasMultipleChars(flags)) { + length = ByteArrayUtils::readStringAndAdvancePosition(buffer, maxLength, outBuffer, + pos); + } else { + if (maxLength > 0) { + outBuffer[0] = getCodePointAndAdvancePosition(buffer, pos); + length = 1; + } + } + return length; + } + + // Returns the number of skipped characters. + static AK_FORCE_INLINE int skipCharacters(const uint8_t *const buffer, const NodeFlags flags, + const int maxLength, int *const pos) { + if (hasMultipleChars(flags)) { + return ByteArrayUtils::advancePositionToBehindString(buffer, maxLength, pos); + } else { + if (maxLength > 0) { + getCodePointAndAdvancePosition(buffer, pos); + return 1; + } else { + return 0; + } + } + } + + static AK_FORCE_INLINE int readProbabilityAndAdvancePosition(const uint8_t *const buffer, + int *const pos) { + return ByteArrayUtils::readUint8AndAdvancePosition(buffer, pos); + } + + static int readChildrenPositionAndAdvancePosition(const uint8_t *const buffer, + const NodeFlags flags, int *const pos); + + /** + * Node Flags + */ + static AK_FORCE_INLINE bool isBlacklisted(const NodeFlags flags) { + return (flags & FLAG_IS_BLACKLISTED) != 0; + } + + static AK_FORCE_INLINE bool isNotAWord(const NodeFlags flags) { + return (flags & FLAG_IS_NOT_A_WORD) != 0; + } + + static AK_FORCE_INLINE bool isTerminal(const NodeFlags flags) { + return (flags & FLAG_IS_TERMINAL) != 0; + } + + static AK_FORCE_INLINE bool hasShortcutTargets(const NodeFlags flags) { + return (flags & FLAG_HAS_SHORTCUT_TARGETS) != 0; + } + + static AK_FORCE_INLINE bool hasBigrams(const NodeFlags flags) { + return (flags & FLAG_HAS_BIGRAMS) != 0; + } + + static AK_FORCE_INLINE bool hasMultipleChars(const NodeFlags flags) { + return (flags & FLAG_HAS_MULTIPLE_CHARS) != 0; + } + + static AK_FORCE_INLINE bool hasChildrenInFlags(const NodeFlags flags) { + return FLAG_GROUP_ADDRESS_TYPE_NOADDRESS != (MASK_GROUP_ADDRESS_TYPE & flags); + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(PatriciaTrieReadingUtils); + + static const NodeFlags MASK_GROUP_ADDRESS_TYPE; + static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_NOADDRESS; + static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_ONEBYTE; + static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_TWOBYTES; + static const NodeFlags FLAG_GROUP_ADDRESS_TYPE_THREEBYTES; + + static const NodeFlags FLAG_HAS_MULTIPLE_CHARS; + static const NodeFlags FLAG_IS_TERMINAL; + static const NodeFlags FLAG_HAS_SHORTCUT_TARGETS; + static const NodeFlags FLAG_HAS_BIGRAMS; + static const NodeFlags FLAG_IS_NOT_A_WORD; + static const NodeFlags FLAG_IS_BLACKLISTED; +}; +} // namespace latinime +#endif /* LATINIME_PATRICIA_TRIE_NODE_READING_UTILS_H */ diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index f87989286..ecceb60d3 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -22,31 +22,36 @@ const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; -const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 125; +// TODO: Unlimit max cache dic node size +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 170; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT = 310; const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f; -const float ScoringParams::PROXIMITY_COST = 0.086f; -const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f; +const float ScoringParams::PROXIMITY_COST = 0.095f; +const float ScoringParams::FIRST_CHAR_PROXIMITY_COST = 0.102f; +const float ScoringParams::FIRST_PROXIMITY_COST = 0.019f; const float ScoringParams::OMISSION_COST = 0.458f; const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; const float ScoringParams::INSERTION_COST = 0.730f; +const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; +const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; -const float ScoringParams::TRANSPOSITION_COST = 0.516f; +const float ScoringParams::TRANSPOSITION_COST = 0.526f; const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.319f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; -const float ScoringParams::SUBSTITUTION_COST = 0.403f; +const float ScoringParams::SUBSTITUTION_COST = 0.383f; const float ScoringParams::COST_NEW_WORD = 0.042f; const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.25f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.545f; const float ScoringParams::COST_LOOKAHEAD = 0.073f; -const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.105f; -const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.038f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.444f; +const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.093f; +const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.041f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.447f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; -const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.06f; +const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.045f; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 53ac999c1..7d4b5c3c7 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -29,6 +29,7 @@ class ScoringParams { static const int THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; static const float AUTOCORRECT_OUTPUT_THRESHOLD; static const int MAX_CACHE_DIC_NODE_SIZE; + static const int MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT; static const int THRESHOLD_SHORT_WORD_LENGTH; // Numerically optimized parameters (currently for tap typing only). @@ -36,12 +37,15 @@ class ScoringParams { // TODO: explore optimization of gesture parameters. static const float DISTANCE_WEIGHT_LENGTH; static const float PROXIMITY_COST; + static const float FIRST_CHAR_PROXIMITY_COST; static const float FIRST_PROXIMITY_COST; static const float OMISSION_COST; static const float OMISSION_COST_SAME_CHAR; static const float OMISSION_COST_FIRST_CHAR; static const float INSERTION_COST; + static const float TERMINAL_INSERTION_COST; static const float INSERTION_COST_SAME_CHAR; + static const float INSERTION_COST_PROXIMITY_CHAR; static const float INSERTION_COST_FIRST_CHAR; static const float TRANSPOSITION_COST; static const float SPACE_SUBSTITUTION_COST; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 90e2133e7..56ffcc93e 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -55,10 +55,10 @@ class TypingScoring : public Scoring { const int inputSize, const bool forceCommit) const { const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; - return static_cast<int>((ScoringParams::TYPING_BASE_OUTPUT_SCORE - - (compoundDistance / maxDistance) - + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f)) - * SUGGEST_INTERFACE_OUTPUT_SCALE); + const float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE + - compoundDistance / maxDistance + + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f); + return static_cast<int>(score * SUGGEST_INTERFACE_OUTPUT_SCALE); } AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(const int terminalIndex, diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index 12110d54f..89e53f441 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -19,14 +19,15 @@ #include <stdint.h> -#include "char_utils.h" #include "defines.h" -#include "proximity_info_state.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_vector.h" +#include "suggest/core/layout/proximity_info_state.h" +#include "suggest/core/layout/proximity_info_utils.h" #include "suggest/core/policy/traversal.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/policyimpl/typing/scoring_params.h" +#include "utils/char_utils.h" namespace latinime { class TypingTraversal : public Traversal { @@ -64,9 +65,9 @@ class TypingTraversal : public Traversal { } const int point0Index = dicNode->getInputIndex(0); const int currentBaseLowerCodePoint = - toBaseLowerCase(childDicNode->getNodeCodePoint()); + CharUtils::toBaseLowerCase(childDicNode->getNodeCodePoint()); const int typedBaseLowerCodePoint = - toBaseLowerCase(traverseSession->getProximityInfoState(0) + CharUtils::toBaseLowerCase(traverseSession->getProximityInfoState(0) ->getPrimaryCodePointAt(point0Index)); return (currentBaseLowerCodePoint != typedBaseLowerCodePoint); } @@ -136,7 +137,7 @@ class TypingTraversal : public Traversal { return ScoringParams::MAX_SPATIAL_DISTANCE; } - AK_FORCE_INLINE bool allowPartialCommit() const { + AK_FORCE_INLINE bool autoCorrectsToMultiWordSuggestionIfTop() const { return true; } @@ -147,11 +148,12 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool sameAsTyped( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { return traverseSession->getProximityInfoState(0)->sameAsTyped( - dicNode->getOutputWordBuf(), dicNode->getDepth()); + dicNode->getOutputWordBuf(), dicNode->getNodeCodePointCount()); } - AK_FORCE_INLINE int getMaxCacheSize() const { - return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + AK_FORCE_INLINE int getMaxCacheSize(const int inputSize) const { + return (inputSize <= 1) ? ScoringParams::MAX_CACHE_DIC_NODE_SIZE_FOR_SINGLE_POINT + : ScoringParams::MAX_CACHE_DIC_NODE_SIZE; } AK_FORCE_INLINE bool isPossibleOmissionChildNode( @@ -159,7 +161,7 @@ class TypingTraversal : public Traversal { const DicNode *const dicNode) const { const ProximityType proximityType = getProximityType(traverseSession, parentDicNode, dicNode); - if (!DicNodeUtils::isProximityChar(proximityType)) { + if (!ProximityInfoUtils::isMatchOrProximityChar(proximityType)) { return false; } return true; @@ -171,8 +173,8 @@ class TypingTraversal : public Traversal { return false; } const int c = dicNode->getOutputWordBuf()[0]; - const bool shortCappedWord = dicNode->getDepth() - < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && isAsciiUpper(c); + const bool shortCappedWord = dicNode->getNodeCodePointCount() + < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && CharUtils::isAsciiUpper(c); return !shortCappedWord || probability >= ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; } diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index e4c69d1f6..408b12ae9 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -44,6 +44,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, break; case CT_SUBSTITUTION: case CT_INSERTION: + case CT_TERMINAL_INSERTION: case CT_TRANSPOSITION: return ET_EDIT_CORRECTION; case CT_NEW_WORD_SPACE_OMITTION: diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 3938c0ec5..7cddb0882 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -18,11 +18,12 @@ #define LATINIME_TYPING_WEIGHTING_H #include "defines.h" -#include "suggest_utils.h" #include "suggest/core/dicnode/dic_node_utils.h" +#include "suggest/core/layout/touch_position_correction_utils.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/policyimpl/typing/scoring_params.h" +#include "utils/char_utils.h" namespace latinime { @@ -54,7 +55,7 @@ class TypingWeighting : public Weighting { const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the traversal omitted the first letter then the dicNode should now be on the second. - const bool isFirstLetterOmission = dicNode->getDepth() == 2; + const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; float cost = 0.0f; if (isZeroCostOmission) { cost = 0.0f; @@ -74,15 +75,18 @@ class TypingWeighting : public Weighting { // the keyboard (like accented letters) const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); - const float normalizedDistance = SuggestUtils::getSweetSpotFactor( + const float normalizedDistance = TouchPositionCorrectionUtils::getSweetSpotFactor( traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); - float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST + float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST : ScoringParams::PROXIMITY_COST) : 0.0f; - if (dicNode->getDepth() == 2) { + if (isProximity && dicNode->getProximityCorrectionCount() == 0) { + cost += ScoringParams::FIRST_PROXIMITY_COST; + } + if (dicNode->getNodeCodePointCount() == 2) { // At the second character of the current word, we check if the first char is uppercase // and the word is a second or later word of a multiple word suggestion. We demote it // if so. @@ -98,9 +102,9 @@ class TypingWeighting : public Weighting { bool isProximityDicNode(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { const int pointIndex = dicNode->getInputIndex(0); - const int primaryCodePoint = toBaseLowerCase( + const int primaryCodePoint = CharUtils::toBaseLowerCase( traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); - const int dicNodeChar = toBaseLowerCase(dicNode->getNodeCodePoint()); + const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()); return primaryCodePoint != dicNodeChar; } @@ -121,31 +125,37 @@ class TypingWeighting : public Weighting { float getInsertionCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const { - const int16_t parentPointIndex = parentDicNode->getInputIndex(0); - const int prevCodePoint = - traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); - + const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt( + insertedPointIndex); const int currentCodePoint = dicNode->getNodeCodePoint(); const bool sameCodePoint = prevCodePoint == currentCodePoint; + const bool existsAdjacentProximityChars = traverseSession->getProximityInfoState(0) + ->existsAdjacentProximityChars(insertedPointIndex); const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( - parentPointIndex + 1, currentCodePoint); + insertedPointIndex + 1, dicNode->getNodeCodePoint()); const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; - const bool singleChar = dicNode->getDepth() == 1; - const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) - + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR - : ScoringParams::INSERTION_COST); + const bool singleChar = dicNode->getNodeCodePointCount() == 1; + float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); + if (sameCodePoint) { + cost += ScoringParams::INSERTION_COST_SAME_CHAR; + } else if (existsAdjacentProximityChars) { + cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; + } else { + cost += ScoringParams::INSERTION_COST; + } return cost + weightedDistance; } - float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const { + float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); } - float getNewWordBigramCost(const DicTraverseSession *const traverseSession, + float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const { - return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), + return DicNodeUtils::getBigramNodeImprobability(traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } @@ -162,9 +172,16 @@ class TypingWeighting : public Weighting { float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { - const float languageImprobability = (dicNode->isExactMatch()) ? - 0.0f : dicNodeLanguageImprobability; - return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + } + + float getTerminalInsertionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + const int inputIndex = dicNode->getInputIndex(0); + const int inputSize = traverseSession->getInputSize(); + ASSERT(inputIndex < inputSize); + // TODO: Implement more efficient logic + return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); } AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { diff --git a/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h index ec1457455..81614bc9c 100644 --- a/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h +++ b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h @@ -17,8 +17,8 @@ #ifndef LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H #define LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H -#include "char_utils.h" #include "suggest/policyimpl/utils/edit_distance_policy.h" +#include "utils/char_utils.h" namespace latinime { @@ -31,8 +31,8 @@ class DamerauLevenshteinEditDistancePolicy : public EditDistancePolicy { ~DamerauLevenshteinEditDistancePolicy() {} AK_FORCE_INLINE float getSubstitutionCost(const int index0, const int index1) const { - const int c0 = toBaseLowerCase(mString0[index0]); - const int c1 = toBaseLowerCase(mString1[index1]); + const int c0 = CharUtils::toBaseLowerCase(mString0[index0]); + const int c1 = CharUtils::toBaseLowerCase(mString1[index1]); return (c0 == c1) ? 0.0f : 1.0f; } @@ -45,10 +45,10 @@ class DamerauLevenshteinEditDistancePolicy : public EditDistancePolicy { } AK_FORCE_INLINE bool allowTransposition(const int index0, const int index1) const { - const int c0 = toBaseLowerCase(mString0[index0]); - const int c1 = toBaseLowerCase(mString1[index1]); - if (index0 > 0 && index1 > 0 && c0 == toBaseLowerCase(mString1[index1 - 1]) - && c1 == toBaseLowerCase(mString0[index0 - 1])) { + const int c0 = CharUtils::toBaseLowerCase(mString0[index0]); + const int c1 = CharUtils::toBaseLowerCase(mString1[index1]); + if (index0 > 0 && index1 > 0 && c0 == CharUtils::toBaseLowerCase(mString1[index1 - 1]) + && c1 == CharUtils::toBaseLowerCase(mString0[index0 - 1])) { return true; } return false; diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance.h b/native/jni/src/suggest/policyimpl/utils/edit_distance.h index cbbd66894..0871c37ce 100644 --- a/native/jni/src/suggest/policyimpl/utils/edit_distance.h +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance.h @@ -62,6 +62,26 @@ class EditDistance { return dp[(beforeLength + 1) * (afterLength + 1) - 1]; } + AK_FORCE_INLINE static void dumpEditDistance10ForDebug(const float *const editDistanceTable, + const int editDistanceTableWidth, const int outputLength) { + if (DEBUG_DICT) { + AKLOGI("EditDistanceTable"); + for (int i = 0; i <= 10; ++i) { + float c[11]; + for (int j = 0; j <= 10; ++j) { + if (j < editDistanceTableWidth + 1 && i < outputLength + 1) { + c[j] = (editDistanceTable + i * (editDistanceTableWidth + 1))[j]; + } else { + c[j] = -1.0f; + } + } + AKLOGI("[ %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f ]", + c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7], c[8], c[9], c[10]); + (void)c; // To suppress compiler warning + } + } + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(EditDistance); }; diff --git a/native/jni/src/terminal_attributes.h b/native/jni/src/terminal_attributes.h deleted file mode 100644 index 92ef71c2c..000000000 --- a/native/jni/src/terminal_attributes.h +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (C) 2012 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_TERMINAL_ATTRIBUTES_H -#define LATINIME_TERMINAL_ATTRIBUTES_H - -#include <stdint.h> -#include "binary_format.h" - -namespace latinime { - -/** - * This class encapsulates information about a terminal that allows to - * retrieve local node attributes like the list of shortcuts without - * exposing the format structure to the client. - */ -class TerminalAttributes { - public: - class ShortcutIterator { - public: - ShortcutIterator(const uint8_t *dict, const int pos, const uint8_t flags) - : mDict(dict), mPos(pos), - mHasNextShortcutTarget(0 != (flags & BinaryFormat::FLAG_HAS_SHORTCUT_TARGETS)) { - } - - inline bool hasNextShortcutTarget() const { - return mHasNextShortcutTarget; - } - - // Gets the shortcut target itself as an int string. For parameters and return value - // see BinaryFormat::getWordAtAddress. - inline int getNextShortcutTarget(const int maxDepth, int *outWord, int *outFreq) { - const int shortcutFlags = BinaryFormat::getFlagsAndForwardPointer(mDict, &mPos); - mHasNextShortcutTarget = 0 != (shortcutFlags & BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT); - unsigned int i; - for (i = 0; i < MAX_WORD_LENGTH; ++i) { - const int codePoint = BinaryFormat::getCodePointAndForwardPointer(mDict, &mPos); - if (NOT_A_CODE_POINT == codePoint) break; - outWord[i] = codePoint; - } - *outFreq = BinaryFormat::getAttributeProbabilityFromFlags(shortcutFlags); - return i; - } - - private: - const uint8_t *const mDict; - int mPos; - bool mHasNextShortcutTarget; - }; - - TerminalAttributes(const uint8_t *const dict, const uint8_t flags, const int pos) - : mDict(dict), mFlags(flags), mStartPos(pos) { - } - - inline ShortcutIterator getShortcutIterator() const { - // The size of the shortcuts is stored here so that the whole shortcut chunk can be - // skipped quickly, so we ignore it. - return ShortcutIterator(mDict, mStartPos + BinaryFormat::SHORTCUT_LIST_SIZE_SIZE, mFlags); - } - - bool isBlacklistedOrNotAWord() const { - return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags); - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(TerminalAttributes); - const uint8_t *const mDict; - const uint8_t mFlags; - const int mStartPos; -}; -} // namespace latinime -#endif // LATINIME_TERMINAL_ATTRIBUTES_H diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp deleted file mode 100644 index a672294b5..000000000 --- a/native/jni/src/unigram_dictionary.cpp +++ /dev/null @@ -1,988 +0,0 @@ -/* - * Copyright (C) 2010, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <cstring> - -#define LOG_TAG "LatinIME: unigram_dictionary.cpp" - -#include "binary_format.h" -#include "char_utils.h" -#include "defines.h" -#include "dictionary.h" -#include "digraph_utils.h" -#include "proximity_info.h" -#include "terminal_attributes.h" -#include "unigram_dictionary.h" -#include "words_priority_queue.h" -#include "words_priority_queue_pool.h" - -namespace latinime { - -// TODO: check the header -UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags) - : DICT_ROOT(streamStart), ROOT_POS(0), - MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), DICT_FLAGS(dictFlags) { - if (DEBUG_DICT) { - AKLOGI("UnigramDictionary - constructor"); - } -} - -UnigramDictionary::~UnigramDictionary() { -} - -// TODO: This needs to take a const int* and not tinker with its contents -static void addWord(int *word, int length, int probability, WordsPriorityQueue *queue, int type) { - queue->push(probability, word, length, type); -} - -// Return the replacement code point for a digraph, or 0 if none. -int UnigramDictionary::getDigraphReplacement(const int *codes, const int i, const int inputSize, - const DigraphUtils::digraph_t *const digraphs, const unsigned int digraphsSize) const { - - // There can't be a digraph if we don't have at least 2 characters to examine - if (i + 2 > inputSize) return false; - - // Search for the first char of some digraph - int lastDigraphIndex = -1; - const int thisChar = codes[i]; - for (lastDigraphIndex = digraphsSize - 1; lastDigraphIndex >= 0; --lastDigraphIndex) { - if (thisChar == digraphs[lastDigraphIndex].first) break; - } - // No match: return early - if (lastDigraphIndex < 0) return 0; - - // It's an interesting digraph if the second char matches too. - if (digraphs[lastDigraphIndex].second == codes[i + 1]) { - return digraphs[lastDigraphIndex].compositeGlyph; - } else { - return 0; - } -} - -// Mostly the same arguments as the non-recursive version, except: -// codes is the original value. It points to the start of the work buffer, and gets passed as is. -// inputSize is the size of the user input (thus, it is the size of codesSrc). -// codesDest is the current point in the work buffer. -// codesSrc is the current point in the user-input, original, content-unmodified buffer. -// codesRemain is the remaining size in codesSrc. -void UnigramDictionary::getWordWithDigraphSuggestionsRec(ProximityInfo *proximityInfo, - const int *xcoordinates, const int *ycoordinates, const int *codesBuffer, - int *xCoordinatesBuffer, int *yCoordinatesBuffer, - const int codesBufferSize, const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, const int *codesSrc, - const int codesRemain, const int currentDepth, int *codesDest, Correction *correction, - WordsPriorityQueuePool *queuePool, - const DigraphUtils::digraph_t *const digraphs, const unsigned int digraphsSize) const { - ASSERT(sizeof(codesDest[0]) == sizeof(codesSrc[0])); - ASSERT(sizeof(xCoordinatesBuffer[0]) == sizeof(xcoordinates[0])); - ASSERT(sizeof(yCoordinatesBuffer[0]) == sizeof(ycoordinates[0])); - - const int startIndex = static_cast<int>(codesDest - codesBuffer); - if (currentDepth < MAX_DIGRAPH_SEARCH_DEPTH) { - for (int i = 0; i < codesRemain; ++i) { - xCoordinatesBuffer[startIndex + i] = xcoordinates[codesBufferSize - codesRemain + i]; - yCoordinatesBuffer[startIndex + i] = ycoordinates[codesBufferSize - codesRemain + i]; - const int replacementCodePoint = - getDigraphReplacement(codesSrc, i, codesRemain, digraphs, digraphsSize); - if (0 != replacementCodePoint) { - // Found a digraph. We will try both spellings. eg. the word is "pruefen" - - // Copy the word up to the first char of the digraph, including proximity chars, - // and overwrite the primary code with the replacement code point. Then, continue - // processing on the remaining part of the word, skipping the second char of the - // digraph. - // In our example, copy "pru", replace "u" with the version with the diaeresis and - // continue running on "fen". - // Make i the index of the second char of the digraph for simplicity. Forgetting - // to do that results in an infinite recursion so take care! - ++i; - memcpy(codesDest, codesSrc, i * sizeof(codesDest[0])); - codesDest[i - 1] = replacementCodePoint; - getWordWithDigraphSuggestionsRec(proximityInfo, xcoordinates, ycoordinates, - codesBuffer, xCoordinatesBuffer, yCoordinatesBuffer, codesBufferSize, - bigramMap, bigramFilter, useFullEditDistance, codesSrc + i + 1, - codesRemain - i - 1, currentDepth + 1, codesDest + i, correction, - queuePool, digraphs, digraphsSize); - - // Copy the second char of the digraph in place, then continue processing on - // the remaining part of the word. - // In our example, after "pru" in the buffer copy the "e", and continue on "fen" - memcpy(codesDest + i, codesSrc + i, sizeof(codesDest[0])); - getWordWithDigraphSuggestionsRec(proximityInfo, xcoordinates, ycoordinates, - codesBuffer, xCoordinatesBuffer, yCoordinatesBuffer, codesBufferSize, - bigramMap, bigramFilter, useFullEditDistance, codesSrc + i, codesRemain - i, - currentDepth + 1, codesDest + i, correction, queuePool, digraphs, - digraphsSize); - return; - } - } - } - - // If we come here, we hit the end of the word: let's check it against the dictionary. - // In our example, we'll come here once for "prufen" and then once for "pruefen". - // If the word contains several digraphs, we'll come it for the product of them. - // eg. if the word is "ueberpruefen" we'll test, in order, against - // "uberprufen", "uberpruefen", "ueberprufen", "ueberpruefen". - const unsigned int remainingBytes = sizeof(codesDest[0]) * codesRemain; - if (0 != remainingBytes) { - memcpy(codesDest, codesSrc, remainingBytes); - memcpy(&xCoordinatesBuffer[startIndex], &xcoordinates[codesBufferSize - codesRemain], - sizeof(xCoordinatesBuffer[0]) * codesRemain); - memcpy(&yCoordinatesBuffer[startIndex], &ycoordinates[codesBufferSize - codesRemain], - sizeof(yCoordinatesBuffer[0]) * codesRemain); - } - - getWordSuggestions(proximityInfo, xCoordinatesBuffer, yCoordinatesBuffer, codesBuffer, - startIndex + codesRemain, bigramMap, bigramFilter, useFullEditDistance, correction, - queuePool); -} - -// bigramMap contains the association <bigram address> -> <bigram probability> -// bigramFilter is a bloom filter for fast rejection: see functions setInFilter and isInFilter -// in bigram_dictionary.cpp -int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *inputCodePoints, const int inputSize, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, int *outWords, int *frequencies, int *outputTypes) const { - WordsPriorityQueuePool queuePool(MAX_RESULTS, SUB_QUEUE_MAX_WORDS); - queuePool.clearAll(); - Correction masterCorrection; - masterCorrection.resetCorrection(); - const DigraphUtils::digraph_t *digraphs = 0; - const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(DICT_FLAGS, &digraphs); - if (digraphsSize > 0) - { // Incrementally tune the word and try all possibilities - int codesBuffer[sizeof(*inputCodePoints) * inputSize]; - int xCoordinatesBuffer[inputSize]; - int yCoordinatesBuffer[inputSize]; - getWordWithDigraphSuggestionsRec(proximityInfo, xcoordinates, ycoordinates, codesBuffer, - xCoordinatesBuffer, yCoordinatesBuffer, inputSize, bigramMap, bigramFilter, - useFullEditDistance, inputCodePoints, inputSize, 0, codesBuffer, &masterCorrection, - &queuePool, digraphs, digraphsSize); - } else { // Normal processing - getWordSuggestions(proximityInfo, xcoordinates, ycoordinates, inputCodePoints, inputSize, - bigramMap, bigramFilter, useFullEditDistance, &masterCorrection, &queuePool); - } - - PROF_START(20); - if (DEBUG_DICT) { - float ns = queuePool.getMasterQueue()->getHighestNormalizedScore( - masterCorrection.getPrimaryInputWord(), inputSize, 0, 0, 0); - ns += 0; - AKLOGI("Max normalized score = %f", ns); - } - const int suggestedWordsCount = - queuePool.getMasterQueue()->outputSuggestions(masterCorrection.getPrimaryInputWord(), - inputSize, frequencies, outWords, outputTypes); - - if (DEBUG_DICT) { - float ns = queuePool.getMasterQueue()->getHighestNormalizedScore( - masterCorrection.getPrimaryInputWord(), inputSize, 0, 0, 0); - ns += 0; - AKLOGI("Returning %d words", suggestedWordsCount); - /// Print the returned words - for (int j = 0; j < suggestedWordsCount; ++j) { - int *w = outWords + j * MAX_WORD_LENGTH; - char s[MAX_WORD_LENGTH]; - for (int i = 0; i <= MAX_WORD_LENGTH; i++) s[i] = w[i]; - (void)s; // To suppress compiler warning - AKLOGI("%s %i", s, frequencies[j]); - } - } - PROF_END(20); - PROF_CLOSE; - return suggestedWordsCount; -} - -void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *inputCodePoints, const int inputSize, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, Correction *correction, WordsPriorityQueuePool *queuePool) - const { - PROF_OPEN; - PROF_START(0); - PROF_END(0); - - PROF_START(1); - getOneWordSuggestions(proximityInfo, xcoordinates, ycoordinates, inputCodePoints, bigramMap, - bigramFilter, useFullEditDistance, inputSize, correction, queuePool); - PROF_END(1); - - PROF_START(2); - // Note: This line is intentionally left blank - PROF_END(2); - - PROF_START(3); - // Note: This line is intentionally left blank - PROF_END(3); - - PROF_START(4); - bool hasAutoCorrectionCandidate = false; - WordsPriorityQueue *masterQueue = queuePool->getMasterQueue(); - if (masterQueue->size() > 0) { - float nsForMaster = masterQueue->getHighestNormalizedScore( - correction->getPrimaryInputWord(), inputSize, 0, 0, 0); - hasAutoCorrectionCandidate = (nsForMaster > START_TWO_WORDS_CORRECTION_THRESHOLD); - } - PROF_END(4); - - PROF_START(5); - // Multiple word suggestions - if (SUGGEST_MULTIPLE_WORDS - && inputSize >= MIN_USER_TYPED_LENGTH_FOR_MULTIPLE_WORD_SUGGESTION) { - getSplitMultipleWordsSuggestions(proximityInfo, xcoordinates, ycoordinates, inputCodePoints, - useFullEditDistance, inputSize, correction, queuePool, - hasAutoCorrectionCandidate); - } - PROF_END(5); - - PROF_START(6); - // Note: This line is intentionally left blank - PROF_END(6); - - if (DEBUG_DICT) { - queuePool->dumpSubQueue1TopSuggestions(); - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - WordsPriorityQueue *queue = queuePool->getSubQueue(FIRST_WORD_INDEX, i); - if (queue->size() > 0) { - WordsPriorityQueue::SuggestedWord *sw = queue->top(); - const int score = sw->mScore; - const int *word = sw->mWord; - const int wordLength = sw->mWordLength; - float ns = Correction::RankingAlgorithm::calcNormalizedScore( - correction->getPrimaryInputWord(), i, word, wordLength, score); - ns += 0; - AKLOGI("--- TOP SUB WORDS for %d --- %d %f [%d]", i, score, ns, - (ns > TWO_WORDS_CORRECTION_WITH_OTHER_ERROR_THRESHOLD)); - DUMP_WORD(correction->getPrimaryInputWord(), i); - DUMP_WORD(word, wordLength); - } - } - } -} - -void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int *xCoordinates, - const int *yCoordinates, const int *codes, const int inputSize, - Correction *correction) const { - if (DEBUG_DICT) { - AKLOGI("initSuggest"); - DUMP_WORD(codes, inputSize); - } - correction->initInputParams(proximityInfo, codes, inputSize, xCoordinates, yCoordinates); - const int maxDepth = min(inputSize * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); - correction->initCorrection(proximityInfo, inputSize, maxDepth); -} - -void UnigramDictionary::getOneWordSuggestions(ProximityInfo *proximityInfo, - const int *xcoordinates, const int *ycoordinates, const int *codes, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, const int inputSize, - Correction *correction, WordsPriorityQueuePool *queuePool) const { - initSuggestions(proximityInfo, xcoordinates, ycoordinates, codes, inputSize, correction); - getSuggestionCandidates(useFullEditDistance, inputSize, bigramMap, bigramFilter, correction, - queuePool, true /* doAutoCompletion */, DEFAULT_MAX_ERRORS, FIRST_WORD_INDEX); -} - -void UnigramDictionary::getSuggestionCandidates(const bool useFullEditDistance, - const int inputSize, const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - Correction *correction, WordsPriorityQueuePool *queuePool, - const bool doAutoCompletion, const int maxErrors, const int currentWordIndex) const { - uint8_t totalTraverseCount = correction->pushAndGetTotalTraverseCount(); - if (DEBUG_DICT) { - AKLOGI("Traverse count %d", totalTraverseCount); - } - if (totalTraverseCount > MULTIPLE_WORDS_SUGGESTION_MAX_TOTAL_TRAVERSE_COUNT) { - if (DEBUG_DICT) { - AKLOGI("Abort traversing %d", totalTraverseCount); - } - return; - } - // TODO: Remove setCorrectionParams - correction->setCorrectionParams(0, 0, 0, - -1 /* spaceProximityPos */, -1 /* missingSpacePos */, useFullEditDistance, - doAutoCompletion, maxErrors); - int rootPosition = ROOT_POS; - // Get the number of children of root, then increment the position - int childCount = BinaryFormat::getGroupCountAndForwardPointer(DICT_ROOT, &rootPosition); - int outputIndex = 0; - - correction->initCorrectionState(rootPosition, childCount, (inputSize <= 0)); - - // Depth first search - while (outputIndex >= 0) { - if (correction->initProcessState(outputIndex)) { - int siblingPos = correction->getTreeSiblingPos(outputIndex); - int firstChildPos; - - const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, - bigramMap, bigramFilter, correction, &childCount, &firstChildPos, &siblingPos, - queuePool, currentWordIndex); - // Update next sibling pos - correction->setTreeSiblingPos(outputIndex, siblingPos); - - if (needsToTraverseChildrenNodes) { - // Goes to child node - outputIndex = correction->goDownTree(outputIndex, childCount, firstChildPos); - } - } else { - // Goes to parent sibling node - outputIndex = correction->getTreeParentIndex(outputIndex); - } - } -} - -void UnigramDictionary::onTerminal(const int probability, - const TerminalAttributes &terminalAttributes, Correction *correction, - WordsPriorityQueuePool *queuePool, const bool addToMasterQueue, - const int currentWordIndex) const { - const int inputIndex = correction->getInputIndex(); - const bool addToSubQueue = inputIndex < SUB_QUEUE_MAX_COUNT; - - int wordLength; - int *wordPointer; - - if ((currentWordIndex == FIRST_WORD_INDEX) && addToMasterQueue) { - WordsPriorityQueue *masterQueue = queuePool->getMasterQueue(); - const int finalProbability = - correction->getFinalProbability(probability, &wordPointer, &wordLength); - - if (0 != finalProbability && !terminalAttributes.isBlacklistedOrNotAWord()) { - // If the probability is 0, we don't want to add this word. However we still - // want to add its shortcuts (including a possible whitelist entry) if any. - // Furthermore, if this is not a word (shortcut only for example) or a blacklisted - // entry then we never want to suggest this. - addWord(wordPointer, wordLength, finalProbability, masterQueue, - Dictionary::KIND_CORRECTION); - } - - const int shortcutProbability = finalProbability > 0 ? finalProbability - 1 : 0; - // Please note that the shortcut candidates will be added to the master queue only. - TerminalAttributes::ShortcutIterator iterator = terminalAttributes.getShortcutIterator(); - while (iterator.hasNextShortcutTarget()) { - // TODO: addWord only supports weak ordering, meaning we have no means - // to control the order of the shortcuts relative to one another or to the word. - // We need to either modulate the probability of each shortcut according - // to its own shortcut probability or to make the queue - // so that the insert order is protected inside the queue for words - // with the same score. For the moment we use -1 to make sure the shortcut will - // never be in front of the word. - int shortcutTarget[MAX_WORD_LENGTH]; - int shortcutFrequency; - const int shortcutTargetStringLength = iterator.getNextShortcutTarget( - MAX_WORD_LENGTH, shortcutTarget, &shortcutFrequency); - int shortcutScore; - int kind; - if (shortcutFrequency == BinaryFormat::WHITELIST_SHORTCUT_PROBABILITY - && correction->sameAsTyped()) { - shortcutScore = S_INT_MAX; - kind = Dictionary::KIND_WHITELIST; - } else { - shortcutScore = shortcutProbability; - kind = Dictionary::KIND_CORRECTION; - } - addWord(shortcutTarget, shortcutTargetStringLength, shortcutScore, - masterQueue, kind); - } - } - - // We only allow two words + other error correction for words with SUB_QUEUE_MIN_WORD_LENGTH - // or more length. - if (inputIndex >= SUB_QUEUE_MIN_WORD_LENGTH && addToSubQueue) { - WordsPriorityQueue *subQueue; - subQueue = queuePool->getSubQueue(currentWordIndex, inputIndex); - if (!subQueue) { - return; - } - const int finalProbability = correction->getFinalProbabilityForSubQueue( - probability, &wordPointer, &wordLength, inputIndex); - addWord(wordPointer, wordLength, finalProbability, subQueue, Dictionary::KIND_CORRECTION); - } -} - -int UnigramDictionary::getSubStringSuggestion( - ProximityInfo *proximityInfo, const int *xcoordinates, const int *ycoordinates, - const int *codes, const bool useFullEditDistance, Correction *correction, - WordsPriorityQueuePool *queuePool, const int inputSize, - const bool hasAutoCorrectionCandidate, const int currentWordIndex, - const int inputWordStartPos, const int inputWordLength, - const int outputWordStartPos, const bool isSpaceProximity, int *freqArray, - int *wordLengthArray, int *outputWord, int *outputWordLength) const { - if (inputWordLength > MULTIPLE_WORDS_SUGGESTION_MAX_WORD_LENGTH) { - return FLAG_MULTIPLE_SUGGEST_ABORT; - } - - ///////////////////////////////////////////// - // safety net for multiple word suggestion // - // TODO: Remove this safety net // - ///////////////////////////////////////////// - int smallWordCount = 0; - int singleLetterWordCount = 0; - if (inputWordLength == 1) { - ++singleLetterWordCount; - } - if (inputWordLength <= 2) { - // small word == single letter or 2-letter word - ++smallWordCount; - } - for (int i = 0; i < currentWordIndex; ++i) { - const int length = wordLengthArray[i]; - if (length == 1) { - ++singleLetterWordCount; - // Safety net to avoid suggesting sequential single letter words - if (i < (currentWordIndex - 1)) { - if (wordLengthArray[i + 1] == 1) { - return FLAG_MULTIPLE_SUGGEST_ABORT; - } - } else if (inputWordLength == 1) { - return FLAG_MULTIPLE_SUGGEST_ABORT; - } - } - if (length <= 2) { - ++smallWordCount; - } - // Safety net to avoid suggesting multiple words with many (4 or more, for now) small words - if (singleLetterWordCount >= 3 || smallWordCount >= 4) { - return FLAG_MULTIPLE_SUGGEST_ABORT; - } - } - ////////////////////////////////////////////// - // TODO: Remove the safety net above // - ////////////////////////////////////////////// - - int *tempOutputWord = 0; - int nextWordLength = 0; - // TODO: Optimize init suggestion - initSuggestions(proximityInfo, xcoordinates, ycoordinates, codes, - inputSize, correction); - - int word[MAX_WORD_LENGTH]; - int freq = getMostProbableWordLike( - inputWordStartPos, inputWordLength, correction, word); - if (freq > 0) { - nextWordLength = inputWordLength; - tempOutputWord = word; - } else if (!hasAutoCorrectionCandidate) { - if (inputWordStartPos > 0) { - const int offset = inputWordStartPos; - initSuggestions(proximityInfo, &xcoordinates[offset], &ycoordinates[offset], - codes + offset, inputWordLength, correction); - queuePool->clearSubQueue(currentWordIndex); - // TODO: pass the bigram list for substring suggestion - getSuggestionCandidates(useFullEditDistance, inputWordLength, - 0 /* bigramMap */, 0 /* bigramFilter */, correction, queuePool, - false /* doAutoCompletion */, MAX_ERRORS_FOR_TWO_WORDS, currentWordIndex); - if (DEBUG_DICT) { - if (currentWordIndex < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS) { - AKLOGI("Dump word candidates(%d) %d", currentWordIndex, inputWordLength); - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - queuePool->getSubQueue(currentWordIndex, i)->dumpTopWord(); - } - } - } - } - WordsPriorityQueue *queue = queuePool->getSubQueue(currentWordIndex, inputWordLength); - // TODO: Return the correct value depending on doAutoCompletion - if (!queue || queue->size() <= 0) { - return FLAG_MULTIPLE_SUGGEST_ABORT; - } - int score = 0; - const float ns = queue->getHighestNormalizedScore( - correction->getPrimaryInputWord(), inputWordLength, - &tempOutputWord, &score, &nextWordLength); - if (DEBUG_DICT) { - AKLOGI("NS(%d) = %f, Score = %d", currentWordIndex, ns, score); - } - // Two words correction won't be done if the score of the first word doesn't exceed the - // threshold. - if (ns < TWO_WORDS_CORRECTION_WITH_OTHER_ERROR_THRESHOLD - || nextWordLength < SUB_QUEUE_MIN_WORD_LENGTH) { - return FLAG_MULTIPLE_SUGGEST_SKIP; - } - freq = score >> (nextWordLength + TWO_WORDS_PLUS_OTHER_ERROR_CORRECTION_DEMOTION_DIVIDER); - } - if (DEBUG_DICT) { - AKLOGI("Freq(%d): %d, length: %d, input length: %d, input start: %d (%d)", - currentWordIndex, freq, nextWordLength, inputWordLength, inputWordStartPos, - (currentWordIndex > 0) ? wordLengthArray[0] : 0); - } - if (freq <= 0 || nextWordLength <= 0 - || MAX_WORD_LENGTH <= (outputWordStartPos + nextWordLength)) { - return FLAG_MULTIPLE_SUGGEST_SKIP; - } - for (int i = 0; i < nextWordLength; ++i) { - outputWord[outputWordStartPos + i] = tempOutputWord[i]; - } - - // Put output values - freqArray[currentWordIndex] = freq; - // TODO: put output length instead of input length - wordLengthArray[currentWordIndex] = inputWordLength; - const int tempOutputWordLength = outputWordStartPos + nextWordLength; - if (outputWordLength) { - *outputWordLength = tempOutputWordLength; - } - - if ((inputWordStartPos + inputWordLength) < inputSize) { - if (outputWordStartPos + nextWordLength >= MAX_WORD_LENGTH) { - return FLAG_MULTIPLE_SUGGEST_SKIP; - } - outputWord[tempOutputWordLength] = KEYCODE_SPACE; - if (outputWordLength) { - ++*outputWordLength; - } - } else if (currentWordIndex >= 1) { - // TODO: Handle 3 or more words - const int pairFreq = correction->getFreqForSplitMultipleWords( - freqArray, wordLengthArray, currentWordIndex + 1, isSpaceProximity, outputWord); - if (DEBUG_DICT) { - DUMP_WORD(outputWord, tempOutputWordLength); - for (int i = 0; i < currentWordIndex + 1; ++i) { - AKLOGI("Split %d,%d words: freq = %d, length = %d", i, currentWordIndex + 1, - freqArray[i], wordLengthArray[i]); - } - AKLOGI("Split two words: freq = %d, length = %d, %d, isSpace ? %d", pairFreq, - inputSize, tempOutputWordLength, isSpaceProximity); - } - addWord(outputWord, tempOutputWordLength, pairFreq, queuePool->getMasterQueue(), - Dictionary::KIND_CORRECTION); - } - return FLAG_MULTIPLE_SUGGEST_CONTINUE; -} - -void UnigramDictionary::getMultiWordsSuggestionRec(ProximityInfo *proximityInfo, - const int *xcoordinates, const int *ycoordinates, const int *codes, - const bool useFullEditDistance, const int inputSize, Correction *correction, - WordsPriorityQueuePool *queuePool, const bool hasAutoCorrectionCandidate, - const int startInputPos, const int startWordIndex, const int outputWordLength, - int *freqArray, int *wordLengthArray, int *outputWord) const { - if (startWordIndex >= (MULTIPLE_WORDS_SUGGESTION_MAX_WORDS - 1)) { - // Return if the last word index - return; - } - if (startWordIndex >= 1 - && (hasAutoCorrectionCandidate - || inputSize < MIN_INPUT_LENGTH_FOR_THREE_OR_MORE_WORDS_CORRECTION)) { - // Do not suggest 3+ words if already has auto correction candidate - return; - } - for (int i = startInputPos + 1; i < inputSize; ++i) { - if (DEBUG_CORRECTION_FREQ) { - AKLOGI("Multi words(%d), start in %d sep %d start out %d", - startWordIndex, startInputPos, i, outputWordLength); - DUMP_WORD(outputWord, outputWordLength); - } - int tempOutputWordLength = 0; - // Current word - int inputWordStartPos = startInputPos; - int inputWordLength = i - startInputPos; - const int suggestionFlag = getSubStringSuggestion(proximityInfo, xcoordinates, ycoordinates, - codes, useFullEditDistance, correction, queuePool, inputSize, - hasAutoCorrectionCandidate, startWordIndex, inputWordStartPos, inputWordLength, - outputWordLength, true /* not used */, freqArray, wordLengthArray, outputWord, - &tempOutputWordLength); - if (suggestionFlag == FLAG_MULTIPLE_SUGGEST_ABORT) { - // TODO: break here - continue; - } else if (suggestionFlag == FLAG_MULTIPLE_SUGGEST_SKIP) { - continue; - } - - if (DEBUG_CORRECTION_FREQ) { - AKLOGI("Do missing space correction"); - } - // Next word - // Missing space - inputWordStartPos = i; - inputWordLength = inputSize - i; - if (getSubStringSuggestion(proximityInfo, xcoordinates, ycoordinates, codes, - useFullEditDistance, correction, queuePool, inputSize, hasAutoCorrectionCandidate, - startWordIndex + 1, inputWordStartPos, inputWordLength, tempOutputWordLength, - false /* missing space */, freqArray, wordLengthArray, outputWord, 0) - != FLAG_MULTIPLE_SUGGEST_CONTINUE) { - getMultiWordsSuggestionRec(proximityInfo, xcoordinates, ycoordinates, codes, - useFullEditDistance, inputSize, correction, queuePool, - hasAutoCorrectionCandidate, inputWordStartPos, startWordIndex + 1, - tempOutputWordLength, freqArray, wordLengthArray, outputWord); - } - - // Mistyped space - ++inputWordStartPos; - --inputWordLength; - - if (inputWordLength <= 0) { - continue; - } - - const int x = xcoordinates[inputWordStartPos - 1]; - const int y = ycoordinates[inputWordStartPos - 1]; - if (!proximityInfo->hasSpaceProximity(x, y)) { - continue; - } - - if (DEBUG_CORRECTION_FREQ) { - AKLOGI("Do mistyped space correction"); - } - getSubStringSuggestion(proximityInfo, xcoordinates, ycoordinates, codes, - useFullEditDistance, correction, queuePool, inputSize, hasAutoCorrectionCandidate, - startWordIndex + 1, inputWordStartPos, inputWordLength, tempOutputWordLength, - true /* mistyped space */, freqArray, wordLengthArray, outputWord, 0); - } -} - -void UnigramDictionary::getSplitMultipleWordsSuggestions(ProximityInfo *proximityInfo, - const int *xcoordinates, const int *ycoordinates, const int *codes, - const bool useFullEditDistance, const int inputSize, - Correction *correction, WordsPriorityQueuePool *queuePool, - const bool hasAutoCorrectionCandidate) const { - if (inputSize >= MAX_WORD_LENGTH) return; - if (DEBUG_DICT) { - AKLOGI("--- Suggest multiple words"); - } - - // Allocating fixed length array on stack - int outputWord[MAX_WORD_LENGTH]; - int freqArray[MULTIPLE_WORDS_SUGGESTION_MAX_WORDS]; - int wordLengthArray[MULTIPLE_WORDS_SUGGESTION_MAX_WORDS]; - const int outputWordLength = 0; - const int startInputPos = 0; - const int startWordIndex = 0; - getMultiWordsSuggestionRec(proximityInfo, xcoordinates, ycoordinates, codes, - useFullEditDistance, inputSize, correction, queuePool, hasAutoCorrectionCandidate, - startInputPos, startWordIndex, outputWordLength, freqArray, wordLengthArray, - outputWord); -} - -// Wrapper for getMostProbableWordLikeInner, which matches it to the previous -// interface. -int UnigramDictionary::getMostProbableWordLike(const int startInputIndex, const int inputSize, - Correction *correction, int *word) const { - int inWord[inputSize]; - for (int i = 0; i < inputSize; ++i) { - inWord[i] = correction->getPrimaryCodePointAt(startInputIndex + i); - } - return getMostProbableWordLikeInner(inWord, inputSize, word); -} - -// This function will take the position of a character array within a CharGroup, -// and check it actually like-matches the word in inWord starting at startInputIndex, -// that is, it matches it with case and accents squashed. -// The function returns true if there was a full match, false otherwise. -// The function will copy on-the-fly the characters in the CharGroup to outNewWord. -// It will also place the end position of the array in outPos; in outInputIndex, -// it will place the index of the first char AFTER the match if there was a match, -// and the initial position if there was not. It makes sense because if there was -// a match we want to continue searching, but if there was not, we want to go to -// the next CharGroup. -// In and out parameters may point to the same location. This function takes care -// not to use any input parameters after it wrote into its outputs. -static inline bool testCharGroupForContinuedLikeness(const uint8_t flags, - const uint8_t *const root, const int startPos, const int *const inWord, - const int startInputIndex, const int inputSize, int *outNewWord, int *outInputIndex, - int *outPos) { - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - int pos = startPos; - int codePoint = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - int baseChar = toBaseLowerCase(codePoint); - const int wChar = toBaseLowerCase(inWord[startInputIndex]); - - if (baseChar != wChar) { - *outPos = hasMultipleChars ? BinaryFormat::skipOtherCharacters(root, pos) : pos; - *outInputIndex = startInputIndex; - return false; - } - int inputIndex = startInputIndex; - outNewWord[inputIndex] = codePoint; - if (hasMultipleChars) { - codePoint = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - while (NOT_A_CODE_POINT != codePoint) { - baseChar = toBaseLowerCase(codePoint); - if (inputIndex + 1 >= inputSize || toBaseLowerCase(inWord[++inputIndex]) != baseChar) { - *outPos = BinaryFormat::skipOtherCharacters(root, pos); - *outInputIndex = startInputIndex; - return false; - } - outNewWord[inputIndex] = codePoint; - codePoint = BinaryFormat::getCodePointAndForwardPointer(root, &pos); - } - } - *outInputIndex = inputIndex + 1; - *outPos = pos; - return true; -} - -// This function is invoked when a word like the word searched for is found. -// It will compare the probability to the max probability, and if greater, will -// copy the word into the output buffer. In output value maxFreq, it will -// write the new maximum probability if it changed. -static inline void onTerminalWordLike(const int freq, int *newWord, const int length, int *outWord, - int *maxFreq) { - if (freq > *maxFreq) { - for (int q = 0; q < length; ++q) { - outWord[q] = newWord[q]; - } - outWord[length] = 0; - *maxFreq = freq; - } -} - -// Will find the highest probability of the words like the one passed as an argument, -// that is, everything that only differs by case/accents. -int UnigramDictionary::getMostProbableWordLikeInner(const int *const inWord, const int inputSize, - int *outWord) const { - int newWord[MAX_WORD_LENGTH]; - int depth = 0; - int maxFreq = -1; - const uint8_t *const root = DICT_ROOT; - int stackChildCount[MAX_WORD_LENGTH]; - int stackInputIndex[MAX_WORD_LENGTH]; - int stackSiblingPos[MAX_WORD_LENGTH]; - - int startPos = 0; - stackChildCount[0] = BinaryFormat::getGroupCountAndForwardPointer(root, &startPos); - stackInputIndex[0] = 0; - stackSiblingPos[0] = startPos; - while (depth >= 0) { - const int charGroupCount = stackChildCount[depth]; - int pos = stackSiblingPos[depth]; - for (int charGroupIndex = charGroupCount - 1; charGroupIndex >= 0; --charGroupIndex) { - int inputIndex = stackInputIndex[depth]; - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - // Test whether all chars in this group match with the word we are searching for. If so, - // we want to traverse its children (or if the inputSize match, evaluate its - // probability). Note that this function will output the position regardless, but will - // only write into inputIndex if there is a match. - const bool isAlike = testCharGroupForContinuedLikeness(flags, root, pos, inWord, - inputIndex, inputSize, newWord, &inputIndex, &pos); - if (isAlike && (!(BinaryFormat::FLAG_IS_NOT_A_WORD & flags)) - && (BinaryFormat::FLAG_IS_TERMINAL & flags) && (inputIndex == inputSize)) { - const int probability = - BinaryFormat::readProbabilityWithoutMovingPointer(root, pos); - onTerminalWordLike(probability, newWord, inputIndex, outWord, &maxFreq); - } - pos = BinaryFormat::skipProbability(flags, pos); - const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(root, flags, pos); - const int childrenNodePos = BinaryFormat::readChildrenPosition(root, flags, pos); - // If we had a match and the word has children, we want to traverse them. We don't have - // to traverse words longer than the one we are searching for, since they will not match - // anyway, so don't traverse unless inputIndex < inputSize. - if (isAlike && (-1 != childrenNodePos) && (inputIndex < inputSize)) { - // Save position for this depth, to get back to this once children are done - stackChildCount[depth] = charGroupIndex; - stackSiblingPos[depth] = siblingPos; - // Prepare stack values for next depth - ++depth; - int childrenPos = childrenNodePos; - stackChildCount[depth] = - BinaryFormat::getGroupCountAndForwardPointer(root, &childrenPos); - stackSiblingPos[depth] = childrenPos; - stackInputIndex[depth] = inputIndex; - pos = childrenPos; - // Go to the next depth level. - ++depth; - break; - } else { - // No match, or no children, or word too long to ever match: go the next sibling. - pos = siblingPos; - } - } - --depth; - } - return maxFreq; -} - -int UnigramDictionary::getProbability(const int *const inWord, const int length) const { - const uint8_t *const root = DICT_ROOT; - int pos = BinaryFormat::getTerminalPosition(root, inWord, length, - false /* forceLowerCaseSearch */); - if (NOT_VALID_WORD == pos) { - return NOT_A_PROBABILITY; - } - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - if (flags & (BinaryFormat::FLAG_IS_BLACKLISTED | BinaryFormat::FLAG_IS_NOT_A_WORD)) { - // If this is not a word, or if it's a blacklisted entry, it should behave as - // having no probability outside of the suggestion process (where it should be used - // for shortcuts). - return NOT_A_PROBABILITY; - } - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - if (hasMultipleChars) { - pos = BinaryFormat::skipOtherCharacters(root, pos); - } else { - BinaryFormat::getCodePointAndForwardPointer(DICT_ROOT, &pos); - } - const int unigramProbability = BinaryFormat::readProbabilityWithoutMovingPointer(root, pos); - return unigramProbability; -} - -// TODO: remove this function. -int UnigramDictionary::getBigramPosition(int pos, int *word, int offset, int length) const { - return -1; -} - -// ProcessCurrentNode returns a boolean telling whether to traverse children nodes or not. -// If the return value is false, then the caller should read in the output "nextSiblingPosition" -// to find out the address of the next sibling node and pass it to a new call of processCurrentNode. -// It is worthy to note that when false is returned, the output values other than -// nextSiblingPosition are undefined. -// If the return value is true, then the caller must proceed to traverse the children of this -// node. processCurrentNode will output the information about the children: their count in -// newCount, their position in newChildrenPosition, the traverseAllNodes flag in -// newTraverseAllNodes, the match weight into newMatchRate, the input index into newInputIndex, the -// diffs into newDiffs, the sibling position in nextSiblingPosition, and the output index into -// newOutputIndex. Please also note the following caveat: processCurrentNode does not know when -// there aren't any more nodes at this level, it merely returns the address of the first byte after -// the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any -// given level, as output into newCount when traversing this level's parent. -bool UnigramDictionary::processCurrentNode(const int initialPos, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, Correction *correction, - int *newCount, int *newChildrenPosition, int *nextSiblingPosition, - WordsPriorityQueuePool *queuePool, const int currentWordIndex) const { - if (DEBUG_DICT) { - correction->checkState(); - } - int pos = initialPos; - - // Flags contain the following information: - // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: - // - FLAG_GROUP_ADDRESS_TYPE_{ONE,TWO,THREE}_BYTES means there are children and their address - // is on the specified number of bytes. - // - FLAG_GROUP_ADDRESS_TYPE_NOADDRESS means there are no children, and therefore no address. - // - FLAG_HAS_MULTIPLE_CHARS: whether this node has multiple char or not. - // - FLAG_IS_TERMINAL: whether this node is a terminal or not (it may still have children) - // - FLAG_HAS_BIGRAMS: whether this node has bigrams or not - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos); - const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); - const bool isTerminalNode = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); - - bool needsToInvokeOnTerminal = false; - - // This gets only ONE character from the stream. Next there will be: - // if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node - // else if FLAG_IS_TERMINAL: the probability - // else if MASK_GROUP_ADDRESS_TYPE is not NONE: the children address - // Note that you can't have a node that both is not a terminal and has no children. - int c = BinaryFormat::getCodePointAndForwardPointer(DICT_ROOT, &pos); - ASSERT(NOT_A_CODE_POINT != c); - - // We are going to loop through each character and make it look like it's a different - // node each time. To do that, we will process characters in this node in order until - // we find the character terminator. This is signalled by getCodePoint* returning - // NOT_A_CODE_POINT. - // As a special case, if there is only one character in this node, we must not read the - // next bytes so we will simulate the NOT_A_CODE_POINT return by testing the flags. - // This way, each loop run will look like a "virtual node". - do { - // We prefetch the next char. If 'c' is the last char of this node, we will have - // NOT_A_CODE_POINT in the next char. From this we can decide whether this virtual node - // should behave as a terminal or not and whether we have children. - const int nextc = hasMultipleChars - ? BinaryFormat::getCodePointAndForwardPointer(DICT_ROOT, &pos) : NOT_A_CODE_POINT; - const bool isLastChar = (NOT_A_CODE_POINT == nextc); - // If there are more chars in this nodes, then this virtual node is not a terminal. - // If we are on the last char, this virtual node is a terminal if this node is. - const bool isTerminal = isLastChar && isTerminalNode; - - Correction::CorrectionType stateType = correction->processCharAndCalcState( - c, isTerminal); - if (stateType == Correction::TRAVERSE_ALL_ON_TERMINAL - || stateType == Correction::ON_TERMINAL) { - needsToInvokeOnTerminal = true; - } else if (stateType == Correction::UNRELATED || correction->needsToPrune()) { - // We found that this is an unrelated character, so we should give up traversing - // this node and its children entirely. - // However we may not be on the last virtual node yet so we skip the remaining - // characters in this node, the probability if it's there, read the next sibling - // position to output it, then return false. - // We don't have to output other values because we return false, as in - // "don't traverse children". - if (!isLastChar) { - pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos); - } - pos = BinaryFormat::skipProbability(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - - // Prepare for the next character. Promote the prefetched char to current char - the loop - // will take care of prefetching the next. If we finally found our last char, nextc will - // contain NOT_A_CODE_POINT. - c = nextc; - } while (NOT_A_CODE_POINT != c); - - if (isTerminalNode) { - // The probability should be here, because we come here only if this is actually - // a terminal node, and we are on its last char. - const int unigramProbability = - BinaryFormat::readProbabilityWithoutMovingPointer(DICT_ROOT, pos); - const int childrenAddressPos = BinaryFormat::skipProbability(flags, pos); - const int attributesPos = BinaryFormat::skipChildrenPosition(flags, childrenAddressPos); - TerminalAttributes terminalAttributes(DICT_ROOT, flags, attributesPos); - // bigramMap contains the bigram frequencies indexed by addresses for fast lookup. - // bigramFilter is a bloom filter of said frequencies for even faster rejection. - const int probability = BinaryFormat::getProbability(initialPos, bigramMap, bigramFilter, - unigramProbability); - onTerminal(probability, terminalAttributes, correction, queuePool, needsToInvokeOnTerminal, - currentWordIndex); - - // If there are more chars in this node, then this virtual node has children. - // If we are on the last char, this virtual node has children if this node has. - const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); - - // This character matched the typed character (enough to traverse the node at least) - // so we just evaluated it. Now we should evaluate this virtual node's children - that - // is, if it has any. If it has no children, we're done here - so we skip the end of - // the node, output the siblings position, and return false "don't traverse children". - // Note that !hasChildren implies isLastChar, so we know we don't have to skip any - // remaining char in this group for there can't be any. - if (!hasChildren) { - pos = BinaryFormat::skipProbability(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - - // Optimization: Prune out words that are too long compared to how much was typed. - if (correction->needsToPrune()) { - pos = BinaryFormat::skipProbability(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - if (DEBUG_DICT_FULL) { - AKLOGI("Traversing was pruned."); - } - return false; - } - } - - // Now we finished processing this node, and we want to traverse children. If there are no - // children, we can't come here. - ASSERT(BinaryFormat::hasChildrenInFlags(flags)); - - // If this node was a terminal it still has the probability under the pointer (it may have been - // read, but not skipped - see readProbabilityWithoutMovingPointer). - // Next come the children position, then possibly attributes (attributes are bigrams only for - // now, maybe something related to shortcuts in the future). - // Once this is read, we still need to output the number of nodes in the immediate children of - // this node, so we read and output it before returning true, as in "please traverse children". - pos = BinaryFormat::skipProbability(flags, pos); - int childrenPos = BinaryFormat::readChildrenPosition(DICT_ROOT, flags, pos); - *nextSiblingPosition = BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - *newCount = BinaryFormat::getGroupCountAndForwardPointer(DICT_ROOT, &childrenPos); - *newChildrenPosition = childrenPos; - return true; -} -} // namespace latinime diff --git a/native/jni/src/unigram_dictionary.h b/native/jni/src/unigram_dictionary.h deleted file mode 100644 index a64a539bd..000000000 --- a/native/jni/src/unigram_dictionary.h +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Copyright (C) 2010 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_UNIGRAM_DICTIONARY_H -#define LATINIME_UNIGRAM_DICTIONARY_H - -#include <map> -#include <stdint.h> -#include "defines.h" -#include "digraph_utils.h" - -namespace latinime { - -class Correction; -class ProximityInfo; -class TerminalAttributes; -class WordsPriorityQueuePool; - -class UnigramDictionary { - public: - // Error tolerances - static const int DEFAULT_MAX_ERRORS = 2; - static const int MAX_ERRORS_FOR_TWO_WORDS = 1; - - static const int FLAG_MULTIPLE_SUGGEST_ABORT = 0; - static const int FLAG_MULTIPLE_SUGGEST_SKIP = 1; - static const int FLAG_MULTIPLE_SUGGEST_CONTINUE = 2; - UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags); - int getProbability(const int *const inWord, const int length) const; - int getBigramPosition(int pos, int *word, int offset, int length) const; - int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *inputCodePoints, const int inputSize, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, int *outWords, int *frequencies, - int *outputTypes) const; - int getDictFlags() const { return DICT_FLAGS; } - virtual ~UnigramDictionary(); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(UnigramDictionary); - void getWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *inputCodePoints, const int inputSize, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - const bool useFullEditDistance, Correction *correction, - WordsPriorityQueuePool *queuePool) const; - int getDigraphReplacement(const int *codes, const int i, const int inputSize, - const DigraphUtils::digraph_t *const digraphs, const unsigned int digraphsSize) const; - void getWordWithDigraphSuggestionsRec(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codesBuffer, int *xCoordinatesBuffer, - int *yCoordinatesBuffer, const int codesBufferSize, const std::map<int, int> *bigramMap, - const uint8_t *bigramFilter, const bool useFullEditDistance, const int *codesSrc, - const int codesRemain, const int currentDepth, int *codesDest, Correction *correction, - WordsPriorityQueuePool *queuePool, const DigraphUtils::digraph_t *const digraphs, - const unsigned int digraphsSize) const; - void initSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codes, const int inputSize, - Correction *correction) const; - void getOneWordSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codes, const std::map<int, int> *bigramMap, - const uint8_t *bigramFilter, const bool useFullEditDistance, const int inputSize, - Correction *correction, WordsPriorityQueuePool *queuePool) const; - void getSuggestionCandidates( - const bool useFullEditDistance, const int inputSize, - const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, - Correction *correction, WordsPriorityQueuePool *queuePool, const bool doAutoCompletion, - const int maxErrors, const int currentWordIndex) const; - void getSplitMultipleWordsSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codes, const bool useFullEditDistance, - const int inputSize, Correction *correction, WordsPriorityQueuePool *queuePool, - const bool hasAutoCorrectionCandidate) const; - void onTerminal(const int freq, const TerminalAttributes &terminalAttributes, - Correction *correction, WordsPriorityQueuePool *queuePool, const bool addToMasterQueue, - const int currentWordIndex) const; - // Process a node by considering proximity, missing and excessive character - bool processCurrentNode(const int initialPos, const std::map<int, int> *bigramMap, - const uint8_t *bigramFilter, Correction *correction, int *newCount, - int *newChildPosition, int *nextSiblingPosition, WordsPriorityQueuePool *queuePool, - const int currentWordIndex) const; - int getMostProbableWordLike(const int startInputIndex, const int inputSize, - Correction *correction, int *word) const; - int getMostProbableWordLikeInner(const int *const inWord, const int inputSize, - int *outWord) const; - int getSubStringSuggestion(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codes, const bool useFullEditDistance, - Correction *correction, WordsPriorityQueuePool *queuePool, const int inputSize, - const bool hasAutoCorrectionCandidate, const int currentWordIndex, - const int inputWordStartPos, const int inputWordLength, const int outputWordStartPos, - const bool isSpaceProximity, int *freqArray, int *wordLengthArray, int *outputWord, - int *outputWordLength) const; - void getMultiWordsSuggestionRec(ProximityInfo *proximityInfo, const int *xcoordinates, - const int *ycoordinates, const int *codes, const bool useFullEditDistance, - const int inputSize, Correction *correction, WordsPriorityQueuePool *queuePool, - const bool hasAutoCorrectionCandidate, const int startPos, const int startWordIndex, - const int outputWordLength, int *freqArray, int *wordLengthArray, - int *outputWord) const; - - const uint8_t *const DICT_ROOT; - const int ROOT_POS; - const int MAX_DIGRAPH_SEARCH_DEPTH; - const int DICT_FLAGS; -}; -} // namespace latinime -#endif // LATINIME_UNIGRAM_DICTIONARY_H diff --git a/native/jni/src/utils/autocorrection_threshold_utils.cpp b/native/jni/src/utils/autocorrection_threshold_utils.cpp new file mode 100644 index 000000000..3406e0f8e --- /dev/null +++ b/native/jni/src/utils/autocorrection_threshold_utils.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/autocorrection_threshold_utils.h" + +#include <cmath> + +#include "defines.h" +#include "suggest/policyimpl/utils/edit_distance.h" +#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" + +namespace latinime { + +const int AutocorrectionThresholdUtils::MAX_INITIAL_SCORE = 255; +const int AutocorrectionThresholdUtils::TYPED_LETTER_MULTIPLIER = 2; +const int AutocorrectionThresholdUtils::FULL_WORD_MULTIPLIER = 2; + +/* static */ int AutocorrectionThresholdUtils::editDistance(const int *before, + const int beforeLength, const int *after, const int afterLength) { + const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( + before, beforeLength, after, afterLength); + return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); +} + +// In dictionary.cpp, getSuggestion() method, +// When USE_SUGGEST_INTERFACE_FOR_TYPING is true: +// +// // TODO: Revise the following logic thoroughly by referring to the logic +// // marked as "Otherwise" below. +// SUGGEST_INTERFACE_OUTPUT_SCALE was multiplied to the original suggestion scores to convert +// them to integers. +// score = (int)((original score) * SUGGEST_INTERFACE_OUTPUT_SCALE) +// Undo the scaling here to recover the original score. +// normalizedScore = ((float)score) / SUGGEST_INTERFACE_OUTPUT_SCALE +// +// Otherwise: suggestion scores are computed using the below formula. +// original score +// := powf(mTypedLetterMultiplier (this is defined 2), +// (the number of matched characters between typed word and suggested word)) +// * (individual word's score which defined in the unigram dictionary, +// and this score is defined in range [0, 255].) +// Then, the following processing is applied. +// - If the dictionary word is matched up to the point of the user entry +// (full match up to min(before.length(), after.length()) +// => Then multiply by FULL_MATCHED_WORDS_PROMOTION_RATE (this is defined 1.2) +// - If the word is a true full match except for differences in accents or +// capitalization, then treat it as if the score was 255. +// - If before.length() == after.length() +// => multiply by mFullWordMultiplier (this is defined 2)) +// So, maximum original score is powf(2, min(before.length(), after.length())) * 255 * 2 * 1.2 +// For historical reasons we ignore the 1.2 modifier (because the measure for a good +// autocorrection threshold was done at a time when it didn't exist). This doesn't change +// the result. +// So, we can normalize original score by dividing powf(2, min(b.l(),a.l())) * 255 * 2. + +/* static */ float AutocorrectionThresholdUtils::calcNormalizedScore(const int *before, + const int beforeLength, const int *after, const int afterLength, const int score) { + if (0 == beforeLength || 0 == afterLength) { + return 0.0f; + } + const int distance = editDistance(before, beforeLength, after, afterLength); + int spaceCount = 0; + for (int i = 0; i < afterLength; ++i) { + if (after[i] == KEYCODE_SPACE) { + ++spaceCount; + } + } + + if (spaceCount == afterLength) { + return 0.0f; + } + + // add a weight based on edit distance. + // distance <= max(afterLength, beforeLength) == afterLength, + // so, 0 <= distance / afterLength <= 1 + const float weight = 1.0f - static_cast<float>(distance) / static_cast<float>(afterLength); + + // TODO: Revise the following logic thoroughly by referring to... + if (true /* USE_SUGGEST_INTERFACE_FOR_TYPING */) { + return (static_cast<float>(score) / SUGGEST_INTERFACE_OUTPUT_SCALE) * weight; + } + // ...this logic. + const float maxScore = score >= S_INT_MAX ? static_cast<float>(S_INT_MAX) + : static_cast<float>(MAX_INITIAL_SCORE) + * powf(static_cast<float>(TYPED_LETTER_MULTIPLIER), + static_cast<float>(min(beforeLength, afterLength - spaceCount))) + * static_cast<float>(FULL_WORD_MULTIPLIER); + + return (static_cast<float>(score) / maxScore) * weight; +} + +} // namespace latinime diff --git a/native/jni/src/utils/autocorrection_threshold_utils.h b/native/jni/src/utils/autocorrection_threshold_utils.h new file mode 100644 index 000000000..c7537a6a5 --- /dev/null +++ b/native/jni/src/utils/autocorrection_threshold_utils.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H +#define LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H + +#include "defines.h" + +namespace latinime { + +class AutocorrectionThresholdUtils { + public: + static float calcNormalizedScore(const int *before, const int beforeLength, + const int *after, const int afterLength, const int score); + static int editDistance(const int *before, const int beforeLength, const int *after, + const int afterLength); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(AutocorrectionThresholdUtils); + + static const int MAX_INITIAL_SCORE; + static const int TYPED_LETTER_MULTIPLIER; + static const int FULL_WORD_MULTIPLIER; +}; +} // namespace latinime +#endif // LATINIME_AUTOCORRECTION_THRESHOLD_UTILS_H diff --git a/native/jni/src/char_utils.cpp b/native/jni/src/utils/char_utils.cpp index e219beb62..0e7039610 100644 --- a/native/jni/src/char_utils.cpp +++ b/native/jni/src/utils/char_utils.cpp @@ -14,9 +14,10 @@ * limitations under the License. */ +#include "utils/char_utils.h" + #include <cstdlib> -#include "char_utils.h" #include "defines.h" namespace latinime { @@ -36,8 +37,7 @@ struct LatinCapitalSmallPair { * $ apt-get install libicu-dev * * 3. Build the following code - * (You need this file, char_utils.h, and defines.h) - * $ g++ -o char_utils -DUPDATING_CHAR_UTILS char_utils.cpp -licuuc + * $ g++ -o char_utils -I.. -DUPDATING_CHAR_UTILS char_utils.cpp -licuuc */ #ifdef UPDATING_CHAR_UTILS #include <stdio.h> @@ -47,7 +47,7 @@ extern "C" int main() { for (unsigned short c = 0; c < 0xFFFF; c++) { if (c <= 0x7F) continue; const unsigned short icu4cLowerC = u_tolower(c); - const unsigned short myLowerC = latin_tolower(c); + const unsigned short myLowerC = CharUtils::latin_tolower(c); if (c != icu4cLowerC) { #ifdef CONFIRMING_CHAR_UTILS if (icu4cLowerC != myLowerC) { @@ -70,7 +70,7 @@ extern "C" int main() { * * 5. Update the SORTED_CHAR_MAP[] array below with the output above. * Then, rebuild with -DCONFIRMING_CHAR_UTILS and confirm the program exits successfully. - * $ g++ -o char_utils -DUPDATING_CHAR_UTILS -DCONFIRMING_CHAR_UTILS char_utils.cpp -licuuc + * $ g++ -o char_utils -I.. -DUPDATING_CHAR_UTILS -DCONFIRMING_CHAR_UTILS char_utils.cpp -licuuc * $ ./char_utils * $ */ @@ -1054,7 +1054,7 @@ static int compare_pair_capital(const void *a, const void *b) { - static_cast<int>((static_cast<const struct LatinCapitalSmallPair *>(b))->capital); } -unsigned short latin_tolower(const unsigned short c) { +/* static */ unsigned short CharUtils::latin_tolower(const unsigned short c) { struct LatinCapitalSmallPair *p = static_cast<struct LatinCapitalSmallPair *>(bsearch(&c, SORTED_CHAR_MAP, NELEMS(SORTED_CHAR_MAP), sizeof(SORTED_CHAR_MAP[0]), compare_pair_capital)); @@ -1063,7 +1063,7 @@ unsigned short latin_tolower(const unsigned short c) { /* * Table mapping most combined Latin, Greek, and Cyrillic characters - * to their base characters. If c is in range, BASE_CHARS[c] == c + * to their base characters. If c is in range, CharUtils::BASE_CHARS[c] == c * if c is not a combined character, or the base character if it * is combined. * @@ -1074,7 +1074,7 @@ unsigned short latin_tolower(const unsigned short c) { * for ($j = $i; $j < $i + 8; $j++) { \ * printf("0x%04X, ", $base[$j] ? $base[$j] : $j)}; print "\n"; }' */ -const unsigned short BASE_CHARS[BASE_CHARS_SIZE] = { +/* static */ const unsigned short CharUtils::BASE_CHARS[CharUtils::BASE_CHARS_SIZE] = { /* U+0000 */ 0x0000, 0x0001, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, /* U+0008 */ 0x0008, 0x0009, 0x000A, 0x000B, 0x000C, 0x000D, 0x000E, 0x000F, /* U+0010 */ 0x0010, 0x0011, 0x0012, 0x0013, 0x0014, 0x0015, 0x0016, 0x0017, diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h new file mode 100644 index 000000000..2e735a81c --- /dev/null +++ b/native/jni/src/utils/char_utils.h @@ -0,0 +1,93 @@ +/* + * Copyright (C) 2010 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_CHAR_UTILS_H +#define LATINIME_CHAR_UTILS_H + +#include <cctype> + +#include "defines.h" + +namespace latinime { + +class CharUtils { + public: + static AK_FORCE_INLINE bool isAsciiUpper(int c) { + // Note: isupper(...) reports false positives for some Cyrillic characters, causing them to + // be incorrectly lower-cased using toAsciiLower(...) rather than latin_tolower(...). + return (c >= 'A' && c <= 'Z'); + } + + static AK_FORCE_INLINE int toAsciiLower(int c) { + return c - 'A' + 'a'; + } + + static AK_FORCE_INLINE bool isAscii(int c) { + return isascii(c) != 0; + } + + static AK_FORCE_INLINE int toLowerCase(const int c) { + if (isAsciiUpper(c)) { + return toAsciiLower(c); + } + if (isAscii(c)) { + return c; + } + return static_cast<int>(latin_tolower(static_cast<unsigned short>(c))); + } + + static AK_FORCE_INLINE int toBaseLowerCase(const int c) { + return toLowerCase(toBaseCodePoint(c)); + } + + static AK_FORCE_INLINE bool isIntentionalOmissionCodePoint(const int codePoint) { + // TODO: Do not hardcode here + return codePoint == KEYCODE_SINGLE_QUOTE || codePoint == KEYCODE_HYPHEN_MINUS; + } + + static AK_FORCE_INLINE int getCodePointCount(const int arraySize, const int *const codePoints) { + int size = 0; + for (; size < arraySize; ++size) { + if (codePoints[size] == '\0') { + break; + } + } + return size; + } + + static AK_FORCE_INLINE int toBaseCodePoint(int c) { + if (c < BASE_CHARS_SIZE) { + return static_cast<int>(BASE_CHARS[c]); + } + return c; + } + + static unsigned short latin_tolower(const unsigned short c); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils); + + /** + * Table mapping most combined Latin, Greek, and Cyrillic characters + * to their base characters. If c is in range, BASE_CHARS[c] == c + * if c is not a combined character, or the base character if it + * is combined. + */ + static const int BASE_CHARS_SIZE = 0x0500; + static const unsigned short BASE_CHARS[BASE_CHARS_SIZE]; +}; +} // namespace latinime +#endif // LATINIME_CHAR_UTILS_H diff --git a/native/jni/src/hash_map_compat.h b/native/jni/src/utils/hash_map_compat.h index a1e982bc4..a1e982bc4 100644 --- a/native/jni/src/hash_map_compat.h +++ b/native/jni/src/utils/hash_map_compat.h diff --git a/native/jni/src/utils/log_utils.cpp b/native/jni/src/utils/log_utils.cpp new file mode 100644 index 000000000..5ab2b2862 --- /dev/null +++ b/native/jni/src/utils/log_utils.cpp @@ -0,0 +1,72 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "log_utils.h" + +#include <cstdio> +#include <stdarg.h> + +#include "defines.h" + +namespace latinime { + /* static */ void LogUtils::logToJava(JNIEnv *const env, const char *const format, ...) { + static const char *TAG = "LatinIME:LogUtils"; + const jclass androidUtilLogClass = env->FindClass("android/util/Log"); + if (!androidUtilLogClass) { + // If we can't find the class, we are probably in off-device testing, and + // it's expected. Regardless, logging is not essential to functionality, so + // we should just return. However, FindClass has thrown an exception behind + // our back and there is no way to prevent it from doing that, so we clear + // the exception before we return. + env->ExceptionClear(); + return; + } + const jmethodID logDotIMethodId = env->GetStaticMethodID(androidUtilLogClass, "i", + "(Ljava/lang/String;Ljava/lang/String;)I"); + if (!logDotIMethodId) { + env->ExceptionClear(); + if (androidUtilLogClass) env->DeleteLocalRef(androidUtilLogClass); + return; + } + const jstring javaTag = env->NewStringUTF(TAG); + + static const int DEFAULT_LINE_SIZE = 128; + char fixedSizeCString[DEFAULT_LINE_SIZE]; + va_list argList; + va_start(argList, format); + // Get the necessary size. Add 1 for the 0 terminator. + const int size = vsnprintf(fixedSizeCString, DEFAULT_LINE_SIZE, format, argList) + 1; + va_end(argList); + + jstring javaString; + if (size <= DEFAULT_LINE_SIZE) { + // The buffer was large enough. + javaString = env->NewStringUTF(fixedSizeCString); + } else { + // The buffer was not large enough. + va_start(argList, format); + char variableSizeCString[size]; + vsnprintf(variableSizeCString, size, format, argList); + va_end(argList); + javaString = env->NewStringUTF(variableSizeCString); + } + + env->CallStaticIntMethod(androidUtilLogClass, logDotIMethodId, javaTag, javaString); + if (javaString) env->DeleteLocalRef(javaString); + if (javaTag) env->DeleteLocalRef(javaTag); + if (androidUtilLogClass) env->DeleteLocalRef(androidUtilLogClass); + } +} diff --git a/native/jni/src/utils/log_utils.h b/native/jni/src/utils/log_utils.h new file mode 100644 index 000000000..6ac16d91a --- /dev/null +++ b/native/jni/src/utils/log_utils.h @@ -0,0 +1,37 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_LOG_UTILS_H +#define LATINIME_LOG_UTILS_H + +#include "defines.h" +#include "jni.h" + +namespace latinime { + +class LogUtils { + public: + static void logToJava(JNIEnv *const env, const char *const format, ...) +#ifdef __GNUC__ + __attribute__ ((format (printf, 2, 3))) +#endif // __GNUC__ + ; + + private: + DISALLOW_COPY_AND_ASSIGN(LogUtils); +}; +} // namespace latinime +#endif // LATINIME_LOG_UTILS_H diff --git a/native/jni/src/words_priority_queue.cpp b/native/jni/src/words_priority_queue.cpp deleted file mode 100644 index 7e18d0f87..000000000 --- a/native/jni/src/words_priority_queue.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (C) 2012, The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "words_priority_queue.h" - -namespace latinime { - -int WordsPriorityQueue::outputSuggestions(const int *before, const int beforeLength, - int *frequencies, int *outputCodePoints, int* outputTypes) { - mHighestSuggestedWord = 0; - const int size = min(MAX_WORDS, static_cast<int>(mSuggestions.size())); - SuggestedWord *swBuffer[size]; - int index = size - 1; - while (!mSuggestions.empty() && index >= 0) { - SuggestedWord *sw = mSuggestions.top(); - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("dump word. %d", sw->mScore); - DUMP_WORD(sw->mWord, sw->mWordLength); - } - swBuffer[index] = sw; - mSuggestions.pop(); - --index; - } - if (size >= 2) { - SuggestedWord *nsMaxSw = 0; - int maxIndex = 0; - float maxNs = 0; - for (int i = 0; i < size; ++i) { - SuggestedWord *tempSw = swBuffer[i]; - if (!tempSw) { - continue; - } - const float tempNs = getNormalizedScore(tempSw, before, beforeLength, 0, 0, 0); - if (tempNs >= maxNs) { - maxNs = tempNs; - maxIndex = i; - nsMaxSw = tempSw; - } - } - if (maxIndex > 0 && nsMaxSw) { - memmove(&swBuffer[1], &swBuffer[0], maxIndex * sizeof(swBuffer[0])); - swBuffer[0] = nsMaxSw; - } - } - for (int i = 0; i < size; ++i) { - SuggestedWord *sw = swBuffer[i]; - if (!sw) { - AKLOGE("SuggestedWord is null %d", i); - continue; - } - const int wordLength = sw->mWordLength; - int *targetAddress = outputCodePoints + i * MAX_WORD_LENGTH; - frequencies[i] = sw->mScore; - outputTypes[i] = sw->mType; - memcpy(targetAddress, sw->mWord, wordLength * sizeof(targetAddress[0])); - if (wordLength < MAX_WORD_LENGTH) { - targetAddress[wordLength] = 0; - } - sw->mUsed = false; - } - return size; -} -} // namespace latinime diff --git a/native/jni/src/words_priority_queue.h b/native/jni/src/words_priority_queue.h deleted file mode 100644 index 54e8007a2..000000000 --- a/native/jni/src/words_priority_queue.h +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_WORDS_PRIORITY_QUEUE_H -#define LATINIME_WORDS_PRIORITY_QUEUE_H - -#include <cstring> // for memcpy() -#include <queue> - -#include "correction.h" -#include "defines.h" - -namespace latinime { - -class WordsPriorityQueue { - public: - struct SuggestedWord { - int mScore; - int mWord[MAX_WORD_LENGTH]; - int mWordLength; - bool mUsed; - int mType; - - void setParams(int score, int *word, int wordLength, int type) { - mScore = score; - mWordLength = wordLength; - memcpy(mWord, word, sizeof(mWord[0]) * wordLength); - mUsed = true; - mType = type; - } - }; - - WordsPriorityQueue(int maxWords) - : mSuggestions(), MAX_WORDS(maxWords), - mSuggestedWords(new SuggestedWord[MAX_WORD_LENGTH]), mHighestSuggestedWord(0) { - for (int i = 0; i < MAX_WORD_LENGTH; ++i) { - mSuggestedWords[i].mUsed = false; - } - } - - // Non virtual inline destructor -- never inherit this class - AK_FORCE_INLINE ~WordsPriorityQueue() { - delete[] mSuggestedWords; - } - - void push(int score, int *word, int wordLength, int type) { - SuggestedWord *sw = 0; - if (size() >= MAX_WORDS) { - sw = mSuggestions.top(); - const int minScore = sw->mScore; - if (minScore >= score) { - return; - } - sw->mUsed = false; - mSuggestions.pop(); - } - if (sw == 0) { - sw = getFreeSuggestedWord(score, word, wordLength, type); - } else { - sw->setParams(score, word, wordLength, type); - } - if (sw == 0) { - AKLOGE("SuggestedWord is accidentally null."); - return; - } - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("Push word. %d, %d", score, wordLength); - DUMP_WORD(word, wordLength); - } - mSuggestions.push(sw); - if (!mHighestSuggestedWord || mHighestSuggestedWord->mScore < sw->mScore) { - mHighestSuggestedWord = sw; - } - } - - SuggestedWord *top() const { - if (mSuggestions.empty()) return 0; - SuggestedWord *sw = mSuggestions.top(); - return sw; - } - - int size() const { - return static_cast<int>(mSuggestions.size()); - } - - AK_FORCE_INLINE void clear() { - mHighestSuggestedWord = 0; - while (!mSuggestions.empty()) { - SuggestedWord *sw = mSuggestions.top(); - if (DEBUG_WORDS_PRIORITY_QUEUE) { - AKLOGI("Clear word. %d", sw->mScore); - DUMP_WORD(sw->mWord, sw->mWordLength); - } - sw->mUsed = false; - mSuggestions.pop(); - } - } - - AK_FORCE_INLINE void dumpTopWord() const { - if (size() <= 0) { - return; - } - DUMP_WORD(mHighestSuggestedWord->mWord, mHighestSuggestedWord->mWordLength); - } - - AK_FORCE_INLINE float getHighestNormalizedScore(const int *before, const int beforeLength, - int **outWord, int *outScore, int *outLength) const { - if (!mHighestSuggestedWord) { - return 0.0f; - } - return getNormalizedScore(mHighestSuggestedWord, before, beforeLength, outWord, outScore, - outLength); - } - - int outputSuggestions(const int *before, const int beforeLength, int *frequencies, - int *outputCodePoints, int* outputTypes); - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(WordsPriorityQueue); - struct wordComparator { - bool operator ()(SuggestedWord * left, SuggestedWord * right) { - return left->mScore > right->mScore; - } - }; - - SuggestedWord *getFreeSuggestedWord(int score, int *word, int wordLength, int type) const { - for (int i = 0; i < MAX_WORD_LENGTH; ++i) { - if (!mSuggestedWords[i].mUsed) { - mSuggestedWords[i].setParams(score, word, wordLength, type); - return &mSuggestedWords[i]; - } - } - return 0; - } - - static float getNormalizedScore(SuggestedWord *sw, const int *before, const int beforeLength, - int **outWord, int *outScore, int *outLength) { - const int score = sw->mScore; - int *word = sw->mWord; - const int wordLength = sw->mWordLength; - if (outScore) { - *outScore = score; - } - if (outWord) { - *outWord = word; - } - if (outLength) { - *outLength = wordLength; - } - return Correction::RankingAlgorithm::calcNormalizedScore(before, beforeLength, word, - wordLength, score); - } - - typedef std::priority_queue<SuggestedWord *, std::vector<SuggestedWord *>, - wordComparator> Suggestions; - Suggestions mSuggestions; - const int MAX_WORDS; - SuggestedWord *mSuggestedWords; - SuggestedWord *mHighestSuggestedWord; -}; -} // namespace latinime -#endif // LATINIME_WORDS_PRIORITY_QUEUE_H diff --git a/native/jni/src/words_priority_queue_pool.h b/native/jni/src/words_priority_queue_pool.h deleted file mode 100644 index 2cd210a05..000000000 --- a/native/jni/src/words_priority_queue_pool.h +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (C) 2011 The Android Open Source Project - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef LATINIME_WORDS_PRIORITY_QUEUE_POOL_H -#define LATINIME_WORDS_PRIORITY_QUEUE_POOL_H - -#include "defines.h" -#include "words_priority_queue.h" - -namespace latinime { - -class WordsPriorityQueuePool { - public: - WordsPriorityQueuePool(int mainQueueMaxWords, int subQueueMaxWords) - // Note: using placement new() requires the caller to call the destructor explicitly. - : mMasterQueue(new(mMasterQueueBuf) WordsPriorityQueue(mainQueueMaxWords)) { - for (int i = 0, subQueueBufOffset = 0; - i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS * SUB_QUEUE_MAX_COUNT; - ++i, subQueueBufOffset += static_cast<int>(sizeof(WordsPriorityQueue))) { - mSubQueues[i] = new(mSubQueueBuf + subQueueBufOffset) - WordsPriorityQueue(subQueueMaxWords); - } - } - - // Non virtual inline destructor -- never inherit this class - ~WordsPriorityQueuePool() { - // Note: these explicit calls to the destructor match the calls to placement new() above. - if (mMasterQueue) mMasterQueue->~WordsPriorityQueue(); - for (int i = 0; i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS * SUB_QUEUE_MAX_COUNT; ++i) { - if (mSubQueues[i]) mSubQueues[i]->~WordsPriorityQueue(); - } - } - - WordsPriorityQueue *getMasterQueue() const { - return mMasterQueue; - } - - WordsPriorityQueue *getSubQueue(const int wordIndex, const int inputWordLength) const { - if (wordIndex >= MULTIPLE_WORDS_SUGGESTION_MAX_WORDS) { - return 0; - } - if (inputWordLength < 0 || inputWordLength >= SUB_QUEUE_MAX_COUNT) { - if (DEBUG_WORDS_PRIORITY_QUEUE) { - ASSERT(false); - } - return 0; - } - return mSubQueues[wordIndex * SUB_QUEUE_MAX_COUNT + inputWordLength]; - } - - inline void clearAll() { - mMasterQueue->clear(); - for (int i = 0; i < MULTIPLE_WORDS_SUGGESTION_MAX_WORDS; ++i) { - clearSubQueue(i); - } - } - - AK_FORCE_INLINE void clearSubQueue(const int wordIndex) { - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - WordsPriorityQueue *queue = getSubQueue(wordIndex, i); - if (queue) { - queue->clear(); - } - } - } - - void dumpSubQueue1TopSuggestions() const { - AKLOGI("DUMP SUBQUEUE1 TOP SUGGESTIONS"); - for (int i = 0; i < SUB_QUEUE_MAX_COUNT; ++i) { - getSubQueue(0, i)->dumpTopWord(); - } - } - - private: - DISALLOW_IMPLICIT_CONSTRUCTORS(WordsPriorityQueuePool); - char mMasterQueueBuf[sizeof(WordsPriorityQueue)]; - char mSubQueueBuf[SUB_QUEUE_MAX_COUNT * MULTIPLE_WORDS_SUGGESTION_MAX_WORDS - * sizeof(WordsPriorityQueue)]; - WordsPriorityQueue *mMasterQueue; - WordsPriorityQueue *mSubQueues[SUB_QUEUE_MAX_COUNT * MULTIPLE_WORDS_SUGGESTION_MAX_WORDS]; -}; -} // namespace latinime -#endif // LATINIME_WORDS_PRIORITY_QUEUE_POOL_H |