diff options
Diffstat (limited to 'native/jni')
-rw-r--r-- | native/jni/src/defines.h | 20 | ||||
-rw-r--r-- | native/jni/src/dictionary.h | 5 | ||||
-rw-r--r-- | native/jni/src/suggest/core/dicnode/dic_node.h | 9 | ||||
-rw-r--r-- | native/jni/src/suggest/core/dicnode/dic_node_state_input.h | 4 | ||||
-rw-r--r-- | native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h | 37 | ||||
-rw-r--r-- | native/jni/src/suggest/core/policy/traversal.h | 3 | ||||
-rw-r--r-- | native/jni/src/suggest/core/policy/weighting.cpp | 63 | ||||
-rw-r--r-- | native/jni/src/suggest/core/policy/weighting.h | 10 | ||||
-rw-r--r-- | native/jni/src/suggest/core/suggest.cpp | 33 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/typing/typing_traversal.h | 9 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp | 35 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/typing/typing_weighting.h | 28 |
12 files changed, 142 insertions, 114 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index e6719c9c3..d3b351f81 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -438,4 +438,24 @@ typedef enum { // Create new word with space substitution CT_NEW_WORD_SPACE_SUBSTITUTION, } CorrectionType; + +// ErrorType is mainly decided by CorrectionType but it is also depending on if +// the correction has really been performed or not. +typedef enum { + // Substitution, omission and transposition + ET_EDIT_CORRECTION, + // Proximity error + ET_PROXIMITY_CORRECTION, + // Completion + ET_COMPLETION, + // New word + // TODO: Remove. + // A new word error should be an edit correction error or a proximity correction error. + ET_NEW_WORD, + // Treat error as an intentional omission when the CorrectionType is omission and the node can + // be intentional omission. + ET_INTENTIONAL_OMISSION, + // Not treated as an error. Tracked for checking exact match + ET_NOT_AN_ERROR +} ErrorType; #endif // LATINIME_DEFINES_H diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h index 0653d3ca9..2ad5b6c0b 100644 --- a/native/jni/src/dictionary.h +++ b/native/jni/src/dictionary.h @@ -31,6 +31,7 @@ class UnigramDictionary; class Dictionary { public: // Taken from SuggestedWords.java + static const int KIND_MASK_KIND = 0xFF; // Mask to get only the kind static const int KIND_TYPED = 0; // What user typed static const int KIND_CORRECTION = 1; // Simple correction/suggestion static const int KIND_COMPLETION = 2; // Completion (suggestion with appended chars) @@ -41,6 +42,10 @@ class Dictionary { static const int KIND_SHORTCUT = 7; // A shortcut static const int KIND_PREDICTION = 8; // A prediction (== a suggestion with no input) + static const int KIND_MASK_FLAGS = 0xFFFFFF00; // Mask to get the flags + static const int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000; + static const int KIND_FLAG_EXACT_MATCH = 0x40000000; + Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust); int getSuggestions(ProximityInfo *proximityInfo, void *traverseSession, int *xcoordinates, diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index e8432546b..92783dec7 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -463,6 +463,10 @@ class DicNode { mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); } + bool isExactMatch() const { + return mDicNodeState.mDicNodeStateScoring.isExactMatch(); + } + uint8_t getFlags() const { return mDicNodeProperties.getFlags(); } @@ -542,13 +546,12 @@ class DicNode { // Caveat: Must not be called outside Weighting // This restriction is guaranteed by "friend" AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, - const bool doNormalization, const int inputSize, const bool isEditCorrection, - const bool isProximityCorrection) { + const bool doNormalization, const int inputSize, const ErrorType errorType) { if (DEBUG_GEO_FULL) { LOGI_SHOW_ADD_COST_PROP; } mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, - inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection); + inputSize, getTotalInputIndex(), errorType); } // Caveat: Must not be called outside Weighting diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h index 7ad3e3e5f..bbd9435b5 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h @@ -46,8 +46,8 @@ class DicNodeStateInput { for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { mInputIndex[i] = src->mInputIndex[i]; mPrevCodePoint[i] = src->mPrevCodePoint[i]; - mTerminalDiffCost[i] = resetTerminalDiffCost ? - static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; + mTerminalDiffCost[i] = resetTerminalDiffCost ? + static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; } } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h index fd9d610e3..dca9d60da 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -31,7 +31,7 @@ class DicNodeStateScoring { mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), - mRawLength(0.0f) { + mRawLength(0.0f), mExactMatch(true) { } virtual ~DicNodeStateScoring() {} @@ -45,6 +45,7 @@ class DicNodeStateScoring { mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; + mExactMatch = true; } AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { @@ -56,17 +57,32 @@ class DicNodeStateScoring { mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDigraphIndex = scoring->mDigraphIndex; + mExactMatch = scoring->mExactMatch; } void addCost(const float spatialCost, const float languageCost, const bool doNormalization, - const int inputSize, const int totalInputIndex, const bool isEditCorrection, - const bool isProximityCorrection) { + const int inputSize, const int totalInputIndex, const ErrorType errorType) { addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); - if (isEditCorrection) { - ++mEditCorrectionCount; - } - if (isProximityCorrection) { - ++mProximityCorrectionCount; + switch (errorType) { + case ET_EDIT_CORRECTION: + ++mEditCorrectionCount; + mExactMatch = false; + break; + case ET_PROXIMITY_CORRECTION: + ++mProximityCorrectionCount; + mExactMatch = false; + break; + case ET_COMPLETION: + mExactMatch = false; + break; + case ET_NEW_WORD: + mExactMatch = false; + break; + case ET_INTENTIONAL_OMISSION: + mExactMatch = false; + break; + case ET_NOT_AN_ERROR: + break; } } @@ -143,6 +159,10 @@ class DicNodeStateScoring { } } + bool isExactMatch() const { + return mExactMatch; + } + private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok @@ -157,6 +177,7 @@ class DicNodeStateScoring { float mSpatialDistance; float mLanguageDistance; float mRawLength; + bool mExactMatch; AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, bool doNormalization, int inputSize, int totalInputIndex) { diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 02c358aec..d3146da7f 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -28,7 +28,8 @@ class Traversal { virtual int getMaxPointerCount() const = 0; virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0; virtual bool isOmission(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; + const DicNode *const dicNode, const DicNode *const childDicNode, + const bool allowsErrorCorrections) const = 0; virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession, diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 6c08e7678..857ddcc1d 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -80,9 +80,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n traverseSession, parentDicNode, dicNode, &inputStateG); const float languageCost = Weighting::getLanguageCost(weighting, correctionType, traverseSession, parentDicNode, dicNode, bigramCacheMap); - const bool edit = Weighting::isEditCorrection(correctionType); - const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, - traverseSession, dicNode); + const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession, + parentDicNode, dicNode); profile(correctionType, dicNode); if (inputStateG.mNeedsToUpdateInputStateG) { dicNode->updateInputIndexG(&inputStateG); @@ -91,7 +90,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n (correctionType == CT_TRANSPOSITION)); } dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), - inputSize, edit, proximity); + inputSize, errorType); } /* static */ float Weighting::getSpatialCost(const Weighting *const weighting, @@ -158,62 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n } } -/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) { - switch(correctionType) { - case CT_OMISSION: - return true; - case CT_ADDITIONAL_PROXIMITY: - return true; - case CT_SUBSTITUTION: - return true; - case CT_NEW_WORD_SPACE_OMITTION: - return false; - case CT_MATCH: - return false; - case CT_COMPLETION: - return false; - case CT_TERMINAL: - return false; - case CT_NEW_WORD_SPACE_SUBSTITUTION: - return false; - case CT_INSERTION: - return true; - case CT_TRANSPOSITION: - return true; - default: - return false; - } -} - -/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting, - const CorrectionType correctionType, - const DicTraverseSession *const traverseSession, const DicNode *const dicNode) { - switch(correctionType) { - case CT_OMISSION: - return false; - case CT_ADDITIONAL_PROXIMITY: - return true; - case CT_SUBSTITUTION: - return false; - case CT_NEW_WORD_SPACE_OMITTION: - return false; - case CT_MATCH: - return weighting->isProximityDicNode(traverseSession, dicNode); - case CT_COMPLETION: - return false; - case CT_TERMINAL: - return false; - case CT_NEW_WORD_SPACE_SUBSTITUTION: - return false; - case CT_INSERTION: - return false; - case CT_TRANSPOSITION: - return false; - default: - return false; - } -} - /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { switch(correctionType) { case CT_OMISSION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index bce479c51..6e740d9d6 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -80,6 +80,10 @@ class Weighting { virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; + virtual ErrorType getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + Weighting() {} virtual ~Weighting() {} @@ -95,12 +99,6 @@ class Weighting { const DicNode *const parentDicNode, const DicNode *const dicNode, hash_map_compat<int, int16_t> *const bigramCacheMap); // TODO: Move to TypingWeighting and GestureWeighting? - static bool isEditCorrection(const CorrectionType correctionType); - // TODO: Move to TypingWeighting and GestureWeighting? - static bool isProximityCorrection(const Weighting *const weighting, - const CorrectionType correctionType, const DicTraverseSession *const traverseSession, - const DicNode *const dicNode); - // TODO: Move to TypingWeighting and GestureWeighting? static int getForwardInputCount(const CorrectionType correctionType); }; } // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 9de2cd2e2..4f94a9a3b 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -161,12 +161,15 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen + doubleLetterCost; const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(), terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); - const int originalTerminalProbability = terminalDicNode->getProbability(); + const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; + const bool isExactMatch = terminalDicNode->isExactMatch(); + const int outputTypeFlags = + isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0 + | isExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0; + + // Entries that are blacklisted or do not represent a word should not be output. + const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord(); - // Do not suggest words with a 0 probability, or entries that are blacklisted or do not - // represent a word. However, we should still submit their shortcuts if any. - const bool isValidWord = - originalTerminalProbability > 0 && !terminalAttributes.isBlacklistedOrNotAWord(); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. // Force autocorrection for obvious long multi-word suggestions. @@ -188,10 +191,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen } } - // Do not suggest words with a 0 probability, or entries that are blacklisted or do not - // represent a word. However, we should still submit their shortcuts if any. + // Don't output invalid words. However, we still need to submit their shortcuts if any. if (isValidWord) { - outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION; + outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; // Populate the outputChars array with the suggested word. const int startIndex = outputWordIndex * MAX_WORD_LENGTH; @@ -294,8 +296,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { correctionDicNode.advanceDigraphIndex(); processDicNodeAsDigraph(traverseSession, &correctionDicNode); } - if (allowsErrorCorrections - && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { + if (TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode, + allowsErrorCorrections)) { // TODO: (Gesture) Change weight between omission and substitution errors // TODO: (Gesture) Terminal node should not be handled as omission correctionDicNode.initByCopy(childDicNode); @@ -422,20 +424,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, */ void Suggest::processDicNodeAsOmission( DicTraverseSession *traverseSession, DicNode *dicNode) const { - // If the omission is surely intentional that it should incur zero cost. - const bool isZeroCostOmission = dicNode->isZeroCostOmission(); DicNodeVector childDicNodes; - DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { DicNode *const childDicNode = childDicNodes[i]; - if (!isZeroCostOmission) { - // Treat this word as omission - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); - } + // Treat this word as omission + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); weightChildNode(traverseSession, childDicNode); if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index 9f8347452..fb1fb79d1 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -43,10 +43,17 @@ class TypingTraversal : public Traversal { } AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode, const DicNode *const childDicNode) const { + const DicNode *const dicNode, const DicNode *const childDicNode, + const bool allowsErrorCorrections) const { if (!CORRECT_OMISSION) { return false; } + // Note: Always consider intentional omissions (like apostrophes) since they are common. + const bool canConsiderOmission = + allowsErrorCorrections || childDicNode->canBeIntentionalOmission(); + if (!canConsiderOmission) { + return false; + } const int inputSize = traverseSession->getInputSize(); // TODO: Don't refer to isCompletion? if (dicNode->isCompletion(inputSize)) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 1500341bd..47bd20425 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -21,4 +21,39 @@ namespace latinime { const TypingWeighting TypingWeighting::sInstance; + +ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const { + switch (correctionType) { + case CT_MATCH: + if (isProximityDicNode(traverseSession, dicNode)) { + return ET_PROXIMITY_CORRECTION; + } else { + return ET_NOT_AN_ERROR; + } + case CT_ADDITIONAL_PROXIMITY: + return ET_PROXIMITY_CORRECTION; + case CT_OMISSION: + if (parentDicNode->canBeIntentionalOmission()) { + return ET_INTENTIONAL_OMISSION; + } else { + return ET_EDIT_CORRECTION; + } + break; + case CT_SUBSTITUTION: + case CT_INSERTION: + case CT_TRANSPOSITION: + return ET_EDIT_CORRECTION; + case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return ET_NEW_WORD; + case CT_TERMINAL: + return ET_NOT_AN_ERROR; + case CT_COMPLETION: + return ET_COMPLETION; + default: + return ET_NOT_AN_ERROR; + } +} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 34d25ae1a..4a0bd7194 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -50,13 +50,14 @@ class TypingWeighting : public Weighting { } float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { - bool sameCodePoint = false; - bool isFirstLetterOmission = false; - float cost = 0.0f; - sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); + const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); + const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the traversal omitted the first letter then the dicNode should now be on the second. - isFirstLetterOmission = dicNode->getDepth() == 2; - if (isFirstLetterOmission) { + const bool isFirstLetterOmission = dicNode->getDepth() == 2; + float cost = 0.0f; + if (isZeroCostOmission) { + cost = 0.0f; + } else if (isFirstLetterOmission) { cost = ScoringParams::OMISSION_COST_FIRST_CHAR; } else { cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR @@ -156,15 +157,8 @@ class TypingWeighting : public Weighting { float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { - const bool hasEditCount = dicNode->getEditCorrectionCount() > 0; - const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize(); - const bool hasMultipleWords = dicNode->hasMultipleWords(); - const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0; - // Gesture input is always assumed to have proximity errors - // because the input word shouldn't be treated as perfect - const bool isExactMatch = !hasEditCount && !hasMultipleWords - && !hasProximityErrors && isSameLength; - const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability; + const float languageImprobability = (dicNode->isExactMatch()) ? + 0.0f : dicNodeLanguageImprobability; return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } @@ -189,6 +183,10 @@ class TypingWeighting : public Weighting { return cost * traverseSession->getMultiWordCostMultiplier(); } + ErrorType getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const; + private: DISALLOW_COPY_AND_ASSIGN(TypingWeighting); static const TypingWeighting sInstance; |