diff options
Diffstat (limited to 'native/jni/src')
23 files changed, 356 insertions, 156 deletions
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index 1c4061fd8..2d2e19501 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -92,6 +92,7 @@ class BinaryFormat { const int unigramProbability, const int bigramProbability); static int getProbability(const int position, const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, const int unigramProbability); + static float getMultiWordCostMultiplier(const uint8_t *const dict); // Flags for special processing // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -241,6 +242,17 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { + const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); + if (headerValue == S_INT_MIN) { + return 1.0f; + } + if (headerValue <= 0) { + return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); + } + return 100.0f / static_cast<float>(headerValue); +} + inline uint8_t BinaryFormat::getFlagsAndForwardPointer(const uint8_t *const dict, int *pos) { return dict[(*pos)++]; } diff --git a/native/jni/src/char_utils.cpp b/native/jni/src/char_utils.cpp index 8d917ea74..e219beb62 100644 --- a/native/jni/src/char_utils.cpp +++ b/native/jni/src/char_utils.cpp @@ -45,18 +45,16 @@ struct LatinCapitalSmallPair { extern "C" int main() { for (unsigned short c = 0; c < 0xFFFF; c++) { - const unsigned short baseC = c < NELEMS(BASE_CHARS) ? BASE_CHARS[c] : c; - if (baseC <= 0x7F) continue; - const unsigned short icu4cLowerBaseC = u_tolower(baseC); - const unsigned short myLowerBaseC = latin_tolower(baseC); - if (baseC != icu4cLowerBaseC) { + if (c <= 0x7F) continue; + const unsigned short icu4cLowerC = u_tolower(c); + const unsigned short myLowerC = latin_tolower(c); + if (c != icu4cLowerC) { #ifdef CONFIRMING_CHAR_UTILS - if (icu4cLowerBaseC != myLowerBaseC) { - fprintf(stderr, "icu4cLowerBaseC != myLowerBaseC, 0x%04X, 0x%04X\n", - icu4cLowerBaseC, myLowerBaseC); + if (icu4cLowerC != myLowerC) { + fprintf(stderr, "icu4cLowerC != myLowerC, 0x%04X, 0x%04X\n", icu4cLowerC, myLowerC); } #else // CONFIRMING_CHAR_UTILS - printf("0x%04X, 0x%04X\n", baseC, icu4cLowerBaseC); + printf("0x%04X, 0x%04X\n", c, icu4cLowerC); #endif // CONFIRMING_CHAR_UTILS } } @@ -77,14 +75,99 @@ extern "C" int main() { * $ */ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { + { 0x00C0, 0x00E0 }, // LATIN CAPITAL LETTER A WITH GRAVE + { 0x00C1, 0x00E1 }, // LATIN CAPITAL LETTER A WITH ACUTE + { 0x00C2, 0x00E2 }, // LATIN CAPITAL LETTER A WITH CIRCUMFLEX + { 0x00C3, 0x00E3 }, // LATIN CAPITAL LETTER A WITH TILDE + { 0x00C4, 0x00E4 }, // LATIN CAPITAL LETTER A WITH DIAERESIS + { 0x00C5, 0x00E5 }, // LATIN CAPITAL LETTER A WITH RING ABOVE { 0x00C6, 0x00E6 }, // LATIN CAPITAL LETTER AE + { 0x00C7, 0x00E7 }, // LATIN CAPITAL LETTER C WITH CEDILLA + { 0x00C8, 0x00E8 }, // LATIN CAPITAL LETTER E WITH GRAVE + { 0x00C9, 0x00E9 }, // LATIN CAPITAL LETTER E WITH ACUTE + { 0x00CA, 0x00EA }, // LATIN CAPITAL LETTER E WITH CIRCUMFLEX + { 0x00CB, 0x00EB }, // LATIN CAPITAL LETTER E WITH DIAERESIS + { 0x00CC, 0x00EC }, // LATIN CAPITAL LETTER I WITH GRAVE + { 0x00CD, 0x00ED }, // LATIN CAPITAL LETTER I WITH ACUTE + { 0x00CE, 0x00EE }, // LATIN CAPITAL LETTER I WITH CIRCUMFLEX + { 0x00CF, 0x00EF }, // LATIN CAPITAL LETTER I WITH DIAERESIS { 0x00D0, 0x00F0 }, // LATIN CAPITAL LETTER ETH + { 0x00D1, 0x00F1 }, // LATIN CAPITAL LETTER N WITH TILDE + { 0x00D2, 0x00F2 }, // LATIN CAPITAL LETTER O WITH GRAVE + { 0x00D3, 0x00F3 }, // LATIN CAPITAL LETTER O WITH ACUTE + { 0x00D4, 0x00F4 }, // LATIN CAPITAL LETTER O WITH CIRCUMFLEX + { 0x00D5, 0x00F5 }, // LATIN CAPITAL LETTER O WITH TILDE + { 0x00D6, 0x00F6 }, // LATIN CAPITAL LETTER O WITH DIAERESIS + { 0x00D8, 0x00F8 }, // LATIN CAPITAL LETTER O WITH STROKE + { 0x00D9, 0x00F9 }, // LATIN CAPITAL LETTER U WITH GRAVE + { 0x00DA, 0x00FA }, // LATIN CAPITAL LETTER U WITH ACUTE + { 0x00DB, 0x00FB }, // LATIN CAPITAL LETTER U WITH CIRCUMFLEX + { 0x00DC, 0x00FC }, // LATIN CAPITAL LETTER U WITH DIAERESIS + { 0x00DD, 0x00FD }, // LATIN CAPITAL LETTER Y WITH ACUTE { 0x00DE, 0x00FE }, // LATIN CAPITAL LETTER THORN + { 0x0100, 0x0101 }, // LATIN CAPITAL LETTER A WITH MACRON + { 0x0102, 0x0103 }, // LATIN CAPITAL LETTER A WITH BREVE + { 0x0104, 0x0105 }, // LATIN CAPITAL LETTER A WITH OGONEK + { 0x0106, 0x0107 }, // LATIN CAPITAL LETTER C WITH ACUTE + { 0x0108, 0x0109 }, // LATIN CAPITAL LETTER C WITH CIRCUMFLEX + { 0x010A, 0x010B }, // LATIN CAPITAL LETTER C WITH DOT ABOVE + { 0x010C, 0x010D }, // LATIN CAPITAL LETTER C WITH CARON + { 0x010E, 0x010F }, // LATIN CAPITAL LETTER D WITH CARON { 0x0110, 0x0111 }, // LATIN CAPITAL LETTER D WITH STROKE + { 0x0112, 0x0113 }, // LATIN CAPITAL LETTER E WITH MACRON + { 0x0114, 0x0115 }, // LATIN CAPITAL LETTER E WITH BREVE + { 0x0116, 0x0117 }, // LATIN CAPITAL LETTER E WITH DOT ABOVE + { 0x0118, 0x0119 }, // LATIN CAPITAL LETTER E WITH OGONEK + { 0x011A, 0x011B }, // LATIN CAPITAL LETTER E WITH CARON + { 0x011C, 0x011D }, // LATIN CAPITAL LETTER G WITH CIRCUMFLEX + { 0x011E, 0x011F }, // LATIN CAPITAL LETTER G WITH BREVE + { 0x0120, 0x0121 }, // LATIN CAPITAL LETTER G WITH DOT ABOVE + { 0x0122, 0x0123 }, // LATIN CAPITAL LETTER G WITH CEDILLA + { 0x0124, 0x0125 }, // LATIN CAPITAL LETTER H WITH CIRCUMFLEX { 0x0126, 0x0127 }, // LATIN CAPITAL LETTER H WITH STROKE + { 0x0128, 0x0129 }, // LATIN CAPITAL LETTER I WITH TILDE + { 0x012A, 0x012B }, // LATIN CAPITAL LETTER I WITH MACRON + { 0x012C, 0x012D }, // LATIN CAPITAL LETTER I WITH BREVE + { 0x012E, 0x012F }, // LATIN CAPITAL LETTER I WITH OGONEK + { 0x0130, 0x0069 }, // LATIN CAPITAL LETTER I WITH DOT ABOVE + { 0x0132, 0x0133 }, // LATIN CAPITAL LIGATURE IJ + { 0x0134, 0x0135 }, // LATIN CAPITAL LETTER J WITH CIRCUMFLEX + { 0x0136, 0x0137 }, // LATIN CAPITAL LETTER K WITH CEDILLA + { 0x0139, 0x013A }, // LATIN CAPITAL LETTER L WITH ACUTE + { 0x013B, 0x013C }, // LATIN CAPITAL LETTER L WITH CEDILLA + { 0x013D, 0x013E }, // LATIN CAPITAL LETTER L WITH CARON + { 0x013F, 0x0140 }, // LATIN CAPITAL LETTER L WITH MIDDLE DOT + { 0x0141, 0x0142 }, // LATIN CAPITAL LETTER L WITH STROKE + { 0x0143, 0x0144 }, // LATIN CAPITAL LETTER N WITH ACUTE + { 0x0145, 0x0146 }, // LATIN CAPITAL LETTER N WITH CEDILLA + { 0x0147, 0x0148 }, // LATIN CAPITAL LETTER N WITH CARON { 0x014A, 0x014B }, // LATIN CAPITAL LETTER ENG + { 0x014C, 0x014D }, // LATIN CAPITAL LETTER O WITH MACRON + { 0x014E, 0x014F }, // LATIN CAPITAL LETTER O WITH BREVE + { 0x0150, 0x0151 }, // LATIN CAPITAL LETTER O WITH DOUBLE ACUTE { 0x0152, 0x0153 }, // LATIN CAPITAL LIGATURE OE + { 0x0154, 0x0155 }, // LATIN CAPITAL LETTER R WITH ACUTE + { 0x0156, 0x0157 }, // LATIN CAPITAL LETTER R WITH CEDILLA + { 0x0158, 0x0159 }, // LATIN CAPITAL LETTER R WITH CARON + { 0x015A, 0x015B }, // LATIN CAPITAL LETTER S WITH ACUTE + { 0x015C, 0x015D }, // LATIN CAPITAL LETTER S WITH CIRCUMFLEX + { 0x015E, 0x015F }, // LATIN CAPITAL LETTER S WITH CEDILLA + { 0x0160, 0x0161 }, // LATIN CAPITAL LETTER S WITH CARON + { 0x0162, 0x0163 }, // LATIN CAPITAL LETTER T WITH CEDILLA + { 0x0164, 0x0165 }, // LATIN CAPITAL LETTER T WITH CARON { 0x0166, 0x0167 }, // LATIN CAPITAL LETTER T WITH STROKE + { 0x0168, 0x0169 }, // LATIN CAPITAL LETTER U WITH TILDE + { 0x016A, 0x016B }, // LATIN CAPITAL LETTER U WITH MACRON + { 0x016C, 0x016D }, // LATIN CAPITAL LETTER U WITH BREVE + { 0x016E, 0x016F }, // LATIN CAPITAL LETTER U WITH RING ABOVE + { 0x0170, 0x0171 }, // LATIN CAPITAL LETTER U WITH DOUBLE ACUTE + { 0x0172, 0x0173 }, // LATIN CAPITAL LETTER U WITH OGONEK + { 0x0174, 0x0175 }, // LATIN CAPITAL LETTER W WITH CIRCUMFLEX + { 0x0176, 0x0177 }, // LATIN CAPITAL LETTER Y WITH CIRCUMFLEX + { 0x0178, 0x00FF }, // LATIN CAPITAL LETTER Y WITH DIAERESIS + { 0x0179, 0x017A }, // LATIN CAPITAL LETTER Z WITH ACUTE + { 0x017B, 0x017C }, // LATIN CAPITAL LETTER Z WITH DOT ABOVE + { 0x017D, 0x017E }, // LATIN CAPITAL LETTER Z WITH CARON { 0x0181, 0x0253 }, // LATIN CAPITAL LETTER B WITH HOOK { 0x0182, 0x0183 }, // LATIN CAPITAL LETTER B WITH TOPBAR { 0x0184, 0x0185 }, // LATIN CAPITAL LETTER TONE SIX @@ -105,6 +188,7 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x019C, 0x026F }, // LATIN CAPITAL LETTER TURNED M { 0x019D, 0x0272 }, // LATIN CAPITAL LETTER N WITH LEFT HOOK { 0x019F, 0x0275 }, // LATIN CAPITAL LETTER O WITH MIDDLE TILDE + { 0x01A0, 0x01A1 }, // LATIN CAPITAL LETTER O WITH HORN { 0x01A2, 0x01A3 }, // LATIN CAPITAL LETTER OI { 0x01A4, 0x01A5 }, // LATIN CAPITAL LETTER P WITH HOOK { 0x01A6, 0x0280 }, // LATIN LETTER YR @@ -112,6 +196,7 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x01A9, 0x0283 }, // LATIN CAPITAL LETTER ESH { 0x01AC, 0x01AD }, // LATIN CAPITAL LETTER T WITH HOOK { 0x01AE, 0x0288 }, // LATIN CAPITAL LETTER T WITH RETROFLEX HOOK + { 0x01AF, 0x01B0 }, // LATIN CAPITAL LETTER U WITH HORN { 0x01B1, 0x028A }, // LATIN CAPITAL LETTER UPSILON { 0x01B2, 0x028B }, // LATIN CAPITAL LETTER V WITH HOOK { 0x01B3, 0x01B4 }, // LATIN CAPITAL LETTER Y WITH HOOK @@ -119,13 +204,64 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x01B7, 0x0292 }, // LATIN CAPITAL LETTER EZH { 0x01B8, 0x01B9 }, // LATIN CAPITAL LETTER EZH REVERSED { 0x01BC, 0x01BD }, // LATIN CAPITAL LETTER TONE FIVE + { 0x01C4, 0x01C6 }, // LATIN CAPITAL LETTER DZ WITH CARON + { 0x01C5, 0x01C6 }, // LATIN CAPITAL LETTER D WITH SMALL LETTER Z WITH CARON + { 0x01C7, 0x01C9 }, // LATIN CAPITAL LETTER LJ + { 0x01C8, 0x01C9 }, // LATIN CAPITAL LETTER L WITH SMALL LETTER J + { 0x01CA, 0x01CC }, // LATIN CAPITAL LETTER NJ + { 0x01CB, 0x01CC }, // LATIN CAPITAL LETTER N WITH SMALL LETTER J + { 0x01CD, 0x01CE }, // LATIN CAPITAL LETTER A WITH CARON + { 0x01CF, 0x01D0 }, // LATIN CAPITAL LETTER I WITH CARON + { 0x01D1, 0x01D2 }, // LATIN CAPITAL LETTER O WITH CARON + { 0x01D3, 0x01D4 }, // LATIN CAPITAL LETTER U WITH CARON + { 0x01D5, 0x01D6 }, // LATIN CAPITAL LETTER U WITH DIAERESIS AND MACRON + { 0x01D7, 0x01D8 }, // LATIN CAPITAL LETTER U WITH DIAERESIS AND ACUTE + { 0x01D9, 0x01DA }, // LATIN CAPITAL LETTER U WITH DIAERESIS AND CARON + { 0x01DB, 0x01DC }, // LATIN CAPITAL LETTER U WITH DIAERESIS AND GRAVE + { 0x01DE, 0x01DF }, // LATIN CAPITAL LETTER A WITH DIAERESIS AND MACRON + { 0x01E0, 0x01E1 }, // LATIN CAPITAL LETTER A WITH DOT ABOVE AND MACRON + { 0x01E2, 0x01E3 }, // LATIN CAPITAL LETTER AE WITH MACRON { 0x01E4, 0x01E5 }, // LATIN CAPITAL LETTER G WITH STROKE + { 0x01E6, 0x01E7 }, // LATIN CAPITAL LETTER G WITH CARON + { 0x01E8, 0x01E9 }, // LATIN CAPITAL LETTER K WITH CARON + { 0x01EA, 0x01EB }, // LATIN CAPITAL LETTER O WITH OGONEK + { 0x01EC, 0x01ED }, // LATIN CAPITAL LETTER O WITH OGONEK AND MACRON + { 0x01EE, 0x01EF }, // LATIN CAPITAL LETTER EZH WITH CARON + { 0x01F1, 0x01F3 }, // LATIN CAPITAL LETTER DZ + { 0x01F2, 0x01F3 }, // LATIN CAPITAL LETTER D WITH SMALL LETTER Z + { 0x01F4, 0x01F5 }, // LATIN CAPITAL LETTER G WITH ACUTE { 0x01F6, 0x0195 }, // LATIN CAPITAL LETTER HWAIR { 0x01F7, 0x01BF }, // LATIN CAPITAL LETTER WYNN + { 0x01F8, 0x01F9 }, // LATIN CAPITAL LETTER N WITH GRAVE + { 0x01FA, 0x01FB }, // LATIN CAPITAL LETTER A WITH RING ABOVE AND ACUTE + { 0x01FC, 0x01FD }, // LATIN CAPITAL LETTER AE WITH ACUTE + { 0x01FE, 0x01FF }, // LATIN CAPITAL LETTER O WITH STROKE AND ACUTE + { 0x0200, 0x0201 }, // LATIN CAPITAL LETTER A WITH DOUBLE GRAVE + { 0x0202, 0x0203 }, // LATIN CAPITAL LETTER A WITH INVERTED BREVE + { 0x0204, 0x0205 }, // LATIN CAPITAL LETTER E WITH DOUBLE GRAVE + { 0x0206, 0x0207 }, // LATIN CAPITAL LETTER E WITH INVERTED BREVE + { 0x0208, 0x0209 }, // LATIN CAPITAL LETTER I WITH DOUBLE GRAVE + { 0x020A, 0x020B }, // LATIN CAPITAL LETTER I WITH INVERTED BREVE + { 0x020C, 0x020D }, // LATIN CAPITAL LETTER O WITH DOUBLE GRAVE + { 0x020E, 0x020F }, // LATIN CAPITAL LETTER O WITH INVERTED BREVE + { 0x0210, 0x0211 }, // LATIN CAPITAL LETTER R WITH DOUBLE GRAVE + { 0x0212, 0x0213 }, // LATIN CAPITAL LETTER R WITH INVERTED BREVE + { 0x0214, 0x0215 }, // LATIN CAPITAL LETTER U WITH DOUBLE GRAVE + { 0x0216, 0x0217 }, // LATIN CAPITAL LETTER U WITH INVERTED BREVE + { 0x0218, 0x0219 }, // LATIN CAPITAL LETTER S WITH COMMA BELOW + { 0x021A, 0x021B }, // LATIN CAPITAL LETTER T WITH COMMA BELOW { 0x021C, 0x021D }, // LATIN CAPITAL LETTER YOGH + { 0x021E, 0x021F }, // LATIN CAPITAL LETTER H WITH CARON { 0x0220, 0x019E }, // LATIN CAPITAL LETTER N WITH LONG RIGHT LEG { 0x0222, 0x0223 }, // LATIN CAPITAL LETTER OU { 0x0224, 0x0225 }, // LATIN CAPITAL LETTER Z WITH HOOK + { 0x0226, 0x0227 }, // LATIN CAPITAL LETTER A WITH DOT ABOVE + { 0x0228, 0x0229 }, // LATIN CAPITAL LETTER E WITH CEDILLA + { 0x022A, 0x022B }, // LATIN CAPITAL LETTER O WITH DIAERESIS AND MACRON + { 0x022C, 0x022D }, // LATIN CAPITAL LETTER O WITH TILDE AND MACRON + { 0x022E, 0x022F }, // LATIN CAPITAL LETTER O WITH DOT ABOVE + { 0x0230, 0x0231 }, // LATIN CAPITAL LETTER O WITH DOT ABOVE AND MACRON + { 0x0232, 0x0233 }, // LATIN CAPITAL LETTER Y WITH MACRON { 0x023A, 0x2C65 }, // LATIN CAPITAL LETTER A WITH STROKE { 0x023B, 0x023C }, // LATIN CAPITAL LETTER C WITH STROKE { 0x023D, 0x019A }, // LATIN CAPITAL LETTER L WITH BAR @@ -142,6 +278,13 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x0370, 0x0371 }, // GREEK CAPITAL LETTER HETA { 0x0372, 0x0373 }, // GREEK CAPITAL LETTER ARCHAIC SAMPI { 0x0376, 0x0377 }, // GREEK CAPITAL LETTER PAMPHYLIAN DIGAMMA + { 0x0386, 0x03AC }, // GREEK CAPITAL LETTER ALPHA WITH TONOS + { 0x0388, 0x03AD }, // GREEK CAPITAL LETTER EPSILON WITH TONOS + { 0x0389, 0x03AE }, // GREEK CAPITAL LETTER ETA WITH TONOS + { 0x038A, 0x03AF }, // GREEK CAPITAL LETTER IOTA WITH TONOS + { 0x038C, 0x03CC }, // GREEK CAPITAL LETTER OMICRON WITH TONOS + { 0x038E, 0x03CD }, // GREEK CAPITAL LETTER UPSILON WITH TONOS + { 0x038F, 0x03CE }, // GREEK CAPITAL LETTER OMEGA WITH TONOS { 0x0391, 0x03B1 }, // GREEK CAPITAL LETTER ALPHA { 0x0392, 0x03B2 }, // GREEK CAPITAL LETTER BETA { 0x0393, 0x03B3 }, // GREEK CAPITAL LETTER GAMMA @@ -166,6 +309,8 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x03A7, 0x03C7 }, // GREEK CAPITAL LETTER CHI { 0x03A8, 0x03C8 }, // GREEK CAPITAL LETTER PSI { 0x03A9, 0x03C9 }, // GREEK CAPITAL LETTER OMEGA + { 0x03AA, 0x03CA }, // GREEK CAPITAL LETTER IOTA WITH DIALYTIKA + { 0x03AB, 0x03CB }, // GREEK CAPITAL LETTER UPSILON WITH DIALYTIKA { 0x03CF, 0x03D7 }, // GREEK CAPITAL KAI SYMBOL { 0x03D8, 0x03D9 }, // GREEK LETTER ARCHAIC KOPPA { 0x03DA, 0x03DB }, // GREEK LETTER STIGMA @@ -179,19 +324,28 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x03EA, 0x03EB }, // COPTIC CAPITAL LETTER GANGIA { 0x03EC, 0x03ED }, // COPTIC CAPITAL LETTER SHIMA { 0x03EE, 0x03EF }, // COPTIC CAPITAL LETTER DEI + { 0x03F4, 0x03B8 }, // GREEK CAPITAL THETA SYMBOL { 0x03F7, 0x03F8 }, // GREEK CAPITAL LETTER SHO + { 0x03F9, 0x03F2 }, // GREEK CAPITAL LUNATE SIGMA SYMBOL { 0x03FA, 0x03FB }, // GREEK CAPITAL LETTER SAN { 0x03FD, 0x037B }, // GREEK CAPITAL REVERSED LUNATE SIGMA SYMBOL { 0x03FE, 0x037C }, // GREEK CAPITAL DOTTED LUNATE SIGMA SYMBOL { 0x03FF, 0x037D }, // GREEK CAPITAL REVERSED DOTTED LUNATE SIGMA SYMBOL + { 0x0400, 0x0450 }, // CYRILLIC CAPITAL LETTER IE WITH GRAVE + { 0x0401, 0x0451 }, // CYRILLIC CAPITAL LETTER IO { 0x0402, 0x0452 }, // CYRILLIC CAPITAL LETTER DJE + { 0x0403, 0x0453 }, // CYRILLIC CAPITAL LETTER GJE { 0x0404, 0x0454 }, // CYRILLIC CAPITAL LETTER UKRAINIAN IE { 0x0405, 0x0455 }, // CYRILLIC CAPITAL LETTER DZE { 0x0406, 0x0456 }, // CYRILLIC CAPITAL LETTER BYELORUSSIAN-UKRAINIAN I + { 0x0407, 0x0457 }, // CYRILLIC CAPITAL LETTER YI { 0x0408, 0x0458 }, // CYRILLIC CAPITAL LETTER JE { 0x0409, 0x0459 }, // CYRILLIC CAPITAL LETTER LJE { 0x040A, 0x045A }, // CYRILLIC CAPITAL LETTER NJE { 0x040B, 0x045B }, // CYRILLIC CAPITAL LETTER TSHE + { 0x040C, 0x045C }, // CYRILLIC CAPITAL LETTER KJE + { 0x040D, 0x045D }, // CYRILLIC CAPITAL LETTER I WITH GRAVE + { 0x040E, 0x045E }, // CYRILLIC CAPITAL LETTER SHORT U { 0x040F, 0x045F }, // CYRILLIC CAPITAL LETTER DZHE { 0x0410, 0x0430 }, // CYRILLIC CAPITAL LETTER A { 0x0411, 0x0431 }, // CYRILLIC CAPITAL LETTER BE @@ -236,6 +390,7 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x0470, 0x0471 }, // CYRILLIC CAPITAL LETTER PSI { 0x0472, 0x0473 }, // CYRILLIC CAPITAL LETTER FITA { 0x0474, 0x0475 }, // CYRILLIC CAPITAL LETTER IZHITSA + { 0x0476, 0x0477 }, // CYRILLIC CAPITAL LETTER IZHITSA WITH DOUBLE GRAVE ACCENT { 0x0478, 0x0479 }, // CYRILLIC CAPITAL LETTER UK { 0x047A, 0x047B }, // CYRILLIC CAPITAL LETTER ROUND OMEGA { 0x047C, 0x047D }, // CYRILLIC CAPITAL LETTER OMEGA WITH TITLO @@ -269,17 +424,34 @@ static const struct LatinCapitalSmallPair SORTED_CHAR_MAP[] = { { 0x04BC, 0x04BD }, // CYRILLIC CAPITAL LETTER ABKHASIAN CHE { 0x04BE, 0x04BF }, // CYRILLIC CAPITAL LETTER ABKHASIAN CHE WITH DESCENDER { 0x04C0, 0x04CF }, // CYRILLIC LETTER PALOCHKA + { 0x04C1, 0x04C2 }, // CYRILLIC CAPITAL LETTER ZHE WITH BREVE { 0x04C3, 0x04C4 }, // CYRILLIC CAPITAL LETTER KA WITH HOOK { 0x04C5, 0x04C6 }, // CYRILLIC CAPITAL LETTER EL WITH TAIL { 0x04C7, 0x04C8 }, // CYRILLIC CAPITAL LETTER EN WITH HOOK { 0x04C9, 0x04CA }, // CYRILLIC CAPITAL LETTER EN WITH TAIL { 0x04CB, 0x04CC }, // CYRILLIC CAPITAL LETTER KHAKASSIAN CHE { 0x04CD, 0x04CE }, // CYRILLIC CAPITAL LETTER EM WITH TAIL + { 0x04D0, 0x04D1 }, // CYRILLIC CAPITAL LETTER A WITH BREVE + { 0x04D2, 0x04D3 }, // CYRILLIC CAPITAL LETTER A WITH DIAERESIS { 0x04D4, 0x04D5 }, // CYRILLIC CAPITAL LIGATURE A IE + { 0x04D6, 0x04D7 }, // CYRILLIC CAPITAL LETTER IE WITH BREVE { 0x04D8, 0x04D9 }, // CYRILLIC CAPITAL LETTER SCHWA + { 0x04DA, 0x04DB }, // CYRILLIC CAPITAL LETTER SCHWA WITH DIAERESIS + { 0x04DC, 0x04DD }, // CYRILLIC CAPITAL LETTER ZHE WITH DIAERESIS + { 0x04DE, 0x04DF }, // CYRILLIC CAPITAL LETTER ZE WITH DIAERESIS { 0x04E0, 0x04E1 }, // CYRILLIC CAPITAL LETTER ABKHASIAN DZE + { 0x04E2, 0x04E3 }, // CYRILLIC CAPITAL LETTER I WITH MACRON + { 0x04E4, 0x04E5 }, // CYRILLIC CAPITAL LETTER I WITH DIAERESIS + { 0x04E6, 0x04E7 }, // CYRILLIC CAPITAL LETTER O WITH DIAERESIS { 0x04E8, 0x04E9 }, // CYRILLIC CAPITAL LETTER BARRED O + { 0x04EA, 0x04EB }, // CYRILLIC CAPITAL LETTER BARRED O WITH DIAERESIS + { 0x04EC, 0x04ED }, // CYRILLIC CAPITAL LETTER E WITH DIAERESIS + { 0x04EE, 0x04EF }, // CYRILLIC CAPITAL LETTER U WITH MACRON + { 0x04F0, 0x04F1 }, // CYRILLIC CAPITAL LETTER U WITH DIAERESIS + { 0x04F2, 0x04F3 }, // CYRILLIC CAPITAL LETTER U WITH DOUBLE ACUTE + { 0x04F4, 0x04F5 }, // CYRILLIC CAPITAL LETTER CHE WITH DIAERESIS { 0x04F6, 0x04F7 }, // CYRILLIC CAPITAL LETTER GHE WITH DESCENDER + { 0x04F8, 0x04F9 }, // CYRILLIC CAPITAL LETTER YERU WITH DIAERESIS { 0x04FA, 0x04FB }, // CYRILLIC CAPITAL LETTER GHE WITH STROKE AND HOOK { 0x04FC, 0x04FD }, // CYRILLIC CAPITAL LETTER HA WITH HOOK { 0x04FE, 0x04FF }, // CYRILLIC CAPITAL LETTER HA WITH STROKE diff --git a/native/jni/src/char_utils.h b/native/jni/src/char_utils.h index 58d388dbf..b429f40b2 100644 --- a/native/jni/src/char_utils.h +++ b/native/jni/src/char_utils.h @@ -58,7 +58,8 @@ inline static int toBaseCodePoint(int c) { AK_FORCE_INLINE static int toLowerCase(const int c) { if (isAsciiUpper(c)) { return toAsciiLower(c); - } else if (isAscii(c)) { + } + if (isAscii(c)) { return c; } return static_cast<int>(latin_tolower(static_cast<unsigned short>(c))); diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 76234f840..0c65939e0 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -675,7 +675,7 @@ inline static bool isUpperCase(unsigned short c) { multiplyIntCapped(typedLetterMultiplier, &finalFreq); } const float factor = - SuggestUtils::getDistanceScalingFactor(static_cast<float>(squaredDistance)); + SuggestUtils::getLengthScalingFactor(static_cast<float>(squaredDistance)); if (factor > 0.0f) { multiplyRate(static_cast<int>(factor * 100.0f), &finalFreq); } else if (squaredDistance == PROXIMITY_CHAR_WITHOUT_DISTANCE_INFO) { diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index a7b023a75..6ef9f414b 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -424,10 +424,9 @@ typedef enum { CT_OMISSION, CT_INSERTION, CT_TRANSPOSITION, - CT_SPACE_SUBSTITUTION, - CT_SPACE_OMISSION, CT_COMPLETION, CT_TERMINAL, - CT_NEW_WORD, + CT_NEW_WORD_SPACE_OMITTION, + CT_NEW_WORD_SPACE_SUBSTITUTION, } CorrectionType; #endif // LATINIME_DEFINES_H diff --git a/native/jni/src/digraph_utils.cpp b/native/jni/src/digraph_utils.cpp index 6a1ab0271..083442669 100644 --- a/native/jni/src/digraph_utils.cpp +++ b/native/jni/src/digraph_utils.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "char_utils.h" #include "binary_format.h" #include "defines.h" #include "digraph_utils.h" @@ -120,10 +121,11 @@ const DigraphUtils::DigraphType DigraphUtils::USED_DIGRAPH_TYPES[] = /* static */ const DigraphUtils::digraph_t *DigraphUtils::getDigraphForDigraphTypeAndCodePoint( const DigraphUtils::DigraphType digraphType, const int compositeGlyphCodePoint) { const DigraphUtils::digraph_t *digraphs = 0; + const int compositeGlyphLowerCodePoint = toLowerCase(compositeGlyphCodePoint); const int digraphsSize = DigraphUtils::getAllDigraphsForDictionaryAndReturnSize(digraphType, &digraphs); for (int i = 0; i < digraphsSize; i++) { - if (digraphs[i].compositeGlyph == compositeGlyphCodePoint) { + if (digraphs[i].compositeGlyph == compositeGlyphLowerCodePoint) { return &digraphs[i]; } } diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp index a10b260e1..cc5b736bd 100644 --- a/native/jni/src/proximity_info_state.cpp +++ b/native/jni/src/proximity_info_state.cpp @@ -81,7 +81,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi mSampledTimes.clear(); mSampledInputIndice.clear(); mSampledLengthCache.clear(); - mSampledDistanceCache_G.clear(); + mSampledNormalizedSquaredLengthCache.clear(); mSampledNearKeySets.clear(); mSampledSearchKeySets.clear(); mSpeedRates.clear(); @@ -122,14 +122,15 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, - &mSampledNearKeySets, &mSampledDistanceCache_G); + &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. ProximityInfoStateUtils::updateAlignPointProbabilities( mMaxPointToKeyLength, mProximityInfo->getMostCommonKeyWidth(), mProximityInfo->getKeyCount(), lastSavedInputSize, mSampledInputSize, &mSampledInputXs, &mSampledInputYs, &mSpeedRates, &mSampledLengthCache, - &mSampledDistanceCache_G, &mSampledNearKeySets, &mCharProbabilities); + &mSampledNormalizedSquaredLengthCache, &mSampledNearKeySets, + &mCharProbabilities); ProximityInfoStateUtils::updateSampledSearchKeySets(mProximityInfo, mSampledInputSize, lastSavedInputSize, &mSampledLengthCache, &mSampledNearKeySets, &mSampledSearchKeySets, @@ -171,7 +172,7 @@ float ProximityInfoState::getPointToKeyLength( const int keyId = mProximityInfo->getKeyIndexOf(codePoint); if (keyId != NOT_AN_INDEX) { const int index = inputIndex * mProximityInfo->getKeyCount() + keyId; - return min(mSampledDistanceCache_G[index], mMaxPointToKeyLength); + return min(mSampledNormalizedSquaredLengthCache[index], mMaxPointToKeyLength); } if (isIntentionalOmissionCodePoint(codePoint)) { return 0.0f; @@ -183,7 +184,8 @@ float ProximityInfoState::getPointToKeyLength( float ProximityInfoState::getPointToKeyByIdLength( const int inputIndex, const int keyId) const { return ProximityInfoStateUtils::getPointToKeyByIdLength(mMaxPointToKeyLength, - &mSampledDistanceCache_G, mProximityInfo->getKeyCount(), inputIndex, keyId); + &mSampledNormalizedSquaredLengthCache, mProximityInfo->getKeyCount(), inputIndex, + keyId); } // In the following function, c is the current character of the dictionary word currently examined. diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h index 9bba751d0..bbe8af240 100644 --- a/native/jni/src/proximity_info_state.h +++ b/native/jni/src/proximity_info_state.h @@ -49,8 +49,8 @@ class ProximityInfoState { mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0), mIsContinuousSuggestionPossible(false), mSampledInputXs(), mSampledInputYs(), mSampledTimes(), mSampledInputIndice(), mSampledLengthCache(), - mBeelineSpeedPercentiles(), mSampledDistanceCache_G(), mSpeedRates(), mDirections(), - mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), + mBeelineSpeedPercentiles(), mSampledNormalizedSquaredLengthCache(), mSpeedRates(), + mDirections(), mCharProbabilities(), mSampledNearKeySets(), mSampledSearchKeySets(), mSampledSearchKeyVectors(), mTouchPositionCorrectionEnabled(false), mSampledInputSize(0), mMostProbableStringProbability(0.0f) { memset(mInputProximities, 0, sizeof(mInputProximities)); @@ -147,7 +147,9 @@ class ProximityInfoState { return mIsContinuousSuggestionPossible; } + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyByIdLength(const int inputIndex, const int keyId) const; + // TODO: Rename s/Length/NormalizedSquaredLength/ float getPointToKeyLength(const int inputIndex, const int codePoint) const; ProximityType getProximityType(const int index, const int codePoint, @@ -231,7 +233,7 @@ class ProximityInfoState { std::vector<int> mSampledInputIndice; std::vector<int> mSampledLengthCache; std::vector<int> mBeelineSpeedPercentiles; - std::vector<float> mSampledDistanceCache_G; + std::vector<float> mSampledNormalizedSquaredLengthCache; std::vector<float> mSpeedRates; std::vector<float> mDirections; // probabilities of skipping or mapping to a key for each point. diff --git a/native/jni/src/proximity_info_state_utils.cpp b/native/jni/src/proximity_info_state_utils.cpp index df70cffdf..359673cd8 100644 --- a/native/jni/src/proximity_info_state_utils.cpp +++ b/native/jni/src/proximity_info_state_utils.cpp @@ -225,13 +225,13 @@ namespace latinime { const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - std::vector<NearKeycodesSet> *SampledNearKeySets, - std::vector<float> *SampledDistanceCache_G) { - SampledNearKeySets->resize(sampledInputSize); + std::vector<NearKeycodesSet> *sampledNearKeySets, + std::vector<float> *sampledNormalizedSquaredLengthCache) { + sampledNearKeySets->resize(sampledInputSize); const int keyCount = proximityInfo->getKeyCount(); - SampledDistanceCache_G->resize(sampledInputSize * keyCount); + sampledNormalizedSquaredLengthCache->resize(sampledInputSize * keyCount); for (int i = lastSavedInputSize; i < sampledInputSize; ++i) { - (*SampledNearKeySets)[i].reset(); + (*sampledNearKeySets)[i].reset(); for (int k = 0; k < keyCount; ++k) { const int index = i * keyCount + k; const int x = (*sampledInputXs)[i]; @@ -239,10 +239,10 @@ namespace latinime { const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( k, x, y, verticalSweetSpotScale); - (*SampledDistanceCache_G)[index] = normalizedSquaredDistance; + (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { - (*SampledNearKeySets)[i][k] = true; + (*sampledNearKeySets)[i][k] = true; } } } @@ -642,11 +642,11 @@ namespace latinime { // This function basically converts from a length to an edit distance. Accordingly, it's obviously // wrong to compare with mMaxPointToKeyLength. /* static */ float ProximityInfoStateUtils::getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector<float> *const SampledDistanceCache_G, const int keyCount, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId) { if (keyId != NOT_AN_INDEX) { const int index = inputIndex * keyCount + keyId; - return min((*SampledDistanceCache_G)[index], maxPointToKeyLength); + return min((*sampledNormalizedSquaredLengthCache)[index], maxPointToKeyLength); } // If the char is not a key on the keyboard then return the max length. return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); @@ -660,8 +660,8 @@ namespace latinime { const std::vector<int> *const sampledInputYs, const std::vector<float> *const sampledSpeedRates, const std::vector<int> *const sampledLengthCache, - const std::vector<float> *const SampledDistanceCache_G, - std::vector<NearKeycodesSet> *SampledNearKeySets, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, + std::vector<NearKeycodesSet> *sampledNearKeySets, std::vector<hash_map_compat<int, float> > *charProbabilities) { charProbabilities->resize(sampledInputSize); // Calculates probabilities of using a point as a correlated point with the character @@ -677,9 +677,9 @@ namespace latinime { float nearestKeyDistance = static_cast<float>(MAX_VALUE_FOR_WEIGHTING); for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { const float distance = getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j); if (distance < nearestKeyDistance) { nearestKeyDistance = distance; } @@ -758,14 +758,15 @@ namespace latinime { // Summing up probability densities of all near keys. float sumOfProbabilityDensities = 0.0f; for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from first point and the // next point to the key is used as a point to key distance. const float nextDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (nextDistance < distance) { // The distance of the first point tends to bigger than continuing // points because the first touch by the user can be sloppy. @@ -779,7 +780,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float previousDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (previousDistance < distance) { // The distance of the last point tends to bigger than continuing points // because the last touch by the user can be sloppy. So we promote the @@ -798,14 +800,15 @@ namespace latinime { // Split the probability of an input point to keys that are close to the input point. for (int j = 0; j < keyCount; ++j) { - if ((*SampledNearKeySets)[i].test(j)) { + if ((*sampledNearKeySets)[i].test(j)) { float distance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, i, j)); if (i == 0 && i != sampledInputSize - 1) { // For the first point, weighted average of distances from the first point and // the next point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i + 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i + 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::NEXT_DISTANCE_WEIGHT) @@ -815,7 +818,8 @@ namespace latinime { // For the first point, weighted average of distances from last point and // the previous point to the key is used as a point to key distance. const float prevDistance = sqrtf(getPointToKeyByIdLength( - maxPointToKeyLength, SampledDistanceCache_G, keyCount, i - 1, j)); + maxPointToKeyLength, sampledNormalizedSquaredLengthCache, keyCount, + i - 1, j)); if (prevDistance < distance) { distance = (distance + prevDistance * ProximityInfoParams::PREV_DISTANCE_WEIGHT) @@ -882,10 +886,10 @@ namespace latinime { for (int j = 0; j < keyCount; ++j) { hash_map_compat<int, float>::iterator it = (*charProbabilities)[i].find(j); if (it == (*charProbabilities)[i].end()){ - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); } else if(it->second < ProximityInfoParams::MIN_PROBABILITY) { // Erases from near keys vector because it has very low probability. - (*SampledNearKeySets)[i].reset(j); + (*sampledNearKeySets)[i].reset(j); (*charProbabilities)[i].erase(j); } else { it->second = -logf(it->second); @@ -899,7 +903,7 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<NearKeycodesSet> *const SampledNearKeySets, + const std::vector<NearKeycodesSet> *const sampledNearKeySets, std::vector<NearKeycodesSet> *sampledSearchKeySets, std::vector<std::vector<int> > *sampledSearchKeyVectors) { sampledSearchKeySets->resize(sampledInputSize); @@ -916,7 +920,7 @@ namespace latinime { if ((*sampledLengthCache)[j] - (*sampledLengthCache)[i] >= readForwordLength) { break; } - (*sampledSearchKeySets)[i] |= (*SampledNearKeySets)[j]; + (*sampledSearchKeySets)[i] |= (*sampledNearKeySets)[j]; } } const int keyCount = proximityInfo->getKeyCount(); diff --git a/native/jni/src/proximity_info_state_utils.h b/native/jni/src/proximity_info_state_utils.h index c9feb59a3..1837c7ab6 100644 --- a/native/jni/src/proximity_info_state_utils.h +++ b/native/jni/src/proximity_info_state_utils.h @@ -71,25 +71,25 @@ class ProximityInfoStateUtils { const std::vector<int> *const sampledInputYs, const std::vector<float> *const sampledSpeedRates, const std::vector<int> *const sampledLengthCache, - const std::vector<float> *const SampledDistanceCache_G, - std::vector<NearKeycodesSet> *SampledNearKeySets, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, + std::vector<NearKeycodesSet> *sampledNearKeySets, std::vector<hash_map_compat<int, float> > *charProbabilities); static void updateSampledSearchKeySets(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const std::vector<int> *const sampledLengthCache, - const std::vector<NearKeycodesSet> *const SampledNearKeySets, + const std::vector<NearKeycodesSet> *const sampledNearKeySets, std::vector<NearKeycodesSet> *sampledSearchKeySets, std::vector<std::vector<int> > *sampledSearchKeyVectors); static float getPointToKeyByIdLength(const float maxPointToKeyLength, - const std::vector<float> *const SampledDistanceCache_G, const int keyCount, + const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, const int sampledInputSize, const int lastSavedInputSize, const float verticalSweetSpotScale, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, - std::vector<NearKeycodesSet> *SampledNearKeySets, - std::vector<float> *SampledDistanceCache_G); + std::vector<NearKeycodesSet> *sampledNearKeySets, + std::vector<float> *sampledNormalizedSquaredLengthCache); static void initPrimaryInputWord(const int inputSize, const int *const inputProximities, int *primaryInputWord); static void initNormalizedSquaredDistances(const ProximityInfo *const proximityInfo, diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 32faae52c..f8d2df452 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -360,11 +360,6 @@ class DicNode { return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight); } - // Note that "cost" means delta for "distance" that is weighted. - float getTotalPrevWordsLanguageCost() const { - return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost(); - } - // Used to commit input partially int getPrevWordNodePos() const { return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos(); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h index 8902d3122..fd9d610e3 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -31,7 +31,7 @@ class DicNodeStateScoring { mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), - mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) { + mRawLength(0.0f) { } virtual ~DicNodeStateScoring() {} @@ -42,7 +42,6 @@ class DicNodeStateScoring { mNormalizedCompoundDistance = 0.0f; mSpatialDistance = 0.0f; mLanguageDistance = 0.0f; - mTotalPrevWordsLanguageCost = 0.0f; mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; @@ -54,7 +53,6 @@ class DicNodeStateScoring { mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance; mSpatialDistance = scoring->mSpatialDistance; mLanguageDistance = scoring->mLanguageDistance; - mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost; mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDigraphIndex = scoring->mDigraphIndex; @@ -70,9 +68,6 @@ class DicNodeStateScoring { if (isProximityCorrection) { ++mProximityCorrectionCount; } - if (languageCost > 0.0f) { - setTotalPrevWordsLanguageCost(mTotalPrevWordsLanguageCost + languageCost); - } } void addRawLength(const float rawLength) { @@ -148,10 +143,6 @@ class DicNodeStateScoring { } } - float getTotalPrevWordsLanguageCost() const { - return mTotalPrevWordsLanguageCost; - } - private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok @@ -165,7 +156,6 @@ class DicNodeStateScoring { float mNormalizedCompoundDistance; float mSpatialDistance; float mLanguageDistance; - float mTotalPrevWordsLanguageCost; float mRawLength; AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, @@ -179,11 +169,6 @@ class DicNodeStateScoring { / static_cast<float>(max(1, totalInputIndex)); } } - - //TODO: remove - AK_FORCE_INLINE void setTotalPrevWordsLanguageCost(float totalPrevWordsLanguageCost) { - mTotalPrevWordsLanguageCost = totalPrevWordsLanguageCost; - } }; } // namespace latinime #endif // LATINIME_DIC_NODE_STATE_SCORING_H diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index e62b70423..b9c0b8129 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -38,7 +38,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: PROF_SUBSTITUTION(node->mProfiler); return; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: PROF_NEW_WORD(node->mProfiler); return; case CT_MATCH: @@ -50,7 +50,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: PROF_TERMINAL(node->mProfiler); return; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: PROF_SPACE_SUBSTITUTION(node->mProfiler); return; case CT_INSERTION: @@ -107,16 +107,16 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: // only used for typing return weighting->getSubstitutionCost(); - case CT_NEW_WORD: - return weighting->getNewWordCost(dicNode); + case CT_NEW_WORD_SPACE_OMITTION: + return weighting->getNewWordCost(traverseSession, dicNode); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: return weighting->getCompletionCost(traverseSession, dicNode); case CT_TERMINAL: return weighting->getTerminalSpatialCost(traverseSession, dicNode); - case CT_SPACE_SUBSTITUTION: - return weighting->getSpaceSubstitutionCost(); + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); case CT_INSERTION: return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); case CT_TRANSPOSITION: @@ -135,7 +135,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0.0f; case CT_SUBSTITUTION: return 0.0f; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); case CT_MATCH: return 0.0f; @@ -147,8 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n traverseSession->getOffsetDict(), dicNode, bigramCacheMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } - case CT_SPACE_SUBSTITUTION: - return 0.0f; + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: @@ -168,7 +168,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: // Should return true? return false; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return false; case CT_MATCH: return false; @@ -176,7 +176,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_TERMINAL: return false; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return false; case CT_INSERTION: return true; @@ -197,7 +197,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_SUBSTITUTION: return false; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return false; case CT_MATCH: return weighting->isProximityDicNode(traverseSession, dicNode); @@ -205,7 +205,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return false; case CT_TERMINAL: return false; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return false; case CT_INSERTION: return false; @@ -224,7 +224,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0; case CT_SUBSTITUTION: return 0; - case CT_NEW_WORD: + case CT_NEW_WORD_SPACE_OMITTION: return 0; case CT_MATCH: return 1; @@ -232,7 +232,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return 0; case CT_TERMINAL: return 0; - case CT_SPACE_SUBSTITUTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: return 1; case CT_INSERTION: return 2; diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index b92dbe278..bce479c51 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,7 +56,8 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicNode *const dicNode) const = 0; + virtual float getNewWordCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; virtual float getNewWordBigramCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, @@ -76,7 +77,8 @@ class Weighting { virtual float getSubstitutionCost() const = 0; - virtual float getSpaceSubstitutionCost() const = 0; + virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; Weighting() {} virtual ~Weighting() {} diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 5b783a2ba..3c44db21c 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -16,6 +16,7 @@ #include "suggest/core/session/dic_traverse_session.h" +#include "binary_format.h" #include "defines.h" #include "dictionary.h" #include "dic_traverse_wrapper.h" @@ -63,6 +64,7 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength) { mDictionary = dictionary; + mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index 525d198cd..d9c2a51d0 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -36,7 +36,8 @@ class DicTraverseSession { AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), mDictionary(0), mDicNodesCache(), mBigramCacheMap(), - mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { + mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), + mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. } @@ -52,6 +53,7 @@ class DicTraverseSession { const int maxPointerCount); void resetCache(const int nextActiveCacheSize, const int maxWords); + // TODO: Remove const uint8_t *getOffsetDict() const; int getDictFlags() const; @@ -134,7 +136,7 @@ class DicTraverseSession { if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) { return false; } - ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G); + ASSERT(mMaxPointerCount <= MAX_POINTER_COUNT_G); for (int i = 0; i < mMaxPointerCount; ++i) { const ProximityInfoState *const pInfoState = getProximityInfoState(i); // If a proximity info state is not continuous suggestion possible, @@ -146,6 +148,14 @@ class DicTraverseSession { return true; } + bool isTouchPositionCorrectionEnabled() const { + return mProximityInfoStates[0].touchPositionCorrectionEnabled(); + } + + float getMultiWordCostMultiplier() const { + return mMultiWordCostMultiplier; + } + private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); // threshold to start caching @@ -166,6 +176,11 @@ class DicTraverseSession { int mInputSize; bool mPartiallyCommited; int mMaxPointerCount; + + ///////////////////////////////// + // Configuration per dictionary + float mMultiWordCostMultiplier; + }; } // namespace latinime #endif // LATINIME_DIC_TRAVERSE_SESSION_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 67d351fa1..9de2cd2e2 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -33,16 +33,9 @@ namespace latinime { // Initialization of class constants. -const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25; const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; -const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f; - -const bool Suggest::CORRECT_SPACE_OMISSION = true; -const bool Suggest::CORRECT_TRANSPOSITION = true; -const bool Suggest::CORRECT_INSERTION = true; -const bool Suggest::CORRECT_OMISSION_G = true; /** * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates @@ -270,12 +263,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { // latest touch point yet. These are needed to apply look-ahead correction operations // that require special handling of the latest touch point. For example, with insertions // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. - if (CORRECT_TRANSPOSITION) { - processDicNodeAsTransposition(traverseSession, &dicNode); - } - if (CORRECT_INSERTION) { - processDicNodeAsInsertion(traverseSession, &dicNode); - } + processDicNodeAsTransposition(traverseSession, &dicNode); + processDicNodeAsInsertion(traverseSession, &dicNode); } else { // !isLookAheadCorrection // Only consider typing error corrections if the normalized compound distance is // below a spatial distance threshold. @@ -531,13 +520,10 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode DicNode newDicNode; DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), traverseSession->getOffsetDict(), dicNode, &newDicNode); - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode, + const CorrectionType correctionType = spaceSubstitution ? + CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; + Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getBigramCacheMap()); - if (spaceSubstitution) { - // Merge this with CT_NEW_WORD - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION, - traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */); - } traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); } } // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index becd6c1de..875cbe4e0 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -76,31 +76,16 @@ class Suggest : public SuggestInterface { void processDicNodeAsMatch(DicTraverseSession *traverseSession, DicNode *childDicNode) const; - // Dic nodes cache size for lookahead (autocompletion) - static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE; - // Max characters to lookahead - static const int MAX_LOOKAHEAD; // Inputs longer than this will autocorrect if the suggestion is multi-word static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE; - // Base value for converting costs into scores (low so will not autocorrect without classifier) - static const float BASE_OUTPUT_SCORE; // Threshold for autocorrection classifier static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; - // Threshold for computing the language model feature for autocorrect classification - static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD; - - // Typing error correction settings - static const bool CORRECT_SPACE_OMISSION; - static const bool CORRECT_TRANSPOSITION; - static const bool CORRECT_INSERTION; const Traversal *const TRAVERSAL; const Scoring *const SCORING; const Weighting *const WEIGHTING; - - static const bool CORRECT_OMISSION_G; }; } // namespace latinime #endif // LATINIME_SUGGEST_IMPL_H diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 0fa684f01..11ccf1773 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -35,17 +35,17 @@ const float ScoringParams::INSERTION_COST = 0.670f; const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f; const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f; const float ScoringParams::TRANSPOSITION_COST = 0.494f; -const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.239f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.289f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; const float ScoringParams::SUBSTITUTION_COST = 0.363f; -const float ScoringParams::COST_NEW_WORD = 0.054f; +const float ScoringParams::COST_NEW_WORD = 0.024f; const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.462f; const float ScoringParams::COST_LOOKAHEAD = 0.092f; const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f; const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.136f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.536f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f; diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp index 66f8ba9fa..e7e40e34d 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp @@ -18,7 +18,7 @@ namespace latinime { const bool TypingTraversal::CORRECT_OMISSION = true; -const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true; -const bool TypingTraversal::CORRECT_SPACE_OMISSION = true; +const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_SUBSTITUTION = true; +const bool TypingTraversal::CORRECT_NEW_WORD_SPACE_OMISSION = true; const TypingTraversal TypingTraversal::sInstance; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index f22029a2c..9f8347452 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -66,7 +66,7 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool isSpaceSubstitutionTerminal( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - if (!CORRECT_SPACE_SUBSTITUTION) { + if (!CORRECT_NEW_WORD_SPACE_SUBSTITUTION) { return false; } if (!canDoLookAheadCorrection(traverseSession, dicNode)) { @@ -80,7 +80,7 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool isSpaceOmissionTerminal( const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - if (!CORRECT_SPACE_OMISSION) { + if (!CORRECT_NEW_WORD_SPACE_OMISSION) { return false; } const int inputSize = traverseSession->getInputSize(); @@ -173,8 +173,8 @@ class TypingTraversal : public Traversal { private: DISALLOW_COPY_AND_ASSIGN(TypingTraversal); static const bool CORRECT_OMISSION; - static const bool CORRECT_SPACE_SUBSTITUTION; - static const bool CORRECT_SPACE_OMISSION; + static const bool CORRECT_NEW_WORD_SPACE_SUBSTITUTION; + static const bool CORRECT_NEW_WORD_SPACE_OMISSION; static const TypingTraversal sInstance; TypingTraversal() {} diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 52d54eb0f..34d25ae1a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -18,6 +18,7 @@ #define LATINIME_TYPING_WEIGHTING_H #include "defines.h" +#include "suggest_utils.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/policy/weighting.h" #include "suggest/core/session/dic_traverse_session.h" @@ -70,10 +71,12 @@ class TypingWeighting : public Weighting { const int pointIndex = dicNode->getInputIndex(0); // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) - const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE, - traverseSession->getProximityInfoState(0)->getPointToKeyLength( - pointIndex, dicNode->getNodeCodePoint())); - const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const float normalizedSquaredLength = traverseSession->getProximityInfoState(0) + ->getPointToKeyLength(pointIndex, dicNode->getNodeCodePoint()); + const float normalizedDistance = SuggestUtils::getSweetSpotFactor( + traverseSession->isTouchPositionCorrectionEnabled(), normalizedSquaredLength); + const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; + const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST @@ -125,17 +128,19 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordCost(const DicNode *const dicNode) const { + float getNewWordCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { const bool isCapitalized = dicNode->isCapitalized(); - return isCapitalized ? + const float cost = isCapitalized ? ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; + return cost * traverseSession->getMultiWordCostMultiplier(); } float getNewWordBigramCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, hash_map_compat<int, int16_t> *const bigramCacheMap) const { return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), - dicNode, bigramCacheMap); + dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } float getCompletionCost(const DicTraverseSession *const traverseSession, @@ -159,13 +164,8 @@ class TypingWeighting : public Weighting { // because the input word shouldn't be treated as perfect const bool isExactMatch = !hasEditCount && !hasMultipleWords && !hasProximityErrors && isSameLength; - - const float totalPrevWordsLanguageCost = dicNode->getTotalPrevWordsLanguageCost(); const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability; - const float languageWeight = ScoringParams::DISTANCE_WEIGHT_LANGUAGE; - // TODO: Caveat: The following equation should be: - // totalPrevWordsLanguageCost + (languageImprobability * languageWeight); - return (totalPrevWordsLanguageCost + languageImprobability) * languageWeight; + return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { @@ -180,8 +180,13 @@ class TypingWeighting : public Weighting { return ScoringParams::SUBSTITUTION_COST; } - AK_FORCE_INLINE float getSpaceSubstitutionCost() const { - return ScoringParams::SPACE_SUBSTITUTION_COST; + AK_FORCE_INLINE float getSpaceSubstitutionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + const bool isCapitalized = dicNode->isCapitalized(); + const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ? + ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD); + return cost * traverseSession->getMultiWordCostMultiplier(); } private: diff --git a/native/jni/src/suggest_utils.h b/native/jni/src/suggest_utils.h index aab9f7ba8..e053dd662 100644 --- a/native/jni/src/suggest_utils.h +++ b/native/jni/src/suggest_utils.h @@ -23,10 +23,8 @@ namespace latinime { class SuggestUtils { public: - static float getDistanceScalingFactor(const float normalizedSquaredDistance) { - if (normalizedSquaredDistance < 0.0f) { - return -1.0f; - } + // TODO: (OLD) Remove + static float getLengthScalingFactor(const float normalizedSquaredDistance) { // Promote or demote the score according to the distance from the sweet spot static const float A = ZERO_DISTANCE_PROMOTION_RATE / 100.0f; static const float B = 1.0f; @@ -50,6 +48,39 @@ class SuggestUtils { return factor; } + static float getSweetSpotFactor(const bool isTouchPositionCorrectionEnabled, + const float normalizedSquaredDistance) { + // Promote or demote the score according to the distance from the sweet spot + static const float A = 0.0f; + static const float B = 0.24f; + static const float C = 1.20f; + static const float R0 = 0.0f; + static const float R1 = 0.25f; // Sweet spot + static const float R2 = 1.0f; + const float x = normalizedSquaredDistance; + if (!isTouchPositionCorrectionEnabled) { + return min(C, x); + } + + // factor is a piecewise linear function like: + // C -------------. + // / . + // B / . + // -/ . + // A _-^ . + // . + // R0 R1 R2 . + + if (x < R0) { + return A; + } else if (x < R1) { + return (A * (R1 - x) + B * (x - R0)) / (R1 - R0); + } else if (x < R2) { + return (B * (R2 - x) + C * (x - R1)) / (R2 - R1); + } else { + return C; + } + } private: DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestUtils); }; |