aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/defines.h1
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_profiler.h9
-rw-r--r--native/jni/src/suggest/core/policy/weighting.cpp17
-rw-r--r--native/jni/src/suggest/core/policy/weighting.h4
-rw-r--r--native/jni/src/suggest/core/suggest.cpp10
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.cpp1
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.h1
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp1
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.h9
9 files changed, 43 insertions, 10 deletions
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 974bb483b..34a646f80 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -381,6 +381,7 @@ typedef enum {
CT_TRANSPOSITION,
CT_COMPLETION,
CT_TERMINAL,
+ CT_TERMINAL_INSERTION,
// Create new word with space omission
CT_NEW_WORD_SPACE_OMITTION,
// Create new word with space substitution
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
index 90f75d0c6..1f4d2570e 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
@@ -31,6 +31,7 @@
#define PROF_TRANSPOSITION(profiler) profiler.profTransposition()
#define PROF_NEARESTKEY(profiler) profiler.profNearestKey()
#define PROF_TERMINAL(profiler) profiler.profTerminal()
+#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion()
#define PROF_NEW_WORD(profiler) profiler.profNewWord()
#define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram()
#define PROF_NODE_RESET(profiler) profiler.reset()
@@ -47,6 +48,7 @@
#define PROF_TRANSPOSITION(profiler)
#define PROF_NEARESTKEY(profiler)
#define PROF_TERMINAL(profiler)
+#define PROF_TERMINAL_INSERTION(profiler)
#define PROF_NEW_WORD(profiler)
#define PROF_NEW_WORD_BIGRAM(profiler)
#define PROF_NODE_RESET(profiler)
@@ -62,7 +64,7 @@ class DicNodeProfiler {
: mProfOmission(0), mProfInsertion(0), mProfTransposition(0),
mProfAdditionalProximity(0), mProfSubstitution(0),
mProfSpaceSubstitution(0), mProfSpaceOmission(0),
- mProfMatch(0), mProfCompletion(0), mProfTerminal(0),
+ mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0),
mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {}
int mProfOmission;
@@ -75,6 +77,7 @@ class DicNodeProfiler {
int mProfMatch;
int mProfCompletion;
int mProfTerminal;
+ int mProfTerminalInsertion;
int mProfNearestKey;
int mProfNewWord;
int mProfNewWordBigram;
@@ -123,6 +126,10 @@ class DicNodeProfiler {
++mProfTerminal;
}
+ void profTerminalInsertion() {
+ ++mProfTerminalInsertion;
+ }
+
void profNewWord() {
++mProfNewWord;
}
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index 117f48f29..58729229f 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -50,6 +50,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_TERMINAL:
PROF_TERMINAL(node->mProfiler);
return;
+ case CT_TERMINAL_INSERTION:
+ PROF_TERMINAL_INSERTION(node->mProfiler);
+ return;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
PROF_SPACE_SUBSTITUTION(node->mProfiler);
return;
@@ -113,6 +116,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return weighting->getCompletionCost(traverseSession, dicNode);
case CT_TERMINAL:
return weighting->getTerminalSpatialCost(traverseSession, dicNode);
+ case CT_TERMINAL_INSERTION:
+ return weighting->getTerminalInsertionCost(traverseSession, dicNode);
case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
case CT_INSERTION:
@@ -146,6 +151,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
}
+ case CT_TERMINAL_INSERTION:
+ return 0.0f;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
return weighting->getNewWordBigramLanguageCost(
traverseSession, parentDicNode, multiBigramMap);
@@ -163,9 +170,9 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_OMISSION:
return 0;
case CT_ADDITIONAL_PROXIMITY:
- return 0;
+ return 0; /* 0 because CT_MATCH will be called */
case CT_SUBSTITUTION:
- return 0;
+ return 0; /* 0 because CT_MATCH will be called */
case CT_NEW_WORD_SPACE_OMITTION:
return 0;
case CT_MATCH:
@@ -174,12 +181,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
return 1;
case CT_TERMINAL:
return 0;
+ case CT_TERMINAL_INSERTION:
+ return 1;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
return 1;
case CT_INSERTION:
- return 2;
+ return 2; /* look ahead + skip the current char */
case CT_TRANSPOSITION:
- return 2;
+ return 2; /* look ahead + skip the current char */
default:
return 0;
}
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index 781a7adbc..2d49e98a6 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -67,6 +67,10 @@ class Weighting {
const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
+ virtual float getTerminalInsertionCost(
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode) const = 0;
+
virtual float getTerminalLanguageCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
float dicNodeLanguageImprobability) const = 0;
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index d6383b958..73e9714bd 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -365,17 +365,17 @@ void Suggest::processTerminalDicNode(
if (!dicNode->isTerminalWordNode()) {
return;
}
- if (TRAVERSAL->needsToTraverseAllUserInput()
- && dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
- return;
- }
-
if (dicNode->shouldBeFilterdBySafetyNetForBigram()) {
return;
}
// Create a non-cached node here.
DicNode terminalDicNode;
DicNodeUtils::initByCopy(dicNode, &terminalDicNode);
+ if (TRAVERSAL->needsToTraverseAllUserInput()
+ && dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
+ &terminalDicNode, traverseSession->getMultiBigramMap());
+ }
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
&terminalDicNode, traverseSession->getMultiBigramMap());
traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index a8f797c5c..4157f411e 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -34,6 +34,7 @@ const float ScoringParams::OMISSION_COST = 0.458f;
const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f;
const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f;
const float ScoringParams::INSERTION_COST = 0.730f;
+const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f;
const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f;
const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f;
const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f;
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
index 4ebcc7dc3..a743b4d81 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
@@ -42,6 +42,7 @@ class ScoringParams {
static const float OMISSION_COST_SAME_CHAR;
static const float OMISSION_COST_FIRST_CHAR;
static const float INSERTION_COST;
+ static const float TERMINAL_INSERTION_COST;
static const float INSERTION_COST_SAME_CHAR;
static const float INSERTION_COST_PROXIMITY_CHAR;
static const float INSERTION_COST_FIRST_CHAR;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index e4c69d1f6..408b12ae9 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -44,6 +44,7 @@ ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
break;
case CT_SUBSTITUTION:
case CT_INSERTION:
+ case CT_TERMINAL_INSERTION:
case CT_TRANSPOSITION:
return ET_EDIT_CORRECTION;
case CT_NEW_WORD_SPACE_OMITTION:
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 1bb160738..7cddb0882 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -175,6 +175,15 @@ class TypingWeighting : public Weighting {
return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
+ float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode) const {
+ const int inputIndex = dicNode->getInputIndex(0);
+ const int inputSize = traverseSession->getInputSize();
+ ASSERT(inputIndex < inputSize);
+ // TODO: Implement more efficient logic
+ return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
+ }
+
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
return false;
}