diff options
Diffstat (limited to 'native/jni/src')
22 files changed, 276 insertions, 90 deletions
diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 76234f840..0c65939e0 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -675,7 +675,7 @@ inline static bool isUpperCase(unsigned short c) { multiplyIntCapped(typedLetterMultiplier, &finalFreq); } const float factor = - SuggestUtils::getDistanceScalingFactor(static_cast<float>(squaredDistance)); + SuggestUtils::getLengthScalingFactor(static_cast<float>(squaredDistance)); if (factor > 0.0f) { multiplyRate(static_cast<int>(factor * 100.0f), &finalFreq); } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index a45691261..a7b023a75 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -216,6 +216,7 @@ static inline void prof_out(void) { #define DEBUG_DOUBLE_LETTER false #define DEBUG_CACHE false #define DEBUG_DUMP_ERROR false +#define DEBUG_EVALUATE_MOST_PROBABLE_STRING false #ifdef FLAG_FULL_DBG #define DEBUG_GEO_FULL true @@ -241,6 +242,7 @@ static inline void prof_out(void) { #define DEBUG_DOUBLE_LETTER false #define DEBUG_CACHE false #define DEBUG_DUMP_ERROR false +#define DEBUG_EVALUATE_MOST_PROBABLE_STRING false #define DEBUG_GEO_FULL false diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp index ed6ddb517..c998c0676 100644 --- a/native/jni/src/dictionary.cpp +++ b/native/jni/src/dictionary.cpp @@ -103,4 +103,9 @@ int Dictionary::getProbability(const int *word, int length) const { bool Dictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const { return mBigramDictionary->isValidBigram(word1, length1, word2, length2); } + +int Dictionary::getDictFlags() const { + return mUnigramDictionary->getDictFlags(); +} + } // namespace latinime diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h index 8c6a7de52..0653d3ca9 100644 --- a/native/jni/src/dictionary.h +++ b/native/jni/src/dictionary.h @@ -63,6 +63,7 @@ class Dictionary { int getDictSize() const { return mDictSize; } int getMmapFd() const { return mMmapFd; } int getDictBufAdjust() const { return mDictBufAdjust; } + int getDictFlags() const; virtual ~Dictionary(); private: diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/digraph_utils.cpp index 8781c5077..6a1ab0271 100644 --- a/native/jni/src/digraph_utils.cpp +++ b/native/jni/src/digraph_utils.cpp @@ -27,39 +27,47 @@ const DigraphUtils::digraph_t DigraphUtils::GERMAN_UMLAUT_DIGRAPHS[] = const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] = { { 'a', 'e', 0x00E6 }, // U+00E6 : LATIN SMALL LETTER AE { 'o', 'e', 0x0153 } }; // U+0153 : LATIN SMALL LIGATURE OE +const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = + { DIGRAPH_TYPE_GERMAN_UMLAUT, DIGRAPH_TYPE_FRENCH_LIGATURES }; /* static */ bool DigraphUtils::hasDigraphForCodePoint( const int dictFlags, const int compositeGlyphCodePoint) { - if (DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint)) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + if (DigraphUtils::getDigraphForDigraphTypeAndCodePoint(digraphType, compositeGlyphCodePoint)) { return true; } return false; } -// Retrieves the set of all digraphs associated with the given dictionary. -// Returns the size of the digraph array, or 0 if none exist. -/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const DigraphUtils::digraph_t **digraphs) { +// Returns the digraph type associated with the given dictionary. +/* static */ DigraphUtils::DigraphType DigraphUtils::getDigraphTypeForDictionary( + const int dictFlags) { if (BinaryFormat::REQUIRES_GERMAN_UMLAUT_PROCESSING & dictFlags) { - *digraphs = DigraphUtils::GERMAN_UMLAUT_DIGRAPHS; - return NELEMS(DigraphUtils::GERMAN_UMLAUT_DIGRAPHS); + return DIGRAPH_TYPE_GERMAN_UMLAUT; } if (BinaryFormat::REQUIRES_FRENCH_LIGATURES_PROCESSING & dictFlags) { - *digraphs = DigraphUtils::FRENCH_LIGATURES_DIGRAPHS; - return NELEMS(DigraphUtils::FRENCH_LIGATURES_DIGRAPHS); + return DIGRAPH_TYPE_FRENCH_LIGATURES; } - return 0; + return DIGRAPH_TYPE_NONE; +} + +// Retrieves the set of all digraphs associated with the given dictionary flags. +// Returns the size of the digraph array, or 0 if none exist. +/* static */ int DigraphUtils::getAllDigraphsForDictionaryAndReturnSize( + const int dictFlags, const DigraphUtils::digraph_t **const digraphs) { + const DigraphUtils::DigraphType digraphType = getDigraphTypeForDictionary(dictFlags); + return getAllDigraphsForDigraphTypeAndReturnSize(digraphType, digraphs); } // Returns the digraph codepoint for the given composite glyph codepoint and digraph codepoint index // (which specifies the first or second codepoint in the digraph). -/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int dictFlags, - const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex) { +/* static */ int DigraphUtils::getDigraphCodePointForIndex(const int compositeGlyphCodePoint, + const DigraphCodePointIndex digraphCodePointIndex) { if (digraphCodePointIndex == NOT_A_DIGRAPH_INDEX) { return NOT_A_CODE_POINT; } - const DigraphUtils::digraph_t *digraph = - DigraphUtils::getDigraphForCodePoint(dictFlags, compositeGlyphCodePoint); + const DigraphUtils::digraph_t *const digraph = + DigraphUtils::getDigraphForCodePoint(compositeGlyphCodePoint); if (!digraph) { return NOT_A_CODE_POINT; } @@ -72,16 +80,48 @@ const DigraphUtils::digraph_t DigraphUtils::FRENCH_LIGATURES_DIGRAPHS[] = return NOT_A_CODE_POINT; } +// Retrieves the set of all digraphs associated with the given digraph type. +// Returns the size of the digraph array, or 0 if none exist. +/* static */ int DigraphUtils::getAllDigraphsForDigraphTypeAndReturnSize( + const DigraphUtils::DigraphType digraphType, + const DigraphUtils::digraph_t **const digraphs) { + if (digraphType == DigraphUtils::DIGRAPH_TYPE_GERMAN_UMLAUT) { + *digraphs = GERMAN_UMLAUT_DIGRAPHS; + return NELEMS(GERMAN_UMLAUT_DIGRAPHS); + } + if (digraphType == DIGRAPH_TYPE_FRENCH_LIGATURES) { + *digraphs = FRENCH_LIGATURES_DIGRAPHS; + return NELEMS(FRENCH_LIGATURES_DIGRAPHS); + } + return 0; +} + /** * Returns the digraph for the input composite glyph codepoint, or 0 if none exists. - * dictFlags: the dictionary flags needed to determine which digraphs are supported. * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint. */ /* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint) { + const int compositeGlyphCodePoint) { + for (size_t i = 0; i < NELEMS(USED_DIGRAPH_TYPES); i++) { + const DigraphUtils::digraph_t *const digraph = getDigraphForDigraphTypeAndCodePoint( + USED_DIGRAPH_TYPES[i], compositeGlyphCodePoint); + if (digraph) { + return digraph; + } + } + return 0; +} + +/** + * Returns the digraph for the input composite glyph codepoint, or 0 if none exists. + * digraphType: the type of digraphs supported. + * compositeGlyphCodePoint: the method returns the digraph corresponding to this codepoint. + */ +/* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint( + const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) { const DigraphUtils::digraph_t *digraphs = 0; const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(dictFlags, &digraphs); + DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs); for (int i = 0; i < digraphsSize; i++) { if (digraphs[i].compositeGlyph == compositeGlyphCodePoint) { return &digraphs[i]; diff --git a/native/jni/src/digraph_utils.h b/native/jni/src/digraph_utils.h index 6e364b67a..94435228e 100644 --- a/native/jni/src/digraph_utils.h +++ b/native/jni/src/digraph_utils.h @@ -27,21 +27,34 @@ class DigraphUtils { SECOND_DIGRAPH_CODEPOINT } DigraphCodePointIndex; + typedef enum { + DIGRAPH_TYPE_NONE, + DIGRAPH_TYPE_GERMAN_UMLAUT, + DIGRAPH_TYPE_FRENCH_LIGATURES + } DigraphType; + typedef struct { int first; int second; int compositeGlyph; } digraph_t; static bool hasDigraphForCodePoint(const int dictFlags, const int compositeGlyphCodePoint); static int getAllDigraphsForDictionaryAndReturnSize( - const int dictFlags, const digraph_t **digraphs); + const int dictFlags, const digraph_t **const digraphs); static int getDigraphCodePointForIndex(const int dictFlags, const int compositeGlyphCodePoint, const DigraphCodePointIndex digraphCodePointIndex); + static int getDigraphCodePointForIndex(const int compositeGlyphCodePoint, + const DigraphCodePointIndex digraphCodePointIndex); private: DISALLOW_IMPLICIT_CONSTRUCTORS(DigraphUtils); - static const digraph_t *getDigraphForCodePoint( - const int dictFlags, const int compositeGlyphCodePoint); + static DigraphType getDigraphTypeForDictionary(const int dictFlags); + static int getAllDigraphsForDigraphTypeAndReturnSize( + const DigraphType digraphType, const digraph_t **const digraphs); + static const digraph_t *getDigraphForCodePoint(const int compositeGlyphCodePoint); + static const digraph_t *getDigraphForDigraphTypeAndCodePoint( + const DigraphType digraphType, const int compositeGlyphCodePoint); static const digraph_t GERMAN_UMLAUT_DIGRAPHS[]; static const digraph_t FRENCH_LIGATURES_DIGRAPHS[]; + static const DigraphType USED_DIGRAPH_TYPES[]; }; } // namespace latinime #endif // DIGRAPH_UTILS_H diff --git a/native/jni/src/proximity_info.cpp b/native/jni/src/proximity_info.cpp index 50f38e82e..88d670d61 100644 --- a/native/jni/src/proximity_info.cpp +++ b/native/jni/src/proximity_info.cpp @@ -49,13 +49,17 @@ static AK_FORCE_INLINE void safeGetOrFillZeroFloatArrayRegion(JNIEnv *env, jfloa ProximityInfo::ProximityInfo(JNIEnv *env, const jstring localeJStr, const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars, - const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates, - const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, - const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs, - const jfloatArray sweetSpotRadii) + const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, + const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, + const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, + const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, + const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii) : GRID_WIDTH(gridWidth), GRID_HEIGHT(gridHeight), MOST_COMMON_KEY_WIDTH(mostCommonKeyWidth), MOST_COMMON_KEY_WIDTH_SQUARE(mostCommonKeyWidth * mostCommonKeyWidth), + MOST_COMMON_KEY_HEIGHT(mostCommonKeyHeight), + NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE(1.0f + + SQUARE_FLOAT(static_cast<float>(mostCommonKeyHeight) / + static_cast<float>(mostCommonKeyWidth))), CELL_WIDTH((keyboardWidth + gridWidth - 1) / gridWidth), CELL_HEIGHT((keyboardHeight + gridHeight - 1) / gridHeight), KEY_COUNT(min(keyCount, MAX_KEY_COUNT_IN_A_KEYBOARD)), diff --git a/native/jni/src/proximity_info.h b/native/jni/src/proximity_info.h index e21262fdb..deb9ae0de 100644 --- a/native/jni/src/proximity_info.h +++ b/native/jni/src/proximity_info.h @@ -30,11 +30,11 @@ class ProximityInfo { public: ProximityInfo(JNIEnv *env, const jstring localeJStr, const int keyboardWidth, const int keyboardHeight, const int gridWidth, - const int gridHeight, const int mostCommonKeyWidth, const jintArray proximityChars, - const int keyCount, const jintArray keyXCoordinates, const jintArray keyYCoordinates, - const jintArray keyWidths, const jintArray keyHeights, const jintArray keyCharCodes, - const jfloatArray sweetSpotCenterXs, const jfloatArray sweetSpotCenterYs, - const jfloatArray sweetSpotRadii); + const int gridHeight, const int mostCommonKeyWidth, const int mostCommonKeyHeight, + const jintArray proximityChars, const int keyCount, const jintArray keyXCoordinates, + const jintArray keyYCoordinates, const jintArray keyWidths, const jintArray keyHeights, + const jintArray keyCharCodes, const jfloatArray sweetSpotCenterXs, + const jfloatArray sweetSpotCenterYs, const jfloatArray sweetSpotRadii); ~ProximityInfo(); bool hasSpaceProximity(const int x, const int y) const; int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const; @@ -56,6 +56,9 @@ class ProximityInfo { bool hasTouchPositionCorrectionData() const { return HAS_TOUCH_POSITION_CORRECTION_DATA; } int getMostCommonKeyWidth() const { return MOST_COMMON_KEY_WIDTH; } int getMostCommonKeyWidthSquare() const { return MOST_COMMON_KEY_WIDTH_SQUARE; } + float getNormalizedSquaredMostCommonKeyHypotenuse() const { + return NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE; + } int getKeyCount() const { return KEY_COUNT; } int getCellHeight() const { return CELL_HEIGHT; } int getCellWidth() const { return CELL_WIDTH; } @@ -99,6 +102,8 @@ class ProximityInfo { const int GRID_HEIGHT; const int MOST_COMMON_KEY_WIDTH; const int MOST_COMMON_KEY_WIDTH_SQUARE; + const int MOST_COMMON_KEY_HEIGHT; + const float NORMALIZED_SQUARED_MOST_COMMON_KEY_HYPOTENUSE; const int CELL_WIDTH; const int CELL_HEIGHT; const int KEY_COUNT; diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index a10b260e1..cc5b736bd 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -81,7 +81,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi mSampledTimes.clear(); mSampledInputIndice.clear(); mSampledLengthCache.clear(); - mSampledDistanceCache_G.clear(); + mSampledNormalizedSquaredLengthCache.clear(); mSampledNearKeySets.clear(); mSampledSearchKeySets.clear(); mSpeedRates.clear(); @@ -122,14 +122,15 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, - &mSampledNearKeySets, &mSampledDistanceCache_G); + &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. ProximityInfoStateUtils::updateAlignPointProbabilities( mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(), mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize, &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache, - &mSampledDistanceCache_G, &mSampledNearKeySets, &mCharProbabilities); + &mSampledNormalizedSquaredLengthCache, &mSampledNearKeySets, + &mCharProbabilities); ProximityInfoStateUtils::updateSampledSearchKeySets(mProximityInfo, mSampledInputSize, lastSavedInputSize, &mSampledLengthCache, &mSampledNearKeySets, &mSampledSearchKeySets, @@ -171,7 +172,7 @@ float ProximityInfoState::getPointToKeyLength( const int keyId = mProximityInfo->getKeyIndexOf(codePoint); if (keyId != NOT_AN_INDEX) { const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; - return min(mSampledDistanceCache_G[index], mMaxPointToKeyLength); + return min(mSampledNormalizedSquaredLengthCache[index], mMaxPointToKeyLength); } if (isIntentionalOmissionCodePoint(codePoint)) { return 0.0f; @@ -183,7 +184,8 @@ float ProximityInfoState::getPointToKeyLength( float ProximityInfoState::getPointToKeyByIdLength( const int inputIndex, const int keyId) const { return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength, - &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId); + &mSampledNormalizedSquaredLengthCache, mProximityInfo->getKeyCount(), inputIndex, + keyId); } // In the following function, c is the current character of the dictionary word currently examined. diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h index 9bba751d0..bbe8af240 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/proximity_info_state.h @@ -49,8 +49,8 @@ class ProximityInfoState { mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), - mBeelineSpeedPercentiles(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(), - mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(), + mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); @@ -147,7 +147,9 @@ class ProximityInfoState { return mIsContinuousSuggestionPossible; } + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyLength(const int inputIndex, const int codePoint) const; ProximityType getProximityType(const int index, const int codePoint, @@ -231,7 +233,7 @@ class ProximityInfoState { std::vector<int> mSampledInputIndice; std::vector<int> mSampledLengthCache; std::vector<int> mBeelineSpeedPercentiles; - std::vector<float> mSampledDistanceCache_G; + std::vector<float> mSampledNormalizedSquaredLengthCache; std::vector<float> mSpeedRates; std::vector<float> mDirections; // probabilities of skipping or mapping to a key for each point. diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp index df70cffdf..359673cd8 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/proximity_info_state_utils.cpp @@ -225,13 +225,13 @@ namespace latinime { const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - std::vector<NearKeycodesSet> *SampledNearKeySets, - std::vector<float> *SampledDistanceCache_G) { - SampledNearKeySets->resize(sampledInputSize); + std::vector<NearKeycodesSet> *sampledNearKeySets, + std::vector<float> *sampledNormalizedSquaredLengthCache) { + sampledNearKeySets->resize(sampledInputSize); const int keyCount = proximityInfo->getKeyCount(); - SampledDistanceCache_G->resize(sampledInputSize * keyCount); + sampledNormalizedSquaredLengthCache->resize(sampledInputSize * keyCount); for (int i = lastSavedInputSize; i < sampledInputSize; ++i) { - (*SampledNearKeySets)[i].reset(); + (*sampledNearKeySets)[i].reset(); for (int k = 0; k < keyCount; ++k) { const int index = i * keyCount + k; const int x = (*sampledInputXs)[i]; @@ -239,10 +239,10 @@ namespace latinime { const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( k, x, y, verticalSweetSpotScale); - (*SampledDistanceCache_G)[index] = normalizedSquaredDistance; + (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { - (*SampledNearKeySets)[i][k] = true; + (*sampledNearKeySets)[i][k] = true; } } } @@ -642,11 +642,11 @@ namespace latinime { // This function basically converts from a length to an edit distance. Accordingly, it's obviously // wrong to compare with mMaxPointToKeyLength. /* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector<float> *const SampledDistanceCache_G, const int keyCount, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId) { if (keyId != NOT_AN_INDEX) { const int index = inputIndex * keyCount + keyId; - return min((*SampledDistanceCache_G)[index], maxPointToKeyLength); + return min((*sampledNormalizedSquaredLengthCache)[index], maxPointToKeyLength); } // If the char is not a key on the keyboard then return the max length. return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); @@ -660,8 +660,8 @@ namespace latinime { const std::vector<int> *const sampledInputYs, const std::vector<float> *const sampledSpeedRates, const std::vector<int> *const sampledLengthCache, - const std::vector<float> *const SampledDistanceCache_G, - std::vector<NearKeycodesSet> *SampledNearKeySets, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, + std::vector<NearKeycodesSet> *sampledNearKeySets, std::vector<hash_map_compat<int, float> > *charProbabilities) { charProbabilities->resize(sampledInputSize); // Calculates probabilities of using a point as a correlated point with the character @@ -677,9 +677,9 @@ namespace latinime { float nearestKeyDistance = static_cast<float>(MAX_VALUE_FOR_WEIGHTING); for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { const float distance = getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j); if (distance < nearestKeyDistance) { nearestKeyDistance = distance; } @@ -758,14 +758,15 @@ namespace latinime { // Summing up probability densities of all near keys. float sumOfProbabilityDensities = 0.0f; for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from first point and the // next point to the key is used as a point to key distance. const float nextDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (nextDistance < distance) { // The distance of the first point tends to bigger than continuing // points because the first touch by the user can be sloppy. @@ -779,7 +780,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float previousDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (previousDistance < distance) { // The distance of the last point tends to bigger than continuing points // because the last touch by the user can be sloppy. So we promote the @@ -798,14 +800,15 @@ namespace latinime { // Split the probability of an input point to keys that are close to the input point. for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from the first point and // the next point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::NEXT_DISTANCE_WEIGHT) @@ -815,7 +818,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::PREV_DISTANCE_WEIGHT) @@ -882,10 +886,10 @@ namespace latinime { for (int j = 0; j < keyCount; ++j) { hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j); if (it == (*charProbabilities)[i].end()){ - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); } else if(it->second < ProximityInfoParams::MIN_PROBABILITY) { // Erases from near keys vector because it has very low probability. - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); (*charProbabilities)[i].erase(j); } else { it->second = -logf(it->second); @@ -899,7 +903,7 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<NearKeycodesSet> *const SampledNearKeySets, + const std::vector<NearKeycodesSet> *const sampledNearKeySets, std::vector<NearKeycodesSet> *sampledSearchKeySets, std::vector<std::vector<int> > *sampledSearchKeyVectors) { sampledSearchKeySets->resize(sampledInputSize); @@ -916,7 +920,7 @@ namespace latinime { if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) { break; } - (*sampledSearchKeySets)[i] |= (*SampledNearKeySets)[j]; + (*sampledSearchKeySets)[i] |= (*sampledNearKeySets)[j]; } } const int keyCount = proximityInfo->getKeyCount(); diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h index c9feb59a3..1837c7ab6 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/proximity_info_state_utils.h @@ -71,25 +71,25 @@ class ProximityInfoStateUtils { const std::vector<int> *const sampledInputYs, const std::vector<float> *const sampledSpeedRates, const std::vector<int> *const sampledLengthCache, - const std::vector<float> *const SampledDistanceCache_G, - std::vector<NearKeycodesSet> *SampledNearKeySets, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, + std::vector<NearKeycodesSet> *sampledNearKeySets, std::vector<hash_map_compat<int, float> > *charProbabilities); static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<NearKeycodesSet> *const SampledNearKeySets, + const std::vector<NearKeycodesSet> *const sampledNearKeySets, std::vector<NearKeycodesSet> *sampledSearchKeySets, std::vector<std::vector<int> > *sampledSearchKeyVectors); static float getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector<float> *const SampledDistanceCache_G, const int keyCount, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - std::vector<NearKeycodesSet> *SampledNearKeySets, - std::vector<float> *SampledDistanceCache_G); + std::vector<NearKeycodesSet> *sampledNearKeySets, + std::vector<float> *sampledNormalizedSquaredLengthCache); static void initPrimaryInputWord(const int inputSize, const int *const inputProximities, int *primaryInputWord); static void initNormalizedSquaredDistances(const ProximityInfo *const proximityInfo, diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index cde7b99a7..32faae52c 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -23,6 +23,7 @@ #include "dic_node_profiler.h" #include "dic_node_properties.h" #include "dic_node_release_listener.h" +#include "digraph_utils.h" #if DEBUG_DICT #define LOGI_SHOW_ADD_COST_PROP \ @@ -399,8 +400,15 @@ class DicNode { // TODO: Remove // ////////////////////// // TODO: Remove once touch path is merged into ProximityInfoState + // Note: Returned codepoint may be a digraph codepoint if the node is in a composite glyph. int getNodeCodePoint() const { - return mDicNodeProperties.getNodeCodePoint(); + const int codePoint = mDicNodeProperties.getNodeCodePoint(); + const DigraphUtils::DigraphCodePointIndex digraphIndex = + mDicNodeState.mDicNodeStateScoring.getDigraphIndex(); + if (digraphIndex == DigraphUtils::NOT_A_DIGRAPH_INDEX) { + return codePoint; + } + return DigraphUtils::getDigraphCodePointForIndex(codePoint, digraphIndex); } //////////////////////////////// @@ -452,6 +460,15 @@ class DicNode { mDicNodeState.mDicNodeStateScoring.setDoubleLetterLevel(doubleLetterLevel); } + bool isInDigraph() const { + return mDicNodeState.mDicNodeStateScoring.getDigraphIndex() + != DigraphUtils::NOT_A_DIGRAPH_INDEX; + } + + void advanceDigraphIndex() { + mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); + } + uint8_t getFlags() const { return mDicNodeProperties.getFlags(); } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h index 8e816329f..8902d3122 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -20,6 +20,7 @@ #include <stdint.h> #include "defines.h" +#include "digraph_utils.h" namespace latinime { @@ -27,6 +28,7 @@ class DicNodeStateScoring { public: AK_FORCE_INLINE DicNodeStateScoring() : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), + mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) { @@ -43,6 +45,7 @@ class DicNodeStateScoring { mTotalPrevWordsLanguageCost = 0.0f; mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; + mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; } AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { @@ -54,6 +57,7 @@ class DicNodeStateScoring { mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost; mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; + mDigraphIndex = scoring->mDigraphIndex; } void addCost(const float spatialCost, const float languageCost, const bool doNormalization, @@ -126,6 +130,24 @@ class DicNodeStateScoring { } } + DigraphUtils::DigraphCodePointIndex getDigraphIndex() const { + return mDigraphIndex; + } + + void advanceDigraphIndex() { + switch(mDigraphIndex) { + case DigraphUtils::NOT_A_DIGRAPH_INDEX: + mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT; + break; + case DigraphUtils::FIRST_DIGRAPH_CODEPOINT: + mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT; + break; + case DigraphUtils::SECOND_DIGRAPH_CODEPOINT: + mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; + break; + } + } + float getTotalPrevWordsLanguageCost() const { return mTotalPrevWordsLanguageCost; } @@ -135,6 +157,7 @@ class DicNodeStateScoring { // Use a default copy constructor and an assign operator because shallow copies are ok // for this class DoubleLetterLevel mDoubleLetterLevel; + DigraphUtils::DigraphCodePointIndex mDigraphIndex; int16_t mEditCorrectionCount; int16_t mProximityCorrectionCount; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index ef6616e40..5b783a2ba 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -84,6 +84,10 @@ const uint8_t *DicTraverseSession::getOffsetDict() const { return mDictionary->getOffsetDict(); } +int DicTraverseSession::getDictFlags() const { + return mDictionary->getDictFlags(); +} + void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { mDicNodesCache.reset(nextActiveCacheSize, maxWords); mBigramCacheMap.clear(); diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 62e1d1ab9..fe0527639 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -53,7 +53,7 @@ class DicTraverseSession { void resetCache(const int nextActiveCacheSize, const int maxWords); const uint8_t *getOffsetDict() const; - bool canUseCache() const; + int getDictFlags() const; //-------------------- // getters and setters @@ -134,7 +134,7 @@ class DicTraverseSession { if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) { return false; } - ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G); + ASSERT(mMaxPointerCount <= MAX_POINTER_COUNT_G); for (int i = 0; i < mMaxPointerCount; ++i) { const ProximityInfoState *const pInfoState = getProximityInfoState(i); // If a proximity info state is not continuous suggestion possible, @@ -146,6 +146,10 @@ class DicTraverseSession { return true; } + bool isTouchPositionCorrectionEnabled() const { + return mProximityInfoStates[0].touchPositionCorrectionEnabled(); + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 764c37240..67d351fa1 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -18,6 +18,7 @@ #include "char_utils.h" #include "dictionary.h" +#include "digraph_utils.h" #include "proximity_info.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_priority_queue.h" @@ -123,8 +124,12 @@ void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPo */ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, int *outputCodePoints, int *spaceIndices, int *outputTypes) const { +#if DEBUG_EVALUATE_MOST_PROBABLE_STRING + const int terminalSize = 0; +#else const int terminalSize = min(MAX_RESULTS, static_cast<int>(traverseSession->getDicTraverseCache()->terminalSize())); +#endif DicNode terminals[MAX_RESULTS]; // Avoiding non-POD variable length array for (int index = terminalSize - 1; index >= 0; --index) { @@ -221,7 +226,7 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { const int inputSize = traverseSession->getInputSize(); DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize()); - DicNode omissionDicNode; + DicNode correctionDicNode; // TODO: Find more efficient caching const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession); @@ -257,7 +262,10 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { dicNode.setCached(); } - if (isLookAheadCorrection) { + if (dicNode.isInDigraph()) { + // Finish digraph handling if the node is in the middle of a digraph expansion. + processDicNodeAsDigraph(traverseSession, &dicNode); + } else if (isLookAheadCorrection) { // The algorithm maintains a small set of "deferred" nodes that have not consumed the // latest touch point yet. These are needed to apply look-ahead correction operations // that require special handling of the latest touch point. For example, with insertions @@ -291,12 +299,18 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { processDicNodeAsMatch(traverseSession, childDicNode); continue; } + if (DigraphUtils::hasDigraphForCodePoint(traverseSession->getDictFlags(), + childDicNode->getNodeCodePoint())) { + correctionDicNode.initByCopy(childDicNode); + correctionDicNode.advanceDigraphIndex(); + processDicNodeAsDigraph(traverseSession, &correctionDicNode); + } if (allowsErrorCorrections && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { // TODO: (Gesture) Change weight between omission and substitution errors // TODO: (Gesture) Terminal node should not be handled as omission - omissionDicNode.initByCopy(childDicNode); - processDicNodeAsOmission(traverseSession, &omissionDicNode); + correctionDicNode.initByCopy(childDicNode); + processDicNodeAsOmission(traverseSession, &correctionDicNode); } const ProximityType proximityType = TRAVERSAL->getProximityType( traverseSession, &dicNode, childDicNode); @@ -400,6 +414,16 @@ void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, processExpandedDicNode(traverseSession, childDicNode); } +// Process the node codepoint as a digraph. This means that composite glyphs like the German +// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with +// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber". +void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, + DicNode *childDicNode) const { + weightChildNode(traverseSession, childDicNode); + childDicNode->advanceDigraphIndex(); + processExpandedDicNode(traverseSession, childDicNode); +} + /** * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider * matches for all possible next letters. Note that just skipping the current letter without any diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 9f609c50c..becd6c1de 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -65,6 +65,7 @@ class Suggest : public SuggestInterface { void generateFeatures( DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const; void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsDigraph(DicTraverseSession *traverseSession, DicNode *dicNode) const; void processDicNodeAsTransposition(DicTraverseSession *traverseSession, DicNode *dicNode) const; void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 52d54eb0f..2dcee343f 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -18,6 +18,7 @@ #define LATINIME_TYPING_WEIGHTING_H #include "defines.h" +#include "suggest_utils.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/session/dic_traverse_session.h" @@ -70,10 +71,12 @@ class TypingWeighting : public Weighting { const int pointIndex = dicNode->getInputIndex(0); // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) - const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE, - traverseSession->getProximityInfoState(0)->getPointToKeyLength( - pointIndex, dicNode->getNodeCodePoint())); - const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) + ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); + const float normalizedDistance = SuggestUtils::getSweetSpotFactor( + traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); + const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; + const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest_utils.h index aab9f7ba8..e053dd662 100644 --- a/native/jni/src/suggest_utils.h +++ b/native/jni/src/suggest_utils.h @@ -23,10 +23,8 @@ namespace latinime { class SuggestUtils { public: - static float getDistanceScalingFactor(const float normalizedSquaredDistance) { - if (normalizedSquaredDistance < 0.0f) { - return -1.0f; - } + // TODO: (OLD) Remove + static float getLengthScalingFactor(const float normalizedSquaredDistance) { // Promote or demote the score according to the distance from the sweet spot static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; static const float B = 1.0f; @@ -50,6 +48,39 @@ class SuggestUtils { return factor; } + static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled, + const float normalizedSquaredDistance) { + // Promote or demote the score according to the distance from the sweet spot + static const float A = 0.0f; + static const float B = 0.24f; + static const float C = 1.20f; + static const float R0 = 0.0f; + static const float R1 = 0.25f; // Sweet spot + static const float R2 = 1.0f; + const float x = normalizedSquaredDistance; + if (!isTouchPositionCorrectionEnabled) { + return min(C, x); + } + + // factor is a piecewise linear function like: + // C -------------. + // / . + // B / . + // -/ . + // A _-^ . + // . + // R0 R1 R2 . + + if (x < R0) { + return A; + } else if (x < R1) { + return (A * (R1 - x) + B * (x - R0)) / (R1 - R0); + } else if (x < R2) { + return (B * (R2 - x) + C * (x - R1)) / (R2 - R1); + } else { + return C; + } + } private: DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestUtils); }; diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp index 33795cade..a672294b5 100644 --- a/native/jni/src/unigram_dictionary.cpp +++ b/native/jni/src/unigram_dictionary.cpp @@ -32,9 +32,9 @@ namespace latinime { // TODO: check the header -UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags) +UnigramDictionary::UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags) : DICT_ROOT(streamStart), ROOT_POS(0), - MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), FLAGS(flags) { + MAX_DIGRAPH_SEARCH_DEPTH(DEFAULT_MAX_DIGRAPH_SEARCH_DEPTH), DICT_FLAGS(dictFlags) { if (DEBUG_DICT) { AKLOGI("UnigramDictionary - constructor"); } @@ -163,7 +163,7 @@ int UnigramDictionary::getSuggestions(ProximityInfo *proximityInfo, const int *x masterCorrection.resetCorrection(); const DigraphUtils::digraph_t *digraphs = 0; const int digraphsSize = - DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(FLAGS, &digraphs); + DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(DICT_FLAGS, &digraphs); if (digraphsSize > 0) { // Incrementally tune the word and try all possibilities int codesBuffer[sizeof(*inputCodePoints) * inputSize]; diff --git a/native/jni/src/unigram_dictionary.h b/native/jni/src/unigram_dictionary.h index 1a01758fe..a64a539bd 100644 --- a/native/jni/src/unigram_dictionary.h +++ b/native/jni/src/unigram_dictionary.h @@ -38,7 +38,7 @@ class UnigramDictionary { static const int FLAG_MULTIPLE_SUGGEST_ABORT = 0; static const int FLAG_MULTIPLE_SUGGEST_SKIP = 1; static const int FLAG_MULTIPLE_SUGGEST_CONTINUE = 2; - UnigramDictionary(const uint8_t *const streamStart, const unsigned int flags); + UnigramDictionary(const uint8_t *const streamStart, const unsigned int dictFlags); int getProbability(const int *const inWord, const int length) const; int getBigramPosition(int pos, int *word, int offset, int length) const; int getSuggestions(ProximityInfo *proximityInfo, const int *xcoordinates, @@ -46,6 +46,7 @@ class UnigramDictionary { const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, const bool useFullEditDistance, int *outWords, int *frequencies, int *outputTypes) const; + int getDictFlags() const { return DICT_FLAGS; } virtual ~UnigramDictionary(); private: @@ -109,7 +110,7 @@ class UnigramDictionary { const uint8_t *const DICT_ROOT; const int ROOT_POS; const int MAX_DIGRAPH_SEARCH_DEPTH; - const int FLAGS; + const int DICT_FLAGS; }; } // namespace latinime #endif // LATINIME_UNIGRAM_DICTIONARY_H |