aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/src/binary_format.h5
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h8
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_properties.h5
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h17
-rw-r--r--native/jni/src/suggest/core/policy/weighting.cpp2
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.cpp10
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.h9
-rw-r--r--native/jni/src/terminal_attributes.h2
8 files changed, 21 insertions, 37 deletions
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h
index 2d2e19501..ad16039ef 100644
--- a/native/jni/src/binary_format.h
+++ b/native/jni/src/binary_format.h
@@ -66,6 +66,7 @@ class BinaryFormat {
static int detectFormat(const uint8_t *const dict);
static int getHeaderSize(const uint8_t *const dict);
static int getFlags(const uint8_t *const dict);
+ static bool hasBlacklistedOrNotAWordFlag(const int flags);
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
const int outValueSize);
static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
@@ -162,6 +163,10 @@ inline int BinaryFormat::getFlags(const uint8_t *const dict) {
}
}
+inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
+ return flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD);
+}
+
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
switch (detectFormat(dict)) {
case 1:
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index 32faae52c..e8432546b 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -210,8 +210,7 @@ class DicNode {
}
bool isImpossibleBigramWord() const {
- const int probability = mDicNodeProperties.getProbability();
- if (probability == 0) {
+ if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) {
return true;
}
const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
@@ -360,11 +359,6 @@ class DicNode {
return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
}
- // Note that "cost" means delta for "distance" that is weighted.
- float getTotalPrevWordsLanguageCost() const {
- return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost();
- }
-
// Used to commit input partially
int getPrevWordNodePos() const {
return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/dic_node_properties.h
index 173ef35d0..63a6b1340 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_properties.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_properties.h
@@ -19,6 +19,7 @@
#include <stdint.h>
+#include "binary_format.h"
#include "defines.h"
namespace latinime {
@@ -144,6 +145,10 @@ class DicNodeProperties {
return mChildrenCount > 0 || mDepth != mLeavingDepth;
}
+ bool hasBlacklistedOrNotAWordFlag() const {
+ return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags);
+ }
+
private:
// Caution!!!
// Use a default copy constructor and an assign operator because shallow copies are ok
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
index 8902d3122..fd9d610e3 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
@@ -31,7 +31,7 @@ class DicNodeStateScoring {
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
- mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) {
+ mRawLength(0.0f) {
}
virtual ~DicNodeStateScoring() {}
@@ -42,7 +42,6 @@ class DicNodeStateScoring {
mNormalizedCompoundDistance = 0.0f;
mSpatialDistance = 0.0f;
mLanguageDistance = 0.0f;
- mTotalPrevWordsLanguageCost = 0.0f;
mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
@@ -54,7 +53,6 @@ class DicNodeStateScoring {
mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
mSpatialDistance = scoring->mSpatialDistance;
mLanguageDistance = scoring->mLanguageDistance;
- mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost;
mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
mDigraphIndex = scoring->mDigraphIndex;
@@ -70,9 +68,6 @@ class DicNodeStateScoring {
if (isProximityCorrection) {
++mProximityCorrectionCount;
}
- if (languageCost > 0.0f) {
- setTotalPrevWordsLanguageCost(mTotalPrevWordsLanguageCost + languageCost);
- }
}
void addRawLength(const float rawLength) {
@@ -148,10 +143,6 @@ class DicNodeStateScoring {
}
}
- float getTotalPrevWordsLanguageCost() const {
- return mTotalPrevWordsLanguageCost;
- }
-
private:
// Caution!!!
// Use a default copy constructor and an assign operator because shallow copies are ok
@@ -165,7 +156,6 @@ class DicNodeStateScoring {
float mNormalizedCompoundDistance;
float mSpatialDistance;
float mLanguageDistance;
- float mTotalPrevWordsLanguageCost;
float mRawLength;
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
@@ -179,11 +169,6 @@ class DicNodeStateScoring {
/ static_cast<float>(max(1, totalInputIndex));
}
}
-
- //TODO: remove
- AK_FORCE_INLINE void setTotalPrevWordsLanguageCost(float totalPrevWordsLanguageCost) {
- mTotalPrevWordsLanguageCost = totalPrevWordsLanguageCost;
- }
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_STATE_SCORING_H
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index b9c0b8129..a6d30e457 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -229,7 +229,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_MATCH:
return 1;
case CT_COMPLETION:
- return 0;
+ return 1;
case CT_TERMINAL:
return 0;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index 0fa684f01..993358616 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -35,17 +35,17 @@ const float ScoringParams::INSERTION_COST = 0.670f;
const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f;
const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f;
const float ScoringParams::TRANSPOSITION_COST = 0.494f;
-const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.239f;
+const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.289f;
const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f;
const float ScoringParams::SUBSTITUTION_COST = 0.363f;
-const float ScoringParams::COST_NEW_WORD = 0.054f;
+const float ScoringParams::COST_NEW_WORD = 0.024f;
const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f;
const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f;
-const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.462f;
-const float ScoringParams::COST_LOOKAHEAD = 0.092f;
+const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.545f;
+const float ScoringParams::COST_LOOKAHEAD = 0.073f;
const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f;
const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f;
-const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.136f;
+const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.536f;
const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f;
const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f;
const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 74e4e34e4..34d25ae1a 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -140,7 +140,7 @@ class TypingWeighting : public Weighting {
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap) const {
return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
- dicNode, bigramCacheMap);
+ dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
float getCompletionCost(const DicTraverseSession *const traverseSession,
@@ -164,13 +164,8 @@ class TypingWeighting : public Weighting {
// because the input word shouldn't be treated as perfect
const bool isExactMatch = !hasEditCount && !hasMultipleWords
&& !hasProximityErrors && isSameLength;
-
- const float totalPrevWordsLanguageCost = dicNode->getTotalPrevWordsLanguageCost();
const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability;
- const float languageWeight = ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
- // TODO: Caveat: The following equation should be:
- // totalPrevWordsLanguageCost + (languageImprobability * languageWeight);
- return (totalPrevWordsLanguageCost + languageImprobability) * languageWeight;
+ return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
diff --git a/native/jni/src/terminal_attributes.h b/native/jni/src/terminal_attributes.h
index 144ae1452..92ef71c2c 100644
--- a/native/jni/src/terminal_attributes.h
+++ b/native/jni/src/terminal_attributes.h
@@ -72,7 +72,7 @@ class TerminalAttributes {
}
bool isBlacklistedOrNotAWord() const {
- return mFlags & (BinaryFormat::FLAG_IS_BLACKLISTED | BinaryFormat::FLAG_IS_NOT_A_WORD);
+ return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags);
}
private: