diff options
Diffstat (limited to 'native')
20 files changed, 230 insertions, 77 deletions
diff --git a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp index 3c482ca58..dedb02abf 100644 --- a/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp +++ b/native/jni/com_android_inputmethod_keyboard_ProximityInfo.cpp @@ -26,13 +26,13 @@ namespace latinime { static jlong latinime_Keyboard_setProximityInfo(JNIEnv *env, jclass clazz, jstring localeJStr, jint displayWidth, jint displayHeight, jint gridWidth, jint gridHeight, - jint mostCommonkeyWidth, jintArray proximityChars, jint keyCount, + jint mostCommonkeyWidth, jint mostCommonkeyHeight, jintArray proximityChars, jint keyCount, jintArray keyXCoordinates, jintArray keyYCoordinates, jintArray keyWidths, jintArray keyHeights, jintArray keyCharCodes, jfloatArray sweetSpotCenterXs, jfloatArray sweetSpotCenterYs, jfloatArray sweetSpotRadii) { ProximityInfo *proximityInfo = new ProximityInfo(env, localeJStr, displayWidth, displayHeight, - gridWidth, gridHeight, mostCommonkeyWidth, proximityChars, keyCount, - keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, keyCharCodes, + gridWidth, gridHeight, mostCommonkeyWidth, mostCommonkeyHeight, proximityChars, + keyCount, keyXCoordinates, keyYCoordinates, keyWidths, keyHeights, keyCharCodes, sweetSpotCenterXs, sweetSpotCenterYs, sweetSpotRadii); return reinterpret_cast<jlong>(proximityInfo); } @@ -44,7 +44,7 @@ static void latinime_Keyboard_release(JNIEnv *env, jclass clazz, jlong proximity static JNINativeMethod sMethods[] = { {const_cast<char *>("setProximityInfoNative"), - const_cast<char *>("(Ljava/lang/String;IIIII[II[I[I[I[I[I[F[F[F)J"), + const_cast<char *>("(Ljava/lang/String;IIIIII[II[I[I[I[I[I[F[F[F)J"), reinterpret_cast<void *>(latinime_Keyboard_setProximityInfo)}, {const_cast<char *>("releaseProximityInfoNative"), const_cast<char *>("(J)V"), diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp index ed6ddb517..c998c0676 100644 --- a/native/jni/src/dictionary.cpp +++ b/native/jni/src/dictionary.cpp @@ -103,4 +103,9 @@ int Dictionary::getProbability(const int *word, int length) const { bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const { return mBigramDictionary->isValidBigram(word1, length1, word2, length2); } + +int Dictionary::getDictFlags() const { + return mUnigramDictionary->getDictFlags(); +} + } // namespace latinime diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h index 8c6a7de52..0653d3ca9 100644 --- a/native/jni/src/dictionary.h +++ b/native/jni/src/dictionary.h @@ -63,6 +63,7 @@ class Dictionary { int getDictSize() const { return mDictSize; } int getMmapFd() const { return mMmapFd; } int getDictBufAdjust() const { return mDictBufAdjust; } + int getDictFlags() const; virtual ~Dictionary(); private: diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/digraph_utils.cpp index 8781c5077..6a1ab0271 100644 --- a/native/jni/src/digraph_utils.cpp +++ b/native/jni/src/digraph_utils.cpp @@ -27,39 +27,47 @@ const DigraphUtils::digraph_t DigraphUtils::GERMAN_UMLAUT_DIGRAPHS[] = const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] = { { 'a', 'e', 0x00E6 }, // U+00E6 : LATIN SMALL LETTER AE { 'o', 'e', 0x0153 } }; // U+0153 : LATIN SMALL LIGATURE OE +const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = + { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES }; /* static */ bool DigraphUtils::hasDigraphForCodePoint( const int dictFlags, const int compositeGlyphCodePoint) { - if (DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint)) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) { return true; } return false; } -// Retrieves the set of all digraphs associated with the given dictionary. -// Returns the size of the digraph array, or 0 if none exist. -/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const DigraphUtils::digraph_t **digraphs) { +// Returns the digraph type associated with the given dictionary. +/* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary( + const int dictFlags) { if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) { - *digraphs = DigraphUtils::GERMAN_UMLAUT_DIGRAPHS; - return NELEMS(DigraphUtils::GERMAN_UMLAUT_DIGRAPHS); + return DIGRAPH_TYPE_GERMAN_UMLAUT; } if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) { - *digraphs = DigraphUtils::FRENCH_LIGATURES_DIGRAPHS; - return NELEMS(DigraphUtils::FRENCH_LIGATURES_DIGRAPHS); + return DIGRAPH_TYPE_FRENCH_LIGATURES; } - return 0; + return DIGRAPH_TYPE_NONE; +} + +// Retrieves the set of all digraphs associated with the given dictionary flags. +// Returns the size of the digraph array, or 0 if none exist. +/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( + const int dictFlags, const DigraphUtils::digraph_t **const digraphs) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs); } // Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index // (which specifies the first or second codepoint in the digraph). -/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int dictFlags, - const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex) { +/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint, + const DigraphCodePointIndex digraphCodePointIndex) { if (digraphCodePointIndex == NOT_A_DIGRAPH_INDEX) { return NOT_A_CODE_POINT; } - const DigraphUtils::digraph_t *digraph = - DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint); + const DigraphUtils::digraph_t *const digraph = + DigraphUtils::getDigraphForCodePoint(compositeGlyphCodePoint); if (!digraph) { return NOT_A_CODE_POINT; } @@ -72,16 +80,48 @@ const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] = return NOT_A_CODE_POINT; } +// Retrieves the set of all digraphs associated with the given digraph type. +// Returns the size of the digraph array, or 0 if none exist. +/* static */ int DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize( + const DigraphUtils::DigraphType digraphType, + const DigraphUtils::digraph_t **const digraphs) { + if (digraphType == DigraphUtils::DIGRAPH_TYPE_GERMAN_UMLAUT) { + *digraphs = GERMAN_UMLAUT_DIGRAPHS; + return NELEMS(GERMAN_UMLAUT_DIGRAPHS); + } + if (digraphType == DIGRAPH_TYPE_FRENCH_LIGATURES) { + *digraphs = FRENCH_LIGATURES_DIGRAPHS; + return NELEMS(FRENCH_LIGATURES_DIGRAPHS); + } + return 0; +} + /** * Returns the digraph for the input composite glyph codepoint, or 0 if none exists. - * dictFlags: the dictionary flags needed to determine which digraphs are supported. * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint. */ /* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint) { + const int compositeGlyphCodePoint) { + for (size_t i = 0; i < NELEMS(USED_DIGRAPH_TYPES); i++) { + const DigraphUtils::digraph_t *const digraph = getDigraphForDigraphTypeAndCodePoint( + USED_DIGRAPH_TYPES[i], compositeGlyphCodePoint); + if (digraph) { + return digraph; + } + } + return 0; +} + +/** + * Returns the digraph for the input composite glyph codepoint, or 0 if none exists. + * digraphType: the type of digraphs supported. + * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint. + */ +/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint( + const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) { const DigraphUtils::digraph_t *digraphs = 0; const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(dictFlags, &digraphs); + DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs); for (int i = 0; i < digraphsSize; i++) { if (digraphs[i].compositeGlyph == compositeGlyphCodePoint) { return &digraphs[i]; diff --git a/native/jni/src/digraph_utils.h b/native/jni/src/digraph_utils.h index 6e364b67a..94435228e 100644 --- a/native/jni/src/digraph_utils.h +++ b/native/jni/src/digraph_utils.h @@ -27,21 +27,34 @@ class DigraphUtils { SECOND_DIGRAPH_CODEPOINT } DigraphCodePointIndex; + typedef enum { + DIGRAPH_TYPE_NONE, + DIGRAPH_TYPE_GERMAN_UMLAUT, + DIGRAPH_TYPE_FRENCH_LIGATURES + } DigraphType; + typedef struct { int first; int second; int compositeGlyph; } digraph_t; static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint); static int getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const digraph_t **digraphs); + const int dictFlags, const digraph_t **const digraphs); static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex); + static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint, + const DigraphCodePointIndex digraphCodePointIndex); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils); - static const digraph_t *getDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint); + static DigraphType getDigraphTypeForDictionary(const int dictFlags); + static int getAllDigraphsForDigraphTypeAndReturnSize( + const DigraphType digraphType, const digraph_t **const digraphs); + static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint); + static const digraph_t *getDigraphForDigraphTypeAndCodePoint( + const DigraphType digraphType, const int compositeGlyphCodePoint); static const digraph_t GERMAN_UMLAUT_DIGRAPHS[]; static const digraph_t FRENCH_LIGATURES_DIGRAPHS[]; + static const DigraphType USED_DIGRAPH_TYPES[]; }; } // namespace latinime #endif // DIGRAPH_UTILS_H diff --git a/native/jni/src/proximity_info.cpp b/native/jni/src/proximity_info.cpp index 74b5e0131..88d670d61 100644 --- a/native/jni/src/proximity_info.cpp +++ b/native/jni/src/proximity_info.cpp @@ -49,13 +49,17 @@ static AK_FORCE_INLINE void safeGetOrFillZeroFloatArrayRegion(JNIEnv *env, jfloa ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr, const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars, - const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates, - const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, - const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs, - const jfloatArray sweetSpotRadii) + const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, + const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, + const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, + const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, + const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii) : GRID_WIDTH(gridWidth), GRID_HEIGHT(gridHeight), MOST_COMMON_KEY_WIDTH(mostCommonKeyWidth), MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth), + MOST_COMMON_KEY_HEIGHT(mostCommonKeyHeight), + NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f + + SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) / + static_cast<float>(mostCommonKeyWidth))), CELL_WIDTH((keyboardWidth + gridWidth - 1) / gridWidth), CELL_HEIGHT((keyboardHeight + gridHeight - 1) / gridHeight), KEY_COUNT(min(keyCount, MAX_KEY_COUNT_IN_A_KEYBOARD)), @@ -129,7 +133,7 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const { } float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y) const { + const int keyId, const int x, const int y, const float verticalScale) const { const bool correctTouchPosition = hasTouchPositionCorrectionData(); const float centerX = static_cast<float>(correctTouchPosition ? getSweetSpotCenterXAt(keyId) : getKeyCenterXOfKeyIdG(keyId)); @@ -138,7 +142,7 @@ float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG( if (correctTouchPosition) { const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId)); const float gapY = sweetSpotCenterY - visualKeyCenterY; - centerY = visualKeyCenterY + gapY * ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G; + centerY = visualKeyCenterY + gapY * verticalScale; } else { centerY = visualKeyCenterY; } diff --git a/native/jni/src/proximity_info.h b/native/jni/src/proximity_info.h index 57a175d2c..deb9ae0de 100644 --- a/native/jni/src/proximity_info.h +++ b/native/jni/src/proximity_info.h @@ -30,16 +30,17 @@ class ProximityInfo { public: ProximityInfo(JNIEnv *env, const jstring localeJStr, const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars, - const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates, - const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, - const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs, - const jfloatArray sweetSpotRadii); + const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, + const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, + const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, + const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, + const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii); ~ProximityInfo(); bool hasSpaceProximity(const int x, const int y) const; int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const; float getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y) const; + const int keyId, const int x, const int y, + const float verticalScale) const; bool sameAsTyped(const unsigned short *word, int length) const; int getCodePointOf(const int keyIndex) const; bool hasSweetSpotData(const int keyIndex) const { @@ -55,6 +56,9 @@ class ProximityInfo { bool hasTouchPositionCorrectionData() const { return HAS_TOUCH_POSITION_CORRECTION_DATA; } int getMostCommonKeyWidth() const { return MOST_COMMON_KEY_WIDTH; } int getMostCommonKeyWidthSquare() const { return MOST_COMMON_KEY_WIDTH_SQUARE; } + float getNormalizedSquaredMostCommonKeyHypotenuse() const { + return NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE; + } int getKeyCount() const { return KEY_COUNT; } int getCellHeight() const { return CELL_HEIGHT; } int getCellWidth() const { return CELL_WIDTH; } @@ -98,6 +102,8 @@ class ProximityInfo { const int GRID_HEIGHT; const int MOST_COMMON_KEY_WIDTH; const int MOST_COMMON_KEY_WIDTH_SQUARE; + const int MOST_COMMON_KEY_HEIGHT; + const float NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE; const int CELL_WIDTH; const int CELL_HEIGHT; const int KEY_COUNT; diff --git a/native/jni/src/proximity_info_params.cpp b/native/jni/src/proximity_info_params.cpp index f9a4352ee..2675d9e70 100644 --- a/native/jni/src/proximity_info_params.cpp +++ b/native/jni/src/proximity_info_params.cpp @@ -20,7 +20,8 @@ namespace latinime { const float ProximityInfoParams::NOT_A_DISTANCE_FLOAT = -1.0f; const int ProximityInfoParams::MIN_DOUBLE_LETTER_BEELINE_SPEED_PERCENTILE = 5; -const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G = 1.1f; +const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE = 1.0f; +const float ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G = 0.5f; /* Per method constants */ // Used by ProximityInfoStateUtils::initGeometricDistanceInfos() diff --git a/native/jni/src/proximity_info_params.h b/native/jni/src/proximity_info_params.h index e7aec0976..4e47f7308 100644 --- a/native/jni/src/proximity_info_params.h +++ b/native/jni/src/proximity_info_params.h @@ -25,6 +25,7 @@ class ProximityInfoParams { public: static const float NOT_A_DISTANCE_FLOAT; static const int MIN_DOUBLE_LETTER_BEELINE_SPEED_PERCENTILE; + static const float VERTICAL_SWEET_SPOT_SCALE; static const float VERTICAL_SWEET_SPOT_SCALE_G; // Used by ProximityInfoStateUtils::initGeometricDistanceInfos() diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index 861ba9971..a10b260e1 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -28,6 +28,7 @@ namespace latinime { +// TODO: Remove the dependency of "isGeometric" void ProximityInfoState::initInputParams(const int pointerId, const float maxPointToKeyLength, const ProximityInfo *proximityInfo, const int *const inputCodes, const int inputSize, const int *const xCoordinates, const int *const yCoordinates, const int *const times, @@ -94,12 +95,17 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi pushTouchPointStartIndex, lastSavedInputSize); } + // TODO: Remove the dependency of "isGeometric" + const float verticalSweetSpotScale = isGeometric + ? ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G + : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE; + if (xCoordinates && yCoordinates) { mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo, mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times, - pointerIds, inputSize, isGeometric, pointerId, pushTouchPointStartIndex, - &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache, - &mSampledInputIndice); + pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId, + pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, + &mSampledLengthCache, &mSampledInputIndice); } if (mSampledInputSize > 0 && isGeometric) { @@ -115,8 +121,8 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, - lastSavedInputSize, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets, - &mSampledDistanceCache_G); + lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, + &mSampledNearKeySets, &mSampledDistanceCache_G); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. ProximityInfoStateUtils::updateAlignPointProbabilities( diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp index 760508076..df70cffdf 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/proximity_info_state_utils.cpp @@ -42,8 +42,8 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, const int *const times, const int *const pointerIds, - const int inputSize, const bool isGeometric, const int pointerId, - const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, + const float verticalSweetSpotScale, const int inputSize, const bool isGeometric, + const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) { if (DEBUG_SAMPLING_POINTS) { @@ -112,10 +112,10 @@ namespace latinime { } if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time, - isGeometric /* doSampling */, i == lastInputIndex, sumAngle, - currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, - sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache, - sampledInputIndice)) { + verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex, + sumAngle, currentNearKeysDistances, prevNearKeysDistances, + prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes, + sampledLengthCache, sampledInputIndice)) { // Previous point information was popped. NearKeysDistanceMap *tmp = prevNearKeysDistances; prevNearKeysDistances = currentNearKeysDistances; @@ -222,7 +222,8 @@ namespace latinime { /* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const int lastSavedInputSize, const std::vector<int> *const sampledInputXs, + const int lastSavedInputSize, const float verticalSweetSpotScale, + const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *SampledNearKeySets, std::vector<float> *SampledDistanceCache_G) { @@ -236,7 +237,8 @@ namespace latinime { const int x = (*sampledInputXs)[i]; const int y = (*sampledInputYs)[i]; const float normalizedSquaredDistance = - proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y); + proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( + k, x, y, verticalSweetSpotScale); (*SampledDistanceCache_G)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { @@ -354,12 +356,14 @@ namespace latinime { // the given point and the nearest key position. /* static */ float ProximityInfoStateUtils::updateNearKeysDistances( const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, - const int y, NearKeysDistanceMap *const currentNearKeysDistances) { + const int y, const float verticalSweetspotScale, + NearKeysDistanceMap *const currentNearKeysDistances) { currentNearKeysDistances->clear(); const int keyCount = proximityInfo->getKeyCount(); float nearestKeyDistance = maxPointToKeyLength; for (int k = 0; k < keyCount; ++k) { - const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y); + const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y, + verticalSweetspotScale); if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) { currentNearKeysDistances->insert(std::pair<int, float>(k, dist)); } @@ -439,7 +443,8 @@ namespace latinime { // Returning if previous point is popped or not. /* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y, - const int time, const bool doSampling, const bool isLastPoint, const float sumAngle, + const int time, const float verticalSweetSpotScale, const bool doSampling, + const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, const NearKeysDistanceMap *const prevPrevNearKeysDistances, @@ -451,8 +456,8 @@ namespace latinime { size_t size = sampledInputXs->size(); bool popped = false; if (nodeCodePoint < 0 && doSampling) { - const float nearest = updateNearKeysDistances( - proximityInfo, maxPointToKeyLength, x, y, currentNearKeysDistances); + const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y, + verticalSweetSpotScale, currentNearKeysDistances); const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs); diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h index 3ceb25d8b..c9feb59a3 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/proximity_info_state_utils.h @@ -38,7 +38,8 @@ class ProximityInfoStateUtils { static int updateTouchPoints(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, - const int *const times, const int *const pointerIds, const int inputSize, + const int *const times, const int *const pointerIds, + const float verticalSweetSpotScale, const int inputSize, const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, @@ -84,6 +85,7 @@ class ProximityInfoStateUtils { const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, + const float verticalSweetSpotScale, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *SampledNearKeySets, @@ -118,6 +120,7 @@ class ProximityInfoStateUtils { static float updateNearKeysDistances(const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, const int y, + const float verticalSweetSpotScale, NearKeysDistanceMap *const currentNearKeysDistances); static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -130,7 +133,8 @@ class ProximityInfoStateUtils { std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs); static bool pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, - int y, const int time, const bool doSampling, const bool isLastPoint, + int y, const int time, const float verticalSweetSpotScale, + const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, const NearKeysDistanceMap *const prevPrevNearKeysDistances, diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 7bfa459a2..32faae52c 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -23,6 +23,7 @@ #include "dic_node_profiler.h" #include "dic_node_properties.h" #include "dic_node_release_listener.h" +#include "digraph_utils.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ @@ -48,13 +49,6 @@ namespace latinime { -// Naming convention -// - Distance: "Weighted" edit distance -- used both for spatial and language. -// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring -// - Cost: delta/diff for Distance -- used both for spatial and language -// - Length: "Non-weighted" -- used only for spatial -// - Probability: "Non-weighted" -- used only for language - // This struct is purely a bucket to return values. No instances of this struct should be kept. struct DicNode_InputStateG { bool mNeedsToUpdateInputStateG; @@ -406,8 +400,15 @@ class DicNode { // TODO: Remove // ////////////////////// // TODO: Remove once touch path is merged into ProximityInfoState + // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. int getNodeCodePoint() const { - return mDicNodeProperties.getNodeCodePoint(); + const int codePoint = mDicNodeProperties.getNodeCodePoint(); + const DigraphUtils::DigraphCodePointIndex digraphIndex = + mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); + if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { + return codePoint; + } + return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); } //////////////////////////////// @@ -459,6 +460,15 @@ class DicNode { mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); } + bool isInDigraph() const { + return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() + != DigraphUtils::NOT_A_DIGRAPH_INDEX; + } + + void advanceDigraphIndex() { + mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); + } + uint8_t getFlags() const { return mDicNodeProperties.getFlags(); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h index 8e816329f..8902d3122 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -20,6 +20,7 @@ #include <stdint.h> #include "defines.h" +#include "digraph_utils.h" namespace latinime { @@ -27,6 +28,7 @@ class DicNodeStateScoring { public: AK_FORCE_INLINE DicNodeStateScoring() : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), + mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) { @@ -43,6 +45,7 @@ class DicNodeStateScoring { mTotalPrevWordsLanguageCost = 0.0f; mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; + mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; } AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { @@ -54,6 +57,7 @@ class DicNodeStateScoring { mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost; mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; + mDigraphIndex = scoring->mDigraphIndex; } void addCost(const float spatialCost, const float languageCost, const bool doNormalization, @@ -126,6 +130,24 @@ class DicNodeStateScoring { } } + DigraphUtils::DigraphCodePointIndex getDigraphIndex() const { + return mDigraphIndex; + } + + void advanceDigraphIndex() { + switch(mDigraphIndex) { + case DigraphUtils::NOT_A_DIGRAPH_INDEX: + mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT; + break; + case DigraphUtils::FIRST_DIGRAPH_CODEPOINT: + mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT; + break; + case DigraphUtils::SECOND_DIGRAPH_CODEPOINT: + mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; + break; + } + } + float getTotalPrevWordsLanguageCost() const { return mTotalPrevWordsLanguageCost; } @@ -135,6 +157,7 @@ class DicNodeStateScoring { // Use a default copy constructor and an assign operator because shallow copies are ok // for this class DoubleLetterLevel mDoubleLetterLevel; + DigraphUtils::DigraphCodePointIndex mDigraphIndex; int16_t mEditCorrectionCount; int16_t mProximityCorrectionCount; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index ef6616e40..5b783a2ba 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -84,6 +84,10 @@ const uint8_t *DicTraverseSession::getOffsetDict() const { return mDictionary->getOffsetDict(); } +int DicTraverseSession::getDictFlags() const { + return mDictionary->getDictFlags(); +} + void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { mDicNodesCache.reset(nextActiveCacheSize, maxWords); mBigramCacheMap.clear(); diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 62e1d1ab9..525d198cd 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -53,7 +53,7 @@ class DicTraverseSession { void resetCache(const int nextActiveCacheSize, const int maxWords); const uint8_t *getOffsetDict() const; - bool canUseCache() const; + int getDictFlags() const; //-------------------- // getters and setters diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 1e97a9176..63bb20004 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -18,6 +18,7 @@ #include "char_utils.h" #include "dictionary.h" +#include "digraph_utils.h" #include "proximity_info.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_priority_queue.h" @@ -221,7 +222,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { const int inputSize = traverseSession->getInputSize(); DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize()); - DicNode omissionDicNode; + DicNode correctionDicNode; // TODO: Find more efficient caching const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession); @@ -257,7 +258,10 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { dicNode.setCached(); } - if (isLookAheadCorrection) { + if (dicNode.isInDigraph()) { + // Finish digraph handling if the node is in the middle of a digraph expansion. + processDicNodeAsDigraph(traverseSession, &dicNode); + } else if (isLookAheadCorrection) { // The algorithm maintains a small set of "deferred" nodes that have not consumed the // latest touch point yet. These are needed to apply look-ahead correction operations // that require special handling of the latest touch point. For example, with insertions @@ -291,12 +295,18 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { processDicNodeAsMatch(traverseSession, childDicNode); continue; } + if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(), + childDicNode->getNodeCodePoint())) { + correctionDicNode.initByCopy(childDicNode); + correctionDicNode.advanceDigraphIndex(); + processDicNodeAsDigraph(traverseSession, &correctionDicNode); + } if (allowsErrorCorrections && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { // TODO: (Gesture) Change weight between omission and substitution errors // TODO: (Gesture) Terminal node should not be handled as omission - omissionDicNode.initByCopy(childDicNode); - processDicNodeAsOmission(traverseSession, &omissionDicNode); + correctionDicNode.initByCopy(childDicNode); + processDicNodeAsOmission(traverseSession, &correctionDicNode); } const ProximityType proximityType = TRAVERSAL->getProximityType( traverseSession, &dicNode, childDicNode); @@ -400,6 +410,16 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, processExpandedDicNode(traverseSession, childDicNode); } +// Process the node codepoint as a digraph. This means that composite glyphs like the German +// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with +// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber". +void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, + DicNode *childDicNode) const { + weightChildNode(traverseSession, childDicNode); + childDicNode->advanceDigraphIndex(); + processExpandedDicNode(traverseSession, childDicNode); +} + /** * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider * matches for all possible next letters. Note that just skipping the current letter without any @@ -426,7 +446,6 @@ void Suggest::processDicNodeAsOmission( weightChildNode(traverseSession, childDicNode); if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { - DicNode::managedDelete(childDicNode); continue; } processExpandedDicNode(traverseSession, childDicNode); diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index a1e7e7a94..136c4e548 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -23,6 +23,15 @@ namespace latinime { +// Naming convention +// - Distance: "Weighted" edit distance -- used both for spatial and language. +// - Compound Distance: Spatial Distance + Language Distance -- used for pruning and scoring +// - Cost: delta/diff for Distance -- used both for spatial and language +// - Length: "Non-weighted" -- used only for spatial +// - Probability: "Non-weighted" -- used only for language +// - Score: Final calibrated score based on the compound distance, which is sent to java as the +// priority of a suggested word + class DicNode; class DicTraverseSession; class ProximityInfo; @@ -55,6 +64,7 @@ class Suggest : public SuggestInterface { void generateFeatures( DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const; void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const; void processDicNodeAsTransposition(DicTraverseSession *traverseSession, DicNode *dicNode) const; void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const; diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp index 33795cade..a672294b5 100644 --- a/native/jni/src/unigram_dictionary.cpp +++ b/native/jni/src/unigram_dictionary.cpp @@ -32,9 +32,9 @@ namespace latinime { // TODO: check the header -UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags) +UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags) : DICT_ROOT(streamStart), ROOT_POS(0), - MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), FLAGS(flags) { + MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), DICT_FLAGS(dictFlags) { if (DEBUG_DICT) { AKLOGI("UnigramDictionary - constructor"); } @@ -163,7 +163,7 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x masterCorrection.resetCorrection(); const DigraphUtils::digraph_t *digraphs = 0; const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(FLAGS, &digraphs); + DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(DICT_FLAGS, &digraphs); if (digraphsSize > 0) { // Incrementally tune the word and try all possibilities int codesBuffer[sizeof(*inputCodePoints) * inputSize]; diff --git a/native/jni/src/unigram_dictionary.h b/native/jni/src/unigram_dictionary.h index 1a01758fe..a64a539bd 100644 --- a/native/jni/src/unigram_dictionary.h +++ b/native/jni/src/unigram_dictionary.h @@ -38,7 +38,7 @@ class UnigramDictionary { static const int FLAG_MULTIPLE_SUGGEST_ABORT = 0; static const int FLAG_MULTIPLE_SUGGEST_SKIP = 1; static const int FLAG_MULTIPLE_SUGGEST_CONTINUE = 2; - UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags); + UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags); int getProbability(const int *const inWord, const int length) const; int getBigramPosition(int pos, int *word, int offset, int length) const; int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, @@ -46,6 +46,7 @@ class UnigramDictionary { const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, const bool useFullEditDistance, int *outWords, int *frequencies, int *outputTypes) const; + int getDictFlags() const { return DICT_FLAGS; } virtual ~UnigramDictionary(); private: @@ -109,7 +110,7 @@ class UnigramDictionary { const uint8_t *const DICT_ROOT; const int ROOT_POS; const int MAX_DIGRAPH_SEARCH_DEPTH; - const int FLAGS; + const int DICT_FLAGS; }; } // namespace latinime #endif // LATINIME_UNIGRAM_DICTIONARY_H |