aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r-- native/jni/src/suggest/core/suggest.cpp 23
-rw-r--r-- native/jni/src/suggest/core/suggest.h 2
-rw-r--r-- native/jni/src/suggest/policyimpl/typing/typing_weighting.h 3
3 files changed, 27 insertions, 1 deletions
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index a8f16c8cb..173a612be 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -36,6 +36,7 @@ namespace latinime {
const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2;
const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f;
+const int Suggest::FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD = 1;
/**
* Returns a set of suggestions for the given input touch points. The commitPoint argument indicates
@@ -148,6 +149,8 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
&doubleLetterTerminalIndex, &doubleLetterLevel);
int maxScore = S_INT_MIN;
+ int bestExactMatchedNodeTerminalIndex = -1;
+ int bestExactMatchedNodeOutputWordIndex = -1;
// Output suggestion results here
for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS;
++terminalIndex) {
@@ -186,7 +189,6 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
const int finalScore = SCORING->calculateFinalScore(
compoundDistance, traverseSession->getInputSize(),
isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord()));
-
maxScore = max(maxScore, finalScore);
if (TRAVERSAL->allowPartialCommit()) {
@@ -200,6 +202,25 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
if (isValidWord) {
outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags;
frequencies[outputWordIndex] = finalScore;
+ if (isSafeExactMatch) {
+ // Demote exact matches that are not the most probable node among all exact
+ // matches.
+ const bool isBestTerminal = bestExactMatchedNodeTerminalIndex < 0
+ || terminals[bestExactMatchedNodeTerminalIndex].getProbability()
+ < terminalDicNode->getProbability();
+ const int outputWordIndexToBeDemoted = isBestTerminal ?
+ bestExactMatchedNodeOutputWordIndex : outputWordIndex;
+ if (outputWordIndexToBeDemoted >= 0) {
+ frequencies[outputWordIndexToBeDemoted] -=
+ FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD;
+ }
+ if (isBestTerminal) {
+ // Updates the best exact matched node index.
+ bestExactMatchedNodeTerminalIndex = terminalIndex;
+ // Updates the best exact matched output word index.
+ bestExactMatchedNodeOutputWordIndex = outputWordIndex;
+ }
+ }
// Populate the outputChars array with the suggested word.
const int startIndex = outputWordIndex * MAX_WORD_LENGTH;
terminalDicNode->outputResult(&outputCodePoints[startIndex]);
diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h
index 875cbe4e0..752bde9ac 100644
--- a/native/jni/src/suggest/core/suggest.h
+++ b/native/jni/src/suggest/core/suggest.h
@@ -82,6 +82,8 @@ class Suggest : public SuggestInterface {
// Threshold for autocorrection classifier
static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD;
+ // Final score penalty applied to exact-match words that are not the most probable exact match.
+ static const int FINAL_SCORE_PENALTY_FOR_NOT_BEST_EXACT_MATCHED_WORD;
const Traversal *const TRAVERSAL;
const Scoring *const SCORING;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 8036ffd0d..7ba4af5f9 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -163,6 +163,9 @@ class TypingWeighting : public Weighting {
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
+ // We promote exact matches here to prevent them from being pruned. The final score of
+ // exact match nodes might be demoted later in Suggest::outputSuggestions if there are
+ // multiple exact matches.
const float languageImprobability = (dicNode->isExactMatch()) ?
0.0f : dicNodeLanguageImprobability;
return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;