aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/correction.cpp2
-rw-r--r--native/jni/src/defines.h2
-rw-r--r--native/jni/src/dictionary.cpp5
-rw-r--r--native/jni/src/dictionary.h1
-rw-r--r--native/jni/src/digraph_utils.cpp74
-rw-r--r--native/jni/src/digraph_utils.h19
-rw-r--r--native/jni/src/proximity_info.cpp14
-rw-r--r--native/jni/src/proximity_info.h15
-rw-r--r--native/jni/src/proximity_info_state.cpp12
-rw-r--r--native/jni/src/proximity_info_state.h8
-rw-r--r--native/jni/src/proximity_info_state_utils.cpp54
-rw-r--r--native/jni/src/proximity_info_state_utils.h12
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h19
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h23
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.cpp4
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.h8
-rw-r--r--native/jni/src/suggest/core/suggest.cpp32
-rw-r--r--native/jni/src/suggest/core/suggest.h1
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.h11
-rw-r--r--native/jni/src/suggest_utils.h39
-rw-r--r--native/jni/src/unigram_dictionary.cpp6
-rw-r--r--native/jni/src/unigram_dictionary.h5
22 files changed, 276 insertions, 90 deletions
diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp
index 76234f840..0c65939e0 100644
--- a/native/jni/src/correction.cpp
+++ b/native/jni/src/correction.cpp
@@ -675,7 +675,7 @@ inline static bool isUpperCase(unsigned short c) {
multiplyIntCapped(typedLetterMultiplier, &finalFreq);
}
const float factor =
- SuggestUtils::getDistanceScalingFactor(static_cast<float>(squaredDistance));
+ SuggestUtils::getLengthScalingFactor(static_cast<float>(squaredDistance));
if (factor > 0.0f) {
multiplyRate(static_cast<int>(factor * 100.0f), &finalFreq);
} else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) {
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index a45691261..a7b023a75 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -216,6 +216,7 @@ static inline void prof_out(void) {
#define DEBUG_DOUBLE_LETTER false
#define DEBUG_CACHE false
#define DEBUG_DUMP_ERROR false
+#define DEBUG_EVALUATE_MOST_PROBABLE_STRING false
#ifdef FLAG_FULL_DBG
#define DEBUG_GEO_FULL true
@@ -241,6 +242,7 @@ static inline void prof_out(void) {
#define DEBUG_DOUBLE_LETTER false
#define DEBUG_CACHE false
#define DEBUG_DUMP_ERROR false
+#define DEBUG_EVALUATE_MOST_PROBABLE_STRING false
#define DEBUG_GEO_FULL false
diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp
index ed6ddb517..c998c0676 100644
--- a/native/jni/src/dictionary.cpp
+++ b/native/jni/src/dictionary.cpp
@@ -103,4 +103,9 @@ int Dictionary::getProbability(const int *word, int length) const {
bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const {
return mBigramDictionary->isValidBigram(word1, length1, word2, length2);
}
+
+int Dictionary::getDictFlags() const {
+ return mUnigramDictionary->getDictFlags();
+}
+
} // namespace latinime
diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h
index 8c6a7de52..0653d3ca9 100644
--- a/native/jni/src/dictionary.h
+++ b/native/jni/src/dictionary.h
@@ -63,6 +63,7 @@ class Dictionary {
int getDictSize() const { return mDictSize; }
int getMmapFd() const { return mMmapFd; }
int getDictBufAdjust() const { return mDictBufAdjust; }
+ int getDictFlags() const;
virtual ~Dictionary();
private:
diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/digraph_utils.cpp
index 8781c5077..6a1ab0271 100644
--- a/native/jni/src/digraph_utils.cpp
+++ b/native/jni/src/digraph_utils.cpp
@@ -27,39 +27,47 @@ const DigraphUtils::digraph_t DigraphUtils::GERMAN_UMLAUT_DIGRAPHS[] =
const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] =
{ { 'a', 'e', 0x00E6 }, // U+00E6 : LATIN SMALL LETTER AE
{ 'o', 'e', 0x0153 } }; // U+0153 : LATIN SMALL LIGATURE OE
+const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] =
+ { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES };
/* static */ bool DigraphUtils::hasDigraphForCodePoint(
const int dictFlags, const int compositeGlyphCodePoint) {
- if (DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint)) {
+ const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags);
+ if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) {
return true;
}
return false;
}
-// Retrieves the set of all digraphs associated with the given dictionary.
-// Returns the size of the digraph array, or 0 if none exist.
-/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(
- const int dictFlags, const DigraphUtils::digraph_t **digraphs) {
+// Returns the digraph type associated with the given dictionary.
+/* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary(
+ const int dictFlags) {
if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) {
- *digraphs = DigraphUtils::GERMAN_UMLAUT_DIGRAPHS;
- return NELEMS(DigraphUtils::GERMAN_UMLAUT_DIGRAPHS);
+ return DIGRAPH_TYPE_GERMAN_UMLAUT;
}
if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) {
- *digraphs = DigraphUtils::FRENCH_LIGATURES_DIGRAPHS;
- return NELEMS(DigraphUtils::FRENCH_LIGATURES_DIGRAPHS);
+ return DIGRAPH_TYPE_FRENCH_LIGATURES;
}
- return 0;
+ return DIGRAPH_TYPE_NONE;
+}
+
+// Retrieves the set of all digraphs associated with the given dictionary flags.
+// Returns the size of the digraph array, or 0 if none exist.
+/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(
+ const int dictFlags, const DigraphUtils::digraph_t **const digraphs) {
+ const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags);
+ return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs);
}
// Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index
// (which specifies the first or second codepoint in the digraph).
-/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int dictFlags,
- const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex) {
+/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint,
+ const DigraphCodePointIndex digraphCodePointIndex) {
if (digraphCodePointIndex == NOT_A_DIGRAPH_INDEX) {
return NOT_A_CODE_POINT;
}
- const DigraphUtils::digraph_t *digraph =
- DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint);
+ const DigraphUtils::digraph_t *const digraph =
+ DigraphUtils::getDigraphForCodePoint(compositeGlyphCodePoint);
if (!digraph) {
return NOT_A_CODE_POINT;
}
@@ -72,16 +80,48 @@ const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] =
return NOT_A_CODE_POINT;
}
+// Retrieves the set of all digraphs associated with the given digraph type.
+// Returns the size of the digraph array, or 0 if none exist.
+/* static */ int DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize(
+ const DigraphUtils::DigraphType digraphType,
+ const DigraphUtils::digraph_t **const digraphs) {
+ if (digraphType == DigraphUtils::DIGRAPH_TYPE_GERMAN_UMLAUT) {
+ *digraphs = GERMAN_UMLAUT_DIGRAPHS;
+ return NELEMS(GERMAN_UMLAUT_DIGRAPHS);
+ }
+ if (digraphType == DIGRAPH_TYPE_FRENCH_LIGATURES) {
+ *digraphs = FRENCH_LIGATURES_DIGRAPHS;
+ return NELEMS(FRENCH_LIGATURES_DIGRAPHS);
+ }
+ return 0;
+}
+
/**
* Returns the digraph for the input composite glyph codepoint, or 0 if none exists.
- * dictFlags: the dictionary flags needed to determine which digraphs are supported.
* compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint.
*/
/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForCodePoint(
- const int dictFlags, const int compositeGlyphCodePoint) {
+ const int compositeGlyphCodePoint) {
+ for (size_t i = 0; i < NELEMS(USED_DIGRAPH_TYPES); i++) {
+ const DigraphUtils::digraph_t *const digraph = getDigraphForDigraphTypeAndCodePoint(
+ USED_DIGRAPH_TYPES[i], compositeGlyphCodePoint);
+ if (digraph) {
+ return digraph;
+ }
+ }
+ return 0;
+}
+
+/**
+ * Returns the digraph for the input composite glyph codepoint, or 0 if none exists.
+ * digraphType: the type of digraphs supported.
+ * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint.
+ */
+/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint(
+ const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) {
const DigraphUtils::digraph_t *digraphs = 0;
const int digraphsSize =
- DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(dictFlags, &digraphs);
+ DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs);
for (int i = 0; i < digraphsSize; i++) {
if (digraphs[i].compositeGlyph == compositeGlyphCodePoint) {
return &digraphs[i];
diff --git a/native/jni/src/digraph_utils.h b/native/jni/src/digraph_utils.h
index 6e364b67a..94435228e 100644
--- a/native/jni/src/digraph_utils.h
+++ b/native/jni/src/digraph_utils.h
@@ -27,21 +27,34 @@ class DigraphUtils {
SECOND_DIGRAPH_CODEPOINT
} DigraphCodePointIndex;
+ typedef enum {
+ DIGRAPH_TYPE_NONE,
+ DIGRAPH_TYPE_GERMAN_UMLAUT,
+ DIGRAPH_TYPE_FRENCH_LIGATURES
+ } DigraphType;
+
typedef struct { int first; int second; int compositeGlyph; } digraph_t;
static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint);
static int getAllDigraphsForDictionaryAndReturnSize(
- const int dictFlags, const digraph_t **digraphs);
+ const int dictFlags, const digraph_t **const digraphs);
static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint,
const DigraphCodePointIndex digraphCodePointIndex);
+ static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint,
+ const DigraphCodePointIndex digraphCodePointIndex);
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils);
- static const digraph_t *getDigraphForCodePoint(
- const int dictFlags, const int compositeGlyphCodePoint);
+ static DigraphType getDigraphTypeForDictionary(const int dictFlags);
+ static int getAllDigraphsForDigraphTypeAndReturnSize(
+ const DigraphType digraphType, const digraph_t **const digraphs);
+ static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint);
+ static const digraph_t *getDigraphForDigraphTypeAndCodePoint(
+ const DigraphType digraphType, const int compositeGlyphCodePoint);
static const digraph_t GERMAN_UMLAUT_DIGRAPHS[];
static const digraph_t FRENCH_LIGATURES_DIGRAPHS[];
+ static const DigraphType USED_DIGRAPH_TYPES[];
};
} // namespace latinime
#endif // DIGRAPH_UTILS_H
diff --git a/native/jni/src/proximity_info.cpp b/native/jni/src/proximity_info.cpp
index 50f38e82e..88d670d61 100644
--- a/native/jni/src/proximity_info.cpp
+++ b/native/jni/src/proximity_info.cpp
@@ -49,13 +49,17 @@ static AK_FORCE_INLINE void safeGetOrFillZeroFloatArrayRegion(JNIEnv *env, jfloa
ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr,
const int keyboardWidth, const int keyboardHeight, const int gridWidth,
- const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars,
- const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates,
- const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes,
- const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs,
- const jfloatArray sweetSpotRadii)
+ const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight,
+ const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates,
+ const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights,
+ const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs,
+ const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii)
: GRID_WIDTH(gridWidth), GRID_HEIGHT(gridHeight), MOST_COMMON_KEY_WIDTH(mostCommonKeyWidth),
MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth),
+ MOST_COMMON_KEY_HEIGHT(mostCommonKeyHeight),
+ NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f +
+ SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) /
+ static_cast<float>(mostCommonKeyWidth))),
CELL_WIDTH((keyboardWidth + gridWidth - 1) / gridWidth),
CELL_HEIGHT((keyboardHeight + gridHeight - 1) / gridHeight),
KEY_COUNT(min(keyCount, MAX_KEY_COUNT_IN_A_KEYBOARD)),
diff --git a/native/jni/src/proximity_info.h b/native/jni/src/proximity_info.h
index e21262fdb..deb9ae0de 100644
--- a/native/jni/src/proximity_info.h
+++ b/native/jni/src/proximity_info.h
@@ -30,11 +30,11 @@ class ProximityInfo {
public:
ProximityInfo(JNIEnv *env, const jstring localeJStr,
const int keyboardWidth, const int keyboardHeight, const int gridWidth,
- const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars,
- const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates,
- const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes,
- const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs,
- const jfloatArray sweetSpotRadii);
+ const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight,
+ const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates,
+ const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights,
+ const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs,
+ const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii);
~ProximityInfo();
bool hasSpaceProximity(const int x, const int y) const;
int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const;
@@ -56,6 +56,9 @@ class ProximityInfo {
bool hasTouchPositionCorrectionData() const { return HAS_TOUCH_POSITION_CORRECTION_DATA; }
int getMostCommonKeyWidth() const { return MOST_COMMON_KEY_WIDTH; }
int getMostCommonKeyWidthSquare() const { return MOST_COMMON_KEY_WIDTH_SQUARE; }
+ float getNormalizedSquaredMostCommonKeyHypotenuse() const {
+ return NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE;
+ }
int getKeyCount() const { return KEY_COUNT; }
int getCellHeight() const { return CELL_HEIGHT; }
int getCellWidth() const { return CELL_WIDTH; }
@@ -99,6 +102,8 @@ class ProximityInfo {
const int GRID_HEIGHT;
const int MOST_COMMON_KEY_WIDTH;
const int MOST_COMMON_KEY_WIDTH_SQUARE;
+ const int MOST_COMMON_KEY_HEIGHT;
+ const float NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE;
const int CELL_WIDTH;
const int CELL_HEIGHT;
const int KEY_COUNT;
diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp
index a10b260e1..cc5b736bd 100644
--- a/native/jni/src/proximity_info_state.cpp
+++ b/native/jni/src/proximity_info_state.cpp
@@ -81,7 +81,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
mSampledTimes.clear();
mSampledInputIndice.clear();
mSampledLengthCache.clear();
- mSampledDistanceCache_G.clear();
+ mSampledNormalizedSquaredLengthCache.clear();
mSampledNearKeySets.clear();
mSampledSearchKeySets.clear();
mSpeedRates.clear();
@@ -122,14 +122,15 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
if (mSampledInputSize > 0) {
ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize,
lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs,
- &mSampledNearKeySets, &mSampledDistanceCache_G);
+ &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache);
if (isGeometric) {
// updates probabilities of skipping or mapping each key for all points.
ProximityInfoStateUtils::updateAlignPointProbabilities(
mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(),
mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize,
&mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache,
- &mSampledDistanceCache_G, &mSampledNearKeySets, &mCharProbabilities);
+ &mSampledNormalizedSquaredLengthCache, &mSampledNearKeySets,
+ &mCharProbabilities);
ProximityInfoStateUtils::updateSampledSearchKeySets(mProximityInfo,
mSampledInputSize, lastSavedInputSize, &mSampledLengthCache,
&mSampledNearKeySets, &mSampledSearchKeySets,
@@ -171,7 +172,7 @@ float ProximityInfoState::getPointToKeyLength(
const int keyId = mProximityInfo->getKeyIndexOf(codePoint);
if (keyId != NOT_AN_INDEX) {
const int index = inputIndex * mProximityInfo->getKeyCount() + keyId;
- return min(mSampledDistanceCache_G[index], mMaxPointToKeyLength);
+ return min(mSampledNormalizedSquaredLengthCache[index], mMaxPointToKeyLength);
}
if (isIntentionalOmissionCodePoint(codePoint)) {
return 0.0f;
@@ -183,7 +184,8 @@ float ProximityInfoState::getPointToKeyLength(
float ProximityInfoState::getPointToKeyByIdLength(
const int inputIndex, const int keyId) const {
return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength,
- &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId);
+ &mSampledNormalizedSquaredLengthCache, mProximityInfo->getKeyCount(), inputIndex,
+ keyId);
}
// In the following function, c is the current character of the dictionary word currently examined.
diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h
index 9bba751d0..bbe8af240 100644
--- a/native/jni/src/proximity_info_state.h
+++ b/native/jni/src/proximity_info_state.h
@@ -49,8 +49,8 @@ class ProximityInfoState {
mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(),
mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(),
- mBeelineSpeedPercentiles(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(),
- mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(),
+ mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(),
+ mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(),
mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false),
mSampledInputSize(0), mMostProbableStringProbability(0.0f) {
memset(mInputProximities, 0, sizeof(mInputProximities));
@@ -147,7 +147,9 @@ class ProximityInfoState {
return mIsContinuousSuggestionPossible;
}
+ // TODO: Rename s/Length/NormalizedSquaredLength/
float getPointToKeyByIdLength(const int inputIndex, const int keyId) const;
+ // TODO: Rename s/Length/NormalizedSquaredLength/
float getPointToKeyLength(const int inputIndex, const int codePoint) const;
ProximityType getProximityType(const int index, const int codePoint,
@@ -231,7 +233,7 @@ class ProximityInfoState {
std::vector<int> mSampledInputIndice;
std::vector<int> mSampledLengthCache;
std::vector<int> mBeelineSpeedPercentiles;
- std::vector<float> mSampledDistanceCache_G;
+ std::vector<float> mSampledNormalizedSquaredLengthCache;
std::vector<float> mSpeedRates;
std::vector<float> mDirections;
// probabilities of skipping or mapping to a key for each point.
diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp
index df70cffdf..359673cd8 100644
--- a/native/jni/src/proximity_info_state_utils.cpp
+++ b/native/jni/src/proximity_info_state_utils.cpp
@@ -225,13 +225,13 @@ namespace latinime {
const int lastSavedInputSize, const float verticalSweetSpotScale,
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
- std::vector<NearKeycodesSet> *SampledNearKeySets,
- std::vector<float> *SampledDistanceCache_G) {
- SampledNearKeySets->resize(sampledInputSize);
+ std::vector<NearKeycodesSet> *sampledNearKeySets,
+ std::vector<float> *sampledNormalizedSquaredLengthCache) {
+ sampledNearKeySets->resize(sampledInputSize);
const int keyCount = proximityInfo->getKeyCount();
- SampledDistanceCache_G->resize(sampledInputSize * keyCount);
+ sampledNormalizedSquaredLengthCache->resize(sampledInputSize * keyCount);
for (int i = lastSavedInputSize; i < sampledInputSize; ++i) {
- (*SampledNearKeySets)[i].reset();
+ (*sampledNearKeySets)[i].reset();
for (int k = 0; k < keyCount; ++k) {
const int index = i * keyCount + k;
const int x = (*sampledInputXs)[i];
@@ -239,10 +239,10 @@ namespace latinime {
const float normalizedSquaredDistance =
proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(
k, x, y, verticalSweetSpotScale);
- (*SampledDistanceCache_G)[index] = normalizedSquaredDistance;
+ (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance;
if (normalizedSquaredDistance
< ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
- (*SampledNearKeySets)[i][k] = true;
+ (*sampledNearKeySets)[i][k] = true;
}
}
}
@@ -642,11 +642,11 @@ namespace latinime {
// This function basically converts from a length to an edit distance. Accordingly, it's obviously
// wrong to compare with mMaxPointToKeyLength.
/* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength,
- const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
+ const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount,
const int inputIndex, const int keyId) {
if (keyId != NOT_AN_INDEX) {
const int index = inputIndex * keyCount + keyId;
- return min((*SampledDistanceCache_G)[index], maxPointToKeyLength);
+ return min((*sampledNormalizedSquaredLengthCache)[index], maxPointToKeyLength);
}
// If the char is not a key on the keyboard then return the max length.
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
@@ -660,8 +660,8 @@ namespace latinime {
const std::vector<int> *const sampledInputYs,
const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache,
- const std::vector<float> *const SampledDistanceCache_G,
- std::vector<NearKeycodesSet> *SampledNearKeySets,
+ const std::vector<float> *const sampledNormalizedSquaredLengthCache,
+ std::vector<NearKeycodesSet> *sampledNearKeySets,
std::vector<hash_map_compat<int, float> > *charProbabilities) {
charProbabilities->resize(sampledInputSize);
// Calculates probabilities of using a point as a correlated point with the character
@@ -677,9 +677,9 @@ namespace latinime {
float nearestKeyDistance = static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
for (int j = 0; j < keyCount; ++j) {
- if ((*SampledNearKeySets)[i].test(j)) {
+ if ((*sampledNearKeySets)[i].test(j)) {
const float distance = getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j);
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j);
if (distance < nearestKeyDistance) {
nearestKeyDistance = distance;
}
@@ -758,14 +758,15 @@ namespace latinime {
// Summing up probability densities of all near keys.
float sumOfProbabilityDensities = 0.0f;
for (int j = 0; j < keyCount; ++j) {
- if ((*SampledNearKeySets)[i].test(j)) {
+ if ((*sampledNearKeySets)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) {
// For the first point, weighted average of distances from first point and the
// next point to the key is used as a point to key distance.
const float nextDistance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount,
+ i + 1, j));
if (nextDistance < distance) {
// The distance of the first point tends to bigger than continuing
// points because the first touch by the user can be sloppy.
@@ -779,7 +780,8 @@ namespace latinime {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float previousDistance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount,
+ i - 1, j));
if (previousDistance < distance) {
// The distance of the last point tends to bigger than continuing points
// because the last touch by the user can be sloppy. So we promote the
@@ -798,14 +800,15 @@ namespace latinime {
// Split the probability of an input point to keys that are close to the input point.
for (int j = 0; j < keyCount; ++j) {
- if ((*SampledNearKeySets)[i].test(j)) {
+ if ((*sampledNearKeySets)[i].test(j)) {
float distance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j));
if (i == 0 && i != sampledInputSize - 1) {
// For the first point, weighted average of distances from the first point and
// the next point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount,
+ i + 1, j));
if (prevDistance < distance) {
distance = (distance
+ prevDistance * ProximityInfoParams::NEXT_DISTANCE_WEIGHT)
@@ -815,7 +818,8 @@ namespace latinime {
// For the first point, weighted average of distances from last point and
// the previous point to the key is used as a point to key distance.
const float prevDistance = sqrtf(getPointToKeyByIdLength(
- maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j));
+ maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount,
+ i - 1, j));
if (prevDistance < distance) {
distance = (distance
+ prevDistance * ProximityInfoParams::PREV_DISTANCE_WEIGHT)
@@ -882,10 +886,10 @@ namespace latinime {
for (int j = 0; j < keyCount; ++j) {
hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j);
if (it == (*charProbabilities)[i].end()){
- (*SampledNearKeySets)[i].reset(j);
+ (*sampledNearKeySets)[i].reset(j);
} else if(it->second < ProximityInfoParams::MIN_PROBABILITY) {
// Erases from near keys vector because it has very low probability.
- (*SampledNearKeySets)[i].reset(j);
+ (*sampledNearKeySets)[i].reset(j);
(*charProbabilities)[i].erase(j);
} else {
it->second = -logf(it->second);
@@ -899,7 +903,7 @@ namespace latinime {
const ProximityInfo *const proximityInfo, const int sampledInputSize,
const int lastSavedInputSize,
const std::vector<int> *const sampledLengthCache,
- const std::vector<NearKeycodesSet> *const SampledNearKeySets,
+ const std::vector<NearKeycodesSet> *const sampledNearKeySets,
std::vector<NearKeycodesSet> *sampledSearchKeySets,
std::vector<std::vector<int> > *sampledSearchKeyVectors) {
sampledSearchKeySets->resize(sampledInputSize);
@@ -916,7 +920,7 @@ namespace latinime {
if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) {
break;
}
- (*sampledSearchKeySets)[i] |= (*SampledNearKeySets)[j];
+ (*sampledSearchKeySets)[i] |= (*sampledNearKeySets)[j];
}
}
const int keyCount = proximityInfo->getKeyCount();
diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h
index c9feb59a3..1837c7ab6 100644
--- a/native/jni/src/proximity_info_state_utils.h
+++ b/native/jni/src/proximity_info_state_utils.h
@@ -71,25 +71,25 @@ class ProximityInfoStateUtils {
const std::vector<int> *const sampledInputYs,
const std::vector<float> *const sampledSpeedRates,
const std::vector<int> *const sampledLengthCache,
- const std::vector<float> *const SampledDistanceCache_G,
- std::vector<NearKeycodesSet> *SampledNearKeySets,
+ const std::vector<float> *const sampledNormalizedSquaredLengthCache,
+ std::vector<NearKeycodesSet> *sampledNearKeySets,
std::vector<hash_map_compat<int, float> > *charProbabilities);
static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo,
const int sampledInputSize, const int lastSavedInputSize,
const std::vector<int> *const sampledLengthCache,
- const std::vector<NearKeycodesSet> *const SampledNearKeySets,
+ const std::vector<NearKeycodesSet> *const sampledNearKeySets,
std::vector<NearKeycodesSet> *sampledSearchKeySets,
std::vector<std::vector<int> > *sampledSearchKeyVectors);
static float getPointToKeyByIdLength(const float maxPointToKeyLength,
- const std::vector<float> *const SampledDistanceCache_G, const int keyCount,
+ const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount,
const int inputIndex, const int keyId);
static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo,
const int sampledInputSize, const int lastSavedInputSize,
const float verticalSweetSpotScale,
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
- std::vector<NearKeycodesSet> *SampledNearKeySets,
- std::vector<float> *SampledDistanceCache_G);
+ std::vector<NearKeycodesSet> *sampledNearKeySets,
+ std::vector<float> *sampledNormalizedSquaredLengthCache);
static void initPrimaryInputWord(const int inputSize, const int *const inputProximities,
int *primaryInputWord);
static void initNormalizedSquaredDistances(const ProximityInfo *const proximityInfo,
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index cde7b99a7..32faae52c 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -23,6 +23,7 @@
#include "dic_node_profiler.h"
#include "dic_node_properties.h"
#include "dic_node_release_listener.h"
+#include "digraph_utils.h"
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
@@ -399,8 +400,15 @@ class DicNode {
// TODO: Remove //
//////////////////////
// TODO: Remove once touch path is merged into ProximityInfoState
+ // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
int getNodeCodePoint() const {
- return mDicNodeProperties.getNodeCodePoint();
+ const int codePoint = mDicNodeProperties.getNodeCodePoint();
+ const DigraphUtils::DigraphCodePointIndex digraphIndex =
+ mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
+ if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
+ return codePoint;
+ }
+ return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
}
////////////////////////////////
@@ -452,6 +460,15 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
}
+ bool isInDigraph() const {
+ return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
+ != DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ }
+
+ void advanceDigraphIndex() {
+ mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
+ }
+
uint8_t getFlags() const {
return mDicNodeProperties.getFlags();
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
index 8e816329f..8902d3122 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
@@ -20,6 +20,7 @@
#include <stdint.h>
#include "defines.h"
+#include "digraph_utils.h"
namespace latinime {
@@ -27,6 +28,7 @@ class DicNodeStateScoring {
public:
AK_FORCE_INLINE DicNodeStateScoring()
: mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
+ mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) {
@@ -43,6 +45,7 @@ class DicNodeStateScoring {
mTotalPrevWordsLanguageCost = 0.0f;
mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
}
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@@ -54,6 +57,7 @@ class DicNodeStateScoring {
mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost;
mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
+ mDigraphIndex = scoring->mDigraphIndex;
}
void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
@@ -126,6 +130,24 @@ class DicNodeStateScoring {
}
}
+ DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
+ return mDigraphIndex;
+ }
+
+ void advanceDigraphIndex() {
+ switch(mDigraphIndex) {
+ case DigraphUtils::NOT_A_DIGRAPH_INDEX:
+ mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
+ break;
+ case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
+ mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
+ break;
+ case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ break;
+ }
+ }
+
float getTotalPrevWordsLanguageCost() const {
return mTotalPrevWordsLanguageCost;
}
@@ -135,6 +157,7 @@ class DicNodeStateScoring {
// Use a default copy constructor and an assign operator because shallow copies are ok
// for this class
DoubleLetterLevel mDoubleLetterLevel;
+ DigraphUtils::DigraphCodePointIndex mDigraphIndex;
int16_t mEditCorrectionCount;
int16_t mProximityCorrectionCount;
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index ef6616e40..5b783a2ba 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -84,6 +84,10 @@ const uint8_t *DicTraverseSession::getOffsetDict() const {
return mDictionary->getOffsetDict();
}
+int DicTraverseSession::getDictFlags() const {
+ return mDictionary->getDictFlags();
+}
+
void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
mDicNodesCache.reset(nextActiveCacheSize, maxWords);
mBigramCacheMap.clear();
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index 62e1d1ab9..fe0527639 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -53,7 +53,7 @@ class DicTraverseSession {
void resetCache(const int nextActiveCacheSize, const int maxWords);
const uint8_t *getOffsetDict() const;
- bool canUseCache() const;
+ int getDictFlags() const;
//--------------------
// getters and setters
@@ -134,7 +134,7 @@ class DicTraverseSession {
if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) {
return false;
}
- ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G);
+ ASSERT(mMaxPointerCount <= MAX_POINTER_COUNT_G);
for (int i = 0; i < mMaxPointerCount; ++i) {
const ProximityInfoState *const pInfoState = getProximityInfoState(i);
// If a proximity info state is not continuous suggestion possible,
@@ -146,6 +146,10 @@ class DicTraverseSession {
return true;
}
+ bool isTouchPositionCorrectionEnabled() const {
+ return mProximityInfoStates[0].touchPositionCorrectionEnabled();
+ }
+
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession);
// threshold to start caching
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 764c37240..67d351fa1 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -18,6 +18,7 @@
#include "char_utils.h"
#include "dictionary.h"
+#include "digraph_utils.h"
#include "proximity_info.h"
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_priority_queue.h"
@@ -123,8 +124,12 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo
*/
int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies,
int *outputCodePoints, int *spaceIndices, int *outputTypes) const {
+#if DEBUG_EVALUATE_MOST_PROBABLE_STRING
+ const int terminalSize = 0;
+#else
const int terminalSize = min(MAX_RESULTS,
static_cast<int>(traverseSession->getDicTraverseCache()->terminalSize()));
+#endif
DicNode terminals[MAX_RESULTS]; // Avoiding non-POD variable length array
for (int index = terminalSize - 1; index >= 0; --index) {
@@ -221,7 +226,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
const int inputSize = traverseSession->getInputSize();
DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize());
- DicNode omissionDicNode;
+ DicNode correctionDicNode;
// TODO: Find more efficient caching
const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession);
@@ -257,7 +262,10 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
dicNode.setCached();
}
- if (isLookAheadCorrection) {
+ if (dicNode.isInDigraph()) {
+ // Finish digraph handling if the node is in the middle of a digraph expansion.
+ processDicNodeAsDigraph(traverseSession, &dicNode);
+ } else if (isLookAheadCorrection) {
// The algorithm maintains a small set of "deferred" nodes that have not consumed the
// latest touch point yet. These are needed to apply look-ahead correction operations
// that require special handling of the latest touch point. For example, with insertions
@@ -291,12 +299,18 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
processDicNodeAsMatch(traverseSession, childDicNode);
continue;
}
+ if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(),
+ childDicNode->getNodeCodePoint())) {
+ correctionDicNode.initByCopy(childDicNode);
+ correctionDicNode.advanceDigraphIndex();
+ processDicNodeAsDigraph(traverseSession, &correctionDicNode);
+ }
if (allowsErrorCorrections
&& TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) {
// TODO: (Gesture) Change weight between omission and substitution errors
// TODO: (Gesture) Terminal node should not be handled as omission
- omissionDicNode.initByCopy(childDicNode);
- processDicNodeAsOmission(traverseSession, &omissionDicNode);
+ correctionDicNode.initByCopy(childDicNode);
+ processDicNodeAsOmission(traverseSession, &correctionDicNode);
}
const ProximityType proximityType = TRAVERSAL->getProximityType(
traverseSession, &dicNode, childDicNode);
@@ -400,6 +414,16 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
processExpandedDicNode(traverseSession, childDicNode);
}
+// Process the node codepoint as a digraph. This means that composite glyphs like the German
+// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with
+// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber".
+void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
+ DicNode *childDicNode) const {
+ weightChildNode(traverseSession, childDicNode);
+ childDicNode->advanceDigraphIndex();
+ processExpandedDicNode(traverseSession, childDicNode);
+}
+
/**
* Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider
* matches for all possible next letters. Note that just skipping the current letter without any
diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h
index 9f609c50c..becd6c1de 100644
--- a/native/jni/src/suggest/core/suggest.h
+++ b/native/jni/src/suggest/core/suggest.h
@@ -65,6 +65,7 @@ class Suggest : public SuggestInterface {
void generateFeatures(
DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const;
void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const;
+ void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processDicNodeAsTransposition(DicTraverseSession *traverseSession,
DicNode *dicNode) const;
void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 52d54eb0f..2dcee343f 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -18,6 +18,7 @@
#define LATINIME_TYPING_WEIGHTING_H
#include "defines.h"
+#include "suggest_utils.h"
#include "suggest/core/dicnode/dic_node_utils.h"
#include "suggest/core/policy/weighting.h"
#include "suggest/core/session/dic_traverse_session.h"
@@ -70,10 +71,12 @@ class TypingWeighting : public Weighting {
const int pointIndex = dicNode->getInputIndex(0);
// Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
// the keyboard (like accented letters)
- const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE,
- traverseSession->getProximityInfoState(0)->getPointToKeyLength(
- pointIndex, dicNode->getNodeCodePoint()));
- const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH;
+ const float normalizedSquaredLength = traverseSession->getProximityInfoState(0)
+ ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint());
+ const float normalizedDistance = SuggestUtils::getSweetSpotFactor(
+ traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength);
+ const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
+
const bool isFirstChar = pointIndex == 0;
const bool isProximity = isProximityDicNode(traverseSession, dicNode);
const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST
diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest_utils.h
index aab9f7ba8..e053dd662 100644
--- a/native/jni/src/suggest_utils.h
+++ b/native/jni/src/suggest_utils.h
@@ -23,10 +23,8 @@
namespace latinime {
class SuggestUtils {
public:
- static float getDistanceScalingFactor(const float normalizedSquaredDistance) {
- if (normalizedSquaredDistance < 0.0f) {
- return -1.0f;
- }
+ // TODO: (OLD) Remove
+ static float getLengthScalingFactor(const float normalizedSquaredDistance) {
// Promote or demote the score according to the distance from the sweet spot
static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f;
static const float B = 1.0f;
@@ -50,6 +48,39 @@ class SuggestUtils {
return factor;
}
+ static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled,
+ const float normalizedSquaredDistance) {
+ // Promote or demote the score according to the distance from the sweet spot
+ static const float A = 0.0f;
+ static const float B = 0.24f;
+ static const float C = 1.20f;
+ static const float R0 = 0.0f;
+ static const float R1 = 0.25f; // Sweet spot
+ static const float R2 = 1.0f;
+ const float x = normalizedSquaredDistance;
+ if (!isTouchPositionCorrectionEnabled) {
+ return min(C, x);
+ }
+
+ // factor is a piecewise linear function like:
+ // C -------------.
+ // / .
+ // B / .
+ // -/ .
+ // A _-^ .
+ // .
+ // R0 R1 R2 .
+
+ if (x < R0) {
+ return A;
+ } else if (x < R1) {
+ return (A * (R1 - x) + B * (x - R0)) / (R1 - R0);
+ } else if (x < R2) {
+ return (B * (R2 - x) + C * (x - R1)) / (R2 - R1);
+ } else {
+ return C;
+ }
+ }
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestUtils);
};
diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp
index 33795cade..a672294b5 100644
--- a/native/jni/src/unigram_dictionary.cpp
+++ b/native/jni/src/unigram_dictionary.cpp
@@ -32,9 +32,9 @@
namespace latinime {
// TODO: check the header
-UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags)
+UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags)
: DICT_ROOT(streamStart), ROOT_POS(0),
- MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), FLAGS(flags) {
+ MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), DICT_FLAGS(dictFlags) {
if (DEBUG_DICT) {
AKLOGI("UnigramDictionary - constructor");
}
@@ -163,7 +163,7 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x
masterCorrection.resetCorrection();
const DigraphUtils::digraph_t *digraphs = 0;
const int digraphsSize =
- DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(FLAGS, &digraphs);
+ DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(DICT_FLAGS, &digraphs);
if (digraphsSize > 0)
{ // Incrementally tune the word and try all possibilities
int codesBuffer[sizeof(*inputCodePoints) * inputSize];
diff --git a/native/jni/src/unigram_dictionary.h b/native/jni/src/unigram_dictionary.h
index 1a01758fe..a64a539bd 100644
--- a/native/jni/src/unigram_dictionary.h
+++ b/native/jni/src/unigram_dictionary.h
@@ -38,7 +38,7 @@ class UnigramDictionary {
static const int FLAG_MULTIPLE_SUGGEST_ABORT = 0;
static const int FLAG_MULTIPLE_SUGGEST_SKIP = 1;
static const int FLAG_MULTIPLE_SUGGEST_CONTINUE = 2;
- UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags);
+ UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags);
int getProbability(const int *const inWord, const int length) const;
int getBigramPosition(int pos, int *word, int offset, int length) const;
int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates,
@@ -46,6 +46,7 @@ class UnigramDictionary {
const std::map<int, int> *bigramMap, const uint8_t *bigramFilter,
const bool useFullEditDistance, int *outWords, int *frequencies,
int *outputTypes) const;
+ int getDictFlags() const { return DICT_FLAGS; }
virtual ~UnigramDictionary();
private:
@@ -109,7 +110,7 @@ class UnigramDictionary {
const uint8_t *const DICT_ROOT;
const int ROOT_POS;
const int MAX_DIGRAPH_SEARCH_DEPTH;
- const int FLAGS;
+ const int DICT_FLAGS;
};
} // namespace latinime
#endif // LATINIME_UNIGRAM_DICTIONARY_H