diff options
Diffstat (limited to 'native')
29 files changed, 681 insertions, 360 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 11fa3da3a..1dd68ea8b 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s } Dictionary *dictionary = 0; if (BinaryFormat::UNKNOWN_FORMAT - == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) { + == BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf), + static_cast<int>(dictSize))) { AKLOGE("DICT: dictionary format is unknown, bad magic number"); #ifdef USE_MMAP_FOR_DICTIONARY releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd); diff --git a/native/jni/src/bigram_dictionary.cpp b/native/jni/src/bigram_dictionary.cpp index 92890383a..9053e7226 100644 --- a/native/jni/src/bigram_dictionary.cpp +++ b/native/jni/src/bigram_dictionary.cpp @@ -187,7 +187,7 @@ void BigramDictionary::fillBigramAddressToProbabilityMapAndFilter(const int *pre &pos); (*map)[bigramPos] = probability; setInFilter(filter, bigramPos); - } while (0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)); + } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); } bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const { diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index ad16039ef..98241532f 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -23,6 +23,7 @@ #include "bloom_filter.h" #include "char_utils.h" +#include "hash_map_compat.h" namespace latinime { @@ -63,13 +64,14 @@ class BinaryFormat { static const int UNKNOWN_FORMAT = -1; static const int SHORTCUT_LIST_SIZE_SIZE = 2; - static int detectFormat(const uint8_t *const dict); - static int getHeaderSize(const uint8_t *const dict); - static int getFlags(const uint8_t *const dict); + static int detectFormat(const uint8_t *const dict, const int dictSize); + static int getHeaderSize(const uint8_t *const dict, const int dictSize); + static int getFlags(const uint8_t *const dict, const int dictSize); static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue, - const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const char *const key); + static void readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize); + static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); @@ -93,7 +95,13 @@ class BinaryFormat { const int unigramProbability, const int bigramProbability); static int getProbability(const int position, const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict); + static int getBigramProbabilityFromHashMap(const int position, + const hash_map_compat<int, int> *bigramMap, const int unigramProbability); + static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); + static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, + hash_map_compat<int, int> *bigramMap); + static int getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability); // Flags for special processing // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -105,6 +113,8 @@ class BinaryFormat { private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); + static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); + static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; @@ -113,6 +123,8 @@ class BinaryFormat { static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; + // Any file smaller than this is not a dictionary. + static const int DICTIONARY_MINIMUM_SIZE = 4; // Originally, format version 1 had a 16-bit magic number, then the version number `01' // then options that must be 0. Hence the first 32-bits of the format are always as follow // and it's okay to consider them a magic number as a whole. @@ -122,6 +134,8 @@ class BinaryFormat { // number, so we had to change it so that version 2 files would be rejected by older // implementations. On this occasion, we made the magic number 32 bits long. static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE + // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 + static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; @@ -132,8 +146,11 @@ class BinaryFormat { static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); }; -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { +AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; switch (magicNumber) { case FORMAT_VERSION_1_MAGIC_NUMBER: @@ -143,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { // Options (2 bytes) must be 0x00 0x00 return 1; case FORMAT_VERSION_2_MAGIC_NUMBER: + // Version 2 dictionaries are at least 12 bytes long (see below details for the header). + // If this dictionary has the version 2 magic number but is less than 12 bytes long, then + // it's an unknown format and we need to avoid confidently reading the next bytes. + if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; // Format 2 header is as follows: // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE // Version number (2 bytes) 0x00 0x02 @@ -154,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { } } -inline int BinaryFormat::getFlags(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? default: @@ -164,11 +185,11 @@ inline int BinaryFormat::getFlags(const uint8_t *const dict) { } inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { - return flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD); + return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; } -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return FORMAT_VERSION_1_HEADER_SIZE; case 2: @@ -179,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { } } -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key, - int *outValue, const int outValueSize) { +inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize) { int outValueIndex = 0; // Only format 2 and above have header attributes as {key,value} string pairs. For prior // formats, we just return an empty string, as if the key wasn't found. - if (2 <= detectFormat(dict)) { + if (2 <= detectFormat(dict, dictSize)) { const int headerOptionsOffset = 4 /* magic number */ + 2 /* dictionary version */ + 2 /* flags */; const int headerSize = @@ -227,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char if (outValueIndex >= 0) outValue[outValueIndex] = 0; } -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) { +inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key) { const int bufferSize = LARGEST_INT_DIGIT_COUNT; int intBuffer[bufferSize]; char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize); + BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); for (int i = 0; i < bufferSize; ++i) { charBuffer[i] = intBuffer[i]; } @@ -247,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { - const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, + const int dictSize) { + const int headerValue = readHeaderValueInt(dict, dictSize, + "MULTIPLE_WORDS_DEMOTION_RATE"); if (headerValue == S_INT_MIN) { return 1.0f; } @@ -687,5 +711,68 @@ inline int BinaryFormat::getProbability(const int position, const std::map<int, } return backoff(unigramProbability); } + +// This returns a probability in log space. +inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, + const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { + if (!bigramMap) return backoff(unigramProbability); + const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); + if (bigramProbabilityIt != bigramMap->end()) { + const int bigramProbability = bigramProbabilityIt->second; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + return backoff(unigramProbability); +} + +AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( + const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return; + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, + &position); + (*bigramMap)[bigramPos] = probability; + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); +} + +AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return backoff(unigramProbability); + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int bigramPos = getAttributeAddressAndForwardPointer( + root, bigramFlags, &position); + if (bigramPos == nextPosition) { + const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + return backoff(unigramProbability); +} + +// Returns a pointer to the start of the bigram list. +AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( + const uint8_t *const root, int position) { + if (NOT_VALID_WORD == position) return 0; + const uint8_t flags = getFlagsAndForwardPointer(root, &position); + if (!(flags & FLAG_HAS_BIGRAMS)) return 0; + if (flags & FLAG_HAS_MULTIPLE_CHARS) { + position = skipOtherCharacters(root, position); + } else { + getCodePointAndForwardPointer(root, &position); + } + position = skipProbability(flags, position); + position = skipChildrenPosition(flags, position); + position = skipShortcuts(root, flags, position); + return position; +} + } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 0c65939e0..61bf3f619 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -23,6 +23,8 @@ #include "defines.h" #include "proximity_info_state.h" #include "suggest_utils.h" +#include "suggest/policyimpl/utils/edit_distance.h" +#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" namespace latinime { @@ -906,50 +908,11 @@ inline static bool isUpperCase(unsigned short c) { return totalFreq; } -/* Damerau-Levenshtein distance */ -inline static int editDistanceInternal(int *editDistanceTable, const int *before, - const int beforeLength, const int *after, const int afterLength) { - // dp[li][lo] dp[a][b] = dp[ a * lo + b] - int *dp = editDistanceTable; - const int li = beforeLength + 1; - const int lo = afterLength + 1; - for (int i = 0; i < li; ++i) { - dp[lo * i] = i; - } - for (int i = 0; i < lo; ++i) { - dp[i] = i; - } - - for (int i = 0; i < li - 1; ++i) { - for (int j = 0; j < lo - 1; ++j) { - const int ci = toBaseLowerCase(before[i]); - const int co = toBaseLowerCase(after[j]); - const int cost = (ci == co) ? 0 : 1; - dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1, - min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost)); - if (i > 0 && j > 0 && ci == toBaseLowerCase(after[j - 1]) - && co == toBaseLowerCase(before[i - 1])) { - dp[(i + 1) * lo + (j + 1)] = min( - dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost); - } - } - } - - if (DEBUG_EDIT_DISTANCE) { - AKLOGI("IN = %d, OUT = %d", beforeLength, afterLength); - for (int i = 0; i < li; ++i) { - for (int j = 0; j < lo; ++j) { - AKLOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]); - } - } - } - return dp[li * lo - 1]; -} - /* static */ int Correction::RankingAlgorithm::editDistance(const int *before, const int beforeLength, const int *after, const int afterLength) { - int table[(beforeLength + 1) * (afterLength + 1)]; - return editDistanceInternal(table, before, beforeLength, after, afterLength); + const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( + before, beforeLength, after, afterLength); + return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); } diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 6ef9f414b..eb59744f6 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -379,6 +379,15 @@ static inline void prof_out(void) { #error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE" #endif +// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could +// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage. +// Also, there are diminishing returns since the most frequently used bigrams are typically near +// the beginning of the input and are thus the first ones to be cached. Note that these bigrams +// are reset for each new composing word. +#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25 +// Most common previous word contexts currently have 100 bigrams +#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100 + template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; } template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; } @@ -417,16 +426,45 @@ typedef enum { } DoubleLetterLevel; typedef enum { + // Correction for MATCH_CHAR CT_MATCH, + // Correction for PROXIMITY_CHAR CT_PROXIMITY, + // Correction for ADDITIONAL_PROXIMITY_CHAR CT_ADDITIONAL_PROXIMITY, + // Correction for SUBSTITUTION_CHAR CT_SUBSTITUTION, + // Skip one omitted letter CT_OMISSION, + // Delete an unnecessarily inserted letter CT_INSERTION, + // Swap the order of next two touch points CT_TRANSPOSITION, CT_COMPLETION, CT_TERMINAL, + // Create new word with space omission CT_NEW_WORD_SPACE_OMITTION, + // Create new word with space substitution CT_NEW_WORD_SPACE_SUBSTITUTION, } CorrectionType; + +// ErrorType is mainly decided by CorrectionType but it is also depending on if +// the correction has really been performed or not. +typedef enum { + // Substitution, omission and transposition + ET_EDIT_CORRECTION, + // Proximity error + ET_PROXIMITY_CORRECTION, + // Completion + ET_COMPLETION, + // New word + // TODO: Remove. + // A new word error should be an edit correction error or a proximity correction error. + ET_NEW_WORD, + // Treat error as an intentional omission when the CorrectionType is omission and the node can + // be intentional omission. + ET_INTENTIONAL_OMISSION, + // Not treated as an error. Tracked for checking exact match + ET_NOT_AN_ERROR +} ErrorType; #endif // LATINIME_DEFINES_H diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp index c998c0676..dadb2bab2 100644 --- a/native/jni/src/dictionary.cpp +++ b/native/jni/src/dictionary.cpp @@ -34,9 +34,11 @@ namespace latinime { Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) : mDict(static_cast<unsigned char *>(dict)), - mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)), + mOffsetDict((static_cast<unsigned char *>(dict)) + + BinaryFormat::getHeaderSize(mDict, dictSize)), mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), - mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))), + mUnigramDictionary(new UnigramDictionary(mOffsetDict, + BinaryFormat::getFlags(mDict, dictSize))), mBigramDictionary(new BigramDictionary(mOffsetDict)), mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h index 0653d3ca9..2ad5b6c0b 100644 --- a/native/jni/src/dictionary.h +++ b/native/jni/src/dictionary.h @@ -31,6 +31,7 @@ class UnigramDictionary; class Dictionary { public: // Taken from SuggestedWords.java + static const int KIND_MASK_KIND = 0xFF; // Mask to get only the kind static const int KIND_TYPED = 0; // What user typed static const int KIND_CORRECTION = 1; // Simple correction/suggestion static const int KIND_COMPLETION = 2; // Completion (suggestion with appended chars) @@ -41,6 +42,10 @@ class Dictionary { static const int KIND_SHORTCUT = 7; // A shortcut static const int KIND_PREDICTION = 8; // A prediction (== a suggestion with no input) + static const int KIND_MASK_FLAGS = 0xFFFFFF00; // Mask to get the flags + static const int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000; + static const int KIND_FLAG_EXACT_MATCH = 0x40000000; + Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust); int getSuggestions(ProximityInfo *proximityInfo, void *traverseSession, int *xcoordinates, diff --git a/native/jni/src/multi_bigram_map.h b/native/jni/src/multi_bigram_map.h new file mode 100644 index 000000000..7e1b6301f --- /dev/null +++ b/native/jni/src/multi_bigram_map.h @@ -0,0 +1,89 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_MULTI_BIGRAM_MAP_H +#define LATINIME_MULTI_BIGRAM_MAP_H + +#include <cstring> +#include <stdint.h> + +#include "defines.h" +#include "binary_format.h" +#include "hash_map_compat.h" + +namespace latinime { + +// Class for caching bigram maps for multiple previous word contexts. This is useful since the +// algorithm needs to look up the set of bigrams for every word pair that occurs in every +// multi-word suggestion. +class MultiBigramMap { + public: + MultiBigramMap() : mBigramMaps() {} + ~MultiBigramMap() {} + + // Look up the bigram probability for the given word pair from the cached bigram maps. + // Also caches the bigrams if there is space remaining and they have not been cached already. + int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition, + const int nextWordPosition, const int unigramProbability) { + hash_map_compat<int, BigramMap>::const_iterator mapPosition = + mBigramMaps.find(wordPosition); + if (mapPosition != mBigramMaps.end()) { + return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); + } + if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { + addBigramsForWordPosition(dicRoot, wordPosition); + return mBigramMaps[wordPosition].getBigramProbability( + nextWordPosition, unigramProbability); + } + return BinaryFormat::getBigramProbability( + dicRoot, wordPosition, nextWordPosition, unigramProbability); + } + + void clear() { + mBigramMaps.clear(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(MultiBigramMap); + + class BigramMap { + public: + BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {} + ~BigramMap() {} + + void init(const uint8_t *const dicRoot, int position) { + BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap); + } + + inline int getBigramProbability(const int nextWordPosition, const int unigramProbability) + const { + return BinaryFormat::getBigramProbabilityFromHashMap( + nextWordPosition, &mBigramMap, unigramProbability); + } + + private: + // Note: Default copy constructor needed for use in hash_map. + hash_map_compat<int, int> mBigramMap; + }; + + void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) { + mBigramMaps[position].init(dicRoot, position); + } + + hash_map_compat<int, BigramMap> mBigramMaps; +}; +} // namespace latinime +#endif // LATINIME_MULTI_BIGRAM_MAP_H diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index e8432546b..4225bb3e5 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -219,7 +219,7 @@ class DicNode { return (prevWordLen == 1 && currentWordLen == 1); } - bool isCapitalized() const { + bool isFirstCharUppercase() const { const int c = getOutputWordBuf()[0]; return isAsciiUpper(c); } @@ -463,6 +463,10 @@ class DicNode { mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex(); } + bool isExactMatch() const { + return mDicNodeState.mDicNodeStateScoring.isExactMatch(); + } + uint8_t getFlags() const { return mDicNodeProperties.getFlags(); } @@ -542,13 +546,12 @@ class DicNode { // Caveat: Must not be called outside Weighting // This restriction is guaranteed by "friend" AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost, - const bool doNormalization, const int inputSize, const bool isEditCorrection, - const bool isProximityCorrection) { + const bool doNormalization, const int inputSize, const ErrorType errorType) { if (DEBUG_GEO_FULL) { LOGI_SHOW_ADD_COST_PROP; } mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization, - inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection); + inputSize, getTotalInputIndex(), errorType); } // Caveat: Must not be called outside Weighting diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h index 7ad3e3e5f..bbd9435b5 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h @@ -46,8 +46,8 @@ class DicNodeStateInput { for (int i = 0; i < MAX_POINTER_COUNT_G; i++) { mInputIndex[i] = src->mInputIndex[i]; mPrevCodePoint[i] = src->mPrevCodePoint[i]; - mTerminalDiffCost[i] = resetTerminalDiffCost ? - static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; + mTerminalDiffCost[i] = resetTerminalDiffCost ? + static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i]; } } diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h index fd9d610e3..dca9d60da 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h @@ -31,7 +31,7 @@ class DicNodeStateScoring { mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), mEditCorrectionCount(0), mProximityCorrectionCount(0), mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f), - mRawLength(0.0f) { + mRawLength(0.0f), mExactMatch(true) { } virtual ~DicNodeStateScoring() {} @@ -45,6 +45,7 @@ class DicNodeStateScoring { mRawLength = 0.0f; mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; + mExactMatch = true; } AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) { @@ -56,17 +57,32 @@ class DicNodeStateScoring { mRawLength = scoring->mRawLength; mDoubleLetterLevel = scoring->mDoubleLetterLevel; mDigraphIndex = scoring->mDigraphIndex; + mExactMatch = scoring->mExactMatch; } void addCost(const float spatialCost, const float languageCost, const bool doNormalization, - const int inputSize, const int totalInputIndex, const bool isEditCorrection, - const bool isProximityCorrection) { + const int inputSize, const int totalInputIndex, const ErrorType errorType) { addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex); - if (isEditCorrection) { - ++mEditCorrectionCount; - } - if (isProximityCorrection) { - ++mProximityCorrectionCount; + switch (errorType) { + case ET_EDIT_CORRECTION: + ++mEditCorrectionCount; + mExactMatch = false; + break; + case ET_PROXIMITY_CORRECTION: + ++mProximityCorrectionCount; + mExactMatch = false; + break; + case ET_COMPLETION: + mExactMatch = false; + break; + case ET_NEW_WORD: + mExactMatch = false; + break; + case ET_INTENTIONAL_OMISSION: + mExactMatch = false; + break; + case ET_NOT_AN_ERROR: + break; } } @@ -143,6 +159,10 @@ class DicNodeStateScoring { } } + bool isExactMatch() const { + return mExactMatch; + } + private: // Caution!!! // Use a default copy constructor and an assign operator because shallow copies are ok @@ -157,6 +177,7 @@ class DicNodeStateScoring { float mSpatialDistance; float mLanguageDistance; float mRawLength; + bool mExactMatch; AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance, bool doNormalization, int inputSize, int totalInputIndex) { diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 031e706ae..5357c3773 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -21,6 +21,7 @@ #include "dic_node.h" #include "dic_node_utils.h" #include "dic_node_vector.h" +#include "multi_bigram_map.h" #include "proximity_info.h" #include "proximity_info_state.h" @@ -191,11 +192,11 @@ namespace latinime { * Computes the combined bigram / unigram cost for the given dicNode. */ /* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { + const DicNode *const node, MultiBigramMap *multiBigramMap) { if (node->isImpossibleBigramWord()) { return static_cast<float>(MAX_VALUE_FOR_WEIGHTING); } - const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap); + const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap); // TODO: This equation to calculate the improbability looks unreasonable. Investigate this. const float cost = static_cast<float>(MAX_PROBABILITY - probability) / static_cast<float>(MAX_PROBABILITY); @@ -203,92 +204,25 @@ namespace latinime { } /* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { + const DicNode *const node, MultiBigramMap *multiBigramMap) { const int unigramProbability = node->getProbability(); - const int encodedDiffOfBigramProbability = - getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap); - if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) { + const int wordPos = node->getPos(); + const int prevWordPos = node->getPrevWordPos(); + if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) { + // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD. return backoff(unigramProbability); } - return BinaryFormat::computeProbabilityForBigram( - unigramProbability, encodedDiffOfBigramProbability); + if (multiBigramMap) { + return multiBigramMap->getBigramProbability( + dicRoot, prevWordPos, wordPos, unigramProbability); + } + return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability); } /////////////////////////////////////// // Bigram / Unigram dictionary utils // /////////////////////////////////////// -/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) { - const int wordPos = node->getPos(); - const int prevWordPos = node->getPrevWordPos(); - return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap); -} - -// TODO: Move this to BigramDictionary -/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos, - const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) { - // TODO: this is painfully slow compared to the method used in the previous version of the - // algorithm. Switch to that method. - if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY; - if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY; - - // Create a hash code for the given node pair (based on Josh Bloch's effective Java). - // TODO: Use a real hash map data structure that deals with collisions. - int hash = 17; - hash = hash * 31 + pos; - hash = hash * 31 + nextPos; - - hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash); - if (mapPos != bigramCacheMap->end()) { - return mapPos->second; - } - if (NOT_VALID_WORD == pos) { - return NOT_A_PROBABILITY; - } - const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); - if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) { - return NOT_A_PROBABILITY; - } - if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) { - BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos); - } else { - pos = BinaryFormat::skipOtherCharacters(dicRoot, pos); - } - pos = BinaryFormat::skipChildrenPosition(flags, pos); - pos = BinaryFormat::skipProbability(flags, pos); - uint8_t bigramFlags; - int count = 0; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos); - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot, - bigramFlags, &pos); - if (bigramPos == nextPos) { - const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { - (*bigramCacheMap)[hash] = probability; - } - return probability; - } - count++; - } while ((0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags)) - && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT); - if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) { - // TODO: does this -1 mean NOT_VALID_WORD? - (*bigramCacheMap)[hash] = -1; - } - return NOT_A_PROBABILITY; -} - -/* static */ int DicNodeUtils::getWordPos(const uint8_t *const dicRoot, const int *word, - const int wordLength) { - if (!word) { - return NOT_VALID_WORD; - } - return BinaryFormat::getTerminalPosition( - dicRoot, word, wordLength, false /* forceLowerCaseSearch */); -} - /* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, const int nodeCodePoint) { if (!pInfoState) { diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h index 15f9730de..5bc542d05 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h @@ -21,7 +21,6 @@ #include <vector> #include "defines.h" -#include "hash_map_compat.h" namespace latinime { @@ -29,6 +28,7 @@ class DicNode; class DicNodeVector; class ProximityInfo; class ProximityInfoState; +class MultiBigramMap; class DicNodeUtils { public: @@ -41,9 +41,8 @@ class DicNodeUtils { static void initByCopy(DicNode *srcNode, DicNode *destNode); static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot, DicNodeVector *childDicNodes); - static int getWordPos(const uint8_t *const dicRoot, const int *word, const int prevWordLength); static float getBigramNodeImprobability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat<int, int16_t> *const bigramCacheMap); + const DicNode *const node, MultiBigramMap *const multiBigramMap); static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo, const std::vector<int> *const codePointsFilter); // TODO: Move to private @@ -58,15 +57,11 @@ class DicNodeUtils { private: DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils); - // Max cache size for the space omission error correction bigram lookup - static const int MAX_BIGRAM_MAP_SIZE = 20000; // Max number of bigrams to look up static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500; static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node, - hash_map_compat<int, int16_t> *bigramCacheMap); - static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot, - const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap); + MultiBigramMap *multiBigramMap); static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes); static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot, @@ -77,8 +72,6 @@ class DicNodeUtils { const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex, const bool exactOnly, const std::vector<int> *const codePointsFilter, const ProximityInfo *const pInfo, DicNodeVector *childDicNodes); - static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos, - hash_map_compat<int, int16_t> *bigramCacheMap); // TODO: Move to proximity info static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex, diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index b8c10e25a..102e856f5 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -29,16 +29,14 @@ class Scoring { public: virtual int calculateFinalScore(const float compoundDistance, const int inputSize, const bool forceCommit) const = 0; - virtual bool getMostProbableString( - const DicTraverseSession *const traverseSession, const int terminalSize, - const float languageWeight, int *const outputCodePoints, int *const type, - int *const freq) const = 0; + virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, + const int terminalSize, const float languageWeight, int *const outputCodePoints, + int *const type, int *const freq) const = 0; virtual void safetyNetForMostProbableString(const int terminalSize, const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0; // TODO: Make more generic - virtual void searchWordWithDoubleLetter(DicNode *terminals, - const int terminalSize, int *doubleLetterTerminalIndex, - DoubleLetterLevel *doubleLetterLevel) const = 0; + virtual void searchWordWithDoubleLetter(DicNode *terminals, const int terminalSize, + int *doubleLetterTerminalIndex, DoubleLetterLevel *doubleLetterLevel) const = 0; virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, DicNode *const terminals, const int size) const = 0; virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex, diff --git a/native/jni/src/suggest/core/policy/suggest_policy.h b/native/jni/src/suggest/core/policy/suggest_policy.h index 885e214f7..5b6402c44 100644 --- a/native/jni/src/suggest/core/policy/suggest_policy.h +++ b/native/jni/src/suggest/core/policy/suggest_policy.h @@ -20,6 +20,7 @@ #include "defines.h" namespace latinime { + class Traversal; class Scoring; class Weighting; diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h index 02c358aec..c6f66f231 100644 --- a/native/jni/src/suggest/core/policy/traversal.h +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -28,7 +28,8 @@ class Traversal { virtual int getMaxPointerCount() const = 0; virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0; virtual bool isOmission(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; + const DicNode *const dicNode, const DicNode *const childDicNode, + const bool allowsErrorCorrections) const = 0; virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession, @@ -38,9 +39,8 @@ class Traversal { const DicNode *const dicNode) const = 0; virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; - virtual ProximityType getProximityType( - const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - const DicNode *const childDicNode) const = 0; + virtual ProximityType getProximityType(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; virtual bool sameAsTyped(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; virtual bool needsToTraverseAllUserInput() const = 0; @@ -48,9 +48,8 @@ class Traversal { virtual bool allowPartialCommit() const = 0; virtual int getDefaultExpandDicNodeSize() const = 0; virtual int getMaxCacheSize() const = 0; - virtual bool isPossibleOmissionChildNode( - const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, - const DicNode *const dicNode) const = 0; + virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; protected: diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index a6d30e457..d01531f07 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -18,7 +18,6 @@ #include "char_utils.h" #include "defines.h" -#include "hash_map_compat.h" #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_profiler.h" #include "suggest/core/dicnode/dic_node_utils.h" @@ -26,6 +25,8 @@ namespace latinime { +class MultiBigramMap; + static inline void profile(const CorrectionType correctionType, DicNode *const node) { #if DEBUG_DICT switch (correctionType) { @@ -69,20 +70,18 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n } /* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting, - const CorrectionType correctionType, - const DicTraverseSession *const traverseSession, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap) { + MultiBigramMap *const multiBigramMap) { const int inputSize = traverseSession->getInputSize(); DicNode_InputStateG inputStateG; inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default const float spatialCost = Weighting::getSpatialCost(weighting, correctionType, traverseSession, parentDicNode, dicNode, &inputStateG); const float languageCost = Weighting::getLanguageCost(weighting, correctionType, - traverseSession, parentDicNode, dicNode, bigramCacheMap); - const bool edit = Weighting::isEditCorrection(correctionType); - const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, - traverseSession, dicNode); + traverseSession, parentDicNode, dicNode, multiBigramMap); + const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession, + parentDicNode, dicNode); profile(correctionType, dicNode); if (inputStateG.mNeedsToUpdateInputStateG) { dicNode->updateInputIndexG(&inputStateG); @@ -91,13 +90,13 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n (correctionType == CT_TRANSPOSITION)); } dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), - inputSize, edit, proximity); + inputSize, errorType); } /* static */ float Weighting::getSpatialCost(const Weighting *const weighting, - const CorrectionType correctionType, - const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, - const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) { + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + DicNode_InputStateG *const inputStateG) { switch(correctionType) { case CT_OMISSION: return weighting->getOmissionCost(parentDicNode, dicNode); @@ -129,14 +128,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n /* static */ float Weighting::getLanguageCost(const Weighting *const weighting, const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap) { + MultiBigramMap *const multiBigramMap) { switch(correctionType) { case CT_OMISSION: return 0.0f; case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -144,11 +143,11 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_TERMINAL: { const float languageImprobability = DicNodeUtils::getBigramNodeImprobability( - traverseSession->getOffsetDict(), dicNode, bigramCacheMap); + traverseSession->getOffsetDict(), dicNode, multiBigramMap); return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: @@ -158,64 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n } } -/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) { - switch(correctionType) { - case CT_OMISSION: - return true; - case CT_ADDITIONAL_PROXIMITY: - // Should return true? - return false; - case CT_SUBSTITUTION: - // Should return true? - return false; - case CT_NEW_WORD_SPACE_OMITTION: - return false; - case CT_MATCH: - return false; - case CT_COMPLETION: - return false; - case CT_TERMINAL: - return false; - case CT_NEW_WORD_SPACE_SUBSTITUTION: - return false; - case CT_INSERTION: - return true; - case CT_TRANSPOSITION: - return true; - default: - return false; - } -} - -/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting, - const CorrectionType correctionType, - const DicTraverseSession *const traverseSession, const DicNode *const dicNode) { - switch(correctionType) { - case CT_OMISSION: - return false; - case CT_ADDITIONAL_PROXIMITY: - return false; - case CT_SUBSTITUTION: - return false; - case CT_NEW_WORD_SPACE_OMITTION: - return false; - case CT_MATCH: - return weighting->isProximityDicNode(traverseSession, dicNode); - case CT_COMPLETION: - return false; - case CT_TERMINAL: - return false; - case CT_NEW_WORD_SPACE_SUBSTITUTION: - return false; - case CT_INSERTION: - return false; - case CT_TRANSPOSITION: - return false; - default: - return false; - } -} - /* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { switch(correctionType) { case CT_OMISSION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index bce479c51..0d2745b40 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -18,13 +18,13 @@ #define LATINIME_WEIGHTING_H #include "defines.h" -#include "hash_map_compat.h" namespace latinime { class DicNode; class DicTraverseSession; struct DicNode_InputStateG; +class MultiBigramMap; class Weighting { public: @@ -32,7 +32,7 @@ class Weighting { const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap); + MultiBigramMap *const multiBigramMap); protected: virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, @@ -61,7 +61,7 @@ class Weighting { virtual float getNewWordBigramCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0; + MultiBigramMap *const multiBigramMap) const = 0; virtual float getCompletionCost( const DicTraverseSession *const traverseSession, @@ -80,6 +80,10 @@ class Weighting { virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const = 0; + virtual ErrorType getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + Weighting() {} virtual ~Weighting() {} @@ -93,13 +97,7 @@ class Weighting { static float getLanguageCost(const Weighting *const weighting, const CorrectionType correctionType, const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap); - // TODO: Move to TypingWeighting and GestureWeighting? - static bool isEditCorrection(const CorrectionType correctionType); - // TODO: Move to TypingWeighting and GestureWeighting? - static bool isProximityCorrection(const Weighting *const weighting, - const CorrectionType correctionType, const DicTraverseSession *const traverseSession, - const DicNode *const dicNode); + MultiBigramMap *const multiBigramMap); // TODO: Move to TypingWeighting and GestureWeighting? static int getForwardInputCount(const CorrectionType correctionType); }; diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 3c44db21c..6408f0163 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -64,12 +64,21 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength) { mDictionary = dictionary; - mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); + mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(), + mDictionary->getDictSize()); if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; } - mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength); + // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call. + mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord, + prevWordLength, false /* forceLowerCaseSearch */); + if (mPrevWordPos == NOT_VALID_WORD) { + // Check bigrams for lower-cased previous word if original was not found. Useful for + // auto-capitalized words like "The [current_word]". + mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord, + prevWordLength, true /* forceLowerCaseSearch */); + } } void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, @@ -92,7 +101,7 @@ int DicTraverseSession::getDictFlags() const { void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { mDicNodesCache.reset(nextActiveCacheSize, maxWords); - mBigramCacheMap.clear(); + mMultiBigramMap.clear(); mPartiallyCommited = false; } diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index d9c2a51d0..d88be5b88 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -21,8 +21,8 @@ #include <vector> #include "defines.h" -#include "hash_map_compat.h" #include "jni.h" +#include "multi_bigram_map.h" #include "proximity_info_state.h" #include "suggest/core/dicnode/dic_nodes_cache.h" @@ -35,7 +35,7 @@ class DicTraverseSession { public: AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), - mDictionary(0), mDicNodesCache(), mBigramCacheMap(), + mDictionary(0), mDicNodesCache(), mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. @@ -67,7 +67,7 @@ class DicTraverseSession { // TODO: Use proper parameter when changed int getDicRootPos() const { return 0; } DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } - hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; } + MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; } const ProximityInfoState *getProximityInfoState(int id) const { return &mProximityInfoStates[id]; } @@ -170,7 +170,7 @@ class DicTraverseSession { DicNodesCache mDicNodesCache; // Temporary cache for bigram frequencies - hash_map_compat<int, int16_t> mBigramCacheMap; + MultiBigramMap mMultiBigramMap; ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; int mInputSize; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 9de2cd2e2..a18794850 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -161,12 +161,20 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen + doubleLetterCost; const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(), terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); - const int originalTerminalProbability = terminalDicNode->getProbability(); + const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; + const bool isExactMatch = terminalDicNode->isExactMatch(); + const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); + // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. + // (e.g. "AMD" and "and") + const bool isSafeExactMatch = isExactMatch + && !(isPossiblyOffensiveWord && isFirstCharUppercase); + const int outputTypeFlags = + (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) + | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); + + // Entries that are blacklisted or do not represent a word should not be output. + const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord(); - // Do not suggest words with a 0 probability, or entries that are blacklisted or do not - // represent a word. However, we should still submit their shortcuts if any. - const bool isValidWord = - originalTerminalProbability > 0 && !terminalAttributes.isBlacklistedOrNotAWord(); // Increase output score of top typing suggestion to ensure autocorrection. // TODO: Better integration with java side autocorrection logic. // Force autocorrection for obvious long multi-word suggestions. @@ -188,10 +196,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen } } - // Do not suggest words with a 0 probability, or entries that are blacklisted or do not - // represent a word. However, we should still submit their shortcuts if any. + // Don't output invalid words. However, we still need to submit their shortcuts if any. if (isValidWord) { - outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION; + outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; frequencies[outputWordIndex] = finalScore; // Populate the outputChars array with the suggested word. const int startIndex = outputWordIndex * MAX_WORD_LENGTH; @@ -294,8 +301,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { correctionDicNode.advanceDigraphIndex(); processDicNodeAsDigraph(traverseSession, &correctionDicNode); } - if (allowsErrorCorrections - && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { + if (TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode, + allowsErrorCorrections)) { // TODO: (Gesture) Change weight between omission and substitution errors // TODO: (Gesture) Terminal node should not be handled as omission correctionDicNode.initByCopy(childDicNode); @@ -357,7 +364,7 @@ void Suggest::processTerminalDicNode( DicNode terminalDicNode; DicNodeUtils::initByCopy(dicNode, &terminalDicNode); Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, - &terminalDicNode, traverseSession->getBigramCacheMap()); + &terminalDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); } @@ -389,8 +396,10 @@ void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { + // Note: Most types of corrections don't need to look up the bigram information since they do + // not treat the node as a terminal. There is no need to pass the bigram map in these cases. Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, - traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */); + traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -398,7 +407,7 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, DicNode *childDicNode) const { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); + dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); processExpandedDicNode(traverseSession, childDicNode); } @@ -422,20 +431,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, */ void Suggest::processDicNodeAsOmission( DicTraverseSession *traverseSession, DicNode *dicNode) const { - // If the omission is surely intentional that it should incur zero cost. - const bool isZeroCostOmission = dicNode->isZeroCostOmission(); DicNodeVector childDicNodes; - DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); const int size = childDicNodes.getSizeAndLock(); for (int i = 0; i < size; i++) { DicNode *const childDicNode = childDicNodes[i]; - if (!isZeroCostOmission) { - // Treat this word as omission - Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); - } + // Treat this word as omission + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, + dicNode, childDicNode, 0 /* multiBigramMap */); weightChildNode(traverseSession, childDicNode); if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { @@ -459,7 +463,7 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, for (int i = 0; i < size; i++) { DicNode *const childDicNode = childDicNodes[i]; Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, - dicNode, childDicNode, 0 /* bigramCacheMap */); + dicNode, childDicNode, 0 /* multiBigramMap */); processExpandedDicNode(traverseSession, childDicNode); } } @@ -484,7 +488,7 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, for (int j = 0; j < childSize2; j++) { DicNode *const childDicNode2 = childDicNodes2[j]; Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, - traverseSession, childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */); + traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */); processExpandedDicNode(traverseSession, childDicNode2); } } @@ -499,10 +503,10 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN const int inputSize = traverseSession->getInputSize(); if (dicNode->isCompletion(inputSize)) { Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, - 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } else { // completion Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, - 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); } } @@ -523,7 +527,7 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode const CorrectionType correctionType = spaceSubstitution ? CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, - &newDicNode, traverseSession->getBigramCacheMap()); + &newDicNode, traverseSession->getMultiBigramMap()); traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp index 993358616..f87989286 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -28,25 +28,25 @@ const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f; const float ScoringParams::PROXIMITY_COST = 0.086f; const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f; -const float ScoringParams::OMISSION_COST = 0.388f; -const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.431f; -const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.532f; -const float ScoringParams::INSERTION_COST = 0.670f; -const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f; -const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f; -const float ScoringParams::TRANSPOSITION_COST = 0.494f; -const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.289f; +const float ScoringParams::OMISSION_COST = 0.458f; +const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f; +const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f; +const float ScoringParams::INSERTION_COST = 0.730f; +const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f; +const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f; +const float ScoringParams::TRANSPOSITION_COST = 0.516f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.319f; const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; -const float ScoringParams::SUBSTITUTION_COST = 0.363f; -const float ScoringParams::COST_NEW_WORD = 0.024f; -const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f; +const float ScoringParams::SUBSTITUTION_COST = 0.403f; +const float ScoringParams::COST_NEW_WORD = 0.042f; +const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.25f; const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.545f; const float ScoringParams::COST_LOOKAHEAD = 0.073f; -const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f; -const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f; -const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.536f; +const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.105f; +const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.038f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.444f; const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; -const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f; +const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.06f; } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h index 8f104b362..53ac999c1 100644 --- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -48,7 +48,7 @@ class ScoringParams { static const float ADDITIONAL_PROXIMITY_COST; static const float SUBSTITUTION_COST; static const float COST_NEW_WORD; - static const float COST_NEW_WORD_CAPITALIZED; + static const float COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; static const float DISTANCE_WEIGHT_LANGUAGE; static const float COST_FIRST_LOOKAHEAD; static const float COST_LOOKAHEAD; @@ -57,7 +57,7 @@ class ScoringParams { static const float HAS_MULTI_WORD_TERMINAL_COST; static const float TYPING_BASE_OUTPUT_SCORE; static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT; - static const float MAX_NORM_DISTANCE_FOR_EDIT; + static const float NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT; private: DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams); diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h index 9f8347452..12110d54f 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -39,14 +39,21 @@ class TypingTraversal : public Traversal { AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const { return dicNode->getNormalizedSpatialDistance() - < ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT; + < ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT; } AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode, const DicNode *const childDicNode) const { + const DicNode *const dicNode, const DicNode *const childDicNode, + const bool allowsErrorCorrections) const { if (!CORRECT_OMISSION) { return false; } + // Note: Always consider intentional omissions (like apostrophes) since they are common. + const bool canConsiderOmission = + allowsErrorCorrections || childDicNode->canBeIntentionalOmission(); + if (!canConsiderOmission) { + return false; + } const int inputSize = traverseSession->getInputSize(); // TODO: Don't refer to isCompletion? if (dicNode->isCompletion(inputSize)) { diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp index 1500341bd..e4c69d1f6 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -20,5 +20,41 @@ #include "suggest/policyimpl/typing/scoring_params.h" namespace latinime { + const TypingWeighting TypingWeighting::sInstance; + +ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const { + switch (correctionType) { + case CT_MATCH: + if (isProximityDicNode(traverseSession, dicNode)) { + return ET_PROXIMITY_CORRECTION; + } else { + return ET_NOT_AN_ERROR; + } + case CT_ADDITIONAL_PROXIMITY: + return ET_PROXIMITY_CORRECTION; + case CT_OMISSION: + if (parentDicNode->canBeIntentionalOmission()) { + return ET_INTENTIONAL_OMISSION; + } else { + return ET_EDIT_CORRECTION; + } + break; + case CT_SUBSTITUTION: + case CT_INSERTION: + case CT_TRANSPOSITION: + return ET_EDIT_CORRECTION; + case CT_NEW_WORD_SPACE_OMITTION: + case CT_NEW_WORD_SPACE_SUBSTITUTION: + return ET_NEW_WORD; + case CT_TERMINAL: + return ET_NOT_AN_ERROR; + case CT_COMPLETION: + return ET_COMPLETION; + default: + return ET_NOT_AN_ERROR; + } +} } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 34d25ae1a..3938c0ec5 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -28,14 +28,15 @@ namespace latinime { class DicNode; struct DicNode_InputStateG; +class MultiBigramMap; class TypingWeighting : public Weighting { public: static const TypingWeighting *getInstance() { return &sInstance; } protected: - float getTerminalSpatialCost( - const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { float cost = 0.0f; if (dicNode->hasMultipleWords()) { cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; @@ -50,13 +51,14 @@ class TypingWeighting : public Weighting { } float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { - bool sameCodePoint = false; - bool isFirstLetterOmission = false; - float cost = 0.0f; - sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); + const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); + const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); // If the traversal omitted the first letter then the dicNode should now be on the second. - isFirstLetterOmission = dicNode->getDepth() == 2; - if (isFirstLetterOmission) { + const bool isFirstLetterOmission = dicNode->getDepth() == 2; + float cost = 0.0f; + if (isZeroCostOmission) { + cost = 0.0f; + } else if (isFirstLetterOmission) { cost = ScoringParams::OMISSION_COST_FIRST_CHAR; } else { cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR @@ -65,9 +67,8 @@ class TypingWeighting : public Weighting { return cost; } - float getMatchedCost( - const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - DicNode_InputStateG *inputStateG) const { + float getMatchedCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { const int pointIndex = dicNode->getInputIndex(0); // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on // the keyboard (like accented letters) @@ -79,13 +80,23 @@ class TypingWeighting : public Weighting { const bool isFirstChar = pointIndex == 0; const bool isProximity = isProximityDicNode(traverseSession, dicNode); - const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST + float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST : ScoringParams::PROXIMITY_COST) : 0.0f; + if (dicNode->getDepth() == 2) { + // At the second character of the current word, we check if the first char is uppercase + // and the word is a second or later word of a multiple word suggestion. We demote it + // if so. + const bool isSecondOrLaterWordFirstCharUppercase = + dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); + if (isSecondOrLaterWordFirstCharUppercase) { + cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; + } + } return weightedDistance + cost; } - bool isProximityDicNode( - const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + bool isProximityDicNode(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { const int pointIndex = dicNode->getInputIndex(0); const int primaryCodePoint = toBaseLowerCase( traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); @@ -93,9 +104,8 @@ class TypingWeighting : public Weighting { return primaryCodePoint != dicNodeChar; } - float getTranspositionCost( - const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, - const DicNode *const dicNode) const { + float getTranspositionCost(const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const { const int16_t parentPointIndex = parentDicNode->getInputIndex(0); const int prevCodePoint = parentDicNode->getNodeCodePoint(); const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( @@ -109,8 +119,7 @@ class TypingWeighting : public Weighting { return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; } - float getInsertionCost( - const DicTraverseSession *const traverseSession, + float getInsertionCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const { const int16_t parentPointIndex = parentDicNode->getInputIndex(0); const int prevCodePoint = @@ -130,17 +139,14 @@ class TypingWeighting : public Weighting { float getNewWordCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - const bool isCapitalized = dicNode->isCapitalized(); - const float cost = isCapitalized ? - ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; - return cost * traverseSession->getMultiWordCostMultiplier(); + return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); } - float getNewWordBigramCost( - const DicTraverseSession *const traverseSession, const DicNode *const dicNode, - hash_map_compat<int, int16_t> *const bigramCacheMap) const { + float getNewWordBigramCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, + MultiBigramMap *const multiBigramMap) const { return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), - dicNode, bigramCacheMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } float getCompletionCost(const DicTraverseSession *const traverseSession, @@ -156,15 +162,8 @@ class TypingWeighting : public Weighting { float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { - const bool hasEditCount = dicNode->getEditCorrectionCount() > 0; - const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize(); - const bool hasMultipleWords = dicNode->hasMultipleWords(); - const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0; - // Gesture input is always assumed to have proximity errors - // because the input word shouldn't be treated as perfect - const bool isExactMatch = !hasEditCount && !hasMultipleWords - && !hasProximityErrors && isSameLength; - const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability; + const float languageImprobability = (dicNode->isExactMatch()) ? + 0.0f : dicNodeLanguageImprobability; return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE; } @@ -180,15 +179,16 @@ class TypingWeighting : public Weighting { return ScoringParams::SUBSTITUTION_COST; } - AK_FORCE_INLINE float getSpaceSubstitutionCost( - const DicTraverseSession *const traverseSession, + AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { - const bool isCapitalized = dicNode->isCapitalized(); - const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ? - ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD); + const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; return cost * traverseSession->getMultiWordCostMultiplier(); } + ErrorType getErrorType(const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const; + private: DISALLOW_COPY_AND_ASSIGN(TypingWeighting); static const TypingWeighting sInstance; diff --git a/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h new file mode 100644 index 000000000..ec1457455 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H +#define LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H + +#include "char_utils.h" +#include "suggest/policyimpl/utils/edit_distance_policy.h" + +namespace latinime { + +class DamerauLevenshteinEditDistancePolicy : public EditDistancePolicy { + public: + DamerauLevenshteinEditDistancePolicy(const int *const string0, const int length0, + const int *const string1, const int length1) + : mString0(string0), mString0Length(length0), mString1(string1), + mString1Length(length1) {} + ~DamerauLevenshteinEditDistancePolicy() {} + + AK_FORCE_INLINE float getSubstitutionCost(const int index0, const int index1) const { + const int c0 = toBaseLowerCase(mString0[index0]); + const int c1 = toBaseLowerCase(mString1[index1]); + return (c0 == c1) ? 0.0f : 1.0f; + } + + AK_FORCE_INLINE float getDeletionCost(const int index0, const int index1) const { + return 1.0f; + } + + AK_FORCE_INLINE float getInsertionCost(const int index0, const int index1) const { + return 1.0f; + } + + AK_FORCE_INLINE bool allowTransposition(const int index0, const int index1) const { + const int c0 = toBaseLowerCase(mString0[index0]); + const int c1 = toBaseLowerCase(mString1[index1]); + if (index0 > 0 && index1 > 0 && c0 == toBaseLowerCase(mString1[index1 - 1]) + && c1 == toBaseLowerCase(mString0[index0 - 1])) { + return true; + } + return false; + } + + AK_FORCE_INLINE float getTranspositionCost(const int index0, const int index1) const { + return getSubstitutionCost(index0, index1); + } + + AK_FORCE_INLINE int getString0Length() const { + return mString0Length; + } + + AK_FORCE_INLINE int getString1Length() const { + return mString1Length; + } + + private: + DISALLOW_COPY_AND_ASSIGN (DamerauLevenshteinEditDistancePolicy); + + const int *const mString0; + const int mString0Length; + const int *const mString1; + const int mString1Length; +}; +} // namespace latinime + +#endif // LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance.h b/native/jni/src/suggest/policyimpl/utils/edit_distance.h new file mode 100644 index 000000000..cbbd66894 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_EDIT_DISTANCE_H +#define LATINIME_EDIT_DISTANCE_H + +#include "defines.h" +#include "suggest/policyimpl/utils/edit_distance_policy.h" + +namespace latinime { + +class EditDistance { + public: + // CAVEAT: There may be performance penalty if you need the edit distance as an integer value. + AK_FORCE_INLINE static float getEditDistance(const EditDistancePolicy *const policy) { + const int beforeLength = policy->getString0Length(); + const int afterLength = policy->getString1Length(); + float dp[(beforeLength + 1) * (afterLength + 1)]; + for (int i = 0; i <= beforeLength; ++i) { + dp[(afterLength + 1) * i] = i * policy->getInsertionCost(i - 1, -1); + } + for (int i = 0; i <= afterLength; ++i) { + dp[i] = i * policy->getDeletionCost(-1, i - 1); + } + + for (int i = 0; i < beforeLength; ++i) { + for (int j = 0; j < afterLength; ++j) { + dp[(afterLength + 1) * (i + 1) + (j + 1)] = min( + dp[(afterLength + 1) * i + (j + 1)] + policy->getInsertionCost(i, j), + min(dp[(afterLength + 1) * (i + 1) + j] + policy->getDeletionCost(i, j), + dp[(afterLength + 1) * i + j] + + policy->getSubstitutionCost(i, j))); + if (policy->allowTransposition(i, j)) { + dp[(afterLength + 1) * (i + 1) + (j + 1)] = min( + dp[(afterLength + 1) * (i + 1) + (j + 1)], + dp[(afterLength + 1) * (i - 1) + (j - 1)] + + policy->getTranspositionCost(i, j)); + } + } + } + if (DEBUG_EDIT_DISTANCE) { + AKLOGI("IN = %d, OUT = %d", beforeLength, afterLength); + for (int i = 0; i < beforeLength + 1; ++i) { + for (int j = 0; j < afterLength + 1; ++j) { + AKLOGI("EDIT[%d][%d], %f", i, j, dp[(afterLength + 1) * i + j]); + } + } + } + return dp[(beforeLength + 1) * (afterLength + 1) - 1]; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(EditDistance); +}; +} // namespace latinime + +#endif // LATINIME_EDIT_DISTANCE_H diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h b/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h new file mode 100644 index 000000000..e3d1792cb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_EDIT_DISTANCE_POLICY_H +#define LATINIME_EDIT_DISTANCE_POLICY_H + +#include "defines.h" + +namespace latinime { + +class EditDistancePolicy { + public: + virtual float getSubstitutionCost(const int index0, const int index1) const = 0; + virtual float getDeletionCost(const int index0, const int index1) const = 0; + virtual float getInsertionCost(const int index0, const int index1) const = 0; + virtual bool allowTransposition(const int index0, const int index1) const = 0; + virtual float getTranspositionCost(const int index0, const int index1) const = 0; + virtual int getString0Length() const = 0; + virtual int getString1Length() const = 0; + + protected: + EditDistancePolicy() {} + virtual ~EditDistancePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(EditDistancePolicy); +}; +} // namespace latinime + +#endif // LATINIME_EDIT_DISTANCE_POLICY_H |