diff options
34 files changed, 587 insertions, 258 deletions
diff --git a/java/res/values/attrs.xml b/java/res/values/attrs.xml index 80bf704db..c31831747 100644 --- a/java/res/values/attrs.xml +++ b/java/res/values/attrs.xml @@ -439,7 +439,6 @@ </declare-styleable> <declare-styleable name="SeekBarDialogPreference"> - <attr name="valueFormatText" format="reference" /> <attr name="maxValue" format="integer" /> <attr name="minValue" format="integer" /> <attr name="stepValue" format="integer" /> diff --git a/java/res/xml/prefs.xml b/java/res/xml/prefs.xml index 51e3420e9..d77adcef9 100644 --- a/java/res/xml/prefs.xml +++ b/java/res/xml/prefs.xml @@ -170,14 +170,12 @@ <com.android.inputmethod.latin.SeekBarDialogPreference android:key="pref_key_longpress_timeout" android:title="@string/prefs_key_longpress_timeout_settings" - latin:valueFormatText="@string/abbreviation_unit_milliseconds" latin:minValue="@integer/config_min_longpress_timeout" latin:maxValue="@integer/config_max_longpress_timeout" latin:stepValue="@integer/config_longpress_timeout_step" /> <com.android.inputmethod.latin.SeekBarDialogPreference android:key="pref_vibration_duration_settings" android:title="@string/prefs_keypress_vibration_duration_settings" - latin:valueFormatText="@string/abbreviation_unit_milliseconds" latin:maxValue="@integer/config_max_vibration_duration" /> <com.android.inputmethod.latin.SeekBarDialogPreference android:key="pref_keypress_sound_volume" diff --git a/java/src/com/android/inputmethod/keyboard/internal/KeyboardParams.java b/java/src/com/android/inputmethod/keyboard/internal/KeyboardParams.java index 15eb690e1..84319eb33 100644 --- a/java/src/com/android/inputmethod/keyboard/internal/KeyboardParams.java +++ b/java/src/com/android/inputmethod/keyboard/internal/KeyboardParams.java @@ -84,11 +84,16 @@ public class KeyboardParams { public void onAddKey(final Key newKey) { final Key key = (mKeysCache != null) ? mKeysCache.get(newKey) : newKey; - final boolean zeroWidthSpacer = key.isSpacer() && key.mWidth == 0; - if (!zeroWidthSpacer) { - mKeys.add(key); - updateHistogram(key); + final boolean isSpacer = key.isSpacer(); + if (isSpacer && key.mWidth == 0) { + // Ignore zero width {@link Spacer}. + return; } + mKeys.add(key); + if (isSpacer) { + return; + } + updateHistogram(key); if (key.mCode == Constants.CODE_SHIFT) { mShiftKeys.add(key); } diff --git a/java/src/com/android/inputmethod/latin/LatinIME.java b/java/src/com/android/inputmethod/latin/LatinIME.java index 6a3066cc1..73ec57871 100644 --- a/java/src/com/android/inputmethod/latin/LatinIME.java +++ b/java/src/com/android/inputmethod/latin/LatinIME.java @@ -2357,7 +2357,8 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen LastComposedWord.NOT_A_SEPARATOR); if (ProductionFlag.USES_DEVELOPMENT_ONLY_DIAGNOSTICS) { ResearchLogger.latinIME_pickSuggestionManually(replacedWord, index, suggestion, - mWordComposer.isBatchMode()); + mWordComposer.isBatchMode(), suggestionInfo.mScore, suggestionInfo.mKind, + suggestionInfo.mSourceDict); } mConnection.endBatchEdit(); // Don't allow cancellation of manual pick diff --git a/java/src/com/android/inputmethod/latin/RecapitalizeStatus.java b/java/src/com/android/inputmethod/latin/RecapitalizeStatus.java index 8a704ab42..b9d7dcf78 100644 --- a/java/src/com/android/inputmethod/latin/RecapitalizeStatus.java +++ b/java/src/com/android/inputmethod/latin/RecapitalizeStatus.java @@ -163,7 +163,10 @@ public class RecapitalizeStatus { final int codePoint = mStringBefore.codePointBefore(nonWhitespaceEnd); if (!Character.isWhitespace(codePoint)) break; } - if (0 != nonWhitespaceStart || len != nonWhitespaceEnd) { + // If nonWhitespaceStart >= nonWhitespaceEnd, that means the selection contained only + // whitespace, so we leave it as is. + if ((0 != nonWhitespaceStart || len != nonWhitespaceEnd) + && nonWhitespaceStart < nonWhitespaceEnd) { mCursorEndAfter = mCursorStartBefore + nonWhitespaceEnd; mCursorStartBefore = mCursorStartAfter = mCursorStartBefore + nonWhitespaceStart; mStringAfter = mStringBefore = diff --git a/java/src/com/android/inputmethod/latin/SeekBarDialogPreference.java b/java/src/com/android/inputmethod/latin/SeekBarDialogPreference.java index 3ea9fedd7..44065ff33 100644 --- a/java/src/com/android/inputmethod/latin/SeekBarDialogPreference.java +++ b/java/src/com/android/inputmethod/latin/SeekBarDialogPreference.java @@ -33,10 +33,10 @@ public final class SeekBarDialogPreference extends DialogPreference public int readDefaultValue(final String key); public void writeValue(final int value, final String key); public void writeDefaultValue(final String key); + public String getValueText(final int value); public void feedbackValue(final int value); } - private final int mValueFormatResId; private final int mMaxValue; private final int mMinValue; private final int mStepValue; @@ -50,7 +50,6 @@ public final class SeekBarDialogPreference extends DialogPreference super(context, attrs); final TypedArray a = context.obtainStyledAttributes( attrs, R.styleable.SeekBarDialogPreference, 0, 0); - mValueFormatResId = a.getResourceId(R.styleable.SeekBarDialogPreference_valueFormatText, 0); mMaxValue = a.getInt(R.styleable.SeekBarDialogPreference_maxValue, 0); mMinValue = a.getInt(R.styleable.SeekBarDialogPreference_minValue, 0); mStepValue = a.getInt(R.styleable.SeekBarDialogPreference_stepValue, 0); @@ -60,15 +59,8 @@ public final class SeekBarDialogPreference extends DialogPreference public void setInterface(final ValueProxy proxy) { mValueProxy = proxy; - setSummary(getValueText(clipValue(proxy.readValue(getKey())))); - } - - private String getValueText(final int value) { - if (mValueFormatResId == 0) { - return Integer.toString(value); - } else { - return getContext().getString(mValueFormatResId, value); - } + final int value = mValueProxy.readValue(getKey()); + setSummary(mValueProxy.getValueText(value)); } @Override @@ -101,16 +93,11 @@ public final class SeekBarDialogPreference extends DialogPreference return clipValue(getValueFromProgress(progress)); } - private void setValue(final int value, final boolean fromUser) { - mValueView.setText(getValueText(value)); - if (!fromUser) { - mSeekBar.setProgress(getProgressFromValue(value)); - } - } - @Override protected void onBindDialogView(final View view) { - setValue(clipValue(mValueProxy.readValue(getKey())), false /* fromUser */); + final int value = mValueProxy.readValue(getKey()); + mValueView.setText(mValueProxy.getValueText(value)); + mSeekBar.setProgress(getProgressFromValue(clipValue(value))); } @Override @@ -125,13 +112,15 @@ public final class SeekBarDialogPreference extends DialogPreference super.onClick(dialog, which); final String key = getKey(); if (which == DialogInterface.BUTTON_NEUTRAL) { - setValue(clipValue(mValueProxy.readDefaultValue(key)), false /* fromUser */); + final int value = mValueProxy.readDefaultValue(key); + setSummary(mValueProxy.getValueText(value)); mValueProxy.writeDefaultValue(key); return; } if (which == DialogInterface.BUTTON_POSITIVE) { - setSummary(mValueView.getText()); - mValueProxy.writeValue(getClippedValueFromProgress(mSeekBar.getProgress()), key); + final int value = getClippedValueFromProgress(mSeekBar.getProgress()); + setSummary(mValueProxy.getValueText(value)); + mValueProxy.writeValue(value, key); return; } } @@ -139,7 +128,11 @@ public final class SeekBarDialogPreference extends DialogPreference @Override public void onProgressChanged(final SeekBar seekBar, final int progress, final boolean fromUser) { - setValue(getClippedValueFromProgress(progress), fromUser); + final int value = getClippedValueFromProgress(progress); + mValueView.setText(mValueProxy.getValueText(value)); + if (!fromUser) { + mSeekBar.setProgress(getProgressFromValue(value)); + } } @Override diff --git a/java/src/com/android/inputmethod/latin/SettingsFragment.java b/java/src/com/android/inputmethod/latin/SettingsFragment.java index 1fad765d7..9cc178598 100644 --- a/java/src/com/android/inputmethod/latin/SettingsFragment.java +++ b/java/src/com/android/inputmethod/latin/SettingsFragment.java @@ -364,6 +364,11 @@ public final class SettingsFragment extends InputMethodSettingsFragment public void feedbackValue(final int value) { AudioAndHapticFeedbackManager.getInstance().vibrate(value); } + + @Override + public String getValueText(final int value) { + return res.getString(R.string.abbreviation_unit_milliseconds, value); + } }); } @@ -396,6 +401,11 @@ public final class SettingsFragment extends InputMethodSettingsFragment } @Override + public String getValueText(final int value) { + return res.getString(R.string.abbreviation_unit_milliseconds, value); + } + + @Override public void feedbackValue(final int value) {} }); } @@ -439,6 +449,11 @@ public final class SettingsFragment extends InputMethodSettingsFragment } @Override + public String getValueText(final int value) { + return Integer.toString(value); + } + + @Override public void feedbackValue(final int value) { am.playSoundEffect( AudioManager.FX_KEYPRESS_STANDARD, getValueFromPercentage(value)); diff --git a/java/src/com/android/inputmethod/latin/userdictionary/UserDictionaryAddWordContents.java b/java/src/com/android/inputmethod/latin/userdictionary/UserDictionaryAddWordContents.java index 2b6fda381..89ec7466e 100644 --- a/java/src/com/android/inputmethod/latin/userdictionary/UserDictionaryAddWordContents.java +++ b/java/src/com/android/inputmethod/latin/userdictionary/UserDictionaryAddWordContents.java @@ -76,7 +76,9 @@ public class UserDictionaryAddWordContents { final String word = args.getString(EXTRA_WORD); if (null != word) { mWordEditText.setText(word); - mWordEditText.setSelection(word.length()); + // Use getText in case the edit text modified the text we set. This happens when + // it's too long to be edited. + mWordEditText.setSelection(mWordEditText.getText().length()); } final String shortcut; if (UserDictionarySettings.IS_SHORTCUT_API_SUPPORTED) { diff --git a/java/src/com/android/inputmethod/research/JsonUtils.java b/java/src/com/android/inputmethod/research/JsonUtils.java index 24cd8d935..63d524df7 100644 --- a/java/src/com/android/inputmethod/research/JsonUtils.java +++ b/java/src/com/android/inputmethod/research/JsonUtils.java @@ -94,12 +94,17 @@ import java.util.Map; .value(words.mIsPunctuationSuggestions); jsonWriter.name("isObsoleteSuggestions").value(words.mIsObsoleteSuggestions); jsonWriter.name("isPrediction").value(words.mIsPrediction); - jsonWriter.name("words"); + jsonWriter.name("suggestedWords"); jsonWriter.beginArray(); final int size = words.size(); for (int j = 0; j < size; j++) { final SuggestedWordInfo wordInfo = words.getInfo(j); - jsonWriter.value(wordInfo.toString()); + jsonWriter.beginObject(); + jsonWriter.name("word").value(wordInfo.toString()); + jsonWriter.name("score").value(wordInfo.mScore); + jsonWriter.name("kind").value(wordInfo.mKind); + jsonWriter.name("sourceDict").value(wordInfo.mSourceDict); + jsonWriter.endObject(); } jsonWriter.endArray(); jsonWriter.endObject(); diff --git a/java/src/com/android/inputmethod/research/ResearchLogger.java b/java/src/com/android/inputmethod/research/ResearchLogger.java index e890b74aa..6029ba963 100644 --- a/java/src/com/android/inputmethod/research/ResearchLogger.java +++ b/java/src/com/android/inputmethod/research/ResearchLogger.java @@ -1308,9 +1308,10 @@ public class ResearchLogger implements SharedPreferences.OnSharedPreferenceChang */ private static final LogStatement LOGSTATEMENT_LATINIME_PICKSUGGESTIONMANUALLY = new LogStatement("LatinIMEPickSuggestionManually", true, false, "replacedWord", "index", - "suggestion", "x", "y", "isBatchMode"); + "suggestion", "x", "y", "isBatchMode", "score", "kind", "sourceDict"); public static void latinIME_pickSuggestionManually(final String replacedWord, - final int index, final String suggestion, final boolean isBatchMode) { + final int index, final String suggestion, final boolean isBatchMode, + final int score, final int kind, final String sourceDict) { final ResearchLogger researchLogger = getInstance(); if (!replacedWord.equals(suggestion.toString())) { // The user chose something other than what was already there. @@ -1321,7 +1322,7 @@ public class ResearchLogger implements SharedPreferences.OnSharedPreferenceChang researchLogger.enqueueEvent(LOGSTATEMENT_LATINIME_PICKSUGGESTIONMANUALLY, scrubDigitsFromString(replacedWord), index, suggestion == null ? null : scrubbedWord, Constants.SUGGESTION_STRIP_COORDINATE, - Constants.SUGGESTION_STRIP_COORDINATE, isBatchMode); + Constants.SUGGESTION_STRIP_COORDINATE, isBatchMode, score, kind, sourceDict); researchLogger.commitCurrentLogUnitAsWord(scrubbedWord, Long.MAX_VALUE, isBatchMode); researchLogger.mStatistics.recordManualSuggestion(SystemClock.uptimeMillis()); } @@ -1843,7 +1844,7 @@ public class ResearchLogger implements SharedPreferences.OnSharedPreferenceChang */ private static final LogStatement LOGSTATEMENT_LATINIME_ONENDBATCHINPUT = new LogStatement("LatinIMEOnEndBatchInput", true, false, "enteredText", - "enteredWordPos"); + "enteredWordPos", "suggestedWords"); public static void latinIME_onEndBatchInput(final CharSequence enteredText, final int enteredWordPos, final SuggestedWords suggestedWords) { final ResearchLogger researchLogger = getInstance(); @@ -1851,7 +1852,7 @@ public class ResearchLogger implements SharedPreferences.OnSharedPreferenceChang researchLogger.mCurrentLogUnit.setWords(enteredText.toString()); } researchLogger.enqueueEvent(LOGSTATEMENT_LATINIME_ONENDBATCHINPUT, enteredText, - enteredWordPos); + enteredWordPos, suggestedWords); researchLogger.mCurrentLogUnit.initializeSuggestions(suggestedWords); researchLogger.mStatistics.recordGestureInput(enteredText.length(), SystemClock.uptimeMillis()); diff --git a/native/jni/Android.mk b/native/jni/Android.mk index 1cdfbe4d1..fb60139d3 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -53,12 +53,15 @@ LATIN_IME_CORE_SRC_FILES := \ dic_nodes_cache.cpp) \ $(addprefix suggest/core/dictionary/, \ bigram_dictionary.cpp \ + binary_dictionary_bigrams_reading_utils.cpp \ binary_dictionary_format_utils.cpp \ binary_dictionary_header.cpp \ binary_dictionary_header_reading_utils.cpp \ + bloom_filter.cpp \ byte_array_utils.cpp \ dictionary.cpp \ - digraph_utils.cpp) \ + digraph_utils.cpp \ + multi_bigram_map.cpp) \ $(addprefix suggest/core/layout/, \ additional_proximity_chars.cpp \ proximity_info.cpp \ diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index a3cf6a4b4..e349aedb1 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -300,33 +300,6 @@ static inline void prof_out(void) { #define DIC_NODES_CACHE_INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 #define DIC_NODES_CACHE_PRIORITY_QUEUES_SIZE 4 -// Size, in bytes, of the bloom filter index for bigrams -// 128 gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, -// where k is the number of hash functions, n the number of bigrams, and m the number of -// bits we can test. -// At the moment 100 is the maximum number of bigrams for a word with the current -// dictionaries, so n = 100. 1024 buckets give us m = 1024. -// With 1 hash function, our false positive rate is about 9.3%, which should be enough for -// our uses since we are only using this to increase average performance. For the record, -// k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, -// and m = 4096 gives 2.4%. -#define BIGRAM_FILTER_BYTE_SIZE 128 -// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest -// prime under 128 * 8. -#define BIGRAM_FILTER_MODULO 1021 -#if BIGRAM_FILTER_BYTE_SIZE * 8 < BIGRAM_FILTER_MODULO -#error "BIGRAM_FILTER_MODULO is larger than BIGRAM_FILTER_BYTE_SIZE" -#endif - -// Max number of bigram maps (previous word contexts) to be cached. Increasing this number could -// improve bigram lookup speed for multi-word suggestions, but at the cost of more memory usage. -// Also, there are diminishing returns since the most frequently used bigrams are typically near -// the beginning of the input and are thus the first ones to be cached. Note that these bigrams -// are reset for each new composing word. -#define MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP 25 -// Most common previous word contexts currently have 100 bigrams -#define DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP 100 - template<typename T> AK_FORCE_INLINE const T &min(const T &a, const T &b) { return a < b ? a : b; } template<typename T> AK_FORCE_INLINE const T &max(const T &a, const T &b) { return a > b ? a : b; } diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 3f64d07b2..25299948d 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -128,7 +128,7 @@ class DicNode { void initAsRootWithPreviousWord(DicNode *dicNode, const int pos, const int childrenPos, const int childrenCount) { mIsUsed = true; - mIsCachedForNextSuggestion = false; + mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion; mDicNodeProperties.init( pos, 0, childrenPos, 0, 0, 0, childrenCount, 0, 0, false, false, true, 0, 0); // TODO: Move to dicNodeState? @@ -479,6 +479,11 @@ class DicNode { return mDicNodeProperties.getDepth(); } + // "Length" includes spaces. + inline uint16_t getTotalLength() const { + return getDepth() + mDicNodeState.mDicNodeStatePrevWord.getPrevWordLength(); + } + AK_FORCE_INLINE void dump(const char *tag) const { #if DEBUG_DICT DUMP_WORD_AND_SCORE(tag); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp index 3deee1a42..f0f26c72b 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp +++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp @@ -233,8 +233,7 @@ namespace latinime { return multiBigramMap->getBigramProbability( binaryDictionaryInfo, prevWordPos, wordPos, unigramProbability); } - return BinaryFormat::getBigramProbability( - binaryDictionaryInfo->getDictRoot(), prevWordPos, wordPos, unigramProbability); + return ProbabilityUtils::backoff(unigramProbability); } /////////////////////////////////////// diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index 53e2df62d..6e02100fc 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -21,6 +21,7 @@ #include "bigram_dictionary.h" #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/binary_format.h" #include "suggest/core/dictionary/dictionary.h" @@ -100,12 +101,11 @@ void BigramDictionary::addWordBigram(int *word, int length, int probability, int * and the bigrams are used to boost unigram result scores, it makes little sense to * reduce their scope to the ones that match the first letter. */ -int BigramDictionary::getBigrams(const int *prevWord, int prevWordLength, int *inputCodePoints, +int BigramDictionary::getPredictions(const int *prevWord, int prevWordLength, int *inputCodePoints, int inputSize, int *bigramCodePoints, int *bigramProbability, int *outputTypes) const { // TODO: remove unused arguments, and refrain from storing stuff in members of this class // TODO: have "in" arguments before "out" ones, and make out args explicit in the name - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); int pos = getBigramListPositionForWord(prevWord, prevWordLength, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams @@ -116,21 +116,20 @@ int BigramDictionary::getBigrams(const int *prevWord, int prevWordLength, int *i } // If still no bigrams, we really don't have them! if (0 == pos) return 0; - uint8_t bigramFlags; + int bigramCount = 0; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - int bigramBuffer[MAX_WORD_LENGTH]; - int unigramProbability = 0; - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH, - bigramBuffer, &unigramProbability); + int unigramProbability = 0; + int bigramBuffer[MAX_WORD_LENGTH]; + for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); + bigramsIt.hasNext(); /* no-op */) { + bigramsIt.next(); + const int length = BinaryFormat::getWordAtAddress( + mBinaryDictionaryInfo->getDictRoot(), bigramsIt.getBigramPos(), + MAX_WORD_LENGTH, bigramBuffer, &unigramProbability); // inputSize == 0 means we are trying to find bigram predictions. if (inputSize < 1 || checkFirstCharacter(bigramBuffer, inputCodePoints)) { - const int bigramProbabilityTemp = - BinaryFormat::MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + const int bigramProbabilityTemp = bigramsIt.getProbability(); // Due to space constraints, the probability for bigrams is approximate - the lower the // unigram probability, the worse the precision. The theoritical maximum error in // resulting probability is 8 - although in the practice it's never bigger than 3 or 4 @@ -142,7 +141,7 @@ int BigramDictionary::getBigrams(const int *prevWord, int prevWordLength, int *i outputTypes); ++bigramCount; } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + } return min(bigramCount, MAX_RESULTS); } @@ -187,22 +186,20 @@ bool BigramDictionary::checkFirstCharacter(int *word, int *inputCodePoints) cons bool BigramDictionary::isValidBigram(const int *word1, int length1, const int *word2, int length2) const { - const uint8_t *const root = mBinaryDictionaryInfo->getDictRoot(); int pos = getBigramListPositionForWord(word1, length1, false /* forceLowerCaseSearch */); // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams if (0 == pos) return false; - int nextWordPos = BinaryFormat::getTerminalPosition(root, word2, length2, - false /* forceLowerCaseSearch */); + int nextWordPos = BinaryFormat::getTerminalPosition(mBinaryDictionaryInfo->getDictRoot(), + word2, length2, false /* forceLowerCaseSearch */); if (NOT_VALID_WORD == nextWordPos) return false; - uint8_t bigramFlags; - do { - bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos); - const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags, - &pos); - if (bigramPos == nextWordPos) { + + for (BinaryDictionaryBigramsIterator bigramsIt(mBinaryDictionaryInfo, pos); + bigramsIt.hasNext(); /* no-op */) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPos) { return true; } - } while (BinaryFormat::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + } return false; } diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h index 06d0e9da3..7706a2c22 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.h +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.h @@ -27,8 +27,8 @@ class BigramDictionary { public: BigramDictionary(const BinaryDictionaryInfo *const binaryDictionaryInfo); - int getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, - int *frequencies, int *outputTypes) const; + int getPredictions(const int *word, int length, int *inputCodePoints, int inputSize, + int *outWords, int *frequencies, int *outputTypes) const; bool isValidBigram(const int *word1, int length1, const int *word2, int length2) const; ~BigramDictionary(); diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h new file mode 100644 index 000000000..0856840b2 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_iterator.h @@ -0,0 +1,67 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H +#define LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" + +namespace latinime { + +class BinaryDictionaryBigramsIterator { + public: + BinaryDictionaryBigramsIterator( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int pos) + : mBinaryDictionaryInfo(binaryDictionaryInfo), mPos(pos), mBigramFlags(0), + mBigramPos(0), mHasNext(true) {} + + AK_FORCE_INLINE bool hasNext() const { + return mHasNext; + } + + AK_FORCE_INLINE void next() { + mBigramFlags = BinaryDictionaryBigramsReadingUtils::getFlagsAndForwardPointer( + mBinaryDictionaryInfo, &mPos); + mBigramPos = BinaryDictionaryBigramsReadingUtils::getBigramAddressAndForwardPointer( + mBinaryDictionaryInfo, mBigramFlags, &mPos); + mHasNext = BinaryDictionaryBigramsReadingUtils::hasNext(mBigramFlags); + } + + AK_FORCE_INLINE int getProbability() const { + return BinaryDictionaryBigramsReadingUtils::getBigramProbability(mBigramFlags); + } + + AK_FORCE_INLINE int getBigramPos() const { + return mBigramPos; + } + + AK_FORCE_INLINE int getFlags() const { + return mBigramFlags; + } + + private: + DISALLOW_COPY_AND_ASSIGN(BinaryDictionaryBigramsIterator); + + const BinaryDictionaryInfo *const mBinaryDictionaryInfo; + int mPos; + BinaryDictionaryBigramsReadingUtils::BigramFlags mBigramFlags; + int mBigramPos; + bool mHasNext; +}; +} // namespace latinime +#endif // LATINIME_BINARY_DICTIONARY_BIGRAMS_ITERATOR_H diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.cpp b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.cpp new file mode 100644 index 000000000..78a54b141 --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.cpp @@ -0,0 +1,68 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h" + +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::MASK_ATTRIBUTE_ADDRESS_TYPE = 0x30; +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE = 0x10; +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_OFFSET_NEGATIVE = 0x40; +// Flag for presence of more attributes +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::FLAG_ATTRIBUTE_HAS_NEXT = 0x80; +// Mask for attribute probability, stored on 4 bits inside the flags byte. +const BinaryDictionaryBigramsReadingUtils::BigramFlags + BinaryDictionaryBigramsReadingUtils::MASK_ATTRIBUTE_PROBABILITY = 0x0F; +const int BinaryDictionaryBigramsReadingUtils::ATTRIBUTE_ADDRESS_SHIFT = 4; + +/* static */ int BinaryDictionaryBigramsReadingUtils::getBigramAddressAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const BigramFlags flags, + int *const pos) { + int offset = 0; + const int origin = *pos; + switch (MASK_ATTRIBUTE_ADDRESS_TYPE & flags) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: + offset = ByteArrayUtils::readUint8andAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: + offset = ByteArrayUtils::readUint16andAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES: + offset = ByteArrayUtils::readUint24andAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + break; + } + if (isOffsetNegative(flags)) { + return origin - offset; + } else { + return origin + offset; + } +} + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h new file mode 100644 index 000000000..e71f2a17a --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/binary_dictionary_bigrams_reading_utils.h @@ -0,0 +1,90 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H +#define LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H + +#include <stdint.h> + +#include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_info.h" +#include "suggest/core/dictionary/byte_array_utils.h" + +namespace latinime { + +class BinaryDictionaryBigramsReadingUtils { + public: + typedef uint8_t BigramFlags; + + static AK_FORCE_INLINE void skipExistingBigrams( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + BigramFlags flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + while (hasNext(flags)) { + *pos += attributeAddressSize(flags); + flags = getFlagsAndForwardPointer(binaryDictionaryInfo, pos); + } + *pos += attributeAddressSize(flags); + } + + static AK_FORCE_INLINE BigramFlags getFlagsAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, int *const pos) { + return ByteArrayUtils::readUint8andAdvancePosition( + binaryDictionaryInfo->getDictRoot(), pos); + } + + static AK_FORCE_INLINE int getBigramProbability(const BigramFlags flags) { + return flags & MASK_ATTRIBUTE_PROBABILITY; + } + + static AK_FORCE_INLINE bool isOffsetNegative(const BigramFlags flags) { + return (flags & FLAG_ATTRIBUTE_OFFSET_NEGATIVE) != 0; + } + + static AK_FORCE_INLINE bool hasNext(const BigramFlags flags) { + return (flags & FLAG_ATTRIBUTE_HAS_NEXT) != 0; + } + + static int getBigramAddressAndForwardPointer( + const BinaryDictionaryInfo *const binaryDictionaryInfo, + const BigramFlags flags, int *const pos); + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryDictionaryBigramsReadingUtils); + + static const BigramFlags MASK_ATTRIBUTE_ADDRESS_TYPE; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES; + static const BigramFlags FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES; + static const BigramFlags FLAG_ATTRIBUTE_OFFSET_NEGATIVE; + static const BigramFlags FLAG_ATTRIBUTE_HAS_NEXT; + static const BigramFlags MASK_ATTRIBUTE_PROBABILITY; + static const int ATTRIBUTE_ADDRESS_SHIFT; + + static AK_FORCE_INLINE int attributeAddressSize(const BigramFlags flags) { + return (flags & MASK_ATTRIBUTE_ADDRESS_TYPE) >> ATTRIBUTE_ADDRESS_SHIFT; + /* Note: this is a value-dependant optimization of what may probably be + more readably written this way: + switch (flags * BinaryFormat::MASK_ATTRIBUTE_ADDRESS_TYPE) { + case FLAG_ATTRIBUTE_ADDRESS_TYPE_ONEBYTE: return 1; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES: return 2; + case FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTE: return 3; + default: return 0; + } + */ + } +}; +} +#endif /* LATINIME_BINARY_DICTIONARY_BIGRAM_READING_UTILS_H */ diff --git a/native/jni/src/suggest/core/dictionary/binary_format.h b/native/jni/src/suggest/core/dictionary/binary_format.h index 0a290d80a..df0ec480d 100644 --- a/native/jni/src/suggest/core/dictionary/binary_format.h +++ b/native/jni/src/suggest/core/dictionary/binary_format.h @@ -21,7 +21,6 @@ #include "suggest/core/dictionary/probability_utils.h" #include "utils/char_utils.h" -#include "utils/hash_map_compat.h" namespace latinime { @@ -81,16 +80,10 @@ class BinaryFormat { const int length, const bool forceLowerCaseSearch); static int getWordAtAddress(const uint8_t *const root, const int address, const int maxDepth, int *outWord, int *outUnigramProbability); - static int getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability); - static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, - hash_map_compat<int, int> *bigramMap); - static int getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability); + static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); - static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; @@ -516,57 +509,6 @@ AK_FORCE_INLINE int BinaryFormat::getWordAtAddress(const uint8_t *const root, co return 0; } -// This returns a probability in log space. -inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, - const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { - if (!bigramMap) { - return ProbabilityUtils::backoff(unigramProbability); - } - const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); - if (bigramProbabilityIt != bigramMap->end()) { - const int bigramProbability = bigramProbabilityIt->second; - return ProbabilityUtils::computeProbabilityForBigram(unigramProbability, bigramProbability); - } - return ProbabilityUtils::backoff(unigramProbability); -} - -AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( - const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) return; - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, - &position); - (*bigramMap)[bigramPos] = probability; - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); -} - -AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, - const int nextPosition, const int unigramProbability) { - position = getBigramListPositionForWordPosition(root, position); - if (0 == position) { - return ProbabilityUtils::backoff(unigramProbability); - } - - uint8_t bigramFlags; - do { - bigramFlags = getFlagsAndForwardPointer(root, &position); - const int bigramPos = getAttributeAddressAndForwardPointer( - root, bigramFlags, &position); - if (bigramPos == nextPosition) { - const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; - return ProbabilityUtils::computeProbabilityForBigram( - unigramProbability, bigramProbability); - } - } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); - return ProbabilityUtils::backoff(unigramProbability); -} - -// Returns a pointer to the start of the bigram list. AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( const uint8_t *const root, int position) { if (NOT_VALID_WORD == position) return 0; diff --git a/native/jni/src/suggest/core/dictionary/bloom_filter.cpp b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp new file mode 100644 index 000000000..4ae474e0c --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.cpp @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/bloom_filter.h" + +namespace latinime { + +// Must be smaller than BIGRAM_FILTER_BYTE_SIZE * 8, and preferably prime. 1021 is the largest +// prime under 128 * 8. +const int BloomFilter::BIGRAM_FILTER_MODULO = 1021; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/bloom_filter.h b/native/jni/src/suggest/core/dictionary/bloom_filter.h index bcce1f7ea..5205456a8 100644 --- a/native/jni/src/suggest/core/dictionary/bloom_filter.h +++ b/native/jni/src/suggest/core/dictionary/bloom_filter.h @@ -23,16 +23,48 @@ namespace latinime { -// TODO: uint32_t position -static inline void setInFilter(uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - filter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); -} - -// TODO: uint32_t position -static inline bool isInFilter(const uint8_t *filter, const int32_t position) { - const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); - return filter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7)); -} +// This bloom filter is used for optimizing bigram retrieval. +// Execution times with previous word "this" are as follows: +// without bloom filter (use only hash_map): +// Total 147792.34 (sum of others 147771.57) +// with bloom filter: +// Total 145900.64 (sum of others 145874.30) +// always read binary dictionary: +// Total 148603.14 (sum of others 148579.90) +class BloomFilter { + public: + BloomFilter() { + ASSERT(BIGRAM_FILTER_BYTE_SIZE * 8 >= BIGRAM_FILTER_MODULO); + } + + // TODO: uint32_t position + AK_FORCE_INLINE void setInFilter(const int32_t position) { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + mFilter[bucket >> 3] |= static_cast<uint8_t>(1 << (bucket & 0x7)); + } + + // TODO: uint32_t position + AK_FORCE_INLINE bool isInFilter(const int32_t position) const { + const uint32_t bucket = static_cast<uint32_t>(position % BIGRAM_FILTER_MODULO); + return (mFilter[bucket >> 3] & static_cast<uint8_t>(1 << (bucket & 0x7))) != 0; + } + + private: + // Size, in bytes, of the bloom filter index for bigrams + // 128 gives us 1024 buckets. The probability of false positive is (1 - e ** (-kn/m))**k, + // where k is the number of hash functions, n the number of bigrams, and m the number of + // bits we can test. + // At the moment 100 is the maximum number of bigrams for a word with the current + // dictionaries, so n = 100. 1024 buckets give us m = 1024. + // With 1 hash function, our false positive rate is about 9.3%, which should be enough for + // our uses since we are only using this to increase average performance. For the record, + // k = 2 gives 3.1% and k = 3 gives 1.6%. With k = 1, making m = 2048 gives 4.8%, + // and m = 4096 gives 2.4%. + // This is assigned here because it is used for array size. + static const int BIGRAM_FILTER_BYTE_SIZE = 128; + static const int BIGRAM_FILTER_MODULO; + + uint8_t mFilter[BIGRAM_FILTER_BYTE_SIZE]; +}; } // namespace latinime #endif // LATINIME_BLOOM_FILTER_H diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 561e22d2d..27b052b7e 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -79,7 +79,7 @@ int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession int Dictionary::getBigrams(const int *word, int length, int *inputCodePoints, int inputSize, int *outWords, int *frequencies, int *outputTypes) const { if (length <= 0) return 0; - return mBigramDictionary->getBigrams(word, length, inputCodePoints, inputSize, outWords, + return mBigramDictionary->getPredictions(word, length, inputCodePoints, inputSize, outWords, frequencies, outputTypes); } diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp new file mode 100644 index 000000000..b1d2f4b4d --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (C) 2013, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest/core/dictionary/multi_bigram_map.h" + +#include <cstddef> + +namespace latinime { + +// Max number of bigram maps (previous word contexts) to be cached. Increasing this number +// could improve bigram lookup speed for multi-word suggestions, but at the cost of more memory +// usage. Also, there are diminishing returns since the most frequently used bigrams are +// typically near the beginning of the input and are thus the first ones to be cached. Note +// that these bigrams are reset for each new composing word. +const size_t MultiBigramMap::MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP = 25; + +// Most common previous word contexts currently have 100 bigrams +const int MultiBigramMap::BigramMap::DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP = 100; + +} // namespace latinime diff --git a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h index ba97e5842..60169f80e 100644 --- a/native/jni/src/suggest/core/dictionary/multi_bigram_map.h +++ b/native/jni/src/suggest/core/dictionary/multi_bigram_map.h @@ -17,9 +17,13 @@ #ifndef LATINIME_MULTI_BIGRAM_MAP_H #define LATINIME_MULTI_BIGRAM_MAP_H +#include <cstddef> + #include "defines.h" +#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h" #include "suggest/core/dictionary/binary_dictionary_info.h" #include "suggest/core/dictionary/binary_format.h" +#include "suggest/core/dictionary/bloom_filter.h" #include "utils/hash_map_compat.h" namespace latinime { @@ -34,7 +38,7 @@ class MultiBigramMap { // Look up the bigram probability for the given word pair from the cached bigram maps. // Also caches the bigrams if there is space remaining and they have not been cached already. - int getBigramProbability(const BinaryDictionaryInfo *const binaryDicitonaryInfo, + int getBigramProbability(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int wordPosition, const int nextWordPosition, const int unigramProbability) { hash_map_compat<int, BigramMap>::const_iterator mapPosition = mBigramMaps.find(wordPosition); @@ -42,11 +46,11 @@ class MultiBigramMap { return mapPosition->second.getBigramProbability(nextWordPosition, unigramProbability); } if (mBigramMaps.size() < MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP) { - addBigramsForWordPosition(binaryDicitonaryInfo, wordPosition); + addBigramsForWordPosition(binaryDictionaryInfo, wordPosition); return mBigramMaps[wordPosition].getBigramProbability( nextWordPosition, unigramProbability); } - return BinaryFormat::getBigramProbability(binaryDicitonaryInfo->getDictRoot(), + return readBigramProbabilityFromBinaryDictionary(binaryDictionaryInfo, wordPosition, nextWordPosition, unigramProbability); } @@ -59,30 +63,70 @@ class MultiBigramMap { class BigramMap { public: - BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP) {} + BigramMap() : mBigramMap(DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP), mBloomFilter() {} ~BigramMap() {} - void init(const BinaryDictionaryInfo *const binaryDicitonaryInfo, const int position) { - BinaryFormat::fillBigramProbabilityToHashMap( - binaryDicitonaryInfo->getDictRoot(), position, &mBigramMap); + void init(const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos) { + const int bigramsListPos = BinaryFormat::getBigramListPositionForWordPosition( + binaryDictionaryInfo->getDictRoot(), nodePos); + if (0 == bigramsListPos) { + return; + } + for (BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); + bigramsIt.hasNext(); /* no-op */) { + bigramsIt.next(); + mBigramMap[bigramsIt.getBigramPos()] = bigramsIt.getProbability(); + mBloomFilter.setInFilter(bigramsIt.getBigramPos()); + } } - inline int getBigramProbability(const int nextWordPosition, const int unigramProbability) - const { - return BinaryFormat::getBigramProbabilityFromHashMap( - nextWordPosition, &mBigramMap, unigramProbability); + AK_FORCE_INLINE int getBigramProbability( + const int nextWordPosition, const int unigramProbability) const { + if (mBloomFilter.isInFilter(nextWordPosition)) { + const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = + mBigramMap.find(nextWordPosition); + if (bigramProbabilityIt != mBigramMap.end()) { + const int bigramProbability = bigramProbabilityIt->second; + return ProbabilityUtils::computeProbabilityForBigram( + unigramProbability, bigramProbability); + } + } + return ProbabilityUtils::backoff(unigramProbability); } private: - // Note: Default copy constructor needed for use in hash_map. + // NOTE: The BigramMap class doesn't use DISALLOW_COPY_AND_ASSIGN() because its default + // copy constructor is needed for use in hash_map. + static const int DEFAULT_HASH_MAP_SIZE_FOR_EACH_BIGRAM_MAP; hash_map_compat<int, int> mBigramMap; + BloomFilter mBloomFilter; }; - void addBigramsForWordPosition(const BinaryDictionaryInfo *const binaryDicitonaryInfo, - const int position) { - mBigramMaps[position].init(binaryDicitonaryInfo, position); + AK_FORCE_INLINE void addBigramsForWordPosition( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int position) { + mBigramMaps[position].init(binaryDictionaryInfo, position); + } + + AK_FORCE_INLINE int readBigramProbabilityFromBinaryDictionary( + const BinaryDictionaryInfo *const binaryDictionaryInfo, const int nodePos, + const int nextWordPosition, const int unigramProbability) { + const int bigramsListPos = BinaryFormat::getBigramListPositionForWordPosition( + binaryDictionaryInfo->getDictRoot(), nodePos); + if (0 == bigramsListPos) { + return ProbabilityUtils::backoff(unigramProbability); + } + for (BinaryDictionaryBigramsIterator bigramsIt(binaryDictionaryInfo, bigramsListPos); + bigramsIt.hasNext(); /* no-op */) { + bigramsIt.next(); + if (bigramsIt.getBigramPos() == nextWordPosition) { + return ProbabilityUtils::computeProbabilityForBigram( + unigramProbability, bigramsIt.getProbability()); + } + } + return ProbabilityUtils::backoff(unigramProbability); } + static const size_t MAX_CACHED_PREV_WORDS_IN_BIGRAM_MAP; hash_map_compat<int, BigramMap> mBigramMaps; }; } // namespace latinime diff --git a/native/jni/src/suggest/core/layout/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp index 80355c148..05826a5a1 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info.cpp @@ -134,24 +134,13 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const { } float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, const float verticalScale) const { - const bool correctTouchPosition = hasTouchPositionCorrectionData(); - const float centerX = static_cast<float>(correctTouchPosition ? getSweetSpotCenterXAt(keyId) - : getKeyCenterXOfKeyIdG(keyId)); - const float visualKeyCenterY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId)); - float centerY; - if (correctTouchPosition) { - const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId)); - const float gapY = sweetSpotCenterY - visualKeyCenterY; - centerY = visualKeyCenterY + gapY * verticalScale; - } else { - centerY = visualKeyCenterY; - } + const int keyId, const int x, const int y, const bool isGeometric) const { + const float centerX = static_cast<float>(getKeyCenterXOfKeyIdG(keyId, x, isGeometric)); + const float centerY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId, y, isGeometric)); const float touchX = static_cast<float>(x); const float touchY = static_cast<float>(y); - const float keyWidth = static_cast<float>(getMostCommonKeyWidth()); return ProximityInfoUtils::getSquaredDistanceFloat(centerX, centerY, touchX, touchY) - / GeometryUtils::SQUARE_FLOAT(keyWidth); + / GeometryUtils::SQUARE_FLOAT(static_cast<float>(getMostCommonKeyWidth())); } int ProximityInfo::getCodePointOf(const int keyIndex) const { @@ -168,41 +157,80 @@ void ProximityInfo::initializeG() { const int lowerCode = CharUtils::toLowerCase(code); mCenterXsG[i] = mKeyXCoordinates[i] + mKeyWidths[i] / 2; mCenterYsG[i] = mKeyYCoordinates[i] + mKeyHeights[i] / 2; + if (hasTouchPositionCorrectionData()) { + // Computes sweet spot center points for geometric input. + const float verticalScale = ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G; + const float sweetSpotCenterY = static_cast<float>(mSweetSpotCenterYs[i]); + const float gapY = sweetSpotCenterY - mCenterYsG[i]; + mSweetSpotCenterYsG[i] = static_cast<int>(mCenterYsG[i] + gapY * verticalScale); + } mCodeToKeyMap[lowerCode] = i; mKeyIndexToCodePointG[i] = lowerCode; } for (int i = 0; i < KEY_COUNT; i++) { mKeyKeyDistancesG[i][i] = 0; for (int j = i + 1; j < KEY_COUNT; j++) { - mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( - mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + if (hasTouchPositionCorrectionData()) { + // Computes distances using sweet spots if they exist. + // We have two types of Y coordinate sweet spots, for geometric and for the others. + // The sweet spots for geometric input are used for calculating key-key distances + // here. + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mSweetSpotCenterXs[i], mSweetSpotCenterYsG[i], + mSweetSpotCenterXs[j], mSweetSpotCenterYsG[j]); + } else { + mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt( + mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]); + } mKeyKeyDistancesG[j][i] = mKeyKeyDistancesG[i][j]; } } } -int ProximityInfo::getKeyCenterXOfCodePointG(int charCode) const { - return getKeyCenterXOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterYOfCodePointG(int charCode) const { - return getKeyCenterYOfKeyIdG( - ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap)); -} - -int ProximityInfo::getKeyCenterXOfKeyIdG(int keyId) const { - if (keyId >= 0) { - return mCenterXsG[keyId]; +// referencePointX is used only for keys wider than most common key width. When the referencePointX +// is NOT_A_COORDINATE, this method calculates the return value without using the line segment. +// isGeometric is currently not used because we don't have extra X coordinates sweet spots for +// geometric input. +int ProximityInfo::getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const { + if (keyId < 0) { + return 0; + } + int centerX = (hasTouchPositionCorrectionData()) ? static_cast<int>(mSweetSpotCenterXs[keyId]) + : mCenterXsG[keyId]; + const int keyWidth = mKeyWidths[keyId]; + if (referencePointX != NOT_A_COORDINATE + && keyWidth > getMostCommonKeyWidth()) { + // For keys wider than most common keys, we use a line segment instead of the center point; + // thus, centerX is adjusted depending on referencePointX. + const int keyWidthHalfDiff = (keyWidth - getMostCommonKeyWidth()) / 2; + if (referencePointX < centerX - keyWidthHalfDiff) { + centerX -= keyWidthHalfDiff; + } else if (referencePointX > centerX + keyWidthHalfDiff) { + centerX += keyWidthHalfDiff; + } else { + centerX = referencePointX; + } } - return 0; + return centerX; } -int ProximityInfo::getKeyCenterYOfKeyIdG(int keyId) const { - if (keyId >= 0) { +// referencePointY is currently not used because we don't specially handle keys higher than the +// most common key height. When the referencePointY is NOT_A_COORDINATE, this method should +// calculate the return value without using the line segment. +int ProximityInfo::getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const { + // TODO: Remove "isGeometric" and have separate "proximity_info"s for gesture and typing. + if (keyId < 0) { + return 0; + } + if (!hasTouchPositionCorrectionData()) { return mCenterYsG[keyId]; + } else if (isGeometric) { + return static_cast<int>(mSweetSpotCenterYsG[keyId]); + } else { + return static_cast<int>(mSweetSpotCenterYs[keyId]); } - return 0; } int ProximityInfo::getKeyKeyDistanceG(const int keyId0, const int keyId1) const { diff --git a/native/jni/src/suggest/core/layout/proximity_info.h b/native/jni/src/suggest/core/layout/proximity_info.h index 534c2c217..f25949001 100644 --- a/native/jni/src/suggest/core/layout/proximity_info.h +++ b/native/jni/src/suggest/core/layout/proximity_info.h @@ -37,8 +37,7 @@ class ProximityInfo { bool hasSpaceProximity(const int x, const int y) const; int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const; float getNormalizedSquaredDistanceFromCenterFloatG( - const int keyId, const int x, const int y, - const float verticalScale) const; + const int keyId, const int x, const int y, const bool isGeometric) const; int getCodePointOf(const int keyIndex) const; bool hasSweetSpotData(const int keyIndex) const { // When there are no calibration data for a key, @@ -65,10 +64,10 @@ class ProximityInfo { int getKeyboardHeight() const { return KEYBOARD_HEIGHT; } float getKeyboardHypotenuse() const { return KEYBOARD_HYPOTENUSE; } - int getKeyCenterXOfCodePointG(int charCode) const; - int getKeyCenterYOfCodePointG(int charCode) const; - int getKeyCenterXOfKeyIdG(int keyId) const; - int getKeyCenterYOfKeyIdG(int keyId) const; + int getKeyCenterXOfKeyIdG( + const int keyId, const int referencePointX, const bool isGeometric) const; + int getKeyCenterYOfKeyIdG( + const int keyId, const int referencePointY, const bool isGeometric) const; int getKeyKeyDistanceG(int keyId0, int keyId1) const; AK_FORCE_INLINE void initializeProximities(const int *const inputCodes, @@ -115,6 +114,8 @@ class ProximityInfo { int mKeyCodePoints[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterXs[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotCenterYs[MAX_KEY_COUNT_IN_A_KEYBOARD]; + // Sweet spots for geometric input. Note that we have extra sweet spots only for Y coordinates. + float mSweetSpotCenterYsG[MAX_KEY_COUNT_IN_A_KEYBOARD]; float mSweetSpotRadii[MAX_KEY_COUNT_IN_A_KEYBOARD]; hash_map_compat<int, int> mCodeToKeyMap; diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp index e8d950060..7780efdfd 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp @@ -97,15 +97,10 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi pushTouchPointStartIndex, lastSavedInputSize); } - // TODO: Remove the dependency of "isGeometric" - const float verticalSweetSpotScale = isGeometric - ? ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G - : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE; - if (xCoordinates && yCoordinates) { mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo, mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times, - pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId, + pointerIds, inputSize, isGeometric, pointerId, pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes, &mSampledLengthCache, &mSampledInputIndice); } @@ -123,7 +118,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi if (mSampledInputSize > 0) { ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize, - lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs, + lastSavedInputSize, isGeometric, &mSampledInputXs, &mSampledInputYs, &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache); if (isGeometric) { // updates probabilities of skipping or mapping each key for all points. diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp index 1bbae652c..904671f7f 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp @@ -43,8 +43,8 @@ namespace latinime { const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, const bool isGeometric, - const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, + const int inputSize, const bool isGeometric, const int pointerId, + const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) { if (DEBUG_SAMPLING_POINTS) { @@ -113,7 +113,7 @@ namespace latinime { } if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time, - verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex, + isGeometric, isGeometric /* doSampling */, i == lastInputIndex, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes, sampledLengthCache, sampledInputIndice)) { @@ -183,7 +183,7 @@ namespace latinime { /* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos( const ProximityInfo *const proximityInfo, const int sampledInputSize, - const int lastSavedInputSize, const float verticalSweetSpotScale, + const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -199,7 +199,7 @@ namespace latinime { const int y = (*sampledInputYs)[i]; const float normalizedSquaredDistance = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG( - k, x, y, verticalSweetSpotScale); + k, x, y, isGeometric); (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance; if (normalizedSquaredDistance < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) { @@ -317,14 +317,13 @@ namespace latinime { // the given point and the nearest key position. /* static */ float ProximityInfoStateUtils::updateNearKeysDistances( const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, - const int y, const float verticalSweetspotScale, - NearKeysDistanceMap *const currentNearKeysDistances) { + const int y, const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances) { currentNearKeysDistances->clear(); const int keyCount = proximityInfo->getKeyCount(); float nearestKeyDistance = maxPointToKeyLength; for (int k = 0; k < keyCount; ++k) { const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y, - verticalSweetspotScale); + isGeometric); if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) { currentNearKeysDistances->insert(std::pair<int, float>(k, dist)); } @@ -405,7 +404,7 @@ namespace latinime { // Returning if previous point is popped or not. /* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y, - const int time, const float verticalSweetSpotScale, const bool doSampling, + const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -419,7 +418,7 @@ namespace latinime { bool popped = false; if (nodeCodePoint < 0 && doSampling) { const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y, - verticalSweetSpotScale, currentNearKeysDistances); + isGeometric, currentNearKeysDistances); const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest, sumAngle, currentNearKeysDistances, prevNearKeysDistances, prevPrevNearKeysDistances, sampledInputXs, sampledInputYs); @@ -453,8 +452,8 @@ namespace latinime { if (nodeCodePoint >= 0 && (x < 0 || y < 0)) { const int keyId = proximityInfo->getKeyIndexOf(nodeCodePoint); if (keyId >= 0) { - x = proximityInfo->getKeyCenterXOfKeyIdG(keyId); - y = proximityInfo->getKeyCenterYOfKeyIdG(keyId); + x = proximityInfo->getKeyCenterXOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); + y = proximityInfo->getKeyCenterYOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric); } } diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h index 66fe07926..6de970033 100644 --- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h +++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h @@ -38,8 +38,7 @@ class ProximityInfoStateUtils { static int updateTouchPoints(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int *const inputProximities, const int *const inputXCoordinates, const int *const inputYCoordinates, - const int *const times, const int *const pointerIds, - const float verticalSweetSpotScale, const int inputSize, + const int *const times, const int *const pointerIds, const int inputSize, const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache, @@ -84,8 +83,7 @@ class ProximityInfoStateUtils { const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount, const int inputIndex, const int keyId); static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo, - const int sampledInputSize, const int lastSavedInputSize, - const float verticalSweetSpotScale, + const int sampledInputSize, const int lastSavedInputSize, const bool isGeometric, const std::vector<int> *const sampledInputXs, const std::vector<int> *const sampledInputYs, std::vector<NearKeycodesSet> *sampledNearKeySets, @@ -120,7 +118,7 @@ class ProximityInfoStateUtils { static float updateNearKeysDistances(const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x, const int y, - const float verticalSweetSpotScale, + const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances); static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, @@ -133,7 +131,7 @@ class ProximityInfoStateUtils { std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs); static bool pushTouchPoint(const ProximityInfo *const proximityInfo, const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, - int y, const int time, const float verticalSweetSpotScale, + int y, const int time, const bool isGeometric, const bool doSampling, const bool isLastPoint, const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances, const NearKeysDistanceMap *const prevNearKeysDistances, diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp index 0c57ca001..117f48f29 100644 --- a/native/jni/src/suggest/core/policy/weighting.cpp +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -106,7 +106,7 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n // only used for typing return weighting->getSubstitutionCost(); case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordCost(traverseSession, dicNode); + return weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG); case CT_MATCH: return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); case CT_COMPLETION: @@ -134,7 +134,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n case CT_SUBSTITUTION: return 0.0f; case CT_NEW_WORD_SPACE_OMITTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_MATCH: return 0.0f; case CT_COMPLETION: @@ -146,7 +147,8 @@ static inline void profile(const CorrectionType correctionType, DicNode *const n return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); } case CT_NEW_WORD_SPACE_SUBSTITUTION: - return weighting->getNewWordBigramCost(traverseSession, parentDicNode, multiBigramMap); + return weighting->getNewWordBigramLanguageCost( + traverseSession, parentDicNode, multiBigramMap); case CT_INSERTION: return 0.0f; case CT_TRANSPOSITION: diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h index 0d2745b40..781a7adbc 100644 --- a/native/jni/src/suggest/core/policy/weighting.h +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -56,10 +56,10 @@ class Weighting { const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; - virtual float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const = 0; + virtual float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const = 0; - virtual float getNewWordBigramCost( + virtual float getNewWordBigramLanguageCost( const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const = 0; diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index 6c4a6c166..a8f16c8cb 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -530,6 +530,12 @@ void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMITTION; Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, &newDicNode, traverseSession->getMultiBigramMap()); - traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + // newDicNode is worth continuing to traverse. + // CAVEAT: This pruning is important for speed. Remove this when we can afford not to prune + // here because here is not the right place to do pruning. Pruning should take place only + // in DicNodePriorityQueue. + traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); + } } } // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h index 17fa11082..a1c99182a 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -138,12 +138,12 @@ class TypingWeighting : public Weighting { return cost + weightedDistance; } - float getNewWordCost(const DicTraverseSession *const traverseSession, - const DicNode *const dicNode) const { + float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const { return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostMultiplier(); } - float getNewWordBigramCost(const DicTraverseSession *const traverseSession, + float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession, const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const { return DicNodeUtils::getBigramNodeImprobability(traverseSession->getBinaryDictionaryInfo(), |