aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/bigram_dictionary.cpp2
-rw-r--r--native/jni/src/binary_format.h77
-rw-r--r--native/jni/src/defines.h38
-rw-r--r--native/jni/src/dictionary.h5
-rw-r--r--native/jni/src/multi_bigram_map.h89
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h19
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_properties.h5
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_state_input.h4
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h48
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.cpp92
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.h13
-rw-r--r--native/jni/src/suggest/core/policy/scoring.h12
-rw-r--r--native/jni/src/suggest/core/policy/suggest_policy.h1
-rw-r--r--native/jni/src/suggest/core/policy/traversal.h13
-rw-r--r--native/jni/src/suggest/core/policy/weighting.cpp91
-rw-r--r--native/jni/src/suggest/core/policy/weighting.h18
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.cpp12
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.h8
-rw-r--r--native/jni/src/suggest/core/suggest.cpp51
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.cpp34
-rw-r--r--native/jni/src/suggest/policyimpl/typing/scoring_params.h4
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_traversal.h11
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp36
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_weighting.h89
-rw-r--r--native/jni/src/terminal_attributes.h2
25 files changed, 450 insertions, 324 deletions
diff --git a/native/jni/src/bigram_dictionary.cpp b/native/jni/src/bigram_dictionary.cpp
index 92890383a..9053e7226 100644
--- a/native/jni/src/bigram_dictionary.cpp
+++ b/native/jni/src/bigram_dictionary.cpp
@@ -187,7 +187,7 @@ void BigramDictionary::fillBigramAddressToProbabilityMapAndFilter(const int *pre
&pos);
(*map)[bigramPos] = probability;
setInFilter(filter, bigramPos);
- } while (0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags));
+ } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
}
bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) const {
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h
index 2d2e19501..06f50dc7f 100644
--- a/native/jni/src/binary_format.h
+++ b/native/jni/src/binary_format.h
@@ -23,6 +23,7 @@
#include "bloom_filter.h"
#include "char_utils.h"
+#include "hash_map_compat.h"
namespace latinime {
@@ -66,6 +67,7 @@ class BinaryFormat {
static int detectFormat(const uint8_t *const dict);
static int getHeaderSize(const uint8_t *const dict);
static int getFlags(const uint8_t *const dict);
+ static bool hasBlacklistedOrNotAWordFlag(const int flags);
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
const int outValueSize);
static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
@@ -92,7 +94,13 @@ class BinaryFormat {
const int unigramProbability, const int bigramProbability);
static int getProbability(const int position, const std::map<int, int> *bigramMap,
const uint8_t *bigramFilter, const int unigramProbability);
+ static int getBigramProbabilityFromHashMap(const int position,
+ const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
static float getMultiWordCostMultiplier(const uint8_t *const dict);
+ static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
+ hash_map_compat<int, int> *bigramMap);
+ static int getBigramProbability(const uint8_t *const root, int position,
+ const int nextPosition, const int unigramProbability);
// Flags for special processing
// Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or
@@ -104,6 +112,8 @@ class BinaryFormat {
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat);
+ static int getBigramListPositionForWordPosition(const uint8_t *const root, int position);
+
static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00;
static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40;
static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80;
@@ -162,6 +172,10 @@ inline int BinaryFormat::getFlags(const uint8_t *const dict) {
}
}
+inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
+ return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
+}
+
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
switch (detectFormat(dict)) {
case 1:
@@ -682,5 +696,68 @@ inline int BinaryFormat::getProbability(const int position, const std::map<int,
}
return backoff(unigramProbability);
}
+
+// This returns a probability in log space.
+inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position,
+ const hash_map_compat<int, int> *bigramMap, const int unigramProbability) {
+ if (!bigramMap) return backoff(unigramProbability);
+ const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position);
+ if (bigramProbabilityIt != bigramMap->end()) {
+ const int bigramProbability = bigramProbabilityIt->second;
+ return computeProbabilityForBigram(unigramProbability, bigramProbability);
+ }
+ return backoff(unigramProbability);
+}
+
+AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap(
+ const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) {
+ position = getBigramListPositionForWordPosition(root, position);
+ if (0 == position) return;
+
+ uint8_t bigramFlags;
+ do {
+ bigramFlags = getFlagsAndForwardPointer(root, &position);
+ const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
+ const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags,
+ &position);
+ (*bigramMap)[bigramPos] = probability;
+ } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
+}
+
+AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position,
+ const int nextPosition, const int unigramProbability) {
+ position = getBigramListPositionForWordPosition(root, position);
+ if (0 == position) return backoff(unigramProbability);
+
+ uint8_t bigramFlags;
+ do {
+ bigramFlags = getFlagsAndForwardPointer(root, &position);
+ const int bigramPos = getAttributeAddressAndForwardPointer(
+ root, bigramFlags, &position);
+ if (bigramPos == nextPosition) {
+ const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
+ return computeProbabilityForBigram(unigramProbability, bigramProbability);
+ }
+ } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags);
+ return backoff(unigramProbability);
+}
+
+// Returns the bigram list start position, or 0 if the word is invalid or has no bigrams.
+AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition(
+ const uint8_t *const root, int position) {
+ if (NOT_VALID_WORD == position) return 0;
+ const uint8_t flags = getFlagsAndForwardPointer(root, &position);
+ if (!(flags & FLAG_HAS_BIGRAMS)) return 0;
+ if (flags & FLAG_HAS_MULTIPLE_CHARS) {
+ position = skipOtherCharacters(root, position);
+ } else {
+ getCodePointAndForwardPointer(root, &position);
+ }
+ position = skipProbability(flags, position);
+ position = skipChildrenPosition(flags, position);
+ position = skipShortcuts(root, flags, position);
+ return position;
+}
+
} // namespace latinime
#endif // LATINIME_BINARY_FORMAT_H
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 6ef9f414b..eb59744f6 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -379,6 +379,15 @@ static inline void prof_out(void) {
#error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE"
#endif
+// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could
+// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage.
+// Also, there are diminishing returns since the most frequently used bigrams are typically near
+// the beginning of the input and are thus the first ones to be cached. Note that these bigrams
+// are reset for each new composing word.
+#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25
+// Most common previous word contexts currently have 100 bigrams
+#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100
+
template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; }
template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; }
@@ -417,16 +426,45 @@ typedef enum {
} DoubleLetterLevel;
typedef enum {
+ // Correction for MATCH_CHAR
CT_MATCH,
+ // Correction for PROXIMITY_CHAR
CT_PROXIMITY,
+ // Correction for ADDITIONAL_PROXIMITY_CHAR
CT_ADDITIONAL_PROXIMITY,
+ // Correction for SUBSTITUTION_CHAR
CT_SUBSTITUTION,
+ // Skip one omitted letter
CT_OMISSION,
+ // Delete an unnecessarily inserted letter
CT_INSERTION,
+ // Swap the order of next two touch points
CT_TRANSPOSITION,
CT_COMPLETION,
CT_TERMINAL,
+ // Create new word with space omission
CT_NEW_WORD_SPACE_OMITTION,
+ // Create new word with space substitution
CT_NEW_WORD_SPACE_SUBSTITUTION,
} CorrectionType;
+
+// ErrorType is mainly determined by CorrectionType, but it also depends on whether
+// the correction has actually been performed.
+typedef enum {
+ // Substitution, omission and transposition
+ ET_EDIT_CORRECTION,
+ // Proximity error
+ ET_PROXIMITY_CORRECTION,
+ // Completion
+ ET_COMPLETION,
+ // New word
+ // TODO: Remove.
+ // A new word error should be an edit correction error or a proximity correction error.
+ ET_NEW_WORD,
+ // Treat the error as an intentional omission when the CorrectionType is omission and the node
+ // can be an intentional omission.
+ ET_INTENTIONAL_OMISSION,
+ // Not treated as an error. Tracked for checking exact match
+ ET_NOT_AN_ERROR
+} ErrorType;
#endif // LATINIME_DEFINES_H
diff --git a/native/jni/src/dictionary.h b/native/jni/src/dictionary.h
index 0653d3ca9..2ad5b6c0b 100644
--- a/native/jni/src/dictionary.h
+++ b/native/jni/src/dictionary.h
@@ -31,6 +31,7 @@ class UnigramDictionary;
class Dictionary {
public:
// Taken from SuggestedWords.java
+ static const int KIND_MASK_KIND = 0xFF; // Mask to get only the kind
static const int KIND_TYPED = 0; // What user typed
static const int KIND_CORRECTION = 1; // Simple correction/suggestion
static const int KIND_COMPLETION = 2; // Completion (suggestion with appended chars)
@@ -41,6 +42,10 @@ class Dictionary {
static const int KIND_SHORTCUT = 7; // A shortcut
static const int KIND_PREDICTION = 8; // A prediction (== a suggestion with no input)
+ static const int KIND_MASK_FLAGS = 0xFFFFFF00; // Mask to get the flags
+ static const int KIND_FLAG_POSSIBLY_OFFENSIVE = 0x80000000;
+ static const int KIND_FLAG_EXACT_MATCH = 0x40000000;
+
Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust);
int getSuggestions(ProximityInfo *proximityInfo, void *traverseSession, int *xcoordinates,
diff --git a/native/jni/src/multi_bigram_map.h b/native/jni/src/multi_bigram_map.h
new file mode 100644
index 000000000..7e1b6301f
--- /dev/null
+++ b/native/jni/src/multi_bigram_map.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2013 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_MULTI_BIGRAM_MAP_H
+#define LATINIME_MULTI_BIGRAM_MAP_H
+
+#include <cstring>
+#include <stdint.h>
+
+#include "defines.h"
+#include "binary_format.h"
+#include "hash_map_compat.h"
+
+namespace latinime {
+
+// Class for caching bigram maps for multiple previous word contexts. This is useful since the
+// algorithm needs to look up the set of bigrams for every word pair that occurs in every
+// multi-word suggestion.
+class MultiBigramMap {
+ public:
+ MultiBigramMap() : mBigramMaps() {}
+ ~MultiBigramMap() {}
+
+ // Look up the bigram probability for the given word pair from the cached bigram maps.
+ // Also caches the bigrams if there is space remaining and they have not been cached already.
+ int getBigramProbability(const uint8_t *const dicRoot, const int wordPosition,
+ const int nextWordPosition, const int unigramProbability) {
+ hash_map_compat<int, BigramMap>::const_iterator mapPosition =
+ mBigramMaps.find(wordPosition);
+ if (mapPosition != mBigramMaps.end()) {
+ return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability);
+ }
+ if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) {
+ addBigramsForWordPosition(dicRoot, wordPosition);
+ return mBigramMaps[wordPosition].getBigramProbability(
+ nextWordPosition, unigramProbability);
+ }
+ return BinaryFormat::getBigramProbability(
+ dicRoot, wordPosition, nextWordPosition, unigramProbability);
+ }
+
+ void clear() {
+ mBigramMaps.clear();
+ }
+
+ private:
+ DISALLOW_COPY_AND_ASSIGN(MultiBigramMap);
+
+ class BigramMap {
+ public:
+ BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {}
+ ~BigramMap() {}
+
+ void init(const uint8_t *const dicRoot, int position) {
+ BinaryFormat::fillBigramProbabilityToHashMap(dicRoot, position, &mBigramMap);
+ }
+
+ inline int getBigramProbability(const int nextWordPosition, const int unigramProbability)
+ const {
+ return BinaryFormat::getBigramProbabilityFromHashMap(
+ nextWordPosition, &mBigramMap, unigramProbability);
+ }
+
+ private:
+ // Note: Default copy constructor needed for use in hash_map.
+ hash_map_compat<int, int> mBigramMap;
+ };
+
+ void addBigramsForWordPosition(const uint8_t *const dicRoot, const int position) {
+ mBigramMaps[position].init(dicRoot, position);
+ }
+
+ hash_map_compat<int, BigramMap> mBigramMaps;
+};
+} // namespace latinime
+#endif // LATINIME_MULTI_BIGRAM_MAP_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index 32faae52c..4225bb3e5 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -210,8 +210,7 @@ class DicNode {
}
bool isImpossibleBigramWord() const {
- const int probability = mDicNodeProperties.getProbability();
- if (probability == 0) {
+ if (mDicNodeProperties.hasBlacklistedOrNotAWordFlag()) {
return true;
}
const int prevWordLen = mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength()
@@ -220,7 +219,7 @@ class DicNode {
return (prevWordLen == 1 && currentWordLen == 1);
}
- bool isCapitalized() const {
+ bool isFirstCharUppercase() const {
const int c = getOutputWordBuf()[0];
return isAsciiUpper(c);
}
@@ -360,11 +359,6 @@ class DicNode {
return mDicNodeState.mDicNodeStateScoring.getCompoundDistance(languageWeight);
}
- // Note that "cost" means delta for "distance" that is weighted.
- float getTotalPrevWordsLanguageCost() const {
- return mDicNodeState.mDicNodeStateScoring.getTotalPrevWordsLanguageCost();
- }
-
// Used to commit input partially
int getPrevWordNodePos() const {
return mDicNodeState.mDicNodeStatePrevWord.getPrevWordNodePos();
@@ -469,6 +463,10 @@ class DicNode {
mDicNodeState.mDicNodeStateScoring.advanceDigraphIndex();
}
+ bool isExactMatch() const {
+ return mDicNodeState.mDicNodeStateScoring.isExactMatch();
+ }
+
uint8_t getFlags() const {
return mDicNodeProperties.getFlags();
}
@@ -548,13 +546,12 @@ class DicNode {
// Caveat: Must not be called outside Weighting
// This restriction is guaranteed by "friend"
AK_FORCE_INLINE void addCost(const float spatialCost, const float languageCost,
- const bool doNormalization, const int inputSize, const bool isEditCorrection,
- const bool isProximityCorrection) {
+ const bool doNormalization, const int inputSize, const ErrorType errorType) {
if (DEBUG_GEO_FULL) {
LOGI_SHOW_ADD_COST_PROP;
}
mDicNodeState.mDicNodeStateScoring.addCost(spatialCost, languageCost, doNormalization,
- inputSize, getTotalInputIndex(), isEditCorrection, isProximityCorrection);
+ inputSize, getTotalInputIndex(), errorType);
}
// Caveat: Must not be called outside Weighting
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/dic_node_properties.h
index 173ef35d0..63a6b1340 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_properties.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_properties.h
@@ -19,6 +19,7 @@
#include <stdint.h>
+#include "binary_format.h"
#include "defines.h"
namespace latinime {
@@ -144,6 +145,10 @@ class DicNodeProperties {
return mChildrenCount > 0 || mDepth != mLeavingDepth;
}
+ bool hasBlacklistedOrNotAWordFlag() const {
+ return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags);
+ }
+
private:
// Caution!!!
// Use a default copy constructor and an assign operator because shallow copies are ok
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
index 7ad3e3e5f..bbd9435b5 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_input.h
@@ -46,8 +46,8 @@ class DicNodeStateInput {
for (int i = 0; i < MAX_POINTER_COUNT_G; i++) {
mInputIndex[i] = src->mInputIndex[i];
mPrevCodePoint[i] = src->mPrevCodePoint[i];
- mTerminalDiffCost[i] = resetTerminalDiffCost ?
- static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
+ mTerminalDiffCost[i] = resetTerminalDiffCost ?
+ static_cast<float>(MAX_VALUE_FOR_WEIGHTING) : src->mTerminalDiffCost[i];
}
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
index 8902d3122..dca9d60da 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_state_scoring.h
@@ -31,7 +31,7 @@ class DicNodeStateScoring {
mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX),
mEditCorrectionCount(0), mProximityCorrectionCount(0),
mNormalizedCompoundDistance(0.0f), mSpatialDistance(0.0f), mLanguageDistance(0.0f),
- mTotalPrevWordsLanguageCost(0.0f), mRawLength(0.0f) {
+ mRawLength(0.0f), mExactMatch(true) {
}
virtual ~DicNodeStateScoring() {}
@@ -42,10 +42,10 @@ class DicNodeStateScoring {
mNormalizedCompoundDistance = 0.0f;
mSpatialDistance = 0.0f;
mLanguageDistance = 0.0f;
- mTotalPrevWordsLanguageCost = 0.0f;
mRawLength = 0.0f;
mDoubleLetterLevel = NOT_A_DOUBLE_LETTER;
mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX;
+ mExactMatch = true;
}
AK_FORCE_INLINE void init(const DicNodeStateScoring *const scoring) {
@@ -54,24 +54,35 @@ class DicNodeStateScoring {
mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance;
mSpatialDistance = scoring->mSpatialDistance;
mLanguageDistance = scoring->mLanguageDistance;
- mTotalPrevWordsLanguageCost = scoring->mTotalPrevWordsLanguageCost;
mRawLength = scoring->mRawLength;
mDoubleLetterLevel = scoring->mDoubleLetterLevel;
mDigraphIndex = scoring->mDigraphIndex;
+ mExactMatch = scoring->mExactMatch;
}
void addCost(const float spatialCost, const float languageCost, const bool doNormalization,
- const int inputSize, const int totalInputIndex, const bool isEditCorrection,
- const bool isProximityCorrection) {
+ const int inputSize, const int totalInputIndex, const ErrorType errorType) {
addDistance(spatialCost, languageCost, doNormalization, inputSize, totalInputIndex);
- if (isEditCorrection) {
- ++mEditCorrectionCount;
- }
- if (isProximityCorrection) {
- ++mProximityCorrectionCount;
- }
- if (languageCost > 0.0f) {
- setTotalPrevWordsLanguageCost(mTotalPrevWordsLanguageCost + languageCost);
+ switch (errorType) {
+ case ET_EDIT_CORRECTION:
+ ++mEditCorrectionCount;
+ mExactMatch = false;
+ break;
+ case ET_PROXIMITY_CORRECTION:
+ ++mProximityCorrectionCount;
+ mExactMatch = false;
+ break;
+ case ET_COMPLETION:
+ mExactMatch = false;
+ break;
+ case ET_NEW_WORD:
+ mExactMatch = false;
+ break;
+ case ET_INTENTIONAL_OMISSION:
+ mExactMatch = false;
+ break;
+ case ET_NOT_AN_ERROR:
+ break;
}
}
@@ -148,8 +159,8 @@ class DicNodeStateScoring {
}
}
- float getTotalPrevWordsLanguageCost() const {
- return mTotalPrevWordsLanguageCost;
+ bool isExactMatch() const {
+ return mExactMatch;
}
private:
@@ -165,8 +176,8 @@ class DicNodeStateScoring {
float mNormalizedCompoundDistance;
float mSpatialDistance;
float mLanguageDistance;
- float mTotalPrevWordsLanguageCost;
float mRawLength;
+ bool mExactMatch;
AK_FORCE_INLINE void addDistance(float spatialDistance, float languageDistance,
bool doNormalization, int inputSize, int totalInputIndex) {
@@ -179,11 +190,6 @@ class DicNodeStateScoring {
/ static_cast<float>(max(1, totalInputIndex));
}
}
-
- //TODO: remove
- AK_FORCE_INLINE void setTotalPrevWordsLanguageCost(float totalPrevWordsLanguageCost) {
- mTotalPrevWordsLanguageCost = totalPrevWordsLanguageCost;
- }
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_STATE_SCORING_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index 031e706ae..5357c3773 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -21,6 +21,7 @@
#include "dic_node.h"
#include "dic_node_utils.h"
#include "dic_node_vector.h"
+#include "multi_bigram_map.h"
#include "proximity_info.h"
#include "proximity_info_state.h"
@@ -191,11 +192,11 @@ namespace latinime {
* Computes the combined bigram / unigram cost for the given dicNode.
*/
/* static */ float DicNodeUtils::getBigramNodeImprobability(const uint8_t *const dicRoot,
- const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
+ const DicNode *const node, MultiBigramMap *multiBigramMap) {
if (node->isImpossibleBigramWord()) {
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}
- const int probability = getBigramNodeProbability(dicRoot, node, bigramCacheMap);
+ const int probability = getBigramNodeProbability(dicRoot, node, multiBigramMap);
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
const float cost = static_cast<float>(MAX_PROBABILITY - probability)
/ static_cast<float>(MAX_PROBABILITY);
@@ -203,92 +204,25 @@ namespace latinime {
}
/* static */ int DicNodeUtils::getBigramNodeProbability(const uint8_t *const dicRoot,
- const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
+ const DicNode *const node, MultiBigramMap *multiBigramMap) {
const int unigramProbability = node->getProbability();
- const int encodedDiffOfBigramProbability =
- getBigramNodeEncodedDiffProbability(dicRoot, node, bigramCacheMap);
- if (NOT_A_PROBABILITY == encodedDiffOfBigramProbability) {
+ const int wordPos = node->getPos();
+ const int prevWordPos = node->getPrevWordPos();
+ if (NOT_VALID_WORD == wordPos || NOT_VALID_WORD == prevWordPos) {
+ // Note: Normally wordPos comes from the dictionary and should never equal NOT_VALID_WORD.
return backoff(unigramProbability);
}
- return BinaryFormat::computeProbabilityForBigram(
- unigramProbability, encodedDiffOfBigramProbability);
+ if (multiBigramMap) {
+ return multiBigramMap->getBigramProbability(
+ dicRoot, prevWordPos, wordPos, unigramProbability);
+ }
+ return BinaryFormat::getBigramProbability(dicRoot, prevWordPos, wordPos, unigramProbability);
}
///////////////////////////////////////
// Bigram / Unigram dictionary utils //
///////////////////////////////////////
-/* static */ int16_t DicNodeUtils::getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
- const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap) {
- const int wordPos = node->getPos();
- const int prevWordPos = node->getPrevWordPos();
- return getBigramProbability(dicRoot, prevWordPos, wordPos, bigramCacheMap);
-}
-
-// TODO: Move this to BigramDictionary
-/* static */ int16_t DicNodeUtils::getBigramProbability(const uint8_t *const dicRoot, int pos,
- const int nextPos, hash_map_compat<int, int16_t> *bigramCacheMap) {
- // TODO: this is painfully slow compared to the method used in the previous version of the
- // algorithm. Switch to that method.
- if (NOT_VALID_WORD == pos) return NOT_A_PROBABILITY;
- if (NOT_VALID_WORD == nextPos) return NOT_A_PROBABILITY;
-
- // Create a hash code for the given node pair (based on Josh Bloch's effective Java).
- // TODO: Use a real hash map data structure that deals with collisions.
- int hash = 17;
- hash = hash * 31 + pos;
- hash = hash * 31 + nextPos;
-
- hash_map_compat<int, int16_t>::const_iterator mapPos = bigramCacheMap->find(hash);
- if (mapPos != bigramCacheMap->end()) {
- return mapPos->second;
- }
- if (NOT_VALID_WORD == pos) {
- return NOT_A_PROBABILITY;
- }
- const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
- if (0 == (flags & BinaryFormat::FLAG_HAS_BIGRAMS)) {
- return NOT_A_PROBABILITY;
- }
- if (0 == (flags & BinaryFormat::FLAG_HAS_MULTIPLE_CHARS)) {
- BinaryFormat::getCodePointAndForwardPointer(dicRoot, &pos);
- } else {
- pos = BinaryFormat::skipOtherCharacters(dicRoot, pos);
- }
- pos = BinaryFormat::skipChildrenPosition(flags, pos);
- pos = BinaryFormat::skipProbability(flags, pos);
- uint8_t bigramFlags;
- int count = 0;
- do {
- bigramFlags = BinaryFormat::getFlagsAndForwardPointer(dicRoot, &pos);
- const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(dicRoot,
- bigramFlags, &pos);
- if (bigramPos == nextPos) {
- const int16_t probability = BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags;
- if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
- (*bigramCacheMap)[hash] = probability;
- }
- return probability;
- }
- count++;
- } while ((0 != (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags))
- && count < MAX_BIGRAMS_CONSIDERED_PER_CONTEXT);
- if (static_cast<int>(bigramCacheMap->size()) < MAX_BIGRAM_MAP_SIZE) {
- // TODO: does this -1 mean NOT_VALID_WORD?
- (*bigramCacheMap)[hash] = -1;
- }
- return NOT_A_PROBABILITY;
-}
-
-/* static */ int DicNodeUtils::getWordPos(const uint8_t *const dicRoot, const int *word,
- const int wordLength) {
- if (!word) {
- return NOT_VALID_WORD;
- }
- return BinaryFormat::getTerminalPosition(
- dicRoot, word, wordLength, false /* forceLowerCaseSearch */);
-}
-
/* static */ bool DicNodeUtils::isMatchedNodeCodePoint(const ProximityInfoState *pInfoState,
const int pointIndex, const bool exactOnly, const int nodeCodePoint) {
if (!pInfoState) {
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 15f9730de..5bc542d05 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -21,7 +21,6 @@
#include <vector>
#include "defines.h"
-#include "hash_map_compat.h"
namespace latinime {
@@ -29,6 +28,7 @@ class DicNode;
class DicNodeVector;
class ProximityInfo;
class ProximityInfoState;
+class MultiBigramMap;
class DicNodeUtils {
public:
@@ -41,9 +41,8 @@ class DicNodeUtils {
static void initByCopy(DicNode *srcNode, DicNode *destNode);
static void getAllChildDicNodes(DicNode *dicNode, const uint8_t *const dicRoot,
DicNodeVector *childDicNodes);
- static int getWordPos(const uint8_t *const dicRoot, const int *word, const int prevWordLength);
static float getBigramNodeImprobability(const uint8_t *const dicRoot,
- const DicNode *const node, hash_map_compat<int, int16_t> *const bigramCacheMap);
+ const DicNode *const node, MultiBigramMap *const multiBigramMap);
static bool isDicNodeFilteredOut(const int nodeCodePoint, const ProximityInfo *const pInfo,
const std::vector<int> *const codePointsFilter);
// TODO: Move to private
@@ -58,15 +57,11 @@ class DicNodeUtils {
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
- // Max cache size for the space omission error correction bigram lookup
- static const int MAX_BIGRAM_MAP_SIZE = 20000;
// Max number of bigrams to look up
static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;
static int getBigramNodeProbability(const uint8_t *const dicRoot, const DicNode *const node,
- hash_map_compat<int, int16_t> *bigramCacheMap);
- static int16_t getBigramNodeEncodedDiffProbability(const uint8_t *const dicRoot,
- const DicNode *const node, hash_map_compat<int, int16_t> *bigramCacheMap);
+ MultiBigramMap *multiBigramMap);
static void createAndGetPassingChildNode(DicNode *dicNode, const ProximityInfoState *pInfoState,
const int pointIndex, const bool exactOnly, DicNodeVector *childDicNodes);
static void createAndGetAllLeavingChildNodes(DicNode *dicNode, const uint8_t *const dicRoot,
@@ -77,8 +72,6 @@ class DicNodeUtils {
const int terminalDepth, const ProximityInfoState *pInfoState, const int pointIndex,
const bool exactOnly, const std::vector<int> *const codePointsFilter,
const ProximityInfo *const pInfo, DicNodeVector *childDicNodes);
- static int16_t getBigramProbability(const uint8_t *const dicRoot, int pos, const int nextPos,
- hash_map_compat<int, int16_t> *bigramCacheMap);
// TODO: Move to proximity info
static bool isMatchedNodeCodePoint(const ProximityInfoState *pInfoState, const int pointIndex,
diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h
index b8c10e25a..102e856f5 100644
--- a/native/jni/src/suggest/core/policy/scoring.h
+++ b/native/jni/src/suggest/core/policy/scoring.h
@@ -29,16 +29,14 @@ class Scoring {
public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
const bool forceCommit) const = 0;
- virtual bool getMostProbableString(
- const DicTraverseSession *const traverseSession, const int terminalSize,
- const float languageWeight, int *const outputCodePoints, int *const type,
- int *const freq) const = 0;
+ virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
+ const int terminalSize, const float languageWeight, int *const outputCodePoints,
+ int *const type, int *const freq) const = 0;
virtual void safetyNetForMostProbableString(const int terminalSize,
const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0;
// TODO: Make more generic
- virtual void searchWordWithDoubleLetter(DicNode *terminals,
- const int terminalSize, int *doubleLetterTerminalIndex,
- DoubleLetterLevel *doubleLetterLevel) const = 0;
+ virtual void searchWordWithDoubleLetter(DicNode *terminals, const int terminalSize,
+ int *doubleLetterTerminalIndex, DoubleLetterLevel *doubleLetterLevel) const = 0;
virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession,
DicNode *const terminals, const int size) const = 0;
virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex,
diff --git a/native/jni/src/suggest/core/policy/suggest_policy.h b/native/jni/src/suggest/core/policy/suggest_policy.h
index 885e214f7..5b6402c44 100644
--- a/native/jni/src/suggest/core/policy/suggest_policy.h
+++ b/native/jni/src/suggest/core/policy/suggest_policy.h
@@ -20,6 +20,7 @@
#include "defines.h"
namespace latinime {
+
class Traversal;
class Scoring;
class Weighting;
diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h
index 02c358aec..c6f66f231 100644
--- a/native/jni/src/suggest/core/policy/traversal.h
+++ b/native/jni/src/suggest/core/policy/traversal.h
@@ -28,7 +28,8 @@ class Traversal {
virtual int getMaxPointerCount() const = 0;
virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0;
virtual bool isOmission(const DicTraverseSession *const traverseSession,
- const DicNode *const dicNode, const DicNode *const childDicNode) const = 0;
+ const DicNode *const dicNode, const DicNode *const childDicNode,
+ const bool allowsErrorCorrections) const = 0;
virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,
@@ -38,9 +39,8 @@ class Traversal {
const DicNode *const dicNode) const = 0;
virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
- virtual ProximityType getProximityType(
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
- const DicNode *const childDicNode) const = 0;
+ virtual ProximityType getProximityType(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode, const DicNode *const childDicNode) const = 0;
virtual bool sameAsTyped(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool needsToTraverseAllUserInput() const = 0;
@@ -48,9 +48,8 @@ class Traversal {
virtual bool allowPartialCommit() const = 0;
virtual int getDefaultExpandDicNodeSize() const = 0;
virtual int getMaxCacheSize() const = 0;
- virtual bool isPossibleOmissionChildNode(
- const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
- const DicNode *const dicNode) const = 0;
+ virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
protected:
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index b9c0b8129..d01531f07 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -18,7 +18,6 @@
#include "char_utils.h"
#include "defines.h"
-#include "hash_map_compat.h"
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_profiler.h"
#include "suggest/core/dicnode/dic_node_utils.h"
@@ -26,6 +25,8 @@
namespace latinime {
+class MultiBigramMap;
+
static inline void profile(const CorrectionType correctionType, DicNode *const node) {
#if DEBUG_DICT
switch (correctionType) {
@@ -69,20 +70,18 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
}
/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
- const CorrectionType correctionType,
- const DicTraverseSession *const traverseSession,
+ const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap) {
+ MultiBigramMap *const multiBigramMap) {
const int inputSize = traverseSession->getInputSize();
DicNode_InputStateG inputStateG;
inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, &inputStateG);
const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
- traverseSession, parentDicNode, dicNode, bigramCacheMap);
- const bool edit = Weighting::isEditCorrection(correctionType);
- const bool proximity = Weighting::isProximityCorrection(weighting, correctionType,
- traverseSession, dicNode);
+ traverseSession, parentDicNode, dicNode, multiBigramMap);
+ const ErrorType errorType = weighting->getErrorType(correctionType, traverseSession,
+ parentDicNode, dicNode);
profile(correctionType, dicNode);
if (inputStateG.mNeedsToUpdateInputStateG) {
dicNode->updateInputIndexG(&inputStateG);
@@ -91,13 +90,13 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
(correctionType == CT_TRANSPOSITION));
}
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
- inputSize, edit, proximity);
+ inputSize, errorType);
}
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
- const CorrectionType correctionType,
- const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
- const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) {
+ const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode,
+ DicNode_InputStateG *const inputStateG) {
switch(correctionType) {
case CT_OMISSION:
return weighting->getOmissionCost(parentDicNode, dicNode);
@@ -129,14 +128,14 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
/* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap) {
+ MultiBigramMap *const multiBigramMap) {
switch(correctionType) {
case CT_OMISSION:
return 0.0f;
case CT_SUBSTITUTION:
return 0.0f;
case CT_NEW_WORD_SPACE_OMITTION:
- return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
+ return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
case CT_MATCH:
return 0.0f;
case CT_COMPLETION:
@@ -144,11 +143,11 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_TERMINAL: {
const float languageImprobability =
DicNodeUtils::getBigramNodeImprobability(
- traverseSession->getOffsetDict(), dicNode, bigramCacheMap);
+ traverseSession->getOffsetDict(), dicNode, multiBigramMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
}
case CT_NEW_WORD_SPACE_SUBSTITUTION:
- return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
+ return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap);
case CT_INSERTION:
return 0.0f;
case CT_TRANSPOSITION:
@@ -158,64 +157,6 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
}
}
-/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) {
- switch(correctionType) {
- case CT_OMISSION:
- return true;
- case CT_ADDITIONAL_PROXIMITY:
- // Should return true?
- return false;
- case CT_SUBSTITUTION:
- // Should return true?
- return false;
- case CT_NEW_WORD_SPACE_OMITTION:
- return false;
- case CT_MATCH:
- return false;
- case CT_COMPLETION:
- return false;
- case CT_TERMINAL:
- return false;
- case CT_NEW_WORD_SPACE_SUBSTITUTION:
- return false;
- case CT_INSERTION:
- return true;
- case CT_TRANSPOSITION:
- return true;
- default:
- return false;
- }
-}
-
-/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting,
- const CorrectionType correctionType,
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode) {
- switch(correctionType) {
- case CT_OMISSION:
- return false;
- case CT_ADDITIONAL_PROXIMITY:
- return false;
- case CT_SUBSTITUTION:
- return false;
- case CT_NEW_WORD_SPACE_OMITTION:
- return false;
- case CT_MATCH:
- return weighting->isProximityDicNode(traverseSession, dicNode);
- case CT_COMPLETION:
- return false;
- case CT_TERMINAL:
- return false;
- case CT_NEW_WORD_SPACE_SUBSTITUTION:
- return false;
- case CT_INSERTION:
- return false;
- case CT_TRANSPOSITION:
- return false;
- default:
- return false;
- }
-}
-
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
switch(correctionType) {
case CT_OMISSION:
@@ -229,7 +170,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n
case CT_MATCH:
return 1;
case CT_COMPLETION:
- return 0;
+ return 1;
case CT_TERMINAL:
return 0;
case CT_NEW_WORD_SPACE_SUBSTITUTION:
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index bce479c51..0d2745b40 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -18,13 +18,13 @@
#define LATINIME_WEIGHTING_H
#include "defines.h"
-#include "hash_map_compat.h"
namespace latinime {
class DicNode;
class DicTraverseSession;
struct DicNode_InputStateG;
+class MultiBigramMap;
class Weighting {
public:
@@ -32,7 +32,7 @@ class Weighting {
const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap);
+ MultiBigramMap *const multiBigramMap);
protected:
virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
@@ -61,7 +61,7 @@ class Weighting {
virtual float getNewWordBigramCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0;
+ MultiBigramMap *const multiBigramMap) const = 0;
virtual float getCompletionCost(
const DicTraverseSession *const traverseSession,
@@ -80,6 +80,10 @@ class Weighting {
virtual float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
+ virtual ErrorType getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
+
Weighting() {}
virtual ~Weighting() {}
@@ -93,13 +97,7 @@ class Weighting {
static float getLanguageCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap);
- // TODO: Move to TypingWeighting and GestureWeighting?
- static bool isEditCorrection(const CorrectionType correctionType);
- // TODO: Move to TypingWeighting and GestureWeighting?
- static bool isProximityCorrection(const Weighting *const weighting,
- const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
- const DicNode *const dicNode);
+ MultiBigramMap *const multiBigramMap);
// TODO: Move to TypingWeighting and GestureWeighting?
static int getForwardInputCount(const CorrectionType correctionType);
};
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index 3c44db21c..51165858b 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -69,7 +69,15 @@ void DicTraverseSession::init(const Dictionary *const dictionary, const int *pre
mPrevWordPos = NOT_VALID_WORD;
return;
}
- mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength);
+ // TODO: merge following similar calls to getTerminalPosition into one case-insensitive call.
+ mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord,
+ prevWordLength, false /* forceLowerCaseSearch */);
+ if (mPrevWordPos == NOT_VALID_WORD) {
+ // Check bigrams for lower-cased previous word if original was not found. Useful for
+ // auto-capitalized words like "The [current_word]".
+ mPrevWordPos = BinaryFormat::getTerminalPosition(dictionary->getOffsetDict(), prevWord,
+ prevWordLength, true /* forceLowerCaseSearch */);
+ }
}
void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,
@@ -92,7 +100,7 @@ int DicTraverseSession::getDictFlags() const {
void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
mDicNodesCache.reset(nextActiveCacheSize, maxWords);
- mBigramCacheMap.clear();
+ mMultiBigramMap.clear();
mPartiallyCommited = false;
}
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index d9c2a51d0..d88be5b88 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -21,8 +21,8 @@
#include <vector>
#include "defines.h"
-#include "hash_map_compat.h"
#include "jni.h"
+#include "multi_bigram_map.h"
#include "proximity_info_state.h"
#include "suggest/core/dicnode/dic_nodes_cache.h"
@@ -35,7 +35,7 @@ class DicTraverseSession {
public:
AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr)
: mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0),
- mDictionary(0), mDicNodesCache(), mBigramCacheMap(),
+ mDictionary(0), mDicNodesCache(), mMultiBigramMap(),
mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1),
mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances.
@@ -67,7 +67,7 @@ class DicTraverseSession {
// TODO: Use proper parameter when changed
int getDicRootPos() const { return 0; }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
- hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; }
+ MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
const ProximityInfoState *getProximityInfoState(int id) const {
return &mProximityInfoStates[id];
}
@@ -170,7 +170,7 @@ class DicTraverseSession {
DicNodesCache mDicNodesCache;
// Temporary cache for bigram frequencies
- hash_map_compat<int, int16_t> mBigramCacheMap;
+ MultiBigramMap mMultiBigramMap;
ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G];
int mInputSize;
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 9de2cd2e2..3221dee9c 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -161,12 +161,15 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
+ doubleLetterCost;
const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(),
terminalDicNode->getFlags(), terminalDicNode->getAttributesPos());
- const int originalTerminalProbability = terminalDicNode->getProbability();
+ const bool isPossiblyOffensiveWord = terminalDicNode->getProbability() <= 0;
+ const bool isExactMatch = terminalDicNode->isExactMatch();
+ const int outputTypeFlags = // '|' binds tighter than '?:', so parenthesize each flag.
+ (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
+ | (isExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0);
+
+ // Entries that are blacklisted or do not represent a word should not be output.
+ const bool isValidWord = !terminalAttributes.isBlacklistedOrNotAWord();
- // Do not suggest words with a 0 probability, or entries that are blacklisted or do not
- // represent a word. However, we should still submit their shortcuts if any.
- const bool isValidWord =
- originalTerminalProbability > 0 && !terminalAttributes.isBlacklistedOrNotAWord();
// Increase output score of top typing suggestion to ensure autocorrection.
// TODO: Better integration with java side autocorrection logic.
// Force autocorrection for obvious long multi-word suggestions.
@@ -188,10 +191,9 @@ int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequen
}
}
- // Do not suggest words with a 0 probability, or entries that are blacklisted or do not
- // represent a word. However, we should still submit their shortcuts if any.
+ // Don't output invalid words. However, we still need to submit their shortcuts if any.
if (isValidWord) {
- outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION;
+ outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags;
frequencies[outputWordIndex] = finalScore;
// Populate the outputChars array with the suggested word.
const int startIndex = outputWordIndex * MAX_WORD_LENGTH;
@@ -294,8 +296,8 @@ void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const {
correctionDicNode.advanceDigraphIndex();
processDicNodeAsDigraph(traverseSession, &correctionDicNode);
}
- if (allowsErrorCorrections
- && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) {
+ if (TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode,
+ allowsErrorCorrections)) {
// TODO: (Gesture) Change weight between omission and substitution errors
// TODO: (Gesture) Terminal node should not be handled as omission
correctionDicNode.initByCopy(childDicNode);
@@ -357,7 +359,7 @@ void Suggest::processTerminalDicNode(
DicNode terminalDicNode;
DicNodeUtils::initByCopy(dicNode, &terminalDicNode);
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
- &terminalDicNode, traverseSession->getBigramCacheMap());
+ &terminalDicNode, traverseSession->getMultiBigramMap());
traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
}
@@ -389,8 +391,10 @@ void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession,
void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession,
DicNode *dicNode, DicNode *childDicNode) const {
+ // Note: Most types of corrections don't need to look up the bigram information since they do
+ // not treat the node as a terminal. There is no need to pass the bigram map in these cases.
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY,
- traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */);
+ traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */);
weightChildNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, childDicNode);
}
@@ -398,7 +402,7 @@ void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traver
void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession,
DicNode *dicNode, DicNode *childDicNode) const {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession,
- dicNode, childDicNode, 0 /* bigramCacheMap */);
+ dicNode, childDicNode, 0 /* multiBigramMap */);
weightChildNode(traverseSession, childDicNode);
processExpandedDicNode(traverseSession, childDicNode);
}
@@ -422,20 +426,15 @@ void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession,
*/
void Suggest::processDicNodeAsOmission(
DicTraverseSession *traverseSession, DicNode *dicNode) const {
- // If the omission is surely intentional that it should incur zero cost.
- const bool isZeroCostOmission = dicNode->isZeroCostOmission();
DicNodeVector childDicNodes;
-
DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes);
const int size = childDicNodes.getSizeAndLock();
for (int i = 0; i < size; i++) {
DicNode *const childDicNode = childDicNodes[i];
- if (!isZeroCostOmission) {
- // Treat this word as omission
- Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
- dicNode, childDicNode, 0 /* bigramCacheMap */);
- }
+ // Treat this word as omission
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession,
+ dicNode, childDicNode, 0 /* multiBigramMap */);
weightChildNode(traverseSession, childDicNode);
if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) {
@@ -459,7 +458,7 @@ void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession,
for (int i = 0; i < size; i++) {
DicNode *const childDicNode = childDicNodes[i];
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession,
- dicNode, childDicNode, 0 /* bigramCacheMap */);
+ dicNode, childDicNode, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode);
}
}
@@ -484,7 +483,7 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession,
for (int j = 0; j < childSize2; j++) {
DicNode *const childDicNode2 = childDicNodes2[j];
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION,
- traverseSession, childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */);
+ traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */);
processExpandedDicNode(traverseSession, childDicNode2);
}
}
@@ -499,10 +498,10 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN
const int inputSize = traverseSession->getInputSize();
if (dicNode->isCompletion(inputSize)) {
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession,
- 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */);
+ 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
} else { // completion
Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession,
- 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */);
+ 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */);
}
}
@@ -523,7 +522,7 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode
const CorrectionType correctionType = spaceSubstitution ?
CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION;
Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode,
- &newDicNode, traverseSession->getBigramCacheMap());
+ &newDicNode, traverseSession->getMultiBigramMap());
traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode);
}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index 0fa684f01..f87989286 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -28,25 +28,25 @@ const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4;
const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f;
const float ScoringParams::PROXIMITY_COST = 0.086f;
const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f;
-const float ScoringParams::OMISSION_COST = 0.388f;
-const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.431f;
-const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.532f;
-const float ScoringParams::INSERTION_COST = 0.670f;
-const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f;
-const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f;
-const float ScoringParams::TRANSPOSITION_COST = 0.494f;
-const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.239f;
+const float ScoringParams::OMISSION_COST = 0.458f;
+const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f;
+const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f;
+const float ScoringParams::INSERTION_COST = 0.730f;
+const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f;
+const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f;
+const float ScoringParams::TRANSPOSITION_COST = 0.516f;
+const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.319f;
const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f;
-const float ScoringParams::SUBSTITUTION_COST = 0.363f;
-const float ScoringParams::COST_NEW_WORD = 0.054f;
-const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f;
+const float ScoringParams::SUBSTITUTION_COST = 0.403f;
+const float ScoringParams::COST_NEW_WORD = 0.042f;
+const float ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE = 0.25f;
const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f;
-const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.462f;
-const float ScoringParams::COST_LOOKAHEAD = 0.092f;
-const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f;
-const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f;
-const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.136f;
+const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.545f;
+const float ScoringParams::COST_LOOKAHEAD = 0.073f;
+const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.105f;
+const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.038f;
+const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.444f;
const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f;
const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f;
-const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f;
+const float ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT = 0.06f;
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
index 8f104b362..53ac999c1 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
@@ -48,7 +48,7 @@ class ScoringParams {
static const float ADDITIONAL_PROXIMITY_COST;
static const float SUBSTITUTION_COST;
static const float COST_NEW_WORD;
- static const float COST_NEW_WORD_CAPITALIZED;
+ static const float COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
static const float DISTANCE_WEIGHT_LANGUAGE;
static const float COST_FIRST_LOOKAHEAD;
static const float COST_LOOKAHEAD;
@@ -57,7 +57,7 @@ class ScoringParams {
static const float HAS_MULTI_WORD_TERMINAL_COST;
static const float TYPING_BASE_OUTPUT_SCORE;
static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
- static const float MAX_NORM_DISTANCE_FOR_EDIT;
+ static const float NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams);
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
index 9f8347452..12110d54f 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
@@ -39,14 +39,21 @@ class TypingTraversal : public Traversal {
AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const {
return dicNode->getNormalizedSpatialDistance()
- < ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT;
+ < ScoringParams::NORMALIZED_SPATIAL_DISTANCE_THRESHOLD_FOR_EDIT;
}
AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession,
- const DicNode *const dicNode, const DicNode *const childDicNode) const {
+ const DicNode *const dicNode, const DicNode *const childDicNode,
+ const bool allowsErrorCorrections) const {
if (!CORRECT_OMISSION) {
return false;
}
+ // Note: Always consider intentional omissions (like apostrophes) since they are common.
+ const bool canConsiderOmission =
+ allowsErrorCorrections || childDicNode->canBeIntentionalOmission();
+ if (!canConsiderOmission) {
+ return false;
+ }
const int inputSize = traverseSession->getInputSize();
// TODO: Don't refer to isCompletion?
if (dicNode->isCompletion(inputSize)) {
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index 1500341bd..e4c69d1f6 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -20,5 +20,41 @@
#include "suggest/policyimpl/typing/scoring_params.h"
namespace latinime {
+
const TypingWeighting TypingWeighting::sInstance;
+
+// Maps a correction type to its ErrorType; unrecognized correction types are not errors.
+ErrorType TypingWeighting::getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
+ const DicNode *const dicNode) const {
+ switch (correctionType) {
+ case CT_MATCH:
+ if (isProximityDicNode(traverseSession, dicNode)) {
+ return ET_PROXIMITY_CORRECTION;
+ } else {
+ return ET_NOT_AN_ERROR;
+ }
+ case CT_ADDITIONAL_PROXIMITY:
+ return ET_PROXIMITY_CORRECTION;
+ case CT_OMISSION:
+ if (parentDicNode->canBeIntentionalOmission()) {
+ return ET_INTENTIONAL_OMISSION;
+ } else {
+ return ET_EDIT_CORRECTION;
+ }
+ case CT_SUBSTITUTION:
+ case CT_INSERTION:
+ case CT_TRANSPOSITION:
+ return ET_EDIT_CORRECTION;
+ case CT_NEW_WORD_SPACE_OMITTION:
+ case CT_NEW_WORD_SPACE_SUBSTITUTION:
+ return ET_NEW_WORD;
+ case CT_TERMINAL:
+ return ET_NOT_AN_ERROR;
+ case CT_COMPLETION:
+ return ET_COMPLETION;
+ default:
+ return ET_NOT_AN_ERROR;
+ }
+}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 74e4e34e4..3938c0ec5 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -28,14 +28,15 @@ namespace latinime {
class DicNode;
struct DicNode_InputStateG;
+class MultiBigramMap;
class TypingWeighting : public Weighting {
public:
static const TypingWeighting *getInstance() { return &sInstance; }
protected:
- float getTerminalSpatialCost(
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
+ float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode) const {
float cost = 0.0f;
if (dicNode->hasMultipleWords()) {
cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
@@ -50,13 +51,14 @@ class TypingWeighting : public Weighting {
}
float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
- bool sameCodePoint = false;
- bool isFirstLetterOmission = false;
- float cost = 0.0f;
- sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
+ const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
+ const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second.
- isFirstLetterOmission = dicNode->getDepth() == 2;
- if (isFirstLetterOmission) {
+ const bool isFirstLetterOmission = dicNode->getDepth() == 2;
+ float cost = 0.0f;
+ if (isZeroCostOmission) {
+ cost = 0.0f;
+ } else if (isFirstLetterOmission) {
cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
} else {
cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
@@ -65,9 +67,8 @@ class TypingWeighting : public Weighting {
return cost;
}
- float getMatchedCost(
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
- DicNode_InputStateG *inputStateG) const {
+ float getMatchedCost(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
const int pointIndex = dicNode->getInputIndex(0);
// Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on
// the keyboard (like accented letters)
@@ -79,13 +80,23 @@ class TypingWeighting : public Weighting {
const bool isFirstChar = pointIndex == 0;
const bool isProximity = isProximityDicNode(traverseSession, dicNode);
- const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST
+ float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST
: ScoringParams::PROXIMITY_COST) : 0.0f;
+ if (dicNode->getDepth() == 2) {
+ // At the second character of the current word, we check if the first char is uppercase
+ // and the word is a second or later word of a multiple word suggestion. We demote it
+ // if so.
+ const bool isSecondOrLaterWordFirstCharUppercase =
+ dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
+ if (isSecondOrLaterWordFirstCharUppercase) {
+ cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
+ }
+ }
return weightedDistance + cost;
}
- bool isProximityDicNode(
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const {
+ bool isProximityDicNode(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode) const {
const int pointIndex = dicNode->getInputIndex(0);
const int primaryCodePoint = toBaseLowerCase(
traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex));
@@ -93,9 +104,8 @@ class TypingWeighting : public Weighting {
return primaryCodePoint != dicNodeChar;
}
- float getTranspositionCost(
- const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
- const DicNode *const dicNode) const {
+ float getTranspositionCost(const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint = parentDicNode->getNodeCodePoint();
const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
@@ -109,8 +119,7 @@ class TypingWeighting : public Weighting {
return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
}
- float getInsertionCost(
- const DicTraverseSession *const traverseSession,
+ float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const {
const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
const int prevCodePoint =
@@ -130,17 +139,14 @@ class TypingWeighting : public Weighting {
float getNewWordCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
- const bool isCapitalized = dicNode->isCapitalized();
- const float cost = isCapitalized ?
- ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD;
- return cost * traverseSession->getMultiWordCostMultiplier();
+ return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier();
}
- float getNewWordBigramCost(
- const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
- hash_map_compat<int, int16_t> *const bigramCacheMap) const {
+ float getNewWordBigramCost(const DicTraverseSession *const traverseSession,
+ const DicNode *const dicNode,
+ MultiBigramMap *const multiBigramMap) const {
return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(),
- dicNode, bigramCacheMap);
+ dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
float getCompletionCost(const DicTraverseSession *const traverseSession,
@@ -156,21 +162,9 @@ class TypingWeighting : public Weighting {
float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
- const bool hasEditCount = dicNode->getEditCorrectionCount() > 0;
- const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize();
- const bool hasMultipleWords = dicNode->hasMultipleWords();
- const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0;
- // Gesture input is always assumed to have proximity errors
- // because the input word shouldn't be treated as perfect
- const bool isExactMatch = !hasEditCount && !hasMultipleWords
- && !hasProximityErrors && isSameLength;
-
- const float totalPrevWordsLanguageCost = dicNode->getTotalPrevWordsLanguageCost();
- const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability;
- const float languageWeight = ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
- // TODO: Caveat: The following equation should be:
- // totalPrevWordsLanguageCost + (languageImprobability * languageWeight);
- return (totalPrevWordsLanguageCost + languageImprobability) * languageWeight;
+ const float languageImprobability = (dicNode->isExactMatch()) ?
+ 0.0f : dicNodeLanguageImprobability;
+ return languageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
@@ -185,15 +179,16 @@ class TypingWeighting : public Weighting {
return ScoringParams::SUBSTITUTION_COST;
}
- AK_FORCE_INLINE float getSpaceSubstitutionCost(
- const DicTraverseSession *const traverseSession,
+ AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const {
- const bool isCapitalized = dicNode->isCapitalized();
- const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + (isCapitalized ?
- ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD);
+ const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
return cost * traverseSession->getMultiWordCostMultiplier();
}
+ ErrorType getErrorType(const CorrectionType correctionType,
+ const DicTraverseSession *const traverseSession,
+ const DicNode *const parentDicNode, const DicNode *const dicNode) const;
+
private:
DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
static const TypingWeighting sInstance;
diff --git a/native/jni/src/terminal_attributes.h b/native/jni/src/terminal_attributes.h
index 144ae1452..92ef71c2c 100644
--- a/native/jni/src/terminal_attributes.h
+++ b/native/jni/src/terminal_attributes.h
@@ -72,7 +72,7 @@ class TerminalAttributes {
}
bool isBlacklistedOrNotAWord() const {
- return mFlags & (BinaryFormat::FLAG_IS_BLACKLISTED | BinaryFormat::FLAG_IS_NOT_A_WORD);
+ return BinaryFormat::hasBlacklistedOrNotAWordFlag(mFlags);
}
private: