about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- native/jni/src/defines.h 20
-rw-r--r-- native/jni/src/suggest/core/dicnode/dic_node.h 9
-rw-r--r-- native/jni/src/suggest/core/dicnode/dic_node_state_input.h 4
-rw-r--r-- native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h 37
-rw-r--r-- native/jni/src/suggest/core/policy/weighting.cpp 63
-rw-r--r-- native/jni/src/suggest/core/policy/weighting.h 10
-rw-r--r-- native/jni/src/suggest/core/suggest.cpp 11
-rw-r--r-- native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp 35
-rw-r--r-- native/jni/src/suggest/policyimpl/typing/typing_weighting.h 28
9 files changed, 115 insertions, 102 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index e6719c9c3..d3b351f81 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -438,4 +438,24 @@ typedef enum {
// Create new word with space substitution
CT_NEW_WORD_SPACE_SUBSTITUTION,
} CorrectionType;
+
+// ErrorType is mainly decided by CorrectionType, but it also depends on whether
+// the correction has really been performed or not.
+typedef enum {
+ // Substitution, omission and transposition
+ ET_EDIT_CORRECTION,
+ // Proximity error
+ ET_PROXIMITY_CORRECTION,
+ // Completion
+ ET_COMPLETION,
+ // New word
+ // TODO: Remove.
+ // A new word error should be an edit correction error or a proximity correction error.
+ ET_NEW_WORD,
+ // Treat error as an intentional omission when the CorrectionType is omission and the node can
+ // be an intentional omission.
+ ET_INTENTIONAL_OMISSION,
+ // Not treated as an error. Tracked for checking exact match
+ ET_NOT_AN_ERROR
+} ErrorType;
#endif // LATINIME_DEFINES_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index e8432546b..92783dec7 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -463,6 +463,10 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
}
+ bool isExactMatch() const {
+ return mDicNodeState.mDicNodeStateScoring.isExactMatch();
+ }
+
uint8_t getFlags() const {
return mDicNodeProperties.getFlags();
}
@@ -542,13 +546,12 @@ class DicNode {
// Caveat: Must not be called outside Weighting
// This restriction is guaranteed by "friend"
AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
- const bool doNormalization, const int inputSize, const bool isEditCorrection,
- const bool isProximityCorrection) {
+ const bool doNormalization, const int inputSize, const ErrorType errorType) {
if (DEBUG_GEO_FULL) {
LOGI_SHOW_ADD_COST_PROP;
}
mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
- inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection);
+ inputSize, getTotalInputIndex(), errorType);
}
// Caveat: Must not be called outside Weighting
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
index 7ad3e3e5f..bbd9435b5 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
@@ -46,8 +46,8 @@ class DicNodeStateInput {
for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
mInputIndex[i] = src->mInputIndex[i];
mPrevCodePoint[i] = src->mPrevCodePoint[i];
- mTerminalDiffCost[i] = resetTerminalDiffCost ?
- static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
+ mTerminalDiffCost[i] = resetTerminalDiffCost ?
+ static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
}
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
index fd9d610e3..dca9d60da 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
@@ -31,7 +31,7 @@ class DicNodeStateScoring {
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
- mRawLength(0.0f) {
+ mRawLength(0.0f), mExactMatch(true) {
}
virtual ~DicNodeStateScoring() {}
@@ -45,6 +45,7 @@ class DicNodeStateScoring {
mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ mExactMatch = true;
}
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@@ -56,17 +57,32 @@ class DicNodeStateScoring {
mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
mDigraphIndex = scoring->mDigraphIndex;
+ mExactMatch = scoring->mExactMatch;
}
void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
- const int inputSize, const int totalInputIndex, const bool isEditCorrection,
- const bool isProximityCorrection) {
+ const int inputSize, const int totalInputIndex, const ErrorType errorType) {
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
- if (isEditCorrection) {
- ++mEditCorrectionCount;
- }
- if (isProximityCorrection) {
- ++mProximityCorrectionCount;
+ switch (errorType) {
+ case ET_EDIT_CORRECTION:
+ ++mEditCorrectionCount;
+ mExactMatch = false;
+ break;
+ case ET_PROXIMITY_CORRECTION:
+ ++mProximityCorrectionCount;
+ mExactMatch = false;
+ break;
+ case ET_COMPLETION:
+ mExactMatch = false;
+ break;
+ case ET_NEW_WORD:
+ mExactMatch = false;
+ break;
+ case ET_INTENTIONAL_OMISSION:
+ mExactMatch = false;
+ break;
+ case ET_NOT_AN_ERROR:
+ break;
}
}
@@ -143,6 +159,10 @@ class DicNodeStateScoring {
}
}
+ bool isExactMatch() const {
+ return mExactMatch;
+ }
+
private:
// Caution!!!
// Use a default copy constructor and an assign operator because shallow copies are ok
@@ -157,6 +177,7 @@ class DicNodeStateScoring {
float mSpatialDistance;
float mLanguageDistance;
float mRawLength;
+ bool mExactMatch;
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
bool doNormalization, int inputSize, int totalInputIndex) {
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index 6c08e7678..857ddcc1d 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -80,9 +80,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
traverseSession, parentDicNode, dicNode, &inputStateG);
const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, bigramCacheMap);
- const bool edit = Weighting::isEditCorrection(correctionType);
- const bool proximity = Weighting::isProximityCorrection(weighting, correctionType,
- traverseSession, dicNode);
+ const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
+ parentDicNode, dicNode);
profile(correctionType, dicNode);
if (inputStateG.mNeedsToUpdateInputStateG) {
dicNode->updateInputIndexG(&inputStateG);
@@ -91,7 +90,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
(correctionType == CT_TRANSPOSITION));
}
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
- inputSize, edit, proximity);
+ inputSize, errorType);
}
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
@@ -158,62 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
}
}
-/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) {
- switch(correctionType) {
- case CT_OMISSION:
- return true;
- case CT_ADDITIONAL_PROXIMITY:
- return true;
- case CT_SUBSTITUTION:
- return true;
- case CT_NEW_WORD_SPACE_OMITTION:
- return false;
- case CT_MATCH:
- return false;
- case CT_COMPLETION:
- return false;
- case CT_TERMINAL:
- return false;
- case CT_NEW_WORD_SPACE_SUBSTITUTION:
- return false;
- case CT_INSERTION:
- return true;
- case CT_TRANSPOSITION:
- return true;
- default:
- return false;
- }
-}
-
-/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting,
- const CorrectionType correctionType,
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode) {
- switch(correctionType) {
- case CT_OMISSION:
- return false;
- case CT_ADDITIONAL_PROXIMITY:
- return true;
- case CT_SUBSTITUTION:
- return false;
- case CT_NEW_WORD_SPACE_OMITTION:
- return false;
- case CT_MATCH:
- return weighting->isProximityDicNode(traverseSession, dicNode);
- case CT_COMPLETION:
- return false;
- case CT_TERMINAL:
- return false;
- case CT_NEW_WORD_SPACE_SUBSTITUTION:
- return false;
- case CT_INSERTION:
- return false;
- case CT_TRANSPOSITION:
- return false;
- default:
- return false;
- }
-}
-
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
switch(correctionType) {
case CT_OMISSION:
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index bce479c51..6e740d9d6 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -80,6 +80,10 @@ class Weighting {
virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
+ virtual ErrorType getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
+
Weighting() {}
virtual ~Weighting() {}
@@ -95,12 +99,6 @@ class Weighting {
const DicNode *const parentDicNode, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap);
// TODO: Move to TypingWeighting and GestureWeighting?
- static bool isEditCorrection(const CorrectionType correctionType);
- // TODO: Move to TypingWeighting and GestureWeighting?
- static bool isProximityCorrection(const Weighting *const weighting,
- const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
- const DicNode *const dicNode);
- // TODO: Move to TypingWeighting and GestureWeighting?
static int getForwardInputCount(const CorrectionType correctionType);
};
} // namespace latinime
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 9de2cd2e2..0cf4e4a68 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -422,20 +422,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
*/
void Suggest::processDicNodeAsOmission(
DicTraverseSession *traverseSession, DicNode *dicNode) const {
- // If the omission is surely intentional that it should incur zero cost.
- const bool isZeroCostOmission = dicNode->isZeroCostOmission();
DicNodeVector childDicNodes;
-
DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes);
const int size = childDicNodes.getSizeAndLock();
for (int i = 0; i < size; i++) {
DicNode *const childDicNode = childDicNodes[i];
- if (!isZeroCostOmission) {
- // Treat this word as omission
- Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
- dicNode, childDicNode, 0 /* bigramCacheMap */);
- }
+ // Treat this word as omission
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
+ dicNode, childDicNode, 0 /* bigramCacheMap */);
weightChildNode(traverseSession, childDicNode);
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index 1500341bd..47bd20425 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -21,4 +21,39 @@
namespace latinime {
const TypingWeighting TypingWeighting::sInstance;
+
+ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const {
+ switch (correctionType) {
+ case CT_MATCH:
+ if (isProximityDicNode(traverseSession, dicNode)) {
+ return ET_PROXIMITY_CORRECTION;
+ } else {
+ return ET_NOT_AN_ERROR;
+ }
+ case CT_ADDITIONAL_PROXIMITY:
+ return ET_PROXIMITY_CORRECTION;
+ case CT_OMISSION:
+ if (parentDicNode->canBeIntentionalOmission()) {
+ return ET_INTENTIONAL_OMISSION;
+ } else {
+ return ET_EDIT_CORRECTION;
+ }
+ break;
+ case CT_SUBSTITUTION:
+ case CT_INSERTION:
+ case CT_TRANSPOSITION:
+ return ET_EDIT_CORRECTION;
+ case CT_NEW_WORD_SPACE_OMITTION:
+ case CT_NEW_WORD_SPACE_SUBSTITUTION:
+ return ET_NEW_WORD;
+ case CT_TERMINAL:
+ return ET_NOT_AN_ERROR;
+ case CT_COMPLETION:
+ return ET_COMPLETION;
+ default:
+ return ET_NOT_AN_ERROR;
+ }
+}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 34d25ae1a..4a0bd7194 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -50,13 +50,14 @@ class TypingWeighting : public Weighting {
}
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
- bool sameCodePoint = false;
- bool isFirstLetterOmission = false;
- float cost = 0.0f;
- sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
+ const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
+ const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second.
- isFirstLetterOmission = dicNode->getDepth() == 2;
- if (isFirstLetterOmission) {
+ const bool isFirstLetterOmission = dicNode->getDepth() == 2;
+ float cost = 0.0f;
+ if (isZeroCostOmission) {
+ cost = 0.0f;
+ } else if (isFirstLetterOmission) {
cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
} else {
cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
@@ -156,15 +157,8 @@ class TypingWeighting : public Weighting {
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
- const bool hasEditCount = dicNode->getEditCorrectionCount() > 0;
- const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize();
- const bool hasMultipleWords = dicNode->hasMultipleWords();
- const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0;
- // Gesture input is always assumed to have proximity errors
- // because the input word shouldn't be treated as perfect
- const bool isExactMatch = !hasEditCount && !hasMultipleWords
- && !hasProximityErrors && isSameLength;
- const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability;
+ const float languageImprobability = (dicNode->isExactMatch()) ?
+ 0.0f : dicNodeLanguageImprobability;
return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
@@ -189,6 +183,10 @@ class TypingWeighting : public Weighting {
return cost * traverseSession->getMultiWordCostMultiplier();
}
+ ErrorType getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const;
+
private:
DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
static const TypingWeighting sInstance;