aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp8
-rw-r--r--native/jni/src/dictionary.cpp5
-rw-r--r--native/jni/src/dictionary.h1
-rw-r--r--native/jni/src/digraph_utils.cpp74
-rw-r--r--native/jni/src/digraph_utils.h19
-rw-r--r--native/jni/src/proximity_info.cpp18
-rw-r--r--native/jni/src/proximity_info.h18
-rw-r--r--native/jni/src/proximity_info_params.cpp3
-rw-r--r--native/jni/src/proximity_info_params.h1
-rw-r--r--native/jni/src/proximity_info_state.cpp16
-rw-r--r--native/jni/src/proximity_info_state_utils.cpp31
-rw-r--r--native/jni/src/proximity_info_state_utils.h8
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h26
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h23
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.cpp4
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.h2
-rw-r--r--native/jni/src/suggest/core/suggest.cpp29
-rw-r--r--native/jni/src/suggest/core/suggest.h10
-rw-r--r--native/jni/src/unigram_dictionary.cpp6
-rw-r--r--native/jni/src/unigram_dictionary.h5
20 files changed, 230 insertions, 77 deletions
diff --git a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp
index 3c482ca58..dedb02abf 100644
--- a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp
+++ b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp
@@ -26,13 +26,13 @@ namespace latinime {
static jlong latinime_Keyboard_setProximityInfo(JNIEnv *env, jclass clazz, jstring localeJStr,
jint displayWidth, jint displayHeight, jint gridWidth, jint gridHeight,
- jint mostCommonkeyWidth, jintArray proximityChars, jint keyCount,
+ jint mostCommonkeyWidth, jint mostCommonkeyHeight, jintArray proximityChars, jint keyCount,
jintArray keyXCoordinates, jintArray keyYCoordinates, jintArray keyWidths,
jintArray keyHeights, jintArray keyCharCodes, jfloatArray sweetSpotCenterXs,
jfloatArray sweetSpotCenterYs, jfloatArray sweetSpotRadii) {
ProximityInfo *proximityInfo = new ProximityInfo(env, localeJStr, displayWidth, displayHeight,
- gridWidth, gridHeight, mostCommonkeyWidth, proximityChars, keyCount,
- keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, keyCharCodes,
+ gridWidth, gridHeight, mostCommonkeyWidth, mostCommonkeyHeight, proximityChars,
+ keyCount, keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, keyCharCodes,
sweetSpotCenterXs, sweetSpotCenterYs, sweetSpotRadii);
return reinterpret_cast<jlong>(proximityInfo);
}
@@ -44,7 +44,7 @@ static void latinime_Keyboard_release(JNIEnv *env, jclass clazz, jlong proximity
static JNINativeMethod sMethods[] = {
{const_cast<char *>("setProximityInfoNative"),
- const_cast<char *>("(Ljava/lang/String;IIIII[II[I[I[I[I[I[F[F[F)J"),
+ const_cast<char *>("(Ljava/lang/String;IIIIII[II[I[I[I[I[I[F[F[F)J"),
reinterpret_cast<void *>(latinime_Keyboard_setProximityInfo)},
{const_cast<char *>("releaseProximityInfoNative"),
const_cast<char *>("(J)V"),
diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp
index ed6ddb517..c998c0676 100644
--- a/native/jni/src/dictionary.cpp
+++ b/native/jni/src/dictionary.cpp
@@ -103,4 +103,9 @@ int Dictionary::getProbability(const int *word, int length) const {
bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const {
return mBigramDictionary->isValidBigram(word1, length1, word2, length2);
}
+
+int Dictionary::getDictFlags() const {
+ return mUnigramDictionary->getDictFlags();
+}
+
} // namespace latinime
diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h
index 8c6a7de52..0653d3ca9 100644
--- a/native/jni/src/dictionary.h
+++ b/native/jni/src/dictionary.h
@@ -63,6 +63,7 @@ class Dictionary {
int getDictSize() const { return mDictSize; }
int getMmapFd() const { return mMmapFd; }
int getDictBufAdjust() const { return mDictBufAdjust; }
+ int getDictFlags() const;
virtual ~Dictionary();
private:
diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/digraph_utils.cpp
index 8781c5077..6a1ab0271 100644
--- a/native/jni/src/digraph_utils.cpp
+++ b/native/jni/src/digraph_utils.cpp
@@ -27,39 +27,47 @@ const DigraphUtils::digraph_t DigraphUtils::GERMAN_UMLAUT_DIGRAPHS[] =
const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] =
{ { 'a', 'e', 0x00E6 }, // U+00E6 : LATIN SMALL LETTER AE
{ 'o', 'e', 0x0153 } }; // U+0153 : LATIN SMALL LIGATURE OE
+const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] =
+ { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES };
/* static */ bool DigraphUtils::hasDigraphForCodePoint(
const int dictFlags, const int compositeGlyphCodePoint) {
- if (DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint)) {
+ const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags);
+ if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) {
return true;
}
return false;
}
-// Retrieves the set of all digraphs associated with the given dictionary.
-// Returns the size of the digraph array, or 0 if none exist.
-/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(
- const int dictFlags, const DigraphUtils::digraph_t **digraphs) {
+// Returns the digraph type associated with the given dictionary.
+/* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary(
+ const int dictFlags) {
if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) {
- *digraphs = DigraphUtils::GERMAN_UMLAUT_DIGRAPHS;
- return NELEMS(DigraphUtils::GERMAN_UMLAUT_DIGRAPHS);
+ return DIGRAPH_TYPE_GERMAN_UMLAUT;
}
if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) {
- *digraphs = DigraphUtils::FRENCH_LIGATURES_DIGRAPHS;
- return NELEMS(DigraphUtils::FRENCH_LIGATURES_DIGRAPHS);
+ return DIGRAPH_TYPE_FRENCH_LIGATURES;
}
- return 0;
+ return DIGRAPH_TYPE_NONE;
+}
+
+// Retrieves the set of all digraphs associated with the given dictionary flags.
+// Returns the size of the digraph array, or 0 if none exist.
+/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(
+ const int dictFlags, const DigraphUtils::digraph_t **const digraphs) {
+ const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags);
+ return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs);
}
// Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index
// (which specifies the first or second codepoint in the digraph).
-/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int dictFlags,
- const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex) {
+/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint,
+ const DigraphCodePointIndex digraphCodePointIndex) {
if (digraphCodePointIndex == NOT_A_DIGRAPH_INDEX) {
return NOT_A_CODE_POINT;
}
- const DigraphUtils::digraph_t *digraph =
- DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint);
+ const DigraphUtils::digraph_t *const digraph =
+ DigraphUtils::getDigraphForCodePoint(compositeGlyphCodePoint);
if (!digraph) {
return NOT_A_CODE_POINT;
}
@@ -72,16 +80,48 @@ const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] =
return NOT_A_CODE_POINT;
}
+// Retrieves the set of all digraphs associated with the given digraph type.
+// Returns the size of the digraph array, or 0 if none exist.
+/* static */ int DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize(
+ const DigraphUtils::DigraphType digraphType,
+ const DigraphUtils::digraph_t **const digraphs) {
+ if (digraphType == DigraphUtils::DIGRAPH_TYPE_GERMAN_UMLAUT) {
+ *digraphs = GERMAN_UMLAUT_DIGRAPHS;
+ return NELEMS(GERMAN_UMLAUT_DIGRAPHS);
+ }
+ if (digraphType == DIGRAPH_TYPE_FRENCH_LIGATURES) {
+ *digraphs = FRENCH_LIGATURES_DIGRAPHS;
+ return NELEMS(FRENCH_LIGATURES_DIGRAPHS);
+ }
+ return 0;
+}
+
/**
* Returns the digraph for the input composite glyph codepoint, or 0 if none exists.
- * dictFlags: the dictionary flags needed to determine which digraphs are supported.
* compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint.
*/
/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForCodePoint(
- const int dictFlags, const int compositeGlyphCodePoint) {
+ const int compositeGlyphCodePoint) {
+ for (size_t i = 0; i < NELEMS(USED_DIGRAPH_TYPES); i++) {
+ const DigraphUtils::digraph_t *const digraph = getDigraphForDigraphTypeAndCodePoint(
+ USED_DIGRAPH_TYPES[i], compositeGlyphCodePoint);
+ if (digraph) {
+ return digraph;
+ }
+ }
+ return 0;
+}
+
+/**
+ * Returns the digraph for the input composite glyph codepoint, or 0 if none exists.
+ * digraphType: the type of digraphs supported.
+ * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint.
+ */
+/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint(
+ const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) {
const DigraphUtils::digraph_t *digraphs = 0;
const int digraphsSize =
- DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(dictFlags, &digraphs);
+ DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs);
for (int i = 0; i < digraphsSize; i++) {
if (digraphs[i].compositeGlyph == compositeGlyphCodePoint) {
return &digraphs[i];
diff --git a/native/jni/src/digraph_utils.h b/native/jni/src/digraph_utils.h
index 6e364b67a..94435228e 100644
--- a/native/jni/src/digraph_utils.h
+++ b/native/jni/src/digraph_utils.h
@@ -27,21 +27,34 @@ class DigraphUtils {
SECOND_DIGRAPH_CODEPOINT
} DigraphCodePointIndex;
+ typedef enum {
+ DIGRAPH_TYPE_NONE,
+ DIGRAPH_TYPE_GERMAN_UMLAUT,
+ DIGRAPH_TYPE_FRENCH_LIGATURES
+ } DigraphType;
+
typedef struct { int first; int second; int compositeGlyph; } digraph_t;
static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint);
static int getAllDigraphsForDictionaryAndReturnSize(
- const int dictFlags, const digraph_t **digraphs);
+ const int dictFlags, const digraph_t **const digraphs);
static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint,
const DigraphCodePointIndex digraphCodePointIndex);
+ static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint,
+ const DigraphCodePointIndex digraphCodePointIndex);
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils);
- static const digraph_t *getDigraphForCodePoint(
- const int dictFlags, const int compositeGlyphCodePoint);
+ static DigraphType getDigraphTypeForDictionary(const int dictFlags);
+ static int getAllDigraphsForDigraphTypeAndReturnSize(
+ const DigraphType digraphType, const digraph_t **const digraphs);
+ static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint);
+ static const digraph_t *getDigraphForDigraphTypeAndCodePoint(
+ const DigraphType digraphType, const int compositeGlyphCodePoint);
static const digraph_t GERMAN_UMLAUT_DIGRAPHS[];
static const digraph_t FRENCH_LIGATURES_DIGRAPHS[];
+ static const DigraphType USED_DIGRAPH_TYPES[];
};
} // namespace latinime
#endif // DIGRAPH_UTILS_H
diff --git a/native/jni/src/proximity_info.cpp b/native/jni/src/proximity_info.cpp
index 74b5e0131..88d670d61 100644
--- a/native/jni/src/proximity_info.cpp
+++ b/native/jni/src/proximity_info.cpp
@@ -49,13 +49,17 @@ static AK_FORCE_INLINE void safeGetOrFillZeroFloatArrayRegion(JNIEnv *env, jfloa
ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr,
const int keyboardWidth, const int keyboardHeight, const int gridWidth,
- const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars,
- const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates,
- const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes,
- const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs,
- const jfloatArray sweetSpotRadii)
+ const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight,
+ const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates,
+ const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights,
+ const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs,
+ const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii)
: GRID_WIDTH(gridWidth), GRID_HEIGHT(gridHeight), MOST_COMMON_KEY_WIDTH(mostCommonKeyWidth),
MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth),
+ MOST_COMMON_KEY_HEIGHT(mostCommonKeyHeight),
+ NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f +
+ SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) /
+ static_cast<float>(mostCommonKeyWidth))),
CELL_WIDTH((keyboardWidth + gridWidth - 1) / gridWidth),
CELL_HEIGHT((keyboardHeight + gridHeight - 1) / gridHeight),
KEY_COUNT(min(keyCount, MAX_KEY_COUNT_IN_A_KEYBOARD)),
@@ -129,7 +133,7 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const {
}
float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG(
- const int keyId, const int x, const int y) const {
+ const int keyId, const int x, const int y, const float verticalScale) const {
const bool correctTouchPosition = hasTouchPositionCorrectionData();
const float centerX = static_cast<float>(correctTouchPosition ? getSweetSpotCenterXAt(keyId)
: getKeyCenterXOfKeyIdG(keyId));
@@ -138,7 +142,7 @@ float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG(
if (correctTouchPosition) {
const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId));
const float gapY = sweetSpotCenterY - visualKeyCenterY;
- centerY = visualKeyCenterY + gapY * ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G;
+ centerY = visualKeyCenterY + gapY * verticalScale;
} else {
centerY = visualKeyCenterY;
}
diff --git a/native/jni/src/proximity_info.h b/native/jni/src/proximity_info.h
index 57a175d2c..deb9ae0de 100644
--- a/native/jni/src/proximity_info.h
+++ b/native/jni/src/proximity_info.h
@@ -30,16 +30,17 @@ class ProximityInfo {
public:
ProximityInfo(JNIEnv *env, const jstring localeJStr,
const int keyboardWidth, const int keyboardHeight, const int gridWidth,
- const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars,
- const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates,
- const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes,
- const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs,
- const jfloatArray sweetSpotRadii);
+ const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight,
+ const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates,
+ const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights,
+ const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs,
+ const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii);
~ProximityInfo();
bool hasSpaceProximity(const int x, const int y) const;
int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const;
float getNormalizedSquaredDistanceFromCenterFloatG(
- const int keyId, const int x, const int y) const;
+ const int keyId, const int x, const int y,
+ const float verticalScale) const;
bool sameAsTyped(const unsigned short *word, int length) const;
int getCodePointOf(const int keyIndex) const;
bool hasSweetSpotData(const int keyIndex) const {
@@ -55,6 +56,9 @@ class ProximityInfo {
bool hasTouchPositionCorrectionData() const { return HAS_TOUCH_POSITION_CORRECTION_DATA; }
int getMostCommonKeyWidth() const { return MOST_COMMON_KEY_WIDTH; }
int getMostCommonKeyWidthSquare() const { return MOST_COMMON_KEY_WIDTH_SQUARE; }
+ float getNormalizedSquaredMostCommonKeyHypotenuse() const {
+ return NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE;
+ }
int getKeyCount() const { return KEY_COUNT; }
int getCellHeight() const { return CELL_HEIGHT; }
int getCellWidth() const { return CELL_WIDTH; }
@@ -98,6 +102,8 @@ class ProximityInfo {
const int GRID_HEIGHT;
const int MOST_COMMON_KEY_WIDTH;
const int MOST_COMMON_KEY_WIDTH_SQUARE;
+ const int MOST_COMMON_KEY_HEIGHT;
+ const float NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE;
const int CELL_WIDTH;
const int CELL_HEIGHT;
const int KEY_COUNT;
diff --git a/native/jni/src/proximity_info_params.cpp b/native/jni/src/proximity_info_params.cpp
index f9a4352ee..2675d9e70 100644
--- a/native/jni/src/proximity_info_params.cpp
+++ b/native/jni/src/proximity_info_params.cpp
@@ -20,7 +20,8 @@
namespace latinime {
const float ProximityInfoParams::NOT_A_DISTANCE_FLOAT = -1.0f;
const int ProximityInfoParams::MIN_DOUBLE_LETTER_BEELINE_SPEED_PERCENTILE = 5;
-const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G = 1.1f;
+const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE = 1.0f;
+const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G = 0.5f;
/* Per method constants */
// Used by ProximityInfoStateUtils::initGeometricDistanceInfos()
diff --git a/native/jni/src/proximity_info_params.h b/native/jni/src/proximity_info_params.h
index e7aec0976..4e47f7308 100644
--- a/native/jni/src/proximity_info_params.h
+++ b/native/jni/src/proximity_info_params.h
@@ -25,6 +25,7 @@ class ProximityInfoParams {
public:
static const float NOT_A_DISTANCE_FLOAT;
static const int MIN_DOUBLE_LETTER_BEELINE_SPEED_PERCENTILE;
+ static const float VERTICAL_SWEET_SPOT_SCALE;
static const float VERTICAL_SWEET_SPOT_SCALE_G;
// Used by ProximityInfoStateUtils::initGeometricDistanceInfos()
diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp
index 861ba9971..a10b260e1 100644
--- a/native/jni/src/proximity_info_state.cpp
+++ b/native/jni/src/proximity_info_state.cpp
@@ -28,6 +28,7 @@
namespace latinime {
+// TODO: Remove the dependency of "isGeometric"
void ProximityInfoState::initInputParams(const int pointerId, const float maxPointToKeyLength,
const ProximityInfo *proximityInfo, const int *const inputCodes, const int inputSize,
const int *const xCoordinates, const int *const yCoordinates, const int *const times,
@@ -94,12 +95,17 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
pushTouchPointStartIndex, lastSavedInputSize);
}
+ // TODO: Remove the dependency of "isGeometric"
+ const float verticalSweetSpotScale = isGeometric
+ ? ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G
+ : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE;
+
if (xCoordinates && yCoordinates) {
mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo,
mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times,
- pointerIds, inputSize, isGeometric, pointerId, pushTouchPointStartIndex,
- &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache,
- &mSampledInputIndice);
+ pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId,
+ pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes,
+ &mSampledLengthCache, &mSampledInputIndice);
}
if (mSampledInputSize > 0 && isGeometric) {
@@ -115,8 +121,8 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
if (mSampledInputSize > 0) {
ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize,
- lastSavedInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets,
- &mSampledDistanceCache_G);
+ lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs,
+ &mSampledNearKeySets, &mSampledDistanceCache_G);
if (isGeometric) {
// updates probabilities of skipping or mapping each key for all points.
ProximityInfoStateUtils::updateAlignPointProbabilities(
diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp
index 760508076..df70cffdf 100644
--- a/native/jni/src/proximity_info_state_utils.cpp
+++ b/native/jni/src/proximity_info_state_utils.cpp
@@ -42,8 +42,8 @@ namespace latinime {
const ProximityInfo *const proximityInfo, const int maxPointToKeyLength,
const int *const inputProximities, const int *const inputXCoordinates,
const int *const inputYCoordinates, const int *const times, const int *const pointerIds,
- const int inputSize, const bool isGeometric, const int pointerId,
- const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs,
+ const float verticalSweetSpotScale, const int inputSize, const bool isGeometric,
+ const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs,
std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes,
std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) {
if (DEBUG_SAMPLING_POINTS) {
@@ -112,10 +112,10 @@ namespace latinime {
}
if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time,
- isGeometric /* doSampling */, i == lastInputIndex, sumAngle,
- currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances,
- sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache,
- sampledInputIndice)) {
+ verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex,
+ sumAngle, currentNearKeysDistances, prevNearKeysDistances,
+ prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes,
+ sampledLengthCache, sampledInputIndice)) {
// Previous point information was popped.
NearKeysDistanceMap *tmp = prevNearKeysDistances;
prevNearKeysDistances = currentNearKeysDistances;
@@ -222,7 +222,8 @@ namespace latinime {
/* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos(
const ProximityInfo *const proximityInfo, const int sampledInputSize,
- const int lastSavedInputSize, const std::vector<int> *const sampledInputXs,
+ const int lastSavedInputSize, const float verticalSweetSpotScale,
+ const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
std::vector<NearKeycodesSet> *SampledNearKeySets,
std::vector<float> *SampledDistanceCache_G) {
@@ -236,7 +237,8 @@ namespace latinime {
const int x = (*sampledInputXs)[i];
const int y = (*sampledInputYs)[i];
const float normalizedSquaredDistance =
- proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
+ proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(
+ k, x, y, verticalSweetSpotScale);
(*SampledDistanceCache_G)[index] = normalizedSquaredDistance;
if (normalizedSquaredDistance
< ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
@@ -354,12 +356,14 @@ namespace latinime {
// the given point and the nearest key position.
/* static */ float ProximityInfoStateUtils::updateNearKeysDistances(
const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x,
- const int y, NearKeysDistanceMap *const currentNearKeysDistances) {
+ const int y, const float verticalSweetspotScale,
+ NearKeysDistanceMap *const currentNearKeysDistances) {
currentNearKeysDistances->clear();
const int keyCount = proximityInfo->getKeyCount();
float nearestKeyDistance = maxPointToKeyLength;
for (int k = 0; k < keyCount; ++k) {
- const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y);
+ const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y,
+ verticalSweetspotScale);
if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) {
currentNearKeysDistances->insert(std::pair<int, float>(k, dist));
}
@@ -439,7 +443,8 @@ namespace latinime {
// Returning if previous point is popped or not.
/* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo,
const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y,
- const int time, const bool doSampling, const bool isLastPoint, const float sumAngle,
+ const int time, const float verticalSweetSpotScale, const bool doSampling,
+ const bool isLastPoint, const float sumAngle,
NearKeysDistanceMap *const currentNearKeysDistances,
const NearKeysDistanceMap *const prevNearKeysDistances,
const NearKeysDistanceMap *const prevPrevNearKeysDistances,
@@ -451,8 +456,8 @@ namespace latinime {
size_t size = sampledInputXs->size();
bool popped = false;
if (nodeCodePoint < 0 && doSampling) {
- const float nearest = updateNearKeysDistances(
- proximityInfo, maxPointToKeyLength, x, y, currentNearKeysDistances);
+ const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y,
+ verticalSweetSpotScale, currentNearKeysDistances);
const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest,
sumAngle, currentNearKeysDistances, prevNearKeysDistances,
prevPrevNearKeysDistances, sampledInputXs, sampledInputYs);
diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h
index 3ceb25d8b..c9feb59a3 100644
--- a/native/jni/src/proximity_info_state_utils.h
+++ b/native/jni/src/proximity_info_state_utils.h
@@ -38,7 +38,8 @@ class ProximityInfoStateUtils {
static int updateTouchPoints(const ProximityInfo *const proximityInfo,
const int maxPointToKeyLength, const int *const inputProximities,
const int *const inputXCoordinates, const int *const inputYCoordinates,
- const int *const times, const int *const pointerIds, const int inputSize,
+ const int *const times, const int *const pointerIds,
+ const float verticalSweetSpotScale, const int inputSize,
const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex,
std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs,
std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache,
@@ -84,6 +85,7 @@ class ProximityInfoStateUtils {
const int inputIndex, const int keyId);
static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo,
const int sampledInputSize, const int lastSavedInputSize,
+ const float verticalSweetSpotScale,
const std::vector<int> *const sampledInputXs,
const std::vector<int> *const sampledInputYs,
std::vector<NearKeycodesSet> *SampledNearKeySets,
@@ -118,6 +120,7 @@ class ProximityInfoStateUtils {
static float updateNearKeysDistances(const ProximityInfo *const proximityInfo,
const float maxPointToKeyLength, const int x, const int y,
+ const float verticalSweetSpotScale,
NearKeysDistanceMap *const currentNearKeysDistances);
static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances,
const NearKeysDistanceMap *const prevNearKeysDistances,
@@ -130,7 +133,8 @@ class ProximityInfoStateUtils {
std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs);
static bool pushTouchPoint(const ProximityInfo *const proximityInfo,
const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x,
- int y, const int time, const bool doSampling, const bool isLastPoint,
+ int y, const int time, const float verticalSweetSpotScale,
+ const bool doSampling, const bool isLastPoint,
const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances,
const NearKeysDistanceMap *const prevNearKeysDistances,
const NearKeysDistanceMap *const prevPrevNearKeysDistances,
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index 7bfa459a2..32faae52c 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -23,6 +23,7 @@
#include "dic_node_profiler.h"
#include "dic_node_properties.h"
#include "dic_node_release_listener.h"
+#include "digraph_utils.h"
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
@@ -48,13 +49,6 @@
namespace latinime {
-// Naming convention
-// - Distance: "Weighted" edit distance -- used both for spatial and language.
-// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring
-// - Cost: delta/diff for Distance -- used both for spatial and language
-// - Length: "Non-weighted" -- used only for spatial
-// - Probability: "Non-weighted" -- used only for language
-
// This struct is purely a bucket to return values. No instances of this struct should be kept.
struct DicNode_InputStateG {
bool mNeedsToUpdateInputStateG;
@@ -406,8 +400,15 @@ class DicNode {
// TODO: Remove //
//////////////////////
// TODO: Remove once touch path is merged into ProximityInfoState
+ // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph.
int getNodeCodePoint() const {
- return mDicNodeProperties.getNodeCodePoint();
+ const int codePoint = mDicNodeProperties.getNodeCodePoint();
+ const DigraphUtils::DigraphCodePointIndex digraphIndex =
+ mDicNodeState.mDicNodeStateScoring.getDigraphIndex();
+ if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) {
+ return codePoint;
+ }
+ return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex);
}
////////////////////////////////
@@ -459,6 +460,15 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel);
}
+ bool isInDigraph() const {
+ return mDicNodeState.mDicNodeStateScoring.getDigraphIndex()
+ != DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ }
+
+ void advanceDigraphIndex() {
+ mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
+ }
+
uint8_t getFlags() const {
return mDicNodeProperties.getFlags();
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
index 8e816329f..8902d3122 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
@@ -20,6 +20,7 @@
#include <stdint.h>
#include "defines.h"
+#include "digraph_utils.h"
namespace latinime {
@@ -27,6 +28,7 @@ class DicNodeStateScoring {
public:
AK_FORCE_INLINE DicNodeStateScoring()
: mDoubleLetterLevel(NOT_A_DOUBLE_LETTER),
+ mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) {
@@ -43,6 +45,7 @@ class DicNodeStateScoring {
mTotalPrevWordsLanguageCost = 0.0f;
mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
}
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@@ -54,6 +57,7 @@ class DicNodeStateScoring {
mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost;
mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
+ mDigraphIndex = scoring->mDigraphIndex;
}
void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
@@ -126,6 +130,24 @@ class DicNodeStateScoring {
}
}
+ DigraphUtils::DigraphCodePointIndex getDigraphIndex() const {
+ return mDigraphIndex;
+ }
+
+ void advanceDigraphIndex() {
+ switch(mDigraphIndex) {
+ case DigraphUtils::NOT_A_DIGRAPH_INDEX:
+ mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT;
+ break;
+ case DigraphUtils::FIRST_DIGRAPH_CODEPOINT:
+ mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT;
+ break;
+ case DigraphUtils::SECOND_DIGRAPH_CODEPOINT:
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ break;
+ }
+ }
+
float getTotalPrevWordsLanguageCost() const {
return mTotalPrevWordsLanguageCost;
}
@@ -135,6 +157,7 @@ class DicNodeStateScoring {
// Use a default copy constructor and an assign operator because shallow copies are ok
// for this class
DoubleLetterLevel mDoubleLetterLevel;
+ DigraphUtils::DigraphCodePointIndex mDigraphIndex;
int16_t mEditCorrectionCount;
int16_t mProximityCorrectionCount;
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index ef6616e40..5b783a2ba 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -84,6 +84,10 @@ const uint8_t *DicTraverseSession::getOffsetDict() const {
return mDictionary->getOffsetDict();
}
+int DicTraverseSession::getDictFlags() const {
+ return mDictionary->getDictFlags();
+}
+
void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
mDicNodesCache.reset(nextActiveCacheSize, maxWords);
mBigramCacheMap.clear();
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index 62e1d1ab9..525d198cd 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -53,7 +53,7 @@ class DicTraverseSession {
void resetCache(const int nextActiveCacheSize, const int maxWords);
const uint8_t *getOffsetDict() const;
- bool canUseCache() const;
+ int getDictFlags() const;
//--------------------
// getters and setters
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 1e97a9176..63bb20004 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -18,6 +18,7 @@
#include "char_utils.h"
#include "dictionary.h"
+#include "digraph_utils.h"
#include "proximity_info.h"
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_priority_queue.h"
@@ -221,7 +222,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
const int inputSize = traverseSession->getInputSize();
DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize());
- DicNode omissionDicNode;
+ DicNode correctionDicNode;
// TODO: Find more efficient caching
const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession);
@@ -257,7 +258,10 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
dicNode.setCached();
}
- if (isLookAheadCorrection) {
+ if (dicNode.isInDigraph()) {
+ // Finish digraph handling if the node is in the middle of a digraph expansion.
+ processDicNodeAsDigraph(traverseSession, &dicNode);
+ } else if (isLookAheadCorrection) {
// The algorithm maintains a small set of "deferred" nodes that have not consumed the
// latest touch point yet. These are needed to apply look-ahead correction operations
// that require special handling of the latest touch point. For example, with insertions
@@ -291,12 +295,18 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
processDicNodeAsMatch(traverseSession, childDicNode);
continue;
}
+ if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(),
+ childDicNode->getNodeCodePoint())) {
+ correctionDicNode.initByCopy(childDicNode);
+ correctionDicNode.advanceDigraphIndex();
+ processDicNodeAsDigraph(traverseSession, &correctionDicNode);
+ }
if (allowsErrorCorrections
&& TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) {
// TODO: (Gesture) Change weight between omission and substitution errors
// TODO: (Gesture) Terminal node should not be handled as omission
- omissionDicNode.initByCopy(childDicNode);
- processDicNodeAsOmission(traverseSession, &omissionDicNode);
+ correctionDicNode.initByCopy(childDicNode);
+ processDicNodeAsOmission(traverseSession, &correctionDicNode);
}
const ProximityType proximityType = TRAVERSAL->getProximityType(
traverseSession, &dicNode, childDicNode);
@@ -400,6 +410,16 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
processExpandedDicNode(traverseSession, childDicNode);
}
+// Process the node codepoint as a digraph. This means that composite glyphs like the German
+// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with
+// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber".
+void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
+ DicNode *childDicNode) const {
+ weightChildNode(traverseSession, childDicNode);
+ childDicNode->advanceDigraphIndex();
+ processExpandedDicNode(traverseSession, childDicNode);
+}
+
/**
* Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider
* matches for all possible next letters. Note that just skipping the current letter without any
@@ -426,7 +446,6 @@ void Suggest::processDicNodeAsOmission(
weightChildNode(traverseSession, childDicNode);
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
- DicNode::managedDelete(childDicNode);
continue;
}
processExpandedDicNode(traverseSession, childDicNode);
diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h
index a1e7e7a94..136c4e548 100644
--- a/native/jni/src/suggest/core/suggest.h
+++ b/native/jni/src/suggest/core/suggest.h
@@ -23,6 +23,15 @@
namespace latinime {
+// Naming convention
+// - Distance: "Weighted" edit distance -- used both for spatial and language.
+// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring
+// - Cost: delta/diff for Distance -- used both for spatial and language
+// - Length: "Non-weighted" -- used only for spatial
+// - Probability: "Non-weighted" -- used only for language
+// - Score: Final calibrated score based on the compound distance, which is sent to java as the
+// priority of a suggested word
+
class DicNode;
class DicTraverseSession;
class ProximityInfo;
@@ -55,6 +64,7 @@ class Suggest : public SuggestInterface {
void generateFeatures(
DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const;
void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const;
+ void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const;
void processDicNodeAsTransposition(DicTraverseSession *traverseSession,
DicNode *dicNode) const;
void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const;
diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp
index 33795cade..a672294b5 100644
--- a/native/jni/src/unigram_dictionary.cpp
+++ b/native/jni/src/unigram_dictionary.cpp
@@ -32,9 +32,9 @@
namespace latinime {
// TODO: check the header
-UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags)
+UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags)
: DICT_ROOT(streamStart), ROOT_POS(0),
- MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), FLAGS(flags) {
+ MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), DICT_FLAGS(dictFlags) {
if (DEBUG_DICT) {
AKLOGI("UnigramDictionary - constructor");
}
@@ -163,7 +163,7 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x
masterCorrection.resetCorrection();
const DigraphUtils::digraph_t *digraphs = 0;
const int digraphsSize =
- DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(FLAGS, &digraphs);
+ DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(DICT_FLAGS, &digraphs);
if (digraphsSize > 0)
{ // Incrementally tune the word and try all possibilities
int codesBuffer[sizeof(*inputCodePoints) * inputSize];
diff --git a/native/jni/src/unigram_dictionary.h b/native/jni/src/unigram_dictionary.h
index 1a01758fe..a64a539bd 100644
--- a/native/jni/src/unigram_dictionary.h
+++ b/native/jni/src/unigram_dictionary.h
@@ -38,7 +38,7 @@ class UnigramDictionary {
static const int FLAG_MULTIPLE_SUGGEST_ABORT = 0;
static const int FLAG_MULTIPLE_SUGGEST_SKIP = 1;
static const int FLAG_MULTIPLE_SUGGEST_CONTINUE = 2;
- UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags);
+ UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags);
int getProbability(const int *const inWord, const int length) const;
int getBigramPosition(int pos, int *word, int offset, int length) const;
int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates,
@@ -46,6 +46,7 @@ class UnigramDictionary {
const std::map<int, int> *bigramMap, const uint8_t *bigramFilter,
const bool useFullEditDistance, int *outWords, int *frequencies,
int *outputTypes) const;
+ int getDictFlags() const { return DICT_FLAGS; }
virtual ~UnigramDictionary();
private:
@@ -109,7 +110,7 @@ class UnigramDictionary {
const uint8_t *const DICT_ROOT;
const int ROOT_POS;
const int MAX_DIGRAPH_SEARCH_DEPTH;
- const int FLAGS;
+ const int DICT_FLAGS;
};
} // namespace latinime
#endif // LATINIME_UNIGRAM_DICTIONARY_H