diff options
Diffstat (limited to 'native/jni/src')
40 files changed, 4770 insertions, 19 deletions
diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index 7fcfd5dc8..861ba9971 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -33,11 +33,13 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi const int *const xCoordinates, const int *const yCoordinates, const int *const times, const int *const pointerIds, const bool isGeometric) { ASSERT(isGeometric || (inputSize < MAX_WORD_LENGTH)); - mIsContinuationPossible = ProximityInfoStateUtils::checkAndReturnIsContinuationPossible( - inputSize, xCoordinates, yCoordinates, times, mSampledInputSize, &mSampledInputXs, - &mSampledInputYs, &mSampledTimes, &mSampledInputIndice); + mIsContinuousSuggestionPossible = + ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( + inputSize, xCoordinates, yCoordinates, times, mSampledInputSize, + &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledInputIndice); if (DEBUG_DICT) { - AKLOGI("isContinuationPossible = %s", (mIsContinuationPossible ? "true" : "false")); + AKLOGI("isContinuousSuggestionPossible = %s", + (mIsContinuousSuggestionPossible ? "true" : "false")); } mProximityInfo = proximityInfo; @@ -64,7 +66,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi mSampledInputSize = 0; mMostProbableStringProbability = 0.0f; - if (mIsContinuationPossible && mSampledInputIndice.size() > 1) { + if (mIsContinuousSuggestionPossible && mSampledInputIndice.size() > 1) { // Just update difference. // Previous two points are never skipped. Thus, we pop 2 input point data here. pushTouchPointStartIndex = ProximityInfoStateUtils::trimLastTwoTouchPoints( diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h index 224240b00..9bba751d0 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/proximity_info_state.h @@ -47,12 +47,12 @@ class ProximityInfoState { : mProximityInfo(0), mMaxPointToKeyLength(0.0f), mAverageSpeed(0.0f), mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), - mIsContinuationPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(), - mSampledInputIndice(), mSampledLengthCache(), mBeelineSpeedPercentiles(), - mSampledDistanceCache_G(), mSpeedRates(), mDirections(), mCharProbabilities(), - mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), - mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), - mMostProbableStringProbability(0.0f) { + mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), + mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), + mBeelineSpeedPercentiles(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(), + mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), + mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances)); memset(mPrimaryInputWord, 0, sizeof(mPrimaryInputWord)); @@ -143,8 +143,8 @@ class ProximityInfoState { return mSampledLengthCache[index]; } - bool isContinuationPossible() const { - return mIsContinuationPossible; + bool isContinuousSuggestionPossible() const { + return mIsContinuousSuggestionPossible; } float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; @@ -223,7 +223,7 @@ class ProximityInfoState { int mCellWidth; int mGridHeight; int mGridWidth; - bool mIsContinuationPossible; + bool mIsContinuousSuggestionPossible; std::vector<int> mSampledInputXs; std::vector<int> mSampledInputYs; diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp index ccb28bc8c..760508076 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/proximity_info_state_utils.cpp @@ -968,10 +968,10 @@ namespace latinime { return true; } -/* static */ bool ProximityInfoStateUtils::checkAndReturnIsContinuationPossible(const int inputSize, - const int *const xCoordinates, const int *const yCoordinates, const int *const times, - const int sampledInputSize, const std::vector<int> *const sampledInputXs, - const std::vector<int> *const sampledInputYs, +/* static */ bool ProximityInfoStateUtils::checkAndReturnIsContinuousSuggestionPossible( + const int inputSize, const int *const xCoordinates, const int *const yCoordinates, + const int *const times, const int sampledInputSize, + const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, const std::vector<int> *const sampledTimes, const std::vector<int> *const sampledInputIndices) { if (inputSize < sampledInputSize) { diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h index a7f4a3425..3ceb25d8b 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/proximity_info_state_utils.h @@ -101,7 +101,7 @@ class ProximityInfoStateUtils { const std::vector<int> *const sampledTimes, const std::vector<float> *const sampledSpeedRates, const std::vector<int> *const sampledBeelineSpeedPercentiles); - static bool checkAndReturnIsContinuationPossible(const int inputSize, + static bool checkAndReturnIsContinuousSuggestionPossible(const int inputSize, const int *const xCoordinates, const int *const yCoordinates, const int *const times, const int sampledInputSize, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, diff --git a/native/jni/src/suggest/core/dicnode/dic_node.cpp b/native/jni/src/suggest/core/dicnode/dic_node.cpp new file mode 100644 index 000000000..8c48c587b --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dic_node.h" + +namespace latinime { + +DicNode::DicNode(const DicNode &dicNode) + : +#if DEBUG_DICT + mProfiler(dicNode.mProfiler), +#endif + mDicNodeProperties(dicNode.mDicNodeProperties), mDicNodeState(dicNode.mDicNodeState), + mIsCachedForNextSuggestion(dicNode.mIsCachedForNextSuggestion), mIsUsed(dicNode.mIsUsed), + mReleaseListener(0) { + /* empty */ +} + +DicNode &DicNode::operator=(const DicNode &dicNode) { +#if DEBUG_DICT + mProfiler = dicNode.mProfiler; +#endif + mDicNodeProperties = dicNode.mDicNodeProperties; + mDicNodeState = dicNode.mDicNodeState; + mIsCachedForNextSuggestion = dicNode.mIsCachedForNextSuggestion; + mIsUsed = dicNode.mIsUsed; + mReleaseListener = dicNode.mReleaseListener; + return *this; +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h new file mode 100644 index 000000000..7bfa459a2 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -0,0 +1,572 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_H +#define LATINIME_DIC_NODE_H + +#include "char_utils.h" +#include "defines.h" +#include "dic_node_state.h" +#include "dic_node_profiler.h" +#include "dic_node_properties.h" +#include "dic_node_release_listener.h" + +#if DEBUG_DICT +#define LOGI_SHOW_ADD_COST_PROP \ + do { char charBuf[50]; \ + INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + AKLOGI("%20s, \"%c\", size = %03d, total = %03d, index(0) = %02d, dist = %.4f, %s,,", \ + __FUNCTION__, getNodeCodePoint(), inputSize, getTotalInputIndex(), \ + getInputIndex(0), getNormalizedCompoundDistance(), charBuf); } while (0) +#define DUMP_WORD_AND_SCORE(header) \ + do { char charBuf[50]; char prevWordCharBuf[50]; \ + INTS_TO_CHARS(getOutputWordBuf(), getDepth(), charBuf); \ + INTS_TO_CHARS(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, \ + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), prevWordCharBuf); \ + AKLOGI("#%8s, %5f, %5f, %5f, %5f, %s, %s, %d,,", header, \ + getSpatialDistanceForScoring(), getLanguageDistanceForScoring(), \ + getNormalizedCompoundDistance(), getRawLength(), prevWordCharBuf, charBuf, \ + getInputIndex(0)); \ + } while (0) +#else +#define LOGI_SHOW_ADD_COST_PROP +#define DUMP_WORD_AND_SCORE(header) +#endif + +namespace latinime { + +// Naming convention +// - Distance: "Weighted" edit distance -- used both for spatial and language. +// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring +// - Cost: delta/diff for Distance -- used both for spatial and language +// - Length: "Non-weighted" -- used only for spatial +// - Probability: "Non-weighted" -- used only for language + +// This struct is purely a bucket to return values. No instances of this struct should be kept. +struct DicNode_InputStateG { + bool mNeedsToUpdateInputStateG; + int mPointerId; + int16_t mInputIndex; + int mPrevCodePoint; + float mTerminalDiffCost; + float mRawLength; + DoubleLetterLevel mDoubleLetterLevel; +}; + +class DicNode { + // Caveat: We define Weighting as a friend class of DicNode to let Weighting change + // the distance of DicNode. + // Caution!!! In general, we avoid using the "friend" access modifier. + // This is an exception to explicitly hide DicNode::addCost() from all classes but Weighting. + friend class Weighting; + + public: +#if DEBUG_DICT + DicNodeProfiler mProfiler; +#endif + ////////////////// + // Memory utils // + ////////////////// + AK_FORCE_INLINE static void managedDelete(DicNode *node) { + node->remove(); + } + // end + ///////////////// + + AK_FORCE_INLINE DicNode() + : +#if DEBUG_DICT + mProfiler(), +#endif + mDicNodeProperties(), mDicNodeState(), mIsCachedForNextSuggestion(false), + mIsUsed(false), mReleaseListener(0) {} + + DicNode(const DicNode &dicNode); + DicNode &operator=(const DicNode &dicNode); + virtual ~DicNode() {} + + // TODO: minimize arguments by looking binary_format + // Init for copy + void initByCopy(const DicNode *dicNode) { + mIsUsed = true; + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; + mDicNodeProperties.init(&dicNode->mDicNodeProperties); + mDicNodeState.init(&dicNode->mDicNodeState); + PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); + } + + // TODO: minimize arguments by looking binary_format + // Init for root with prevWordNodePos which is used for bigram + void initAsRoot(const int pos, const int childrenPos, const int childrenCount, + const int prevWordNodePos) { + mIsUsed = true; + mIsCachedForNextSuggestion = false; + mDicNodeProperties.init( + pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + mDicNodeState.init(prevWordNodePos); + PROF_NODE_RESET(mProfiler); + } + + void initAsPassingChild(DicNode *parentNode) { + mIsUsed = true; + mIsCachedForNextSuggestion = parentNode->mIsCachedForNextSuggestion; + const int c = parentNode->getNodeTypedCodePoint(); + mDicNodeProperties.init(&parentNode->mDicNodeProperties, c); + mDicNodeState.init(&parentNode->mDicNodeState); + PROF_NODE_COPY(&parentNode->mProfiler, mProfiler); + } + + // TODO: minimize arguments by looking binary_format + // Init for root with previous word + void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, + const int childrenCount) { + mIsUsed = true; + mIsCachedForNextSuggestion = false; + mDicNodeProperties.init( + pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); + // TODO: Move to dicNodeState? + mDicNodeState.mDicNodeStateOutput.init(); // reset for next word + mDicNodeState.mDicNodeStateInput.init( + &dicNode->mDicNodeState.mDicNodeStateInput, true /* resetTerminalDiffCost */); + mDicNodeState.mDicNodeStateScoring.init( + &dicNode->mDicNodeState.mDicNodeStateScoring); + mDicNodeState.mDicNodeStatePrevWord.init( + dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1, + dicNode->mDicNodeProperties.getProbability(), + dicNode->mDicNodeProperties.getPos(), + dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevWord, + dicNode->mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(), + dicNode->getOutputWordBuf(), + dicNode->mDicNodeProperties.getDepth(), + dicNode->mDicNodeState.mDicNodeStatePrevWord.mPrevSpacePositions, + mDicNodeState.mDicNodeStateInput.getInputIndex(0) /* lastInputIndex */); + PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); + } + + // TODO: minimize arguments by looking binary_format + void initAsChild(DicNode *dicNode, const int pos, const uint8_t flags, const int childrenPos, + const int attributesPos, const int siblingPos, const int nodeCodePoint, + const int childrenCount, const int probability, const int bigramProbability, + const bool isTerminal, const bool hasMultipleChars, const bool hasChildren, + const uint16_t additionalSubwordLength, const int *additionalSubword) { + mIsUsed = true; + uint16_t newDepth = static_cast<uint16_t>(dicNode->getDepth() + 1); + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; + const uint16_t newLeavingDepth = static_cast<uint16_t>( + dicNode->mDicNodeProperties.getLeavingDepth() + additionalSubwordLength); + mDicNodeProperties.init(pos, flags, childrenPos, attributesPos, siblingPos, nodeCodePoint, + childrenCount, probability, bigramProbability, isTerminal, hasMultipleChars, + hasChildren, newDepth, newLeavingDepth); + mDicNodeState.init(&dicNode->mDicNodeState, additionalSubwordLength, additionalSubword); + PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); + } + + AK_FORCE_INLINE void remove() { + mIsUsed = false; + if (mReleaseListener) { + mReleaseListener->onReleased(this); + } + } + + bool isUsed() const { + return mIsUsed; + } + + bool isRoot() const { + return getDepth() == 0; + } + + bool hasChildren() const { + return mDicNodeProperties.hasChildren(); + } + + bool isLeavingNode() const { + ASSERT(getDepth() <= getLeavingDepth()); + return getDepth() == getLeavingDepth(); + } + + AK_FORCE_INLINE bool isFirstLetter() const { + return getDepth() == 1; + } + + bool isCached() const { + return mIsCachedForNextSuggestion; + } + + void setCached() { + mIsCachedForNextSuggestion = true; + } + + // Used to expand the node in DicNodeUtils + int getNodeTypedCodePoint() const { + return mDicNodeState.mDicNodeStateOutput.getCodePointAt(getDepth()); + } + + bool isImpossibleBigramWord() const { + const int probability = mDicNodeProperties.getProbability(); + if (probability == 0) { + return true; + } + const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() + - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; + const int currentWordLen = getDepth(); + return (prevWordLen == 1 && currentWordLen == 1); + } + + bool isCapitalized() const { + const int c = getOutputWordBuf()[0]; + return isAsciiUpper(c); + } + + bool isFirstWord() const { + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos() == NOT_VALID_WORD; + } + + bool isCompletion(const int inputSize) const { + return mDicNodeState.mDicNodeStateInput.getInputIndex(0) >= inputSize; + } + + bool canDoLookAheadCorrection(const int inputSize) const { + return mDicNodeState.mDicNodeStateInput.getInputIndex(0) < inputSize - 1; + } + + // Used to get bigram probability in DicNodeUtils + int getPos() const { + return mDicNodeProperties.getPos(); + } + + // Used to get bigram probability in DicNodeUtils + int getPrevWordPos() const { + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); + } + + // Used in DicNodeUtils + int getChildrenPos() const { + return mDicNodeProperties.getChildrenPos(); + } + + // Used in DicNodeUtils + int getChildrenCount() const { + return mDicNodeProperties.getChildrenCount(); + } + + // Used in DicNodeUtils + int getProbability() const { + return mDicNodeProperties.getProbability(); + } + + AK_FORCE_INLINE bool isTerminalWordNode() const { + const bool isTerminalNodes = mDicNodeProperties.isTerminal(); + const int currentNodeDepth = getDepth(); + const int terminalNodeDepth = mDicNodeProperties.getLeavingDepth(); + return isTerminalNodes && currentNodeDepth > 0 && currentNodeDepth == terminalNodeDepth; + } + + bool shouldBeFilterdBySafetyNetForBigram() const { + const uint16_t currentDepth = getDepth(); + const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength() + - mDicNodeState.mDicNodeStatePrevWord.getPrevWordStart() - 1; + return !(currentDepth > 0 && (currentDepth != 1 || prevWordLen != 1)); + } + + uint16_t getLeavingDepth() const { + return mDicNodeProperties.getLeavingDepth(); + } + + bool isTotalInputSizeExceedingLimit() const { + const int prevWordsLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + const int currentWordDepth = getDepth(); + // TODO: 3 can be 2? Needs to be investigated. + // TODO: Have a const variable for 3 (or 2) + return prevWordsLen + currentWordDepth > MAX_WORD_LENGTH - 3; + } + + // TODO: This may be defective. Needs to be revised. + bool truncateNode(const DicNode *const topNode, const int inputCommitPoint) { + const int prevWordLenOfTop = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + int newPrevWordStartIndex = inputCommitPoint; + int charCount = 0; + // Find new word start index + for (int i = 0; i < prevWordLenOfTop; ++i) { + const int c = mDicNodeState.mDicNodeStatePrevWord.getPrevWordCodePointAt(i); + // TODO: Check other separators. + if (c != KEYCODE_SPACE && c != KEYCODE_SINGLE_QUOTE) { + if (charCount == inputCommitPoint) { + newPrevWordStartIndex = i; + break; + } + ++charCount; + } + } + if (!mDicNodeState.mDicNodeStatePrevWord.startsWith( + &topNode->mDicNodeState.mDicNodeStatePrevWord, newPrevWordStartIndex - 1)) { + // Node mismatch. + return false; + } + mDicNodeState.mDicNodeStateInput.truncate(inputCommitPoint); + mDicNodeState.mDicNodeStatePrevWord.truncate(newPrevWordStartIndex); + return true; + } + + void outputResult(int *dest) const { + const uint16_t prevWordLength = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + const uint16_t currentDepth = getDepth(); + DicNodeUtils::appendTwoWords(mDicNodeState.mDicNodeStatePrevWord.mPrevWord, + prevWordLength, getOutputWordBuf(), currentDepth, dest); + DUMP_WORD_AND_SCORE("OUTPUT"); + } + + void outputSpacePositionsResult(int *spaceIndices) const { + mDicNodeState.mDicNodeStatePrevWord.outputSpacePositions(spaceIndices); + } + + bool hasMultipleWords() const { + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() > 0; + } + + float getProximityCorrectionCount() const { + return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount()); + } + + float getEditCorrectionCount() const { + return static_cast<float>(mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount()); + } + + // Used to prune nodes + float getNormalizedCompoundDistance() const { + return mDicNodeState.mDicNodeStateScoring.getNormalizedCompoundDistance(); + } + + // Used to prune nodes + float getNormalizedSpatialDistance() const { + return mDicNodeState.mDicNodeStateScoring.getSpatialDistance() + / static_cast<float>(getInputIndex(0) + 1); + } + + // Used to prune nodes + float getCompoundDistance() const { + return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(); + } + + // Used to prune nodes + float getCompoundDistance(const float languageWeight) const { + return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); + } + + // Note that "cost" means delta for "distance" that is weighted. + float getTotalPrevWordsLanguageCost() const { + return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost(); + } + + // Used to commit input partially + int getPrevWordNodePos() const { + return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); + } + + AK_FORCE_INLINE const int *getOutputWordBuf() const { + return mDicNodeState.mDicNodeStateOutput.mWordBuf; + } + + int getPrevCodePointG(int pointerId) const { + return mDicNodeState.mDicNodeStateInput.getPrevCodePoint(pointerId); + } + + // Whether the current codepoint can be an intentional omission, in which case the traversal + // algorithm will always check for a possible omission here. + bool canBeIntentionalOmission() const { + return isIntentionalOmissionCodePoint(getNodeCodePoint()); + } + + // Whether the omission is so frequent that it should incur zero cost. + bool isZeroCostOmission() const { + // TODO: do not hardcode and read from header + return (getNodeCodePoint() == KEYCODE_SINGLE_QUOTE); + } + + // TODO: remove + float getTerminalDiffCostG(int path) const { + return mDicNodeState.mDicNodeStateInput.getTerminalDiffCost(path); + } + + ////////////////////// + // Temporary getter // + // TODO: Remove // + ////////////////////// + // TODO: Remove once touch path is merged into ProximityInfoState + int getNodeCodePoint() const { + return mDicNodeProperties.getNodeCodePoint(); + } + + //////////////////////////////// + // Utils for cost calculation // + //////////////////////////////// + AK_FORCE_INLINE bool isSameNodeCodePoint(const DicNode *const dicNode) const { + return mDicNodeProperties.getNodeCodePoint() + == dicNode->mDicNodeProperties.getNodeCodePoint(); + } + + // TODO: remove + // TODO: rename getNextInputIndex + int16_t getInputIndex(int pointerId) const { + return mDicNodeState.mDicNodeStateInput.getInputIndex(pointerId); + } + + //////////////////////////////////// + // Getter of features for scoring // + //////////////////////////////////// + float getSpatialDistanceForScoring() const { + return mDicNodeState.mDicNodeStateScoring.getSpatialDistance(); + } + + float getLanguageDistanceForScoring() const { + return mDicNodeState.mDicNodeStateScoring.getLanguageDistance(); + } + + float getLanguageDistanceRatePerWordForScoring() const { + const float langDist = getLanguageDistanceForScoring(); + const float totalWordCount = + static_cast<float>(mDicNodeState.mDicNodeStatePrevWord.getPrevWordCount() + 1); + return langDist / totalWordCount; + } + + float getRawLength() const { + return mDicNodeState.mDicNodeStateScoring.getRawLength(); + } + + bool isLessThanOneErrorForScoring() const { + return mDicNodeState.mDicNodeStateScoring.getEditCorrectionCount() + + mDicNodeState.mDicNodeStateScoring.getProximityCorrectionCount() <= 1; + } + + DoubleLetterLevel getDoubleLetterLevel() const { + return mDicNodeState.mDicNodeStateScoring.getDoubleLetterLevel(); + } + + void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { + mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); + } + + uint8_t getFlags() const { + return mDicNodeProperties.getFlags(); + } + + int getAttributesPos() const { + return mDicNodeProperties.getAttributesPos(); + } + + inline uint16_t getDepth() const { + return mDicNodeProperties.getDepth(); + } + + AK_FORCE_INLINE void dump(const char *tag) const { +#if DEBUG_DICT + DUMP_WORD_AND_SCORE(tag); +#if DEBUG_DUMP_ERROR + mProfiler.dump(); +#endif +#endif + } + + void setReleaseListener(DicNodeReleaseListener *releaseListener) { + mReleaseListener = releaseListener; + } + + AK_FORCE_INLINE bool compare(const DicNode *right) { + if (!isUsed() && !right->isUsed()) { + // Compare pointer values here for stable comparison + return this > right; + } + if (!isUsed()) { + return true; + } + if (!right->isUsed()) { + return false; + } + const float diff = + right->getNormalizedCompoundDistance() - getNormalizedCompoundDistance(); + static const float MIN_DIFF = 0.000001f; + if (diff > MIN_DIFF) { + return true; + } else if (diff < -MIN_DIFF) { + return false; + } + const int depth = getDepth(); + const int depthDiff = right->getDepth() - depth; + if (depthDiff != 0) { + return depthDiff > 0; + } + for (int i = 0; i < depth; ++i) { + const int codePoint = mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); + const int rightCodePoint = right->mDicNodeState.mDicNodeStateOutput.getCodePointAt(i); + if (codePoint != rightCodePoint) { + return rightCodePoint > codePoint; + } + } + // Compare pointer values here for stable comparison + return this > right; + } + + private: + DicNodeProperties mDicNodeProperties; + DicNodeState mDicNodeState; + // TODO: Remove + bool mIsCachedForNextSuggestion; + bool mIsUsed; + DicNodeReleaseListener *mReleaseListener; + + AK_FORCE_INLINE int getTotalInputIndex() const { + int index = 0; + for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { + index += mDicNodeState.mDicNodeStateInput.getInputIndex(i); + } + return index; + } + + // Caveat: Must not be called outside Weighting + // This restriction is guaranteed by "friend" + AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, + const bool doNormalization, const int inputSize, const bool isEditCorrection, + const bool isProximityCorrection) { + if (DEBUG_GEO_FULL) { + LOGI_SHOW_ADD_COST_PROP; + } + mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, + inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection); + } + + // Caveat: Must not be called outside Weighting + // This restriction is guaranteed by "friend" + AK_FORCE_INLINE void forwardInputIndex(const int pointerId, const int count, + const bool overwritesPrevCodePointByNodeCodePoint) { + if (count == 0) { + return; + } + mDicNodeState.mDicNodeStateInput.forwardInputIndex(pointerId, count); + if (overwritesPrevCodePointByNodeCodePoint) { + mDicNodeState.mDicNodeStateInput.setPrevCodePoint(0, getNodeCodePoint()); + } + } + + AK_FORCE_INLINE void updateInputIndexG(DicNode_InputStateG *inputStateG) { + mDicNodeState.mDicNodeStateInput.updateInputIndexG(inputStateG->mPointerId, + inputStateG->mInputIndex, inputStateG->mPrevCodePoint, + inputStateG->mTerminalDiffCost, inputStateG->mRawLength); + mDicNodeState.mDicNodeStateScoring.addRawLength(inputStateG->mRawLength); + mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(inputStateG->mDoubleLetterLevel); + } +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h new file mode 100644 index 000000000..d3f28a8bd --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h @@ -0,0 +1,213 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_PRIORITY_QUEUE_H +#define LATINIME_DIC_NODE_PRIORITY_QUEUE_H + +#include <queue> +#include <vector> + +#include "defines.h" +#include "dic_node.h" +#include "dic_node_release_listener.h" + +#define MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY 200 + +namespace latinime { + +class DicNodePriorityQueue : public DicNodeReleaseListener { + public: + AK_FORCE_INLINE DicNodePriorityQueue() + : MAX_CAPACITY(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), + mMaxSize(MAX_DIC_NODE_PRIORITY_QUEUE_CAPACITY), mDicNodesBuf(), mUnusedNodeIndices(), + mNextUnusedNodeId(0), mDicNodesQueue() { + mDicNodesBuf.resize(MAX_CAPACITY + 1); + mUnusedNodeIndices.resize(MAX_CAPACITY + 1); + reset(); + } + + // Non virtual inline destructor -- never inherit this class + AK_FORCE_INLINE ~DicNodePriorityQueue() {} + + int getSize() const { + return static_cast<int>(mDicNodesQueue.size()); + } + + int getMaxSize() const { + return mMaxSize; + } + + AK_FORCE_INLINE void setMaxSize(const int maxSize) { + mMaxSize = min(maxSize, MAX_CAPACITY); + } + + AK_FORCE_INLINE void reset() { + clearAndResize(MAX_CAPACITY); + } + + AK_FORCE_INLINE void clear() { + clearAndResize(mMaxSize); + } + + AK_FORCE_INLINE void clearAndResize(const int maxSize) { + while (!mDicNodesQueue.empty()) { + mDicNodesQueue.pop(); + } + setMaxSize(maxSize); + for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + mDicNodesBuf[i].remove(); + mDicNodesBuf[i].setReleaseListener(this); + mUnusedNodeIndices[i] = i == MAX_CAPACITY ? NOT_A_NODE_ID : static_cast<int>(i) + 1; + } + mNextUnusedNodeId = 0; + } + + AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { + DicNode *newNode = searchEmptyDicNode(); + if (newNode) { + DicNodeUtils::initByCopy(dicNode, newNode); + return newNode; + } + return 0; + } + + // Copy + AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode) { + return copyPush(dicNode, mMaxSize); + } + + AK_FORCE_INLINE void copyPop(DicNode *dest) { + if (mDicNodesQueue.empty()) { + ASSERT(false); + return; + } + DicNode *node = mDicNodesQueue.top(); + if (dest) { + DicNodeUtils::initByCopy(node, dest); + } + node->remove(); + mDicNodesQueue.pop(); + } + + void onReleased(DicNode *dicNode) { + const int index = static_cast<int>(dicNode - &mDicNodesBuf[0]); + if (mUnusedNodeIndices[index] != NOT_A_NODE_ID) { + // it's already released + return; + } + mUnusedNodeIndices[index] = mNextUnusedNodeId; + mNextUnusedNodeId = index; + ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + } + + AK_FORCE_INLINE void dump() const { + AKLOGI("\n\n\n\n\n==========================="); + for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + if (mDicNodesBuf[i].isUsed()) { + mDicNodesBuf[i].dump("QUEUE: "); + } + } + AKLOGI("===========================\n\n\n\n\n"); + } + + private: + DISALLOW_COPY_AND_ASSIGN(DicNodePriorityQueue); + static const int NOT_A_NODE_ID = -1; + + AK_FORCE_INLINE static bool compareDicNode(DicNode *left, DicNode *right) { + return left->compare(right); + } + + struct DicNodeComparator { + bool operator ()(DicNode *left, DicNode *right) { + return compareDicNode(left, right); + } + }; + + typedef std::priority_queue<DicNode *, std::vector<DicNode *>, DicNodeComparator> DicNodesQueue; + const int MAX_CAPACITY; + int mMaxSize; + std::vector<DicNode> mDicNodesBuf; // of each element of mDicNodesBuf respectively + std::vector<int> mUnusedNodeIndices; + int mNextUnusedNodeId; + DicNodesQueue mDicNodesQueue; + + inline bool isFull(const int maxSize) const { + return getSize() >= maxSize; + } + + AK_FORCE_INLINE void pop() { + copyPop(0); + } + + AK_FORCE_INLINE bool betterThanWorstDicNode(DicNode *dicNode) const { + DicNode *worstNode = mDicNodesQueue.top(); + if (!worstNode) { + return true; + } + return compareDicNode(dicNode, worstNode); + } + + AK_FORCE_INLINE DicNode *searchEmptyDicNode() { + // TODO: Currently O(n) but should be improved to O(1) + if (MAX_CAPACITY == 0) { + return 0; + } + if (mNextUnusedNodeId == NOT_A_NODE_ID) { + AKLOGI("No unused node found."); + for (int i = 0; i < MAX_CAPACITY + 1; ++i) { + AKLOGI("Dump node availability, %d, %d, %d", + i, mDicNodesBuf[i].isUsed(), mUnusedNodeIndices[i]); + } + ASSERT(false); + return 0; + } + DicNode *dicNode = &mDicNodesBuf[mNextUnusedNodeId]; + markNodeAsUsed(dicNode); + return dicNode; + } + + AK_FORCE_INLINE void markNodeAsUsed(DicNode *dicNode) { + const int index = static_cast<int>(dicNode - &mDicNodesBuf[0]); + mNextUnusedNodeId = mUnusedNodeIndices[index]; + mUnusedNodeIndices[index] = NOT_A_NODE_ID; + ASSERT(index >= 0 && index < (MAX_CAPACITY + 1)); + } + + AK_FORCE_INLINE DicNode *pushPoolNodeWithMaxSize(DicNode *dicNode, const int maxSize) { + if (!dicNode) { + return 0; + } + if (!isFull(maxSize)) { + mDicNodesQueue.push(dicNode); + return dicNode; + } + if (betterThanWorstDicNode(dicNode)) { + pop(); + mDicNodesQueue.push(dicNode); + return dicNode; + } + dicNode->remove(); + return 0; + } + + // Copy + AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode, const int maxSize) { + return pushPoolNodeWithMaxSize(newDicNode(dicNode), maxSize); + } +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_PRIORITY_QUEUE_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h new file mode 100644 index 000000000..90f75d0c6 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h @@ -0,0 +1,181 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_PROFILER_H +#define LATINIME_DIC_NODE_PROFILER_H + +#include "defines.h" + +#if DEBUG_DICT +#define PROF_SPACE_SUBSTITUTION(profiler) profiler.profSpaceSubstitution() +#define PROF_SPACE_OMISSION(profiler) profiler.profSpaceOmission() +#define PROF_ADDITIONAL_PROXIMITY(profiler) profiler.profAdditionalProximity() +#define PROF_SUBSTITUTION(profiler) profiler.profSubstitution() +#define PROF_OMISSION(profiler) profiler.profOmission() +#define PROF_INSERTION(profiler) profiler.profInsertion() +#define PROF_MATCH(profiler) profiler.profMatch() +#define PROF_COMPLETION(profiler) profiler.profCompletion() +#define PROF_TRANSPOSITION(profiler) profiler.profTransposition() +#define PROF_NEARESTKEY(profiler) profiler.profNearestKey() +#define PROF_TERMINAL(profiler) profiler.profTerminal() +#define PROF_NEW_WORD(profiler) profiler.profNewWord() +#define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram() +#define PROF_NODE_RESET(profiler) profiler.reset() +#define PROF_NODE_COPY(src, dest) dest.copy(src) +#else +#define PROF_SPACE_SUBSTITUTION(profiler) +#define PROF_SPACE_OMISSION(profiler) +#define PROF_ADDITONAL_PROXIMITY(profiler) +#define PROF_SUBSTITUTION(profiler) +#define PROF_OMISSION(profiler) +#define PROF_INSERTION(profiler) +#define PROF_MATCH(profiler) +#define PROF_COMPLETION(profiler) +#define PROF_TRANSPOSITION(profiler) +#define PROF_NEARESTKEY(profiler) +#define PROF_TERMINAL(profiler) +#define PROF_NEW_WORD(profiler) +#define PROF_NEW_WORD_BIGRAM(profiler) +#define PROF_NODE_RESET(profiler) +#define PROF_NODE_COPY(src, dest) +#endif + +namespace latinime { + +class DicNodeProfiler { + public: +#if DEBUG_DICT + AK_FORCE_INLINE DicNodeProfiler() + : mProfOmission(0), mProfInsertion(0), mProfTransposition(0), + mProfAdditionalProximity(0), mProfSubstitution(0), + mProfSpaceSubstitution(0), mProfSpaceOmission(0), + mProfMatch(0), mProfCompletion(0), mProfTerminal(0), + mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {} + + int mProfOmission; + int mProfInsertion; + int mProfTransposition; + int mProfAdditionalProximity; + int mProfSubstitution; + int mProfSpaceSubstitution; + int mProfSpaceOmission; + int mProfMatch; + int mProfCompletion; + int mProfTerminal; + int mProfNearestKey; + int mProfNewWord; + int mProfNewWordBigram; + + void profSpaceSubstitution() { + ++mProfSpaceSubstitution; + } + + void profSpaceOmission() { + ++mProfSpaceOmission; + } + + void profAdditionalProximity() { + ++mProfAdditionalProximity; + } + + void profSubstitution() { + ++mProfSubstitution; + } + + void profOmission() { + ++mProfOmission; + } + + void profInsertion() { + ++mProfInsertion; + } + + void profMatch() { + ++mProfMatch; + } + + void profCompletion() { + ++mProfCompletion; + } + + void profTransposition() { + ++mProfTransposition; + } + + void profNearestKey() { + ++mProfNearestKey; + } + + void profTerminal() { + ++mProfTerminal; + } + + void profNewWord() { + ++mProfNewWord; + } + + void profNewWordBigram() { + ++mProfNewWordBigram; + } + + void reset() { + mProfSpaceSubstitution = 0; + mProfSpaceOmission = 0; + mProfAdditionalProximity = 0; + mProfSubstitution = 0; + mProfOmission = 0; + mProfInsertion = 0; + mProfMatch = 0; + mProfCompletion = 0; + mProfTransposition = 0; + mProfNearestKey = 0; + mProfTerminal = 0; + mProfNewWord = 0; + mProfNewWordBigram = 0; + } + + void copy(const DicNodeProfiler *const profiler) { + mProfSpaceSubstitution = profiler->mProfSpaceSubstitution; + mProfSpaceOmission = profiler->mProfSpaceOmission; + mProfAdditionalProximity = profiler->mProfAdditionalProximity; + mProfSubstitution = profiler->mProfSubstitution; + mProfOmission = profiler->mProfOmission; + mProfInsertion = profiler->mProfInsertion; + mProfMatch = profiler->mProfMatch; + mProfCompletion = profiler->mProfCompletion; + mProfTransposition = profiler->mProfTransposition; + mProfNearestKey = profiler->mProfNearestKey; + mProfTerminal = profiler->mProfTerminal; + mProfNewWord = profiler->mProfNewWord; + mProfNewWordBigram = profiler->mProfNewWordBigram; + } + + void dump() const { + AKLOGI("O %d, I %d, T %d, AP %d, S %d, SS %d, SO %d, M %d, C %d, TE %d, NW = %d, NWB = %d", + mProfOmission, mProfInsertion, mProfTransposition, mProfAdditionalProximity, + mProfSubstitution, mProfSpaceSubstitution, mProfSpaceOmission, mProfMatch, + mProfCompletion, mProfTerminal, mProfNewWord, mProfNewWordBigram); + } +#else + DicNodeProfiler() {} +#endif + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class +}; +} +#endif // LATINIME_DIC_NODE_PROFILER_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/dic_node_properties.h new file mode 100644 index 000000000..173ef35d0 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_properties.h @@ -0,0 +1,173 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_PROPERTIES_H +#define LATINIME_DIC_NODE_PROPERTIES_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +/** + * Node for traversing the lexicon trie. + */ +class DicNodeProperties { + public: + AK_FORCE_INLINE DicNodeProperties() + : mPos(0), mFlags(0), mChildrenPos(0), mAttributesPos(0), mSiblingPos(0), + mChildrenCount(0), mProbability(0), mBigramProbability(0), mNodeCodePoint(0), + mDepth(0), mLeavingDepth(0), mIsTerminal(false), mHasMultipleChars(false), + mHasChildren(false) { + } + + virtual ~DicNodeProperties() {} + + // Should be called only once per DicNode is initialized. + void init(const int pos, const uint8_t flags, const int childrenPos, const int attributesPos, + const int siblingPos, const int nodeCodePoint, const int childrenCount, + const int probability, const int bigramProbability, const bool isTerminal, + const bool hasMultipleChars, const bool hasChildren, const uint16_t depth, + const uint16_t terminalDepth) { + mPos = pos; + mFlags = flags; + mChildrenPos = childrenPos; + mAttributesPos = attributesPos; + mSiblingPos = siblingPos; + mNodeCodePoint = nodeCodePoint; + mChildrenCount = childrenCount; + mProbability = probability; + mBigramProbability = bigramProbability; + mIsTerminal = isTerminal; + mHasMultipleChars = hasMultipleChars; + mHasChildren = hasChildren; + mDepth = depth; + mLeavingDepth = terminalDepth; + } + + // Init for copy + void init(const DicNodeProperties *const nodeProp) { + mPos = nodeProp->mPos; + mFlags = nodeProp->mFlags; + mChildrenPos = nodeProp->mChildrenPos; + mAttributesPos = nodeProp->mAttributesPos; + mSiblingPos = nodeProp->mSiblingPos; + mNodeCodePoint = nodeProp->mNodeCodePoint; + mChildrenCount = nodeProp->mChildrenCount; + mProbability = nodeProp->mProbability; + mBigramProbability = nodeProp->mBigramProbability; + mIsTerminal = nodeProp->mIsTerminal; + mHasMultipleChars = nodeProp->mHasMultipleChars; + mHasChildren = nodeProp->mHasChildren; + mDepth = nodeProp->mDepth; + mLeavingDepth = nodeProp->mLeavingDepth; + } + + // Init as passing child + void init(const DicNodeProperties *const nodeProp, const int codePoint) { + mPos = nodeProp->mPos; + mFlags = nodeProp->mFlags; + mChildrenPos = nodeProp->mChildrenPos; + mAttributesPos = nodeProp->mAttributesPos; + mSiblingPos = nodeProp->mSiblingPos; + mNodeCodePoint = codePoint; // Overwrite the node char of a passing child + mChildrenCount = nodeProp->mChildrenCount; + mProbability = nodeProp->mProbability; + mBigramProbability = nodeProp->mBigramProbability; + mIsTerminal = nodeProp->mIsTerminal; + mHasMultipleChars = nodeProp->mHasMultipleChars; + mHasChildren = nodeProp->mHasChildren; + mDepth = nodeProp->mDepth + 1; // Increment the depth of a passing child + mLeavingDepth = nodeProp->mLeavingDepth; + } + + int getPos() const { + return mPos; + } + + uint8_t getFlags() const { + return mFlags; + } + + int getChildrenPos() const { + return mChildrenPos; + } + + int getAttributesPos() const { + return mAttributesPos; + } + + int getChildrenCount() const { + return mChildrenCount; + } + + int getProbability() const { + return mProbability; + } + + int getNodeCodePoint() const { + return mNodeCodePoint; + } + + uint16_t getDepth() const { + return mDepth; + } + + // TODO: Move to output? + uint16_t getLeavingDepth() const { + return mLeavingDepth; + } + + bool isTerminal() const { + return mIsTerminal; + } + + bool hasMultipleChars() const { + return mHasMultipleChars; + } + + bool hasChildren() const { + return mChildrenCount > 0 || mDepth != mLeavingDepth; + } + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + + // Not used + int getSiblingPos() const { + return mSiblingPos; + } + + int mPos; + uint8_t mFlags; + int mChildrenPos; + int mAttributesPos; + int mSiblingPos; + int mChildrenCount; + int mProbability; + int mBigramProbability; // not used for now + int mNodeCodePoint; + uint16_t mDepth; + uint16_t mLeavingDepth; + bool mIsTerminal; + bool mHasMultipleChars; + bool mHasChildren; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_PROPERTIES_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h new file mode 100644 index 000000000..2a81c3cae --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_RELEASE_LISTENER_H +#define LATINIME_DIC_NODE_RELEASE_LISTENER_H + +#include "defines.h" + +namespace latinime { + +class DicNodeReleaseListener { + public: + DicNodeReleaseListener() {} + virtual ~DicNodeReleaseListener() {} + virtual void onReleased(DicNode *dicNode) = 0; + private: + DISALLOW_COPY_AND_ASSIGN(DicNodeReleaseListener); +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_RELEASE_LISTENER_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state.h b/native/jni/src/suggest/core/dicnode/dic_node_state.h new file mode 100644 index 000000000..239b63c32 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_state.h @@ -0,0 +1,71 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_STATE_H +#define LATINIME_DIC_NODE_STATE_H + +#include "defines.h" +#include "dic_node_state_input.h" +#include "dic_node_state_output.h" +#include "dic_node_state_prevword.h" +#include "dic_node_state_scoring.h" + +namespace latinime { + +class DicNodeState { + public: + DicNodeStateInput mDicNodeStateInput; + DicNodeStateOutput mDicNodeStateOutput; + DicNodeStatePrevWord mDicNodeStatePrevWord; + DicNodeStateScoring mDicNodeStateScoring; + + AK_FORCE_INLINE DicNodeState() + : mDicNodeStateInput(), mDicNodeStateOutput(), mDicNodeStatePrevWord(), + mDicNodeStateScoring() { + } + + virtual ~DicNodeState() {} + + // Init with prevWordPos + void init(const int prevWordPos) { + mDicNodeStateInput.init(); + mDicNodeStateOutput.init(); + mDicNodeStatePrevWord.init(prevWordPos); + mDicNodeStateScoring.init(); + } + + // Init by copy + AK_FORCE_INLINE void init(const DicNodeState *const src) { + mDicNodeStateInput.init(&src->mDicNodeStateInput); + mDicNodeStateOutput.init(&src->mDicNodeStateOutput); + mDicNodeStatePrevWord.init(&src->mDicNodeStatePrevWord); + mDicNodeStateScoring.init(&src->mDicNodeStateScoring); + } + + // Init by copy and adding subword + void init(const DicNodeState *const src, const uint16_t additionalSubwordLength, + const int *const additionalSubword) { + init(src); + mDicNodeStateOutput.addSubword(additionalSubwordLength, additionalSubword); + } + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h new file mode 100644 index 000000000..7ad3e3e5f --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_STATE_INPUT_H +#define LATINIME_DIC_NODE_STATE_INPUT_H + +#include "defines.h" + +namespace latinime { + +// TODO: Have a .cpp for this class +class DicNodeStateInput { + public: + DicNodeStateInput() {} + virtual ~DicNodeStateInput() {} + + // TODO: Merge into DicNodeStatePrevWord::truncate + void truncate(const int commitPoint) { + mInputIndex[0] -= commitPoint; + } + + void init() { + for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { + // TODO: The initial value for mInputIndex should be -1? + //mInputIndex[i] = i == 0 ? 0 : -1; + mInputIndex[i] = 0; + mPrevCodePoint[i] = NOT_A_CODE_POINT; + mTerminalDiffCost[i] = static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + } + + void init(const DicNodeStateInput *const src, const bool resetTerminalDiffCost) { + for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { + mInputIndex[i] = src->mInputIndex[i]; + mPrevCodePoint[i] = src->mPrevCodePoint[i]; + mTerminalDiffCost[i] = resetTerminalDiffCost ? + static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; + } + } + + void updateInputIndexG(const int pointerId, const int inputIndex, + const int prevCodePoint, const float terminalDiffCost, const float rawLength) { + mInputIndex[pointerId] = inputIndex; + mPrevCodePoint[pointerId] = prevCodePoint; + mTerminalDiffCost[pointerId] = terminalDiffCost; + } + + void init(const DicNodeStateInput *const src) { + init(src, false); + } + + // For transposition + void setPrevCodePoint(const int pointerId, const int c) { + mPrevCodePoint[pointerId] = c; + } + + void forwardInputIndex(const int pointerId, const int val) { + if (mInputIndex[pointerId] < 0) { + mInputIndex[pointerId] = val; + } else { + mInputIndex[pointerId] = mInputIndex[pointerId] + val; + } + } + + int getInputIndex(const int pointerId) const { + // when "inputIndex" exceeds "inputSize", auto-completion needs to be done + return mInputIndex[pointerId]; + } + + int getPrevCodePoint(const int pointerId) const { + return mPrevCodePoint[pointerId]; + } + + float getTerminalDiffCost(const int pointerId) const { + return mTerminalDiffCost[pointerId]; + } + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + int mInputIndex[MAX_POINTER_COUNT_G]; + int mPrevCodePoint[MAX_POINTER_COUNT_G]; + float mTerminalDiffCost[MAX_POINTER_COUNT_G]; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_INPUT_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_output.h b/native/jni/src/suggest/core/dicnode/dic_node_state_output.h new file mode 100644 index 000000000..1d4f50a06 --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_output.h @@ -0,0 +1,75 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_STATE_OUTPUT_H +#define LATINIME_DIC_NODE_STATE_OUTPUT_H + +#include <cstring> // for memcpy() +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class DicNodeStateOutput { + public: + DicNodeStateOutput() : mOutputtedLength(0) { + init(); + } + + virtual ~DicNodeStateOutput() {} + + void init() { + mOutputtedLength = 0; + mWordBuf[0] = 0; + } + + void init(const DicNodeStateOutput *const stateOutput) { + memcpy(mWordBuf, stateOutput->mWordBuf, + stateOutput->mOutputtedLength * sizeof(mWordBuf[0])); + mOutputtedLength = stateOutput->mOutputtedLength; + if (mOutputtedLength < MAX_WORD_LENGTH) { + mWordBuf[mOutputtedLength] = 0; + } + } + + void addSubword(const uint16_t additionalSubwordLength, const int *const additionalSubword) { + if (additionalSubword) { + memcpy(&mWordBuf[mOutputtedLength], additionalSubword, + additionalSubwordLength * sizeof(mWordBuf[0])); + mOutputtedLength = static_cast<uint16_t>(mOutputtedLength + additionalSubwordLength); + if (mOutputtedLength < MAX_WORD_LENGTH) { + mWordBuf[mOutputtedLength] = 0; + } + } + } + + // TODO: Remove + int getCodePointAt(const int id) const { + return mWordBuf[id]; + } + + // TODO: Move to private + int mWordBuf[MAX_WORD_LENGTH]; + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + uint16_t mOutputtedLength; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_OUTPUT_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h b/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h new file mode 100644 index 000000000..e3b892bda --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_prevword.h @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_STATE_PREVWORD_H +#define LATINIME_DIC_NODE_STATE_PREVWORD_H + +#include <cstring> // for memset() +#include <stdint.h> + +#include "defines.h" +#include "dic_node_utils.h" + +namespace latinime { + +class DicNodeStatePrevWord { + public: + AK_FORCE_INLINE DicNodeStatePrevWord() + : mPrevWordCount(0), mPrevWordLength(0), mPrevWordStart(0), mPrevWordProbability(0), + mPrevWordNodePos(0) { + memset(mPrevWord, 0, sizeof(mPrevWord)); + memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + } + + virtual ~DicNodeStatePrevWord() {} + + void init() { + mPrevWordLength = 0; + mPrevWordCount = 0; + mPrevWordStart = 0; + mPrevWordProbability = -1; + mPrevWordNodePos = NOT_VALID_WORD; + memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + } + + void init(const int prevWordNodePos) { + mPrevWordLength = 0; + mPrevWordCount = 0; + mPrevWordStart = 0; + mPrevWordProbability = -1; + mPrevWordNodePos = prevWordNodePos; + memset(mPrevSpacePositions, 0, sizeof(mPrevSpacePositions)); + } + + // Init by copy + AK_FORCE_INLINE void init(const DicNodeStatePrevWord *const prevWord) { + mPrevWordLength = prevWord->mPrevWordLength; + mPrevWordCount = prevWord->mPrevWordCount; + mPrevWordStart = prevWord->mPrevWordStart; + mPrevWordProbability = prevWord->mPrevWordProbability; + mPrevWordNodePos = prevWord->mPrevWordNodePos; + memcpy(mPrevWord, prevWord->mPrevWord, prevWord->mPrevWordLength * sizeof(mPrevWord[0])); + memcpy(mPrevSpacePositions, prevWord->mPrevSpacePositions, sizeof(mPrevSpacePositions)); + } + + void init(const int16_t prevWordCount, const int16_t prevWordProbability, + const int prevWordNodePos, const int *const src0, const int16_t length0, + const int *const src1, const int16_t length1, const int *const prevSpacePositions, + const int lastInputIndex) { + mPrevWordCount = prevWordCount; + mPrevWordProbability = prevWordProbability; + mPrevWordNodePos = prevWordNodePos; + const int twoWordsLen = + DicNodeUtils::appendTwoWords(src0, length0, src1, length1, mPrevWord); + mPrevWord[twoWordsLen] = KEYCODE_SPACE; + mPrevWordStart = length0; + mPrevWordLength = static_cast<int16_t>(twoWordsLen + 1); + memcpy(mPrevSpacePositions, prevSpacePositions, sizeof(mPrevSpacePositions)); + mPrevSpacePositions[mPrevWordCount - 1] = lastInputIndex; + } + + void truncate(const int offset) { + // TODO: memmove + if (mPrevWordLength < offset) { + memset(mPrevWord, 0, sizeof(mPrevWord)); + mPrevWordLength = 0; + return; + } + const int newPrevWordLength = mPrevWordLength - offset; + memmove(mPrevWord, &mPrevWord[offset], newPrevWordLength * sizeof(mPrevWord[0])); + mPrevWordLength = newPrevWordLength; + } + + void outputSpacePositions(int *spaceIndices) const { + // Convert uint16_t to int + for (int i = 0; i < MAX_RESULTS; i++) { + spaceIndices[i] = mPrevSpacePositions[i]; + } + } + + // TODO: remove + int16_t getPrevWordLength() const { + return mPrevWordLength; + } + + int16_t getPrevWordCount() const { + return mPrevWordCount; + } + + int16_t getPrevWordStart() const { + return mPrevWordStart; + } + + int16_t getPrevWordProbability() const { + return mPrevWordProbability; + } + + int getPrevWordNodePos() const { + return mPrevWordNodePos; + } + + int getPrevWordCodePointAt(const int id) const { + return mPrevWord[id]; + } + + bool startsWith(const DicNodeStatePrevWord *const prefix, const int prefixLen) const { + if (prefixLen > mPrevWordLength) { + return false; + } + for (int i = 0; i < prefixLen; ++i) { + if (mPrevWord[i] != prefix->mPrevWord[i]) { + return false; + } + } + return true; + } + + // TODO: Move to private + int mPrevWord[MAX_WORD_LENGTH]; + // TODO: Move to private + int mPrevSpacePositions[MAX_RESULTS]; + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + int16_t mPrevWordCount; + int16_t mPrevWordLength; + int16_t mPrevWordStart; + int16_t mPrevWordProbability; + int mPrevWordNodePos; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_PREVWORD_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h new file mode 100644 index 000000000..8e816329f --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_STATE_SCORING_H +#define LATINIME_DIC_NODE_STATE_SCORING_H + +#include <stdint.h> + +#include "defines.h" + +namespace latinime { + +class DicNodeStateScoring { + public: + AK_FORCE_INLINE DicNodeStateScoring() + : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), + mEditCorrectionCount(0), mProximityCorrectionCount(0), + mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), + mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) { + } + + virtual ~DicNodeStateScoring() {} + + void init() { + mEditCorrectionCount = 0; + mProximityCorrectionCount = 0; + mNormalizedCompoundDistance = 0.0f; + mSpatialDistance = 0.0f; + mLanguageDistance = 0.0f; + mTotalPrevWordsLanguageCost = 0.0f; + mRawLength = 0.0f; + mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; + } + + AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { + mEditCorrectionCount = scoring->mEditCorrectionCount; + mProximityCorrectionCount = scoring->mProximityCorrectionCount; + mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance; + mSpatialDistance = scoring->mSpatialDistance; + mLanguageDistance = scoring->mLanguageDistance; + mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost; + mRawLength = scoring->mRawLength; + mDoubleLetterLevel = scoring->mDoubleLetterLevel; + } + + void addCost(const float spatialCost, const float languageCost, const bool doNormalization, + const int inputSize, const int totalInputIndex, const bool isEditCorrection, + const bool isProximityCorrection) { + addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); + if (isEditCorrection) { + ++mEditCorrectionCount; + } + if (isProximityCorrection) { + ++mProximityCorrectionCount; + } + if (languageCost > 0.0f) { + setTotalPrevWordsLanguageCost(mTotalPrevWordsLanguageCost + languageCost); + } + } + + void addRawLength(const float rawLength) { + mRawLength += rawLength; + } + + float getCompoundDistance() const { + return getCompoundDistance(1.0f); + } + + float getCompoundDistance(const float languageWeight) const { + return mSpatialDistance + mLanguageDistance * languageWeight; + } + + float getNormalizedCompoundDistance() const { + return mNormalizedCompoundDistance; + } + + float getSpatialDistance() const { + return mSpatialDistance; + } + + float getLanguageDistance() const { + return mLanguageDistance; + } + + int16_t getEditCorrectionCount() const { + return mEditCorrectionCount; + } + + int16_t getProximityCorrectionCount() const { + return mProximityCorrectionCount; + } + + float getRawLength() const { + return mRawLength; + } + + DoubleLetterLevel getDoubleLetterLevel() const { + return mDoubleLetterLevel; + } + + void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { + switch(doubleLetterLevel) { + case NOT_A_DOUBLE_LETTER: + break; + case A_DOUBLE_LETTER: + if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) { + mDoubleLetterLevel = doubleLetterLevel; + } + break; + case A_STRONG_DOUBLE_LETTER: + mDoubleLetterLevel = doubleLetterLevel; + break; + } + } + + float getTotalPrevWordsLanguageCost() const { + return mTotalPrevWordsLanguageCost; + } + + private: + // Caution!!! + // Use a default copy constructor and an assign operator because shallow copies are ok + // for this class + DoubleLetterLevel mDoubleLetterLevel; + + int16_t mEditCorrectionCount; + int16_t mProximityCorrectionCount; + + float mNormalizedCompoundDistance; + float mSpatialDistance; + float mLanguageDistance; + float mTotalPrevWordsLanguageCost; + float mRawLength; + + AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, + bool doNormalization, int inputSize, int totalInputIndex) { + mSpatialDistance += spatialDistance; + mLanguageDistance += languageDistance; + if (!doNormalization) { + mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance; + } else { + mNormalizedCompoundDistance = (mSpatialDistance + mLanguageDistance) + / static_cast<float>(max(1, totalInputIndex)); + } + } + + //TODO: remove + AK_FORCE_INLINE void setTotalPrevWordsLanguageCost(float totalPrevWordsLanguageCost) { + mTotalPrevWordsLanguageCost = totalPrevWordsLanguageCost; + } +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_STATE_SCORING_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp new file mode 100644 index 000000000..031e706ae --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -0,0 +1,335 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <cstring> +#include <vector> + +#include "binary_format.h" +#include "dic_node.h" +#include "dic_node_utils.h" +#include "dic_node_vector.h" +#include "proximity_info.h" +#include "proximity_info_state.h" + +namespace latinime { + +/////////////////////////////// +// Node initialization utils // +/////////////////////////////// + +/* static */ void DicNodeUtils::initAsRoot(const int rootPos, const uint8_t *const dicRoot, + const int prevWordNodePos, DicNode *newRootNode) { + int curPos = rootPos; + const int pos = curPos; + const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); + const int childrenPos = curPos; + newRootNode->initAsRoot(pos, childrenPos, childrenCount, prevWordNodePos); +} + +/*static */ void DicNodeUtils::initAsRootWithPreviousWord(const int rootPos, + const uint8_t *const dicRoot, DicNode *prevWordLastNode, DicNode *newRootNode) { + int curPos = rootPos; + const int pos = curPos; + const int childrenCount = BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &curPos); + const int childrenPos = curPos; + newRootNode->initAsRootWithPreviousWord(prevWordLastNode, pos, childrenPos, childrenCount); +} + +/* static */ void DicNodeUtils::initByCopy(DicNode *srcNode, DicNode *destNode) { + destNode->initByCopy(srcNode); +} + +/////////////////////////////////// +// Traverse node expansion utils // +/////////////////////////////////// + +/* static */ void DicNodeUtils::createAndGetPassingChildNode(DicNode *dicNode, + const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, + DicNodeVector *childDicNodes) { + // Passing multiple chars node. No need to traverse child + const int codePoint = dicNode->getNodeTypedCodePoint(); + const int baseLowerCaseCodePoint = toBaseLowerCase(codePoint); + const bool isMatch = isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, codePoint); + if (isMatch || isIntentionalOmissionCodePoint(baseLowerCaseCodePoint)) { + childDicNodes->pushPassingChild(dicNode); + } +} + +/* static */ int DicNodeUtils::createAndGetLeavingChildNode(DicNode *dicNode, int pos, + const uint8_t *const dicRoot, const int terminalDepth, const ProximityInfoState *pInfoState, + const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter, + const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { + int nextPos = pos; + const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); + const bool hasMultipleChars = (0 != (BinaryFormat::FLAG_HAS_MULTIPLE_CHARS & flags)); + const bool isTerminal = (0 != (BinaryFormat::FLAG_IS_TERMINAL & flags)); + const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); + + int codePoint = BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); + ASSERT(NOT_A_CODE_POINT != codePoint); + const int nodeCodePoint = codePoint; + // TODO: optimize this + int additionalWordBuf[MAX_WORD_LENGTH]; + uint16_t additionalSubwordLength = 0; + additionalWordBuf[additionalSubwordLength++] = codePoint; + + do { + const int nextCodePoint = hasMultipleChars + ? BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos) : NOT_A_CODE_POINT; + const bool isLastChar = (NOT_A_CODE_POINT == nextCodePoint); + if (!isLastChar) { + additionalWordBuf[additionalSubwordLength++] = nextCodePoint; + } + codePoint = nextCodePoint; + } while (NOT_A_CODE_POINT != codePoint); + + const int probability = + isTerminal ? BinaryFormat::readProbabilityWithoutMovingPointer(dicRoot, pos) : -1; + pos = BinaryFormat::skipProbability(flags, pos); + int childrenPos = hasChildren ? BinaryFormat::readChildrenPosition(dicRoot, flags, pos) : 0; + const int attributesPos = BinaryFormat::skipChildrenPosition(flags, pos); + const int siblingPos = BinaryFormat::skipChildrenPosAndAttributes(dicRoot, flags, pos); + + if (isDicNodeFilteredOut(nodeCodePoint, pInfo, codePointsFilter)) { + return siblingPos; + } + if (!isMatchedNodeCodePoint(pInfoState, pointIndex, exactOnly, nodeCodePoint)) { + return siblingPos; + } + const int childrenCount = hasChildren + ? BinaryFormat::getGroupCountAndForwardPointer(dicRoot, &childrenPos) : 0; + childDicNodes->pushLeavingChild(dicNode, nextPos, flags, childrenPos, attributesPos, siblingPos, + nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, + hasMultipleChars, hasChildren, additionalSubwordLength, additionalWordBuf); + return siblingPos; +} + +/* static */ bool DicNodeUtils::isDicNodeFilteredOut(const int nodeCodePoint, + const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter) { + const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; + if (filterSize <= 0) { + return false; + } + if (pInfo && (pInfo->getKeyIndexOf(nodeCodePoint) == NOT_AN_INDEX + || isIntentionalOmissionCodePoint(nodeCodePoint))) { + // If normalized nodeCodePoint is not on the keyboard or skippable, this child is never + // filtered. + return false; + } + const int lowerCodePoint = toLowerCase(nodeCodePoint); + const int baseLowerCodePoint = toBaseCodePoint(lowerCodePoint); + // TODO: Avoid linear search + for (int i = 0; i < filterSize; ++i) { + // Checking if a normalized code point is in filter characters when pInfo is not + // null. When pInfo is null, nodeCodePoint is used to check filtering without + // normalizing. + if ((pInfo && ((*codePointsFilter)[i] == lowerCodePoint + || (*codePointsFilter)[i] == baseLowerCodePoint)) + || (!pInfo && (*codePointsFilter)[i] == nodeCodePoint)) { + return false; + } + } + return true; +} + +/* static */ void DicNodeUtils::createAndGetAllLeavingChildNodes(DicNode *dicNode, + const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, + const bool exactOnly, const std::vector<int> *const codePointsFilter, + const ProximityInfo *const pInfo, DicNodeVector *childDicNodes) { + const int terminalDepth = dicNode->getLeavingDepth(); + const int childCount = dicNode->getChildrenCount(); + int nextPos = dicNode->getChildrenPos(); + for (int i = 0; i < childCount; i++) { + const int filterSize = codePointsFilter ? codePointsFilter->size() : 0; + nextPos = createAndGetLeavingChildNode(dicNode, nextPos, dicRoot, terminalDepth, pInfoState, + pointIndex, exactOnly, codePointsFilter, pInfo, childDicNodes); + if (!pInfo && filterSize > 0 && childDicNodes->exceeds(filterSize)) { + // All code points have been found. + break; + } + } +} + +/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, + DicNodeVector *childDicNodes) { + getProximityChildDicNodes(dicNode, dicRoot, 0, 0, false, childDicNodes); +} + +/* static */ void DicNodeUtils::getProximityChildDicNodes(DicNode *dicNode, + const uint8_t *const dicRoot, const ProximityInfoState *pInfoState, const int pointIndex, + bool exactOnly, DicNodeVector *childDicNodes) { + if (dicNode->isTotalInputSizeExceedingLimit()) { + return; + } + if (!dicNode->isLeavingNode()) { + DicNodeUtils::createAndGetPassingChildNode(dicNode, pInfoState, pointIndex, exactOnly, + childDicNodes); + } else { + DicNodeUtils::createAndGetAllLeavingChildNodes(dicNode, dicRoot, pInfoState, pointIndex, + exactOnly, 0 /* codePointsFilter */, 0 /* pInfo */, + childDicNodes); + } +} + +/////////////////// +// Scoring utils // +/////////////////// +/** + * Computes the combined bigram / unigram cost for the given dicNode. + */ +/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot, + const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { + if (node->isImpossibleBigramWord()) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap); + // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. + const float cost = static_cast<float>(MAX_PROBABILITY - probability) + / static_cast<float>(MAX_PROBABILITY); + return cost; +} + +/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot, + const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { + const int unigramProbability = node->getProbability(); + const int encodedDiffOfBigramProbability = + getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap); + if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) { + return backoff(unigramProbability); + } + return BinaryFormat::computeProbabilityForBigram( + unigramProbability, encodedDiffOfBigramProbability); +} + +/////////////////////////////////////// +// Bigram / Unigram dictionary utils // +/////////////////////////////////////// + +/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, + const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { + const int wordPos = node->getPos(); + const int prevWordPos = node->getPrevWordPos(); + return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap); +} + +// TODO: Move this to BigramDictionary +/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos, + const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) { + // TODO: this is painfully slow compared to the method used in the previous version of the + // algorithm. Switch to that method. + if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY; + if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY; + + // Create a hash code for the given node pair (based on Josh Bloch's effective Java). + // TODO: Use a real hash map data structure that deals with collisions. + int hash = 17; + hash = hash * 31 + pos; + hash = hash * 31 + nextPos; + + hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash); + if (mapPos != bigramCacheMap->end()) { + return mapPos->second; + } + if (NOT_VALID_WORD == pos) { + return NOT_A_PROBABILITY; + } + const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); + if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) { + return NOT_A_PROBABILITY; + } + if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { + BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); + } else { + pos = BinaryFormat::skipOtherCharacters(dicRoot, pos); + } + pos = BinaryFormat::skipChildrenPosition(flags, pos); + pos = BinaryFormat::skipProbability(flags, pos); + uint8_t bigramFlags; + int count = 0; + do { + bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); + const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot, + bigramFlags, &pos); + if (bigramPos == nextPos) { + const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { + (*bigramCacheMap)[hash] = probability; + } + return probability; + } + count++; + } while ((0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)) + && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT); + if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { + // TODO: does this -1 mean NOT_VALID_WORD? + (*bigramCacheMap)[hash] = -1; + } + return NOT_A_PROBABILITY; +} + +/* static */ int DicNodeUtils::getWordPos(const uint8_t *const dicRoot, const int *word, + const int wordLength) { + if (!word) { + return NOT_VALID_WORD; + } + return BinaryFormat::getTerminalPosition( + dicRoot, word, wordLength, false /* forceLowerCaseSearch */); +} + +/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, + const int pointIndex, const bool exactOnly, const int nodeCodePoint) { + if (!pInfoState) { + return true; + } + if (exactOnly) { + return pInfoState->getPrimaryCodePointAt(pointIndex) == nodeCodePoint; + } + const ProximityType matchedId = pInfoState->getProximityType(pointIndex, nodeCodePoint, + true /* checkProximityChars */); + return isProximityChar(matchedId); +} + +//////////////// +// Char utils // +//////////////// + +// TODO: Move to char_utils? +/* static */ int DicNodeUtils::appendTwoWords(const int *const src0, const int16_t length0, + const int *const src1, const int16_t length1, int *dest) { + int actualLength0 = 0; + for (int i = 0; i < length0; ++i) { + if (src0[i] == 0) { + break; + } + actualLength0 = i + 1; + } + actualLength0 = min(actualLength0, MAX_WORD_LENGTH); + memcpy(dest, src0, actualLength0 * sizeof(dest[0])); + if (!src1 || length1 == 0) { + return actualLength0; + } + int actualLength1 = 0; + for (int i = 0; i < length1; ++i) { + if (src1[i] == 0) { + break; + } + actualLength1 = i + 1; + } + actualLength1 = min(actualLength1, MAX_WORD_LENGTH - actualLength0 - 1); + memcpy(&dest[actualLength0], src1, actualLength1 * sizeof(dest[0])); + return actualLength0 + actualLength1; +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h new file mode 100644 index 000000000..15f9730de --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -0,0 +1,88 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_UTILS_H +#define LATINIME_DIC_NODE_UTILS_H + +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "hash_map_compat.h" + +namespace latinime { + +class DicNode; +class DicNodeVector; +class ProximityInfo; +class ProximityInfoState; + +class DicNodeUtils { + public: + static int appendTwoWords(const int *src0, const int16_t length0, const int *src1, + const int16_t length1, int *dest); + static void initAsRoot(const int rootPos, const uint8_t *const dicRoot, + const int prevWordNodePos, DicNode *newRootNode); + static void initAsRootWithPreviousWord(const int rootPos, const uint8_t *const dicRoot, + DicNode *prevWordLastNode, DicNode *newRootNode); + static void initByCopy(DicNode *srcNode, DicNode *destNode); + static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, + DicNodeVector *childDicNodes); + static int getWordPos(const uint8_t *const dicRoot, const int *word, const int prevWordLength); + static float getBigramNodeImprobability(const uint8_t *const dicRoot, + const DicNode *const node, hash_map_compat<int, int16_t> *const bigramCacheMap); + static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo, + const std::vector<int> *const codePointsFilter); + // TODO: Move to private + static void getProximityChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, + const ProximityInfoState *pInfoState, const int pointIndex, bool exactOnly, + DicNodeVector *childDicNodes); + + // TODO: Move to proximity info + static bool isProximityChar(ProximityType type) { + return type == MATCH_CHAR || type == PROXIMITY_CHAR || type == ADDITIONAL_PROXIMITY_CHAR; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); + // Max cache size for the space omission error correction bigram lookup + static const int MAX_BIGRAM_MAP_SIZE = 20000; + // Max number of bigrams to look up + static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; + + static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node, + hash_map_compat<int, int16_t> *bigramCacheMap); + static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, + const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap); + static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState, + const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes); + static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot, + const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, + const std::vector<int> *const codePointsFilter, + const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); + static int createAndGetLeavingChildNode(DicNode *dicNode, int pos, const uint8_t *const dicRoot, + const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex, + const bool exactOnly, const std::vector<int> *const codePointsFilter, + const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); + static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos, + hash_map_compat<int, int16_t> *bigramCacheMap); + + // TODO: Move to proximity info + static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, + const bool exactOnly, const int nodeCodePoint); +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_UTILS_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h new file mode 100644 index 000000000..ca07edaee --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODE_VECTOR_H +#define LATINIME_DIC_NODE_VECTOR_H + +#include <vector> + +#include "defines.h" +#include "dic_node.h" + +namespace latinime { + +class DicNodeVector { + public: +#ifdef FLAG_DBG + // 0 will introduce resizing the vector. + static const int DEFAULT_NODES_SIZE_FOR_OPTIMIZATION = 0; +#else + static const int DEFAULT_NODES_SIZE_FOR_OPTIMIZATION = 60; +#endif + AK_FORCE_INLINE DicNodeVector() : mDicNodes(0), mLock(false), mEmptyNode() {} + + // Specify the capacity of the vector + AK_FORCE_INLINE DicNodeVector(const int size) : mDicNodes(0), mLock(false), mEmptyNode() { + mDicNodes.reserve(size); + } + + // Non virtual inline destructor -- never inherit this class + AK_FORCE_INLINE ~DicNodeVector() {} + + AK_FORCE_INLINE void clear() { + mDicNodes.clear(); + mLock = false; + } + + int getSizeAndLock() { + mLock = true; + return static_cast<int>(mDicNodes.size()); + } + + bool exceeds(const size_t limit) const { + return mDicNodes.size() >= limit; + } + + void pushPassingChild(DicNode *dicNode) { + ASSERT(!mLock); + mDicNodes.push_back(mEmptyNode); + mDicNodes.back().initAsPassingChild(dicNode); + } + + void pushLeavingChild(DicNode *dicNode, const int pos, const uint8_t flags, + const int childrenPos, const int attributesPos, const int siblingPos, + const int nodeCodePoint, const int childrenCount, const int probability, + const int bigramProbability, const bool isTerminal, const bool hasMultipleChars, + const bool hasChildren, const uint16_t additionalSubwordLength, + const int *additionalSubword) { + ASSERT(!mLock); + mDicNodes.push_back(mEmptyNode); + mDicNodes.back().initAsChild(dicNode, pos, flags, childrenPos, attributesPos, siblingPos, + nodeCodePoint, childrenCount, probability, -1 /* bigramProbability */, isTerminal, + hasMultipleChars, hasChildren, additionalSubwordLength, additionalSubword); + } + + DicNode *operator[](const int id) { + ASSERT(id < static_cast<int>(mDicNodes.size())); + return &mDicNodes[id]; + } + + DicNode *front() { + ASSERT(1 <= static_cast<int>(mDicNodes.size())); + return &mDicNodes[0]; + } + + private: + DISALLOW_COPY_AND_ASSIGN(DicNodeVector); + std::vector<DicNode> mDicNodes; + bool mLock; + DicNode mEmptyNode; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODE_VECTOR_H diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp new file mode 100644 index 000000000..b9a60780b --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <list> + +#include "defines.h" +#include "dic_node_priority_queue.h" +#include "dic_node_utils.h" +#include "dic_nodes_cache.h" + +namespace latinime { + +/** + * Truncates all of the dicNodes so that they start at the given commit point. + * Only called for multi-word typing input. + */ +DicNode *DicNodesCache::setCommitPoint(int commitPoint) { + std::list<DicNode> dicNodesList; + while (mCachedDicNodesForContinuousSuggestion->getSize() > 0) { + DicNode dicNode; + mCachedDicNodesForContinuousSuggestion->copyPop(&dicNode); + dicNodesList.push_front(dicNode); + } + + // Get the starting words of the top scoring dicNode (last dicNode popped from priority queue) + // up to the commit point. These words have already been committed to the text view. + DicNode *topDicNode = &dicNodesList.front(); + DicNode topDicNodeCopy; + DicNodeUtils::initByCopy(topDicNode, &topDicNodeCopy); + + // Keep only those dicNodes that match the same starting words. + std::list<DicNode>::iterator iter; + for (iter = dicNodesList.begin(); iter != dicNodesList.end(); iter++) { + DicNode *dicNode = &*iter; + if (dicNode->truncateNode(&topDicNodeCopy, commitPoint)) { + mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } else { + // Top dicNode should be reprocessed. + ASSERT(dicNode != topDicNode); + DicNode::managedDelete(dicNode); + } + } + mInputIndex -= commitPoint; + return topDicNode; +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h new file mode 100644 index 000000000..a62aa422a --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODES_CACHE_H +#define LATINIME_DIC_NODES_CACHE_H + +#include <stdint.h> + +#include "defines.h" +#include "dic_node_priority_queue.h" + +#define INITIAL_QUEUE_ID_ACTIVE 0 +#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1 +#define INITIAL_QUEUE_ID_TERMINAL 2 +#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 +#define PRIORITY_QUEUES_SIZE 4 + +namespace latinime { + +class DicNode; + +/** + * Class for controlling dicNode search priority queue and lexicon trie traversal. + */ +class DicNodesCache { + public: + AK_FORCE_INLINE DicNodesCache() + : mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]), + mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]), + mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]), + mCachedDicNodesForContinuousSuggestion( + &mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), + mInputIndex(0), mLastCachedInputIndex(0) { + } + + AK_FORCE_INLINE virtual ~DicNodesCache() {} + + AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) { + mInputIndex = 0; + mLastCachedInputIndex = 0; + mActiveDicNodes->reset(); + mNextActiveDicNodes->clearAndResize(nextActiveSize); + mTerminalDicNodes->clearAndResize(terminalSize); + mCachedDicNodesForContinuousSuggestion->reset(); + } + + AK_FORCE_INLINE void continueSearch() { + resetTemporaryCaches(); + restoreActiveDicNodesFromCache(); + } + + AK_FORCE_INLINE void advanceActiveDicNodes() { + if (DEBUG_DICT) { + AKLOGI("Advance active %d nodes.", mNextActiveDicNodes->getSize()); + } + if (DEBUG_DICT_FULL) { + mNextActiveDicNodes->dump(); + } + mNextActiveDicNodes = + moveNodesAndReturnReusableEmptyQueue(mNextActiveDicNodes, &mActiveDicNodes); + } + + DicNode *setCommitPoint(int commitPoint); + + int activeSize() const { return mActiveDicNodes->getSize(); } + int terminalSize() const { return mTerminalDicNodes->getSize(); } + bool isLookAheadCorrectionInputIndex(const int inputIndex) const { + return inputIndex == mInputIndex - 1; + } + void advanceInputIndex(const int inputSize) { + if (mInputIndex < inputSize) { + mInputIndex++; + } + } + + AK_FORCE_INLINE void copyPushTerminal(DicNode *dicNode) { + mTerminalDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushActive(DicNode *dicNode) { + mActiveDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE bool copyPushContinue(DicNode *dicNode) { + return mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushNextActive(DicNode *dicNode) { + DicNode *pushedDicNode = mNextActiveDicNodes->copyPush(dicNode); + if (!pushedDicNode) { + if (dicNode->isCached()) { + dicNode->remove(); + } + // We simply drop any dic node that was not cached, ignoring the slim chance + // that one of its children represents what the user really wanted. + } + } + + void popTerminal(DicNode *dest) { + mTerminalDicNodes->copyPop(dest); + } + + void popActive(DicNode *dest) { + mActiveDicNodes->copyPop(dest); + } + + bool hasCachedDicNodesForContinuousSuggestion() const { + return mCachedDicNodesForContinuousSuggestion + && mCachedDicNodesForContinuousSuggestion->getSize() > 0; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + // TODO: Move this variable to header + static const int CACHE_BACK_LENGTH = 3; + const int cacheInputIndex = inputSize - CACHE_BACK_LENGTH; + const bool shouldCache = (cacheInputIndex == mInputIndex) + && (cacheInputIndex != mLastCachedInputIndex); + return shouldCache; + } + + AK_FORCE_INLINE void updateLastCachedInputIndex() { + mLastCachedInputIndex = mInputIndex; + } + + private: + DISALLOW_COPY_AND_ASSIGN(DicNodesCache); + + AK_FORCE_INLINE void restoreActiveDicNodesFromCache() { + if (DEBUG_DICT) { + AKLOGI("Restore %d nodes. inputIndex = %d.", + mCachedDicNodesForContinuousSuggestion->getSize(), mLastCachedInputIndex); + } + if (DEBUG_DICT_FULL || DEBUG_CACHE) { + mCachedDicNodesForContinuousSuggestion->dump(); + } + mInputIndex = mLastCachedInputIndex; + mCachedDicNodesForContinuousSuggestion = + moveNodesAndReturnReusableEmptyQueue( + mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); + } + + AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue( + DicNodePriorityQueue *src, DicNodePriorityQueue **dest) { + const int srcMaxSize = src->getMaxSize(); + const int destMaxSize = (*dest)->getMaxSize(); + DicNodePriorityQueue *tmp = *dest; + *dest = src; + (*dest)->setMaxSize(destMaxSize); + tmp->clearAndResize(srcMaxSize); + return tmp; + } + + AK_FORCE_INLINE void resetTemporaryCaches() { + mActiveDicNodes->clear(); + mNextActiveDicNodes->clear(); + mTerminalDicNodes->clear(); + } + + DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE]; + // Active dicNodes currently being expanded. + DicNodePriorityQueue *mActiveDicNodes; + // Next dicNodes to be expanded. + DicNodePriorityQueue *mNextActiveDicNodes; + // Current top terminal dicNodes. + DicNodePriorityQueue *mTerminalDicNodes; + // Cached dicNodes used for continuous suggestion. + DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion; + int mInputIndex; + int mLastCachedInputIndex; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODES_CACHE_H diff --git a/native/jni/src/suggest/core/dictionary/shortcut_utils.h b/native/jni/src/suggest/core/dictionary/shortcut_utils.h new file mode 100644 index 000000000..e592136cc --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/shortcut_utils.h @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SHORTCUT_UTILS +#define LATINIME_SHORTCUT_UTILS + +#include "defines.h" +#include "dic_node_utils.h" +#include "terminal_attributes.h" + +namespace latinime { + +class ShortcutUtils { + public: + static int outputShortcuts(const TerminalAttributes *const terminalAttributes, + int outputWordIndex, const int finalScore, int *const outputCodePoints, + int *const frequencies, int *const outputTypes, const bool sameAsTyped) { + TerminalAttributes::ShortcutIterator iterator = terminalAttributes->getShortcutIterator(); + while (iterator.hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { + int shortcutTarget[MAX_WORD_LENGTH]; + int shortcutProbability; + const int shortcutTargetStringLength = iterator.getNextShortcutTarget( + MAX_WORD_LENGTH, shortcutTarget, &shortcutProbability); + int shortcutScore; + int kind; + if (shortcutProbability == BinaryFormat::WHITELIST_SHORTCUT_PROBABILITY + && sameAsTyped) { + shortcutScore = S_INT_MAX; + kind = Dictionary::KIND_WHITELIST; + } else { + // shortcut entry's score == its base entry's score - 1 + shortcutScore = finalScore; + // Protection against int underflow + shortcutScore = max(S_INT_MIN + 1, shortcutScore) - 1; + kind = Dictionary::KIND_CORRECTION; + } + outputTypes[outputWordIndex] = kind; + frequencies[outputWordIndex] = shortcutScore; + frequencies[outputWordIndex] = max(S_INT_MIN + 1, shortcutScore) - 1; + const int startIndex2 = outputWordIndex * MAX_WORD_LENGTH; + DicNodeUtils::appendTwoWords(0, 0, shortcutTarget, shortcutTargetStringLength, + &outputCodePoints[startIndex2]); + ++outputWordIndex; + } + return outputWordIndex; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutUtils); +}; +} // namespace latinime +#endif // LATINIME_SHORTCUT_UTILS diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h new file mode 100644 index 000000000..b8c10e25a --- /dev/null +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SCORING_H +#define LATINIME_SCORING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; + +// This class basically tweaks suggestions and distances apart from CompoundDistance +class Scoring { + public: + virtual int calculateFinalScore(const float compoundDistance, const int inputSize, + const bool forceCommit) const = 0; + virtual bool getMostProbableString( + const DicTraverseSession *const traverseSession, const int terminalSize, + const float languageWeight, int *const outputCodePoints, int *const type, + int *const freq) const = 0; + virtual void safetyNetForMostProbableString(const int terminalSize, + const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0; + // TODO: Make more generic + virtual void searchWordWithDoubleLetter(DicNode *terminals, + const int terminalSize, int *doubleLetterTerminalIndex, + DoubleLetterLevel *doubleLetterLevel) const = 0; + virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, + DicNode *const terminals, const int size) const = 0; + virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex, + const int doubleLetterTerminalIndex, + const DoubleLetterLevel doubleLetterLevel) const = 0; + virtual bool doesAutoCorrectValidWord() const = 0; + + protected: + Scoring() {} + virtual ~Scoring() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Scoring); +}; +} // namespace latinime +#endif // LATINIME_SCORING_H diff --git a/native/jni/src/suggest/core/policy/suggest_policy.h b/native/jni/src/suggest/core/policy/suggest_policy.h new file mode 100644 index 000000000..885e214f7 --- /dev/null +++ b/native/jni/src/suggest/core/policy/suggest_policy.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_POLICY_H +#define LATINIME_SUGGEST_POLICY_H + +#include "defines.h" + +namespace latinime { +class Traversal; +class Scoring; +class Weighting; + +class SuggestPolicy { + public: + SuggestPolicy() {} + virtual ~SuggestPolicy() {} + virtual const Traversal *getTraversal() const = 0; + virtual const Scoring *getScoring() const = 0; + virtual const Weighting *getWeighting() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(SuggestPolicy); +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_POLICY_H diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h new file mode 100644 index 000000000..1d5082ff8 --- /dev/null +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TRAVERSAL_H +#define LATINIME_TRAVERSAL_H + +#include "defines.h" + +namespace latinime { +class Traversal { + public: + virtual int getMaxPointerCount() const = 0; + virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0; + virtual bool isOmission(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; + virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const = 0; + virtual bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual ProximityType getProximityType( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + const DicNode *const childDicNode) const = 0; + virtual bool sameAsTyped(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool needsToTraverseAllUserInput() const = 0; + virtual float getMaxSpatialDistance() const = 0; + virtual bool allowPartialCommit() const = 0; + virtual int getDefaultExpandDicNodeSize() const = 0; + virtual int getMaxCacheSize() const = 0; + virtual bool isPossibleOmissionChildNode( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + + protected: + Traversal() {} + virtual ~Traversal() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Traversal); +}; +} // namespace latinime +#endif // LATINIME_TRAVERSAL_H diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp new file mode 100644 index 000000000..4d08fa0fa --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "char_utils.h" +#include "defines.h" +#include "dic_node.h" +#include "dic_node_profiler.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "hash_map_compat.h" +#include "weighting.h" + +namespace latinime { + +static inline void profile(const CorrectionType correctionType, DicNode *const node) { +#if DEBUG_DICT + switch (correctionType) { + case CT_OMISSION: + PROF_OMISSION(node->mProfiler); + return; + case CT_ADDITIONAL_PROXIMITY: + PROF_ADDITIONAL_PROXIMITY(node->mProfiler); + return; + case CT_SUBSTITUTION: + PROF_SUBSTITUTION(node->mProfiler); + return; + case CT_NEW_WORD: + PROF_NEW_WORD(node->mProfiler); + return; + case CT_MATCH: + PROF_MATCH(node->mProfiler); + return; + case CT_COMPLETION: + PROF_COMPLETION(node->mProfiler); + return; + case CT_TERMINAL: + PROF_TERMINAL(node->mProfiler); + return; + case CT_SPACE_SUBSTITUTION: + PROF_SPACE_SUBSTITUTION(node->mProfiler); + return; + case CT_INSERTION: + PROF_INSERTION(node->mProfiler); + return; + case CT_TRANSPOSITION: + PROF_TRANSPOSITION(node->mProfiler); + return; + default: + // do nothing + return; + } +#else + // do nothing +#endif +} + +/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) { + const int inputSize = traverseSession->getInputSize(); + DicNode_InputStateG inputStateG; + inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default + const float spatialCost = Weighting::getSpatialCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, &inputStateG); + const float languageCost = Weighting::getLanguageCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, bigramCacheMap); + const bool edit = Weighting::isEditCorrection(correctionType); + const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, + traverseSession, dicNode); + profile(correctionType, dicNode); + if (inputStateG.mNeedsToUpdateInputStateG) { + dicNode->updateInputIndexG(&inputStateG); + } else { + dicNode->forwardInputIndex(0, getForwardInputCount(correctionType), + (correctionType == CT_TRANSPOSITION)); + } + dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), + inputSize, edit, proximity); +} + +/* static */ float Weighting::getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) { + switch(correctionType) { + case CT_OMISSION: + return weighting->getOmissionCost(parentDicNode, dicNode); + case CT_ADDITIONAL_PROXIMITY: + // only used for typing + return weighting->getAdditionalProximityCost(); + case CT_SUBSTITUTION: + // only used for typing + return weighting->getSubstitutionCost(); + case CT_NEW_WORD: + return weighting->getNewWordCost(dicNode); + case CT_MATCH: + return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); + case CT_COMPLETION: + return weighting->getCompletionCost(traverseSession, dicNode); + case CT_TERMINAL: + return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_SPACE_SUBSTITUTION: + return weighting->getSpaceSubstitutionCost(); + case CT_INSERTION: + return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); + case CT_TRANSPOSITION: + return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode); + default: + return 0.0f; + } +} + +/* static */ float Weighting::getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) { + switch(correctionType) { + case CT_OMISSION: + return 0.0f; + case CT_SUBSTITUTION: + return 0.0f; + case CT_NEW_WORD: + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + case CT_MATCH: + return 0.0f; + case CT_COMPLETION: + return 0.0f; + case CT_TERMINAL: { + const float languageImprobability = + DicNodeUtils::getBigramNodeImprobability( + traverseSession->getOffsetDict(), dicNode, bigramCacheMap); + return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); + } + case CT_SPACE_SUBSTITUTION: + return 0.0f; + case CT_INSERTION: + return 0.0f; + case CT_TRANSPOSITION: + return 0.0f; + default: + return 0.0f; + } +} + +/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return true; + case CT_ADDITIONAL_PROXIMITY: + // Should return true? + return false; + case CT_SUBSTITUTION: + // Should return true? + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return false; + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return true; + case CT_TRANSPOSITION: + return true; + default: + return false; + } +} + +/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) { + switch(correctionType) { + case CT_OMISSION: + return false; + case CT_ADDITIONAL_PROXIMITY: + return false; + case CT_SUBSTITUTION: + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return weighting->isProximityDicNode(traverseSession, dicNode); + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return false; + case CT_TRANSPOSITION: + return false; + default: + return false; + } +} + +/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return 0; + case CT_ADDITIONAL_PROXIMITY: + return 0; + case CT_SUBSTITUTION: + return 0; + case CT_NEW_WORD: + return 0; + case CT_MATCH: + return 1; + case CT_COMPLETION: + return 0; + case CT_TERMINAL: + return 0; + case CT_SPACE_SUBSTITUTION: + return 1; + case CT_INSERTION: + return 2; + case CT_TRANSPOSITION: + return 2; + default: + return 0; + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h new file mode 100644 index 000000000..83a0f4b45 --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_WEIGHTING_H +#define LATINIME_WEIGHTING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; +struct DicNode_InputStateG; + +class Weighting { + public: + static void addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap); + + protected: + virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getOmissionCost( + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getMatchedCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + DicNode_InputStateG *inputStateG) const = 0; + + virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTranspositionCost( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + + virtual float getInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getNewWordCost(const DicNode *const dicNode) const = 0; + + virtual float getNewWordBigramCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0; + + virtual float getCompletionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTerminalLanguageCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + float dicNodeLanguageImprobability) const = 0; + + virtual bool needsToNormalizeCompoundDistance() const = 0; + + virtual float getAdditionalProximityCost() const = 0; + + virtual float getSubstitutionCost() const = 0; + + virtual float getSpaceSubstitutionCost() const = 0; + + Weighting() {} + virtual ~Weighting() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Weighting); + + static float getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + DicNode_InputStateG *const inputStateG); + static float getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isEditCorrection(const CorrectionType correctionType); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const dicNode); + // TODO: Move to TypingWeighting and GestureWeighting? + static int getForwardInputCount(const CorrectionType correctionType); +}; +} // namespace latinime +#endif // LATINIME_WEIGHTING_H diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp new file mode 100644 index 000000000..1f781dd43 --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "defines.h" +#include "dictionary.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "dic_traverse_wrapper.h" +#include "jni.h" + +namespace latinime { + +const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20; + +// A factory method for DicTraverseSession +static void *getSessionInstance(JNIEnv *env, jstring localeStr) { + return new DicTraverseSession(env, localeStr); +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary, + const int *prevWord, const int prevWordLength) { + if (traverseSession) { + DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); + tSession->init(dictionary, prevWord, prevWordLength); + } +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void releaseSessionInstance(void *traverseSession) { + delete static_cast<DicTraverseSession *>(traverseSession); +} + +// An ad-hoc internal class to register the factory method defined above +class TraverseSessionFactoryRegisterer { + public: + TraverseSessionFactoryRegisterer() { + DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance); + DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance); + DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance); + } + private: + DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer); +}; + +// To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor. +static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; + +void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, + int prevWordLength) { + mDictionary = dictionary; + if (!prevWord) { + mPrevWordPos = NOT_VALID_WORD; + return; + } + mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength); +} + +void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, + const int *inputCodePoints, const int inputSize, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const float maxSpatialDistance, const int maxPointerCount) { + mProximityInfo = pInfo; + mMaxPointerCount = maxPointerCount; + initializeProximityInfoStates(inputCodePoints, inputXs, inputYs, times, pointerIds, inputSize, + maxSpatialDistance, maxPointerCount); +} + +const uint8_t *DicTraverseSession::getOffsetDict() const { + return mDictionary->getOffsetDict(); +} + +void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { + mDicNodesCache.reset(nextActiveCacheSize, maxWords); + mBigramCacheMap.clear(); + mPartiallyCommited = false; +} + +void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints, + const int *const inputXs, const int *const inputYs, const int *const times, + const int *const pointerIds, const int inputSize, const float maxSpatialDistance, + const int maxPointerCount) { + ASSERT(1 <= maxPointerCount && maxPointerCount <= MAX_POINTER_COUNT_G); + mInputSize = 0; + for (int i = 0; i < maxPointerCount; ++i) { + mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(), + inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds, + maxPointerCount == MAX_POINTER_COUNT_G + /* TODO: this is a hack. fix proximity info state */); + mInputSize += mProximityInfoStates[i].size(); + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h new file mode 100644 index 000000000..af036f82b --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_TRAVERSE_SESSION_H +#define LATINIME_DIC_TRAVERSE_SESSION_H + +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "dic_nodes_cache.h" +#include "hash_map_compat.h" +#include "jni.h" +#include "proximity_info_state.h" + +namespace latinime { + +class Dictionary; +class ProximityInfo; + +class DicTraverseSession { + public: + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) + : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), + mDictionary(0), mDicNodesCache(), mBigramCacheMap(), + mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { + // NOTE: mProximityInfoStates is an array of instances. + // No need to initialize it explicitly here. + } + + // Non virtual inline destructor -- never inherit this class + AK_FORCE_INLINE ~DicTraverseSession() {} + + void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength); + // TODO: Remove and merge into init + void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, + const int inputSize, const int *const inputXs, const int *const inputYs, + const int *const times, const int *const pointerIds, const float maxSpatialDistance, + const int maxPointerCount); + void resetCache(const int nextActiveCacheSize, const int maxWords); + + const uint8_t *getOffsetDict() const; + bool canUseCache() const; + + //-------------------- + // getters and setters + //-------------------- + const ProximityInfo *getProximityInfo() const { return mProximityInfo; } + int getPrevWordPos() const { return mPrevWordPos; } + // TODO: REMOVE + void setPrevWordPos(int pos) { mPrevWordPos = pos; } + // TODO: Use proper parameter when changed + int getDicRootPos() const { return 0; } + DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } + hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; } + const ProximityInfoState *getProximityInfoState(int id) const { + return &mProximityInfoStates[id]; + } + int getInputSize() const { return mInputSize; } + void setPartiallyCommited() { mPartiallyCommited = true; } + bool isPartiallyCommited() const { return mPartiallyCommited; } + + bool isOnlyOnePointerUsed(int *pointerId) const { + // Not in the dictionary word + int usedPointerCount = 0; + int usedPointerId = 0; + for (int i = 0; i < mMaxPointerCount; ++i) { + if (mProximityInfoStates[i].isUsed()) { + ++usedPointerCount; + usedPointerId = i; + } + } + if (usedPointerCount != 1) { + return false; + } + *pointerId = usedPointerId; + return true; + } + + void getSearchKeys(const DicNode *node, std::vector<int> *const outputSearchKeyVector) const { + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + const std::vector<int> *const searchKeyVector = + mProximityInfoStates[i].getSearchKeyVector(pointerId); + outputSearchKeyVector->insert(outputSearchKeyVector->end(), searchKeyVector->begin(), + searchKeyVector->end()); + } + } + + ProximityType getProximityTypeG(const DicNode *const node, const int childCodePoint) const { + ProximityType proximityType = UNRELATED_CHAR; + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + proximityType = mProximityInfoStates[i].getProximityTypeG(pointerId, childCodePoint); + ASSERT(proximityType == UNRELATED_CHAR || proximityType == MATCH_CHAR); + // TODO: Make this more generic + // Currently we assume there are only two types here -- UNRELATED_CHAR + // and MATCH_CHAR + if (proximityType != UNRELATED_CHAR) { + return proximityType; + } + } + return proximityType; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + return mDicNodesCache.isCacheBorderForTyping(inputSize); + } + + /** + * Returns whether or not it is possible to continue suggestion from the previous search. + */ + // TODO: Remove. No need to check once the session is fully implemented. + bool isContinuousSuggestionPossible() const { + if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) { + return false; + } + ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G); + for (int i = 0; i < mMaxPointerCount; ++i) { + const ProximityInfoState *const pInfoState = getProximityInfoState(i); + // If a proximity info state is not continuous suggestion possible, + // do not continue searching. + if (pInfoState->isUsed() && !pInfoState->isContinuousSuggestionPossible()) { + return false; + } + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); + // threshold to start caching + static const int CACHE_START_INPUT_LENGTH_THRESHOLD; + void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const int inputSize, const float maxSpatialDistance, const int maxPointerCount); + + int mPrevWordPos; + const ProximityInfo *mProximityInfo; + const Dictionary *mDictionary; + + DicNodesCache mDicNodesCache; + // Temporary cache for bigram frequencies + hash_map_compat<int, int16_t> mBigramCacheMap; + ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; + + int mInputSize; + bool mPartiallyCommited; + int mMaxPointerCount; +}; +} // namespace latinime +#endif // LATINIME_DIC_TRAVERSE_SESSION_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp new file mode 100644 index 000000000..7fba1d504 --- /dev/null +++ b/native/jni/src/suggest/core/suggest.cpp @@ -0,0 +1,518 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "char_utils.h" +#include "dictionary.h" +#include "dic_node_priority_queue.h" +#include "dic_node_vector.h" +#include "dic_traverse_session.h" +#include "proximity_info.h" +#include "scoring.h" +#include "shortcut_utils.h" +#include "suggest.h" +#include "terminal_attributes.h" +#include "traversal.h" +#include "weighting.h" + +namespace latinime { + +// Initialization of class constants. +const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25; +const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; +const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; +const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; +const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f; + +const bool Suggest::CORRECT_SPACE_OMISSION = true; +const bool Suggest::CORRECT_TRANSPOSITION = true; +const bool Suggest::CORRECT_INSERTION = true; +const bool Suggest::CORRECT_OMISSION_G = true; + +/** + * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates + * whether to prematurely commit the suggested words up to the given point for sentence-level + * suggestion. + * + * Note: Currently does not support concurrent calls across threads. Continuous suggestion is + * automatically activated for sequential calls that share the same starting input. + * TODO: Stop detecting continuous suggestion. Start using traverseSession instead. + */ +int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, + int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, + int inputSize, int commitPoint, int *outWords, int *frequencies, int *outputIndices, + int *outputTypes) const { + PROF_OPEN; + PROF_START(0); + const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); + DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); + tSession->setupForGetSuggestions(pInfo, inputCodePoints, inputSize, inputXs, inputYs, times, + pointerIds, maxSpatialDistance, TRAVERSAL->getMaxPointerCount()); + // TODO: Add the way to evaluate cache + + initializeSearch(tSession, commitPoint); + PROF_END(0); + PROF_START(1); + + // keep expanding search dicNodes until all have terminated. + while (tSession->getDicTraverseCache()->activeSize() > 0) { + expandCurrentDicNodes(tSession); + tSession->getDicTraverseCache()->advanceActiveDicNodes(); + tSession->getDicTraverseCache()->advanceInputIndex(inputSize); + } + PROF_END(1); + PROF_START(2); + const int size = outputSuggestions(tSession, frequencies, outWords, outputIndices, outputTypes); + PROF_END(2); + PROF_CLOSE; + return size; +} + +/** + * Initializes the search at the root of the lexicon trie. Note that when possible the search will + * continue suggestion from where it left off during the last call. + */ +void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const { + if (!traverseSession->getProximityInfoState(0)->isUsed()) { + return; + } + if (TRAVERSAL->allowPartialCommit()) { + commitPoint = 0; + } + + if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE + && traverseSession->isContinuousSuggestionPossible()) { + if (commitPoint == 0) { + // Continue suggestion + traverseSession->getDicTraverseCache()->continueSearch(); + } else { + // Continue suggestion after partial commit. + DicNode *topDicNode = + traverseSession->getDicTraverseCache()->setCommitPoint(commitPoint); + traverseSession->setPrevWordPos(topDicNode->getPrevWordNodePos()); + traverseSession->getDicTraverseCache()->continueSearch(); + traverseSession->setPartiallyCommited(); + } + } else { + // Restart recognition at the root. + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(), MAX_RESULTS); + // Create a new dic node here + DicNode rootNode; + DicNodeUtils::initAsRoot(traverseSession->getDicRootPos(), + traverseSession->getOffsetDict(), traverseSession->getPrevWordPos(), &rootNode); + traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); + } +} + +/** + * Outputs the final list of suggestions (i.e., terminal nodes). + */ +int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, + int *outputCodePoints, int *spaceIndices, int *outputTypes) const { + const int terminalSize = min(MAX_RESULTS, + static_cast<int>(traverseSession->getDicTraverseCache()->terminalSize())); + DicNode terminals[MAX_RESULTS]; // Avoiding non-POD variable length array + + for (int index = terminalSize - 1; index >= 0; --index) { + traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); + } + + const float languageWeight = SCORING->getAdjustedLanguageWeight( + traverseSession, terminals, terminalSize); + + int outputWordIndex = 0; + // Insert most probable word at index == 0 as long as there is one terminal at least + const bool hasMostProbableString = + SCORING->getMostProbableString(traverseSession, terminalSize, languageWeight, + &outputCodePoints[0], &outputTypes[0], &frequencies[0]); + if (hasMostProbableString) { + ++outputWordIndex; + } + + // Initial value of the loop index for terminal nodes (words) + int doubleLetterTerminalIndex = -1; + DoubleLetterLevel doubleLetterLevel = NOT_A_DOUBLE_LETTER; + SCORING->searchWordWithDoubleLetter(terminals, terminalSize, + &doubleLetterTerminalIndex, &doubleLetterLevel); + + int maxScore = S_INT_MIN; + // Output suggestion results here + for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; + ++terminalIndex) { + DicNode *terminalDicNode = &terminals[terminalIndex]; + if (DEBUG_GEO_FULL) { + terminalDicNode->dump("OUT:"); + } + const float doubleLetterCost = SCORING->getDoubleLetterDemotionDistanceCost( + terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); + const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + + doubleLetterCost; + const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(), + terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); + const int originalTerminalProbability = terminalDicNode->getProbability(); + + // Do not suggest words with a 0 probability, or entries that are blacklisted or do not + // represent a word. However, we should still submit their shortcuts if any. + const bool isValidWord = + originalTerminalProbability > 0 && !terminalAttributes.isBlacklistedOrNotAWord(); + // Increase output score of top typing suggestion to ensure autocorrection. + // TODO: Better integration with java side autocorrection logic. + // Force autocorrection for obvious long multi-word suggestions. + const bool isForceCommitMultiWords = TRAVERSAL->allowPartialCommit() + && (traverseSession->isPartiallyCommited() + || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT + && terminalDicNode->hasMultipleWords())); + + const int finalScore = SCORING->calculateFinalScore( + compoundDistance, traverseSession->getInputSize(), + isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord())); + + maxScore = max(maxScore, finalScore); + + if (TRAVERSAL->allowPartialCommit()) { + // Index for top typing suggestion should be 0. + if (isValidWord && outputWordIndex == 0) { + terminalDicNode->outputSpacePositionsResult(spaceIndices); + } + } + + // Do not suggest words with a 0 probability, or entries that are blacklisted or do not + // represent a word. However, we should still submit their shortcuts if any. + if (isValidWord) { + outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION; + frequencies[outputWordIndex] = finalScore; + // Populate the outputChars array with the suggested word. + const int startIndex = outputWordIndex * MAX_WORD_LENGTH; + terminalDicNode->outputResult(&outputCodePoints[startIndex]); + ++outputWordIndex; + } + + const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); + outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, + finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + DicNode::managedDelete(terminalDicNode); + } + + if (hasMostProbableString) { + SCORING->safetyNetForMostProbableString(terminalSize, maxScore, + &outputCodePoints[0], &frequencies[0]); + } + return outputWordIndex; +} + +/** + * Expands the dicNodes in the current search priority queue by advancing to the possible child + * nodes based on the next touch point(s) (or no touch points for lookahead) + */ +void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { + const int inputSize = traverseSession->getInputSize(); + DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize()); + DicNode omissionDicNode; + + // TODO: Find more efficient caching + const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession); + if (shouldDepthLevelCache) { + traverseSession->getDicTraverseCache()->updateLastCachedInputIndex(); + } + if (DEBUG_CACHE) { + AKLOGI("expandCurrentDicNodes depth level cache = %d, inputSize = %d", + shouldDepthLevelCache, inputSize); + } + while (traverseSession->getDicTraverseCache()->activeSize() > 0) { + DicNode dicNode; + traverseSession->getDicTraverseCache()->popActive(&dicNode); + if (dicNode.isTotalInputSizeExceedingLimit()) { + return; + } + childDicNodes.clear(); + const int point0Index = dicNode.getInputIndex(0); + const bool canDoLookAheadCorrection = + TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode); + const bool isLookAheadCorrection = canDoLookAheadCorrection + && traverseSession->getDicTraverseCache()-> + isLookAheadCorrectionInputIndex(static_cast<int>(point0Index)); + const bool isCompletion = dicNode.isCompletion(inputSize); + + const bool shouldNodeLevelCache = + TRAVERSAL->shouldNodeLevelCache(traverseSession, &dicNode); + if (shouldDepthLevelCache || shouldNodeLevelCache) { + if (DEBUG_CACHE) { + dicNode.dump("PUSH_CACHE"); + } + traverseSession->getDicTraverseCache()->copyPushContinue(&dicNode); + dicNode.setCached(); + } + + if (isLookAheadCorrection) { + // The algorithm maintains a small set of "deferred" nodes that have not consumed the + // latest touch point yet. These are needed to apply look-ahead correction operations + // that require special handling of the latest touch point. For example, with insertions + // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. + if (CORRECT_TRANSPOSITION) { + processDicNodeAsTransposition(traverseSession, &dicNode); + } + if (CORRECT_INSERTION) { + processDicNodeAsInsertion(traverseSession, &dicNode); + } + } else { // !isLookAheadCorrection + // Only consider typing error corrections if the normalized compound distance is + // below a spatial distance threshold. + // NOTE: the threshold may need to be updated if scoring model changes. + // TODO: Remove. Do not prune node here. + const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode); + // Process for handling space substitution (e.g., hevis => he is) + if (allowsErrorCorrections + && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { + createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */); + } + + DicNodeUtils::getAllChildDicNodes( + &dicNode, traverseSession->getOffsetDict(), &childDicNodes); + + const int childDicNodesSize = childDicNodes.getSizeAndLock(); + for (int i = 0; i < childDicNodesSize; ++i) { + DicNode *const childDicNode = childDicNodes[i]; + if (isCompletion) { + // Handle forward lookahead when the lexicon letter exceeds the input size. + processDicNodeAsMatch(traverseSession, childDicNode); + continue; + } + if (allowsErrorCorrections + && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { + // TODO: (Gesture) Change weight between omission and substitution errors + // TODO: (Gesture) Terminal node should not be handled as omission + omissionDicNode.initByCopy(childDicNode); + processDicNodeAsOmission(traverseSession, &omissionDicNode); + } + const ProximityType proximityType = TRAVERSAL->getProximityType( + traverseSession, &dicNode, childDicNode); + switch (proximityType) { + // TODO: Consider the difference of proximityType here + case MATCH_CHAR: + case PROXIMITY_CHAR: + processDicNodeAsMatch(traverseSession, childDicNode); + break; + case ADDITIONAL_PROXIMITY_CHAR: + if (allowsErrorCorrections) { + processDicNodeAsAdditionalProximityChar(traverseSession, &dicNode, + childDicNode); + } + break; + case SUBSTITUTION_CHAR: + if (allowsErrorCorrections) { + processDicNodeAsSubstitution(traverseSession, &dicNode, childDicNode); + } + break; + case UNRELATED_CHAR: + // Just drop this node and do nothing. + break; + default: + // Just drop this node and do nothing. + break; + } + } + + // Push the node for look-ahead correction + if (allowsErrorCorrections && canDoLookAheadCorrection) { + traverseSession->getDicTraverseCache()->copyPushNextActive(&dicNode); + } + } + } +} + +void Suggest::processTerminalDicNode( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + return; + } + if (!dicNode->isTerminalWordNode()) { + return; + } + if (TRAVERSAL->needsToTraverseAllUserInput() + && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { + return; + } + + if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { + return; + } + // Create a non-cached node here. + DicNode terminalDicNode; + DicNodeUtils::initByCopy(dicNode, &terminalDicNode); + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, + &terminalDicNode, traverseSession->getBigramCacheMap()); + traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); +} + +/** + * Adds the expanded dicNode to the next search priority queue. Also creates an additional next word + * (by the space omission error correction) search path if input dicNode is on a terminal node. + */ +void Suggest::processExpandedDicNode( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + processTerminalDicNode(traverseSession, dicNode); + if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) { + createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */); + } + const int allowsLookAhead = !(dicNode->hasMultipleWords() + && dicNode->isCompletion(traverseSession->getInputSize())); + if (dicNode->hasChildren() && allowsLookAhead) { + traverseSession->getDicTraverseCache()->copyPushNextActive(dicNode); + } + } + DicNode::managedDelete(dicNode); +} + +void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, + DicNode *childDicNode) const { + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, + traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */); + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +/** + * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider + * matches for all possible next letters. Note that just skipping the current letter without any + * other conditions tends to flood the search dic nodes cache with omission nodes. Instead, check + * the possible *next* letters after the omission to better limit search to plausible omissions. + * Note that apostrophes are handled as omissions. + */ +void Suggest::processDicNodeAsOmission( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + // If the omission is surely intentional that it should incur zero cost. + const bool isZeroCostOmission = dicNode->isZeroCostOmission(); + DicNodeVector childDicNodes; + + DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); + + const int size = childDicNodes.getSizeAndLock(); + for (int i = 0; i < size; i++) { + DicNode *const childDicNode = childDicNodes[i]; + if (!isZeroCostOmission) { + // Treat this word as omission + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + } + weightChildNode(traverseSession, childDicNode); + + if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { + DicNode::managedDelete(childDicNode); + continue; + } + processExpandedDicNode(traverseSession, childDicNode); + } +} + +/** + * Handle the dicNode as an insertion error (e.g., thiis => this). Skip the current touch point and + * consider matches for the next touch point. + */ +void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, + DicNode *dicNode) const { + const int16_t pointIndex = dicNode->getInputIndex(0); + DicNodeVector childDicNodes; + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex + 1, true, &childDicNodes); + const int size = childDicNodes.getSizeAndLock(); + for (int i = 0; i < size; i++) { + DicNode *const childDicNode = childDicNodes[i]; + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + processExpandedDicNode(traverseSession, childDicNode); + } +} + +/** + * Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points. + */ +void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, + DicNode *dicNode) const { + const int16_t pointIndex = dicNode->getInputIndex(0); + DicNodeVector childDicNodes1; + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex + 1, false, &childDicNodes1); + const int childSize1 = childDicNodes1.getSizeAndLock(); + for (int i = 0; i < childSize1; i++) { + if (childDicNodes1[i]->hasChildren()) { + DicNodeVector childDicNodes2; + DicNodeUtils::getProximityChildDicNodes( + childDicNodes1[i], traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex, false, &childDicNodes2); + const int childSize2 = childDicNodes2.getSizeAndLock(); + for (int j = 0; j < childSize2; j++) { + DicNode *const childDicNode2 = childDicNodes2[j]; + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, + traverseSession, childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */); + processExpandedDicNode(traverseSession, childDicNode2); + } + } + DicNode::managedDelete(childDicNodes1[i]); + } +} + +/** + * Weight child node by aligning it to the key + */ +void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const { + const int inputSize = traverseSession->getInputSize(); + if (dicNode->isCompletion(inputSize)) { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, + 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + } else { // completion + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, + 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + } +} + +/** + * Creates a new dicNode that represents a space insertion at the end of the input dicNode. Also + * incorporates the unigram / bigram score for the ending word into the new dicNode. + */ +void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, + const bool spaceSubstitution) const { + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { + return; + } + + // Create a non-cached node here. + DicNode newDicNode; + DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), + traverseSession->getOffsetDict(), dicNode, &newDicNode); + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode, + &newDicNode, traverseSession->getBigramCacheMap()); + if (spaceSubstitution) { + // Merge this with CT_NEW_WORD + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION, + traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */); + } + traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h new file mode 100644 index 000000000..75d646bdd --- /dev/null +++ b/native/jni/src/suggest/core/suggest.h @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_IMPL_H +#define LATINIME_SUGGEST_IMPL_H + +#include "defines.h" +#include "suggest_interface.h" +#include "suggest_policy.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; +class ProximityInfo; +class Scoring; +class Traversal; +class Weighting; + +class Suggest : public SuggestInterface { + public: + AK_FORCE_INLINE Suggest(const SuggestPolicy *const suggestPolicy) + : TRAVERSAL(suggestPolicy->getTraversal()), + SCORING(suggestPolicy->getScoring()), WEIGHTING(suggestPolicy->getWeighting()) {} + AK_FORCE_INLINE virtual ~Suggest() {} + int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, + int *times, int *pointerIds, int *inputCodePoints, int inputSize, int commitPoint, + int *outWords, int *frequencies, int *outputIndices, int *outputTypes) const; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); + void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, + const bool spaceSubstitution) const; + int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, + int *outputCodePoints, int *outputIndices, int *outputTypes) const; + void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; + void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; + void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processExpandedDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + float getAutocorrectScore(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void generateFeatures( + DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const; + void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsTransposition(DicTraverseSession *traverseSession, + DicNode *dicNode) const; + void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const; + void processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, + DicNode *childDicNode) const; + void processDicNodeAsMatch(DicTraverseSession *traverseSession, + DicNode *childDicNode) const; + + // Dic nodes cache size for lookahead (autocompletion) + static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE; + // Max characters to lookahead + static const int MAX_LOOKAHEAD; + // Inputs longer than this will autocorrect if the suggestion is multi-word + static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; + static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE; + // Base value for converting costs into scores (low so will not autocorrect without classifier) + static const float BASE_OUTPUT_SCORE; + + // Threshold for autocorrection classifier + static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; + // Threshold for computing the language model feature for autocorrect classification + static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD; + + // Typing error correction settings + static const bool CORRECT_SPACE_OMISSION; + static const bool CORRECT_TRANSPOSITION; + static const bool CORRECT_INSERTION; + + const Traversal *const TRAVERSAL; + const Scoring *const SCORING; + const Weighting *const WEIGHTING; + + static const bool CORRECT_OMISSION_G; +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_IMPL_H diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp new file mode 100644 index 000000000..90985d0fe --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "scoring_params.h" + +namespace latinime { +// TODO: RENAME all +const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f; +const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; +const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; +const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 125; +const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; + +const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f; +const float ScoringParams::PROXIMITY_COST = 0.086f; +const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f; +const float ScoringParams::OMISSION_COST = 0.388f; +const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.431f; +const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.532f; +const float ScoringParams::INSERTION_COST = 0.670f; +const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f; +const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f; +const float ScoringParams::TRANSPOSITION_COST = 0.494f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.239f; +const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; +const float ScoringParams::SUBSTITUTION_COST = 0.363f; +const float ScoringParams::COST_NEW_WORD = 0.054f; +const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f; +const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; +const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.462f; +const float ScoringParams::COST_LOOKAHEAD = 0.092f; +const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f; +const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.136f; +const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; +const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; +const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h new file mode 100644 index 000000000..8f104b362 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SCORING_PARAMS_H +#define LATINIME_SCORING_PARAMS_H + +#include "defines.h" + +namespace latinime { + +class ScoringParams { + public: + // Fixed model parameters + static const float MAX_SPATIAL_DISTANCE; + static const int THRESHOLD_NEXT_WORD_PROBABILITY; + static const int THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; + static const float AUTOCORRECT_OUTPUT_THRESHOLD; + static const int MAX_CACHE_DIC_NODE_SIZE; + static const int THRESHOLD_SHORT_WORD_LENGTH; + + // Numerically optimized parameters (currently for tap typing only). + // TODO: add ability to modify these constants programmatically. + // TODO: explore optimization of gesture parameters. + static const float DISTANCE_WEIGHT_LENGTH; + static const float PROXIMITY_COST; + static const float FIRST_PROXIMITY_COST; + static const float OMISSION_COST; + static const float OMISSION_COST_SAME_CHAR; + static const float OMISSION_COST_FIRST_CHAR; + static const float INSERTION_COST; + static const float INSERTION_COST_SAME_CHAR; + static const float INSERTION_COST_FIRST_CHAR; + static const float TRANSPOSITION_COST; + static const float SPACE_SUBSTITUTION_COST; + static const float ADDITIONAL_PROXIMITY_COST; + static const float SUBSTITUTION_COST; + static const float COST_NEW_WORD; + static const float COST_NEW_WORD_CAPITALIZED; + static const float DISTANCE_WEIGHT_LANGUAGE; + static const float COST_FIRST_LOOKAHEAD; + static const float COST_LOOKAHEAD; + static const float HAS_PROXIMITY_TERMINAL_COST; + static const float HAS_EDIT_CORRECTION_TERMINAL_COST; + static const float HAS_MULTI_WORD_TERMINAL_COST; + static const float TYPING_BASE_OUTPUT_SCORE; + static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT; + static const float MAX_NORM_DISTANCE_FOR_EDIT; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams); +}; +} // namespace latinime +#endif // LATINIME_SCORING_PARAMS_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp b/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp new file mode 100644 index 000000000..53f68f20f --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "typing_scoring.h" + +namespace latinime { +const TypingScoring TypingScoring::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h new file mode 100644 index 000000000..ed941f0ae --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_SCORING_H +#define LATINIME_TYPING_SCORING_H + +#include "defines.h" +#include "scoring.h" +#include "scoring_params.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; + +class TypingScoring : public Scoring { + public: + static const TypingScoring *getInstance() { return &sInstance; } + + AK_FORCE_INLINE bool getMostProbableString( + const DicTraverseSession *const traverseSession, const int terminalSize, + const float languageWeight, int *const outputCodePoints, int *const type, + int *const freq) const { + return false; + } + + AK_FORCE_INLINE void safetyNetForMostProbableString(const int terminalSize, + const int maxScore, int *const outputCodePoints, int *const frequencies) const { + } + + AK_FORCE_INLINE void searchWordWithDoubleLetter(DicNode *terminals, + const int terminalSize, int *doubleLetterTerminalIndex, + DoubleLetterLevel *doubleLetterLevel) const { + } + + AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, + DicNode *const terminals, const int size) const { + return 1.0f; + } + + AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, + const int inputSize, const bool forceCommit) const { + const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; + return static_cast<int>((ScoringParams::TYPING_BASE_OUTPUT_SCORE + - (compoundDistance / maxDistance) + + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f)) + * SUGGEST_INTERFACE_OUTPUT_SCALE); + } + + AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(const int terminalIndex, + const int doubleLetterTerminalIndex, + const DoubleLetterLevel doubleLetterLevel) const { + return 0.0f; + } + + AK_FORCE_INLINE bool doesAutoCorrectValidWord() const { + return false; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingScoring); + static const TypingScoring sInstance; + + TypingScoring() {} + ~TypingScoring() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_SCORING_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp new file mode 100644 index 000000000..ebba37531 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest.h" +#include "typing_suggest.h" +#include "typing_suggest_policy.h" + +namespace latinime { + +const TypingSuggestPolicy TypingSuggestPolicy::sInstance; + +// A factory method for a "typing" Suggest instance +static SuggestInterface *getTypingSuggestInstance() { + return new Suggest(TypingSuggestPolicy::getInstance()); +} + +// An ad-hoc internal class to register the factory method getTypingSuggestInstance() defined above +class TypingSuggestFactoryRegisterer { + public: + TypingSuggestFactoryRegisterer() { + TypingSuggest::setTypingSuggestFactoryMethod(getTypingSuggestInstance); + } + private: + DISALLOW_COPY_AND_ASSIGN(TypingSuggestFactoryRegisterer); +}; + +// To invoke the TypingSuggestFactoryRegisterer's constructor in the global constructor +static TypingSuggestFactoryRegisterer typingSuggestFactoryregisterer; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h new file mode 100644 index 000000000..55668fc25 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_SUGGEST_POLICY_H +#define LATINIME_TYPING_SUGGEST_POLICY_H + +#include "defines.h" +#include "suggest_policy.h" +#include "typing_scoring.h" +#include "typing_traversal.h" +#include "typing_weighting.h" + +namespace latinime { + +class Scoring; +class Traversal; +class Weighting; + +class TypingSuggestPolicy : public SuggestPolicy { + public: + static const TypingSuggestPolicy *getInstance() { return &sInstance; } + + TypingSuggestPolicy() {} + virtual ~TypingSuggestPolicy() {} + AK_FORCE_INLINE const Traversal *getTraversal() const { + return TypingTraversal::getInstance(); + } + + AK_FORCE_INLINE const Scoring *getScoring() const { + return TypingScoring::getInstance(); + } + + AK_FORCE_INLINE const Weighting *getWeighting() const { + return TypingWeighting::getInstance(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingSuggestPolicy); + static const TypingSuggestPolicy sInstance; +}; +} // namespace latinime +#endif // LATINIME_TYPING_SUGGEST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp new file mode 100644 index 000000000..68c614e77 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "typing_traversal.h" + +namespace latinime { +const bool TypingTraversal::CORRECT_OMISSION = true; +const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true; +const bool TypingTraversal::CORRECT_SPACE_OMISSION = true; +const TypingTraversal TypingTraversal::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h new file mode 100644 index 000000000..16153f8bb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_TRAVERSAL_H +#define LATINIME_TYPING_TRAVERSAL_H + +#include <stdint.h> + +#include "char_utils.h" +#include "defines.h" +#include "dic_node.h" +#include "dic_node_vector.h" +#include "dic_traverse_session.h" +#include "proximity_info_state.h" +#include "scoring_params.h" +#include "traversal.h" + +namespace latinime { +class TypingTraversal : public Traversal { + public: + static const TypingTraversal *getInstance() { return &sInstance; } + + AK_FORCE_INLINE int getMaxPointerCount() const { + return MAX_POINTER_COUNT; + } + + AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const { + return dicNode->getNormalizedSpatialDistance() + < ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT; + } + + AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const { + if (!CORRECT_OMISSION) { + return false; + } + const int inputSize = traverseSession->getInputSize(); + // TODO: Don't refer to isCompletion? + if (dicNode->isCompletion(inputSize)) { + return false; + } + if (dicNode->canBeIntentionalOmission()) { + return true; + } + const int point0Index = dicNode->getInputIndex(0); + const int currentBaseLowerCodePoint = + toBaseLowerCase(childDicNode->getNodeCodePoint()); + const int typedBaseLowerCodePoint = + toBaseLowerCase(traverseSession->getProximityInfoState(0) + ->getPrimaryCodePointAt(point0Index)); + return (currentBaseLowerCodePoint != typedBaseLowerCodePoint); + } + + AK_FORCE_INLINE bool isSpaceSubstitutionTerminal( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + if (!CORRECT_SPACE_SUBSTITUTION) { + return false; + } + if (!canDoLookAheadCorrection(traverseSession, dicNode)) { + return false; + } + const int point0Index = dicNode->getInputIndex(0); + return dicNode->isTerminalWordNode() + && traverseSession->getProximityInfoState(0)-> + hasSpaceProximity(point0Index); + } + + AK_FORCE_INLINE bool isSpaceOmissionTerminal( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + if (!CORRECT_SPACE_OMISSION) { + return false; + } + const int inputSize = traverseSession->getInputSize(); + // TODO: Don't refer to isCompletion? + if (dicNode->isCompletion(inputSize)) { + return false; + } + if (!dicNode->isTerminalWordNode()) { + return false; + } + const int16_t pointIndex = dicNode->getInputIndex(0); + return pointIndex <= inputSize && !dicNode->isTotalInputSizeExceedingLimit() + && !dicNode->shouldBeFilterdBySafetyNetForBigram(); + } + + AK_FORCE_INLINE bool shouldDepthLevelCache( + const DicTraverseSession *const traverseSession) const { + const int inputSize = traverseSession->getInputSize(); + return traverseSession->isCacheBorderForTyping(inputSize); + } + + AK_FORCE_INLINE bool shouldNodeLevelCache( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + return false; + } + + AK_FORCE_INLINE bool canDoLookAheadCorrection( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + const int inputSize = traverseSession->getInputSize(); + return dicNode->canDoLookAheadCorrection(inputSize); + } + + AK_FORCE_INLINE ProximityType getProximityType( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + const DicNode *const childDicNode) const { + return traverseSession->getProximityInfoState(0)->getProximityType( + dicNode->getInputIndex(0), childDicNode->getNodeCodePoint(), + true /* checkProximityChars */); + } + + AK_FORCE_INLINE bool needsToTraverseAllUserInput() const { + return true; + } + + AK_FORCE_INLINE float getMaxSpatialDistance() const { + return ScoringParams::MAX_SPATIAL_DISTANCE; + } + + AK_FORCE_INLINE bool allowPartialCommit() const { + return true; + } + + AK_FORCE_INLINE int getDefaultExpandDicNodeSize() const { + return DicNodeVector::DEFAULT_NODES_SIZE_FOR_OPTIMIZATION; + } + + AK_FORCE_INLINE bool sameAsTyped( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + return traverseSession->getProximityInfoState(0)->sameAsTyped( + dicNode->getOutputWordBuf(), dicNode->getDepth()); + } + + AK_FORCE_INLINE int getMaxCacheSize() const { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + } + + AK_FORCE_INLINE bool isPossibleOmissionChildNode( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const { + const ProximityType proximityType = + getProximityType(traverseSession, parentDicNode, dicNode); + if (!DicNodeUtils::isProximityChar(proximityType)) { + return false; + } + return true; + } + + AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const { + const int probability = dicNode->getProbability(); + if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) { + return false; + } + const int c = dicNode->getOutputWordBuf()[0]; + const bool shortCappedWord = dicNode->getDepth() + < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && isAsciiUpper(c); + return !shortCappedWord + || probability >= ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingTraversal); + static const bool CORRECT_OMISSION; + static const bool CORRECT_SPACE_SUBSTITUTION; + static const bool CORRECT_SPACE_OMISSION; + static const TypingTraversal sInstance; + + TypingTraversal() {} + ~TypingTraversal() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_TRAVERSAL_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp new file mode 100644 index 000000000..6e4b2fb6a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dic_node.h" +#include "scoring_params.h" +#include "typing_weighting.h" + +namespace latinime { +const TypingWeighting TypingWeighting::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h new file mode 100644 index 000000000..e8075f41a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -0,0 +1,194 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_WEIGHTING_H +#define LATINIME_TYPING_WEIGHTING_H + +#include "defines.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "weighting.h" + +namespace latinime { + +class DicNode; +struct DicNode_InputStateG; + +class TypingWeighting : public Weighting { + public: + static const TypingWeighting *getInstance() { return &sInstance; } + + protected: + float getTerminalSpatialCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + float cost = 0.0f; + if (dicNode->hasMultipleWords()) { + cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; + } + if (dicNode->getProximityCorrectionCount() > 0) { + cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; + } + if (dicNode->getEditCorrectionCount() > 0) { + cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; + } + return cost; + } + + float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { + bool sameCodePoint = false; + bool isFirstLetterOmission = false; + float cost = 0.0f; + sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); + // If the traversal omitted the first letter then the dicNode should now be on the second. + isFirstLetterOmission = dicNode->getDepth() == 2; + if (isFirstLetterOmission) { + cost = ScoringParams::OMISSION_COST_FIRST_CHAR; + } else { + cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR + : ScoringParams::OMISSION_COST; + } + return cost; + } + + float getMatchedCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + DicNode_InputStateG *inputStateG) const { + const int pointIndex = dicNode->getInputIndex(0); + // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on + // the keyboard (like accented letters) + const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE, + traverseSession->getProximityInfoState(0)->getPointToKeyLength( + pointIndex, dicNode->getNodeCodePoint())); + const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const bool isFirstChar = pointIndex == 0; + const bool isProximity = isProximityDicNode(traverseSession, dicNode); + const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST + : ScoringParams::PROXIMITY_COST) : 0.0f; + return weightedDistance + cost; + } + + bool isProximityDicNode( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + const int pointIndex = dicNode->getInputIndex(0); + const int primaryCodePoint = toBaseLowerCase( + traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); + const int dicNodeChar = toBaseLowerCase(dicNode->getNodeCodePoint()); + return primaryCodePoint != dicNodeChar; + } + + float getTranspositionCost( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const { + const int16_t parentPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = parentDicNode->getNodeCodePoint(); + const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex + 1, prevCodePoint); + const int codePoint = dicNode->getNodeCodePoint(); + const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex, codePoint); + const float distance = distance1 + distance2; + const float weightedLengthDistance = + distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; + return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; + } + + float getInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const { + const int16_t parentPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = + traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); + + const int currentCodePoint = dicNode->getNodeCodePoint(); + const bool sameCodePoint = prevCodePoint == currentCodePoint; + const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex + 1, currentCodePoint); + const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const bool singleChar = dicNode->getDepth() == 1; + const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) + + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR + : ScoringParams::INSERTION_COST); + return cost + weightedDistance; + } + + float getNewWordCost(const DicNode *const dicNode) const { + const bool isCapitalized = dicNode->isCapitalized(); + return isCapitalized ? + ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; + } + + float getNewWordBigramCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) const { + return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), + dicNode, bigramCacheMap); + } + + float getCompletionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + // The auto completion starts when the input index is same as the input size + const bool firstCompletion = dicNode->getInputIndex(0) + == traverseSession->getInputSize(); + // TODO: Change the cost for the first completion for the gesture? + const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD + : ScoringParams::COST_LOOKAHEAD; + return cost; + } + + float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { + const bool hasEditCount = dicNode->getEditCorrectionCount() > 0; + const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize(); + const bool hasMultipleWords = dicNode->hasMultipleWords(); + const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0; + // Gesture input is always assumed to have proximity errors + // because the input word shouldn't be treated as perfect + const bool isExactMatch = !hasEditCount && !hasMultipleWords + && !hasProximityErrors && isSameLength; + + const float totalPrevWordsLanguageCost = dicNode->getTotalPrevWordsLanguageCost(); + const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability; + const float languageWeight = ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + // TODO: Caveat: The following equation should be: + // totalPrevWordsLanguageCost + (languageImprobability * languageWeight); + return (totalPrevWordsLanguageCost + languageImprobability) * languageWeight; + } + + AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { + return false; + } + + AK_FORCE_INLINE float getAdditionalProximityCost() const { + return ScoringParams::ADDITIONAL_PROXIMITY_COST; + } + + AK_FORCE_INLINE float getSubstitutionCost() const { + return ScoringParams::SUBSTITUTION_COST; + } + + AK_FORCE_INLINE float getSpaceSubstitutionCost() const { + return ScoringParams::SPACE_SUBSTITUTION_COST; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingWeighting); + static const TypingWeighting sInstance; + + TypingWeighting() {} + ~TypingWeighting() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_WEIGHTING_H |