diff options
Diffstat (limited to 'native/jni/src')
-rw-r--r-- | native/jni/src/binary_format.h | 53 | ||||
-rw-r--r-- | native/jni/src/correction.cpp | 47 | ||||
-rw-r--r-- | native/jni/src/dictionary.cpp | 6 | ||||
-rw-r--r-- | native/jni/src/suggest/core/session/dic_traverse_session.cpp | 3 | ||||
-rw-r--r-- | native/jni/src/suggest/core/suggest.cpp | 9 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h | 79 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/utils/edit_distance.h | 70 | ||||
-rw-r--r-- | native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h | 43 |
8 files changed, 244 insertions, 66 deletions
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index 06f50dc7f..98241532f 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -64,13 +64,14 @@ class BinaryFormat { static const int UNKNOWN_FORMAT = -1; static const int SHORTCUT_LIST_SIZE_SIZE = 2; - static int detectFormat(const uint8_t *const dict); - static int getHeaderSize(const uint8_t *const dict); - static int getFlags(const uint8_t *const dict); + static int detectFormat(const uint8_t *const dict, const int dictSize); + static int getHeaderSize(const uint8_t *const dict, const int dictSize); + static int getFlags(const uint8_t *const dict, const int dictSize); static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue, - const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const char *const key); + static void readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize); + static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); @@ -96,7 +97,7 @@ class BinaryFormat { const uint8_t *bigramFilter, const int unigramProbability); static int getBigramProbabilityFromHashMap(const int position, const hash_map_compat<int, int> *bigramMap, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict); + static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap); static int getBigramProbability(const uint8_t *const root, int position, @@ -122,6 +123,8 @@ class BinaryFormat { static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; + // Any file smaller than this is not a dictionary. + static const int DICTIONARY_MINIMUM_SIZE = 4; // Originally, format version 1 had a 16-bit magic number, then the version number `01' // then options that must be 0. Hence the first 32-bits of the format are always as follow // and it's okay to consider them a magic number as a whole. @@ -131,6 +134,8 @@ class BinaryFormat { // number, so we had to change it so that version 2 files would be rejected by older // implementations. On this occasion, we made the magic number 32 bits long. static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE + // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 + static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; @@ -141,8 +146,11 @@ class BinaryFormat { static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); }; -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { +AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; switch (magicNumber) { case FORMAT_VERSION_1_MAGIC_NUMBER: @@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { // Options (2 bytes) must be 0x00 0x00 return 1; case FORMAT_VERSION_2_MAGIC_NUMBER: + // Version 2 dictionaries are at least 12 bytes long (see below details for the header). + // If this dictionary has the version 2 magic number but is less than 12 bytes long, then + // it's an unknown format and we need to avoid confidently reading the next bytes. + if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; // Format 2 header is as follows: // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE // Version number (2 bytes) 0x00 0x02 @@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { } } -inline int BinaryFormat::getFlags(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? default: @@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; } -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return FORMAT_VERSION_1_HEADER_SIZE; case 2: @@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { } } -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key, - int *outValue, const int outValueSize) { +inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize) { int outValueIndex = 0; // Only format 2 and above have header attributes as {key,value} string pairs. For prior // formats, we just return an empty string, as if the key wasn't found. - if (2 <= detectFormat(dict)) { + if (2 <= detectFormat(dict, dictSize)) { const int headerOptionsOffset = 4 /* magic number */ + 2 /* dictionary version */ + 2 /* flags */; const int headerSize = @@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char if (outValueIndex >= 0) outValue[outValueIndex] = 0; } -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) { +inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key) { const int bufferSize = LARGEST_INT_DIGIT_COUNT; int intBuffer[bufferSize]; char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize); + BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); for (int i = 0; i < bufferSize; ++i) { charBuffer[i] = intBuffer[i]; } @@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { - const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, + const int dictSize) { + const int headerValue = readHeaderValueInt(dict, dictSize, + "MULTIPLE_WORDS_DEMOTION_RATE"); if (headerValue == S_INT_MIN) { return 1.0f; } diff --git a/native/jni/src/correction.cpp b/native/jni/src/correction.cpp index 0c65939e0..61bf3f619 100644 --- a/native/jni/src/correction.cpp +++ b/native/jni/src/correction.cpp @@ -23,6 +23,8 @@ #include "defines.h" #include "proximity_info_state.h" #include "suggest_utils.h" +#include "suggest/policyimpl/utils/edit_distance.h" +#include "suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h" namespace latinime { @@ -906,50 +908,11 @@ inline static bool isUpperCase(unsigned short c) { return totalFreq; } -/* Damerau-Levenshtein distance */ -inline static int editDistanceInternal(int *editDistanceTable, const int *before, - const int beforeLength, const int *after, const int afterLength) { - // dp[li][lo] dp[a][b] = dp[ a * lo + b] - int *dp = editDistanceTable; - const int li = beforeLength + 1; - const int lo = afterLength + 1; - for (int i = 0; i < li; ++i) { - dp[lo * i] = i; - } - for (int i = 0; i < lo; ++i) { - dp[i] = i; - } - - for (int i = 0; i < li - 1; ++i) { - for (int j = 0; j < lo - 1; ++j) { - const int ci = toBaseLowerCase(before[i]); - const int co = toBaseLowerCase(after[j]); - const int cost = (ci == co) ? 0 : 1; - dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1, - min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost)); - if (i > 0 && j > 0 && ci == toBaseLowerCase(after[j - 1]) - && co == toBaseLowerCase(before[i - 1])) { - dp[(i + 1) * lo + (j + 1)] = min( - dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost); - } - } - } - - if (DEBUG_EDIT_DISTANCE) { - AKLOGI("IN = %d, OUT = %d", beforeLength, afterLength); - for (int i = 0; i < li; ++i) { - for (int j = 0; j < lo; ++j) { - AKLOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]); - } - } - } - return dp[li * lo - 1]; -} - /* static */ int Correction::RankingAlgorithm::editDistance(const int *before, const int beforeLength, const int *after, const int afterLength) { - int table[(beforeLength + 1) * (afterLength + 1)]; - return editDistanceInternal(table, before, beforeLength, after, afterLength); + const DamerauLevenshteinEditDistancePolicy daemaruLevenshtein( + before, beforeLength, after, afterLength); + return static_cast<int>(EditDistance::getEditDistance(&daemaruLevenshtein)); } diff --git a/native/jni/src/dictionary.cpp b/native/jni/src/dictionary.cpp index c998c0676..dadb2bab2 100644 --- a/native/jni/src/dictionary.cpp +++ b/native/jni/src/dictionary.cpp @@ -34,9 +34,11 @@ namespace latinime { Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust) : mDict(static_cast<unsigned char *>(dict)), - mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)), + mOffsetDict((static_cast<unsigned char *>(dict)) + + BinaryFormat::getHeaderSize(mDict, dictSize)), mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust), - mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))), + mUnigramDictionary(new UnigramDictionary(mOffsetDict, + BinaryFormat::getFlags(mDict, dictSize))), mBigramDictionary(new BigramDictionary(mOffsetDict)), mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())), mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) { diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 51165858b..6408f0163 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, int prevWordLength) { mDictionary = dictionary; - mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict()); + mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(), + mDictionary->getDictSize()); if (!prevWord) { mPrevWordPos = NOT_VALID_WORD; return; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 3221dee9c..a18794850 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -163,9 +163,14 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0; const bool isExactMatch = terminalDicNode->isExactMatch(); + const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); + // Heuristic: We exclude freq=0 first-char-uppercase words from exact match. + // (e.g. "AMD" and "and") + const bool isSafeExactMatch = isExactMatch + && !(isPossiblyOffensiveWord && isFirstCharUppercase); const int outputTypeFlags = - isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0 - | isExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0; + (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) + | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); // Entries that are blacklisted or do not represent a word should not be output. const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord(); diff --git a/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h new file mode 100644 index 000000000..ec1457455 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/damerau_levenshtein_edit_distance_policy.h @@ -0,0 +1,79 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H +#define LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H + +#include "char_utils.h" +#include "suggest/policyimpl/utils/edit_distance_policy.h" + +namespace latinime { + +class DamerauLevenshteinEditDistancePolicy : public EditDistancePolicy { + public: + DamerauLevenshteinEditDistancePolicy(const int *const string0, const int length0, + const int *const string1, const int length1) + : mString0(string0), mString0Length(length0), mString1(string1), + mString1Length(length1) {} + ~DamerauLevenshteinEditDistancePolicy() {} + + AK_FORCE_INLINE float getSubstitutionCost(const int index0, const int index1) const { + const int c0 = toBaseLowerCase(mString0[index0]); + const int c1 = toBaseLowerCase(mString1[index1]); + return (c0 == c1) ? 0.0f : 1.0f; + } + + AK_FORCE_INLINE float getDeletionCost(const int index0, const int index1) const { + return 1.0f; + } + + AK_FORCE_INLINE float getInsertionCost(const int index0, const int index1) const { + return 1.0f; + } + + AK_FORCE_INLINE bool allowTransposition(const int index0, const int index1) const { + const int c0 = toBaseLowerCase(mString0[index0]); + const int c1 = toBaseLowerCase(mString1[index1]); + if (index0 > 0 && index1 > 0 && c0 == toBaseLowerCase(mString1[index1 - 1]) + && c1 == toBaseLowerCase(mString0[index0 - 1])) { + return true; + } + return false; + } + + AK_FORCE_INLINE float getTranspositionCost(const int index0, const int index1) const { + return getSubstitutionCost(index0, index1); + } + + AK_FORCE_INLINE int getString0Length() const { + return mString0Length; + } + + AK_FORCE_INLINE int getString1Length() const { + return mString1Length; + } + + private: + DISALLOW_COPY_AND_ASSIGN (DamerauLevenshteinEditDistancePolicy); + + const int *const mString0; + const int mString0Length; + const int *const mString1; + const int mString1Length; +}; +} // namespace latinime + +#endif // LATINIME_DAEMARU_LEVENSHTEIN_EDIT_DISTANCE_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance.h b/native/jni/src/suggest/policyimpl/utils/edit_distance.h new file mode 100644 index 000000000..cbbd66894 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance.h @@ -0,0 +1,70 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_EDIT_DISTANCE_H +#define LATINIME_EDIT_DISTANCE_H + +#include "defines.h" +#include "suggest/policyimpl/utils/edit_distance_policy.h" + +namespace latinime { + +class EditDistance { + public: + // CAVEAT: There may be performance penalty if you need the edit distance as an integer value. + AK_FORCE_INLINE static float getEditDistance(const EditDistancePolicy *const policy) { + const int beforeLength = policy->getString0Length(); + const int afterLength = policy->getString1Length(); + float dp[(beforeLength + 1) * (afterLength + 1)]; + for (int i = 0; i <= beforeLength; ++i) { + dp[(afterLength + 1) * i] = i * policy->getInsertionCost(i - 1, -1); + } + for (int i = 0; i <= afterLength; ++i) { + dp[i] = i * policy->getDeletionCost(-1, i - 1); + } + + for (int i = 0; i < beforeLength; ++i) { + for (int j = 0; j < afterLength; ++j) { + dp[(afterLength + 1) * (i + 1) + (j + 1)] = min( + dp[(afterLength + 1) * i + (j + 1)] + policy->getInsertionCost(i, j), + min(dp[(afterLength + 1) * (i + 1) + j] + policy->getDeletionCost(i, j), + dp[(afterLength + 1) * i + j] + + policy->getSubstitutionCost(i, j))); + if (policy->allowTransposition(i, j)) { + dp[(afterLength + 1) * (i + 1) + (j + 1)] = min( + dp[(afterLength + 1) * (i + 1) + (j + 1)], + dp[(afterLength + 1) * (i - 1) + (j - 1)] + + policy->getTranspositionCost(i, j)); + } + } + } + if (DEBUG_EDIT_DISTANCE) { + AKLOGI("IN = %d, OUT = %d", beforeLength, afterLength); + for (int i = 0; i < beforeLength + 1; ++i) { + for (int j = 0; j < afterLength + 1; ++j) { + AKLOGI("EDIT[%d][%d], %f", i, j, dp[(afterLength + 1) * i + j]); + } + } + } + return dp[(beforeLength + 1) * (afterLength + 1) - 1]; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(EditDistance); +}; +} // namespace latinime + +#endif // LATINIME_EDIT_DISTANCE_H diff --git a/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h b/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h new file mode 100644 index 000000000..e3d1792cb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/utils/edit_distance_policy.h @@ -0,0 +1,43 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_EDIT_DISTANCE_POLICY_H +#define LATINIME_EDIT_DISTANCE_POLICY_H + +#include "defines.h" + +namespace latinime { + +class EditDistancePolicy { + public: + virtual float getSubstitutionCost(const int index0, const int index1) const = 0; + virtual float getDeletionCost(const int index0, const int index1) const = 0; + virtual float getInsertionCost(const int index0, const int index1) const = 0; + virtual bool allowTransposition(const int index0, const int index1) const = 0; + virtual float getTranspositionCost(const int index0, const int index1) const = 0; + virtual int getString0Length() const = 0; + virtual int getString1Length() const = 0; + + protected: + EditDistancePolicy() {} + virtual ~EditDistancePolicy() {} + + private: + DISALLOW_COPY_AND_ASSIGN(EditDistancePolicy); +}; +} // namespace latinime + +#endif // LATINIME_EDIT_DISTANCE_POLICY_H |