aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--java-overridable/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryUpdater.java9
-rw-r--r--java/src/com/android/inputmethod/latin/DictionaryFacilitator.java50
-rw-r--r--java/src/com/android/inputmethod/latin/LatinIME.java9
-rw-r--r--java/src/com/android/inputmethod/latin/PersonalizationDictionaryFacilitator.java185
-rw-r--r--java/src/com/android/inputmethod/latin/personalization/PersonalizationDataChunk.java2
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h27
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.cpp24
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.h4
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_vector.h7
-rw-r--r--native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h30
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp19
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.h4
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary_utils.cpp9
-rw-r--r--native/jni/src/suggest/core/dictionary/word_attributes.h60
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h6
-rw-r--r--native/jni/src/suggest/core/policy/traversal.h3
-rw-r--r--native/jni/src/suggest/core/result/suggestions_output_utils.cpp14
-rw-r--r--native/jni/src/suggest/core/suggest.cpp7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp39
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp36
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp34
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp27
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h5
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_traversal.h5
-rw-r--r--native/jni/src/utils/int_array_view.h6
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp58
-rw-r--r--native/jni/tests/utils/int_array_view_test.cpp2
-rw-r--r--tests/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryTests.java20
33 files changed, 545 insertions, 178 deletions
diff --git a/java-overridable/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryUpdater.java b/java-overridable/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryUpdater.java
index c97a0d232..8b66cff53 100644
--- a/java-overridable/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryUpdater.java
+++ b/java-overridable/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryUpdater.java
@@ -16,8 +16,6 @@
package com.android.inputmethod.latin.personalization;
-import java.util.Locale;
-
import android.content.Context;
import com.android.inputmethod.latin.DictionaryFacilitator;
@@ -33,12 +31,7 @@ public class PersonalizationDictionaryUpdater {
mDictionaryFacilitator = dictionaryFacilitator;
}
- public Locale getLocale() {
- return null;
- }
-
- public void onLoadSettings(final boolean usePersonalizedDicts,
- final boolean isSystemLocaleSameAsLocaleOfAllEnabledSubtypesOfEnabledImes) {
+ public void onLoadSettings(final boolean usePersonalizedDicts) {
if (!mDictCleared) {
// Clear and never update the personalization dictionary.
PersonalizationHelper.removeAllPersonalizationDictionaries(mContext);
diff --git a/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java b/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java
index 47aaeadcc..46428839f 100644
--- a/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java
+++ b/java/src/com/android/inputmethod/latin/DictionaryFacilitator.java
@@ -24,6 +24,7 @@ import android.view.inputmethod.InputMethodSubtype;
import com.android.inputmethod.annotations.UsedForTesting;
import com.android.inputmethod.keyboard.ProximityInfo;
+import com.android.inputmethod.latin.ExpandableBinaryDictionary.AddMultipleDictionaryEntriesCallback;
import com.android.inputmethod.latin.PrevWordsInfo.WordInfo;
import com.android.inputmethod.latin.SuggestedWords.SuggestedWordInfo;
import com.android.inputmethod.latin.personalization.ContextualDictionary;
@@ -36,7 +37,6 @@ import com.android.inputmethod.latin.utils.DistracterFilter;
import com.android.inputmethod.latin.utils.DistracterFilterCheckingExactMatchesAndSuggestions;
import com.android.inputmethod.latin.utils.DistracterFilterCheckingIsInDictionary;
import com.android.inputmethod.latin.utils.ExecutorUtils;
-import com.android.inputmethod.latin.utils.LanguageModelParam;
import com.android.inputmethod.latin.utils.SuggestionResults;
import java.io.File;
@@ -67,6 +67,7 @@ public class DictionaryFacilitator {
// To synchronize assigning mDictionaryGroup to ensure closing dictionaries.
private final Object mLock = new Object();
private final DistracterFilter mDistracterFilter;
+ private final PersonalizationDictionaryFacilitator mPersonalizationDictionaryFacilitator;
private static final String[] DICT_TYPES_ORDERED_TO_GET_SUGGESTIONS =
new String[] {
@@ -174,14 +175,22 @@ public class DictionaryFacilitator {
public DictionaryFacilitator() {
mDistracterFilter = DistracterFilter.EMPTY_DISTRACTER_FILTER;
+ mPersonalizationDictionaryFacilitator = null;
}
public DictionaryFacilitator(final Context context) {
mDistracterFilter = new DistracterFilterCheckingExactMatchesAndSuggestions(context);
+ mPersonalizationDictionaryFacilitator =
+ new PersonalizationDictionaryFacilitator(context, mDistracterFilter);
}
public void updateEnabledSubtypes(final List<InputMethodSubtype> enabledSubtypes) {
mDistracterFilter.updateEnabledSubtypes(enabledSubtypes);
+ mPersonalizationDictionaryFacilitator.updateEnabledSubtypes(enabledSubtypes);
+ }
+
+ public void setIsMonolingualUser(final boolean isMonolingualUser) {
+ mPersonalizationDictionaryFacilitator.setIsMonolingualUser(isMonolingualUser);
}
public Locale getLocale() {
@@ -353,6 +362,9 @@ public class DictionaryFacilitator {
dictionaryGroup.closeDict(dictType);
}
mDistracterFilter.close();
+ if (mPersonalizationDictionaryFacilitator != null) {
+ mPersonalizationDictionaryFacilitator.close();
+ }
}
@UsedForTesting
@@ -372,11 +384,11 @@ public class DictionaryFacilitator {
}
public void flushPersonalizationDictionary() {
- final ExpandableBinaryDictionary personalizationDict =
+ final ExpandableBinaryDictionary personalizationDictUsedForSuggestion =
mDictionaryGroup.getSubDict(Dictionary.TYPE_PERSONALIZATION);
- if (personalizationDict != null) {
- personalizationDict.asyncFlushBinaryDictionary();
- }
+ mPersonalizationDictionaryFacilitator.flushPersonalizationDictionariesToUpdate(
+ personalizationDictUsedForSuggestion);
+ mDistracterFilter.close();
}
public void waitForLoadingMainDictionary(final long timeout, final TimeUnit unit)
@@ -580,6 +592,7 @@ public class DictionaryFacilitator {
// personalization dictionary.
public void clearPersonalizationDictionary() {
clearSubDictionary(Dictionary.TYPE_PERSONALIZATION);
+ mPersonalizationDictionaryFacilitator.clearDictionariesToUpdate();
}
public void clearContextualDictionary() {
@@ -589,30 +602,9 @@ public class DictionaryFacilitator {
public void addEntriesToPersonalizationDictionary(
final PersonalizationDataChunk personalizationDataChunk,
final SpacingAndPunctuations spacingAndPunctuations,
- final ExpandableBinaryDictionary.AddMultipleDictionaryEntriesCallback callback) {
- final ExpandableBinaryDictionary personalizationDict =
- mDictionaryGroup.getSubDict(Dictionary.TYPE_PERSONALIZATION);
- if (personalizationDict == null) {
- if (callback != null) {
- callback.onFinished();
- }
- return;
- }
- // TODO: Get locale from personalizationDataChunk.mDetectedLanguage.
- final Locale dataChunkLocale = getLocale();
- final ArrayList<LanguageModelParam> languageModelParams =
- LanguageModelParam.createLanguageModelParamsFrom(
- personalizationDataChunk.mTokens,
- personalizationDataChunk.mTimestampInSeconds, spacingAndPunctuations,
- dataChunkLocale, new DistracterFilterCheckingIsInDictionary(
- mDistracterFilter, personalizationDict));
- if (languageModelParams == null || languageModelParams.isEmpty()) {
- if (callback != null) {
- callback.onFinished();
- }
- return;
- }
- personalizationDict.addMultipleDictionaryEntriesDynamically(languageModelParams, callback);
+ final AddMultipleDictionaryEntriesCallback callback) {
+ mPersonalizationDictionaryFacilitator.addEntriesToPersonalizationDictionariesToUpdate(
+ getLocale(), personalizationDataChunk, spacingAndPunctuations, callback);
}
public void addPhraseToContextualDictionary(final String[] phrase, final int probability,
diff --git a/java/src/com/android/inputmethod/latin/LatinIME.java b/java/src/com/android/inputmethod/latin/LatinIME.java
index c853d2d68..a6243430b 100644
--- a/java/src/com/android/inputmethod/latin/LatinIME.java
+++ b/java/src/com/android/inputmethod/latin/LatinIME.java
@@ -614,9 +614,10 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
private void refreshPersonalizationDictionarySession(
final SettingsValues currentSettingsValues) {
- mPersonalizationDictionaryUpdater.onLoadSettings(
- currentSettingsValues.mUsePersonalizedDicts,
+ mDictionaryFacilitator.setIsMonolingualUser(
mSubtypeSwitcher.isSystemLocaleSameAsLocaleOfAllEnabledSubtypesOfEnabledImes());
+ mPersonalizationDictionaryUpdater.onLoadSettings(
+ currentSettingsValues.mUsePersonalizedDicts);
mContextualDictionaryUpdater.onLoadSettings(currentSettingsValues.mUsePersonalizedDicts);
final boolean shouldKeepUserHistoryDictionaries;
if (currentSettingsValues.mUsePersonalizedDicts) {
@@ -734,10 +735,6 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
cleanupInternalStateForFinishInput();
}
}
- // TODO: Remove this test.
- if (!conf.locale.equals(mPersonalizationDictionaryUpdater.getLocale())) {
- refreshPersonalizationDictionarySession(settingsValues);
- }
super.onConfigurationChanged(conf);
}
diff --git a/java/src/com/android/inputmethod/latin/PersonalizationDictionaryFacilitator.java b/java/src/com/android/inputmethod/latin/PersonalizationDictionaryFacilitator.java
new file mode 100644
index 000000000..aa8e312a4
--- /dev/null
+++ b/java/src/com/android/inputmethod/latin/PersonalizationDictionaryFacilitator.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.inputmethod.latin;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import android.content.Context;
+import android.view.inputmethod.InputMethodSubtype;
+
+import com.android.inputmethod.latin.ExpandableBinaryDictionary.AddMultipleDictionaryEntriesCallback;
+import com.android.inputmethod.latin.personalization.PersonalizationDataChunk;
+import com.android.inputmethod.latin.personalization.PersonalizationDictionary;
+import com.android.inputmethod.latin.settings.SpacingAndPunctuations;
+import com.android.inputmethod.latin.utils.DistracterFilter;
+import com.android.inputmethod.latin.utils.DistracterFilterCheckingIsInDictionary;
+import com.android.inputmethod.latin.utils.LanguageModelParam;
+import com.android.inputmethod.latin.utils.SubtypeLocaleUtils;
+
+/**
+ * Class for managing and updating personalization dictionaries.
+ */
+public class PersonalizationDictionaryFacilitator {
+ private final Context mContext;
+ private final DistracterFilter mDistracterFilter;
+ private final HashMap<String, HashSet<Locale>> mLangToLocalesMap = new HashMap<>();
+ private final HashMap<Locale, ExpandableBinaryDictionary> mPersonalizationDictsToUpdate =
+ new HashMap<>();
+ private boolean mIsMonolingualUser = false;;
+
+ PersonalizationDictionaryFacilitator(final Context context,
+ final DistracterFilter distracterFilter) {
+ mContext = context;
+ mDistracterFilter = distracterFilter;
+ }
+
+ public void close() {
+ mLangToLocalesMap.clear();
+ for (final ExpandableBinaryDictionary dict : mPersonalizationDictsToUpdate.values()) {
+ dict.close();
+ }
+ mPersonalizationDictsToUpdate.clear();
+ }
+
+ public void clearDictionariesToUpdate() {
+ for (final ExpandableBinaryDictionary dict : mPersonalizationDictsToUpdate.values()) {
+ dict.clear();
+ }
+ mPersonalizationDictsToUpdate.clear();
+ }
+
+ public void updateEnabledSubtypes(final List<InputMethodSubtype> enabledSubtypes) {
+ for (final InputMethodSubtype subtype : enabledSubtypes) {
+ final Locale locale = SubtypeLocaleUtils.getSubtypeLocale(subtype);
+ final String language = locale.getLanguage();
+ final HashSet<Locale> locales = mLangToLocalesMap.get(language);
+ if (locales != null) {
+ locales.add(locale);
+ } else {
+ final HashSet<Locale> localeSet = new HashSet<>();
+ localeSet.add(locale);
+ mLangToLocalesMap.put(language, localeSet);
+ }
+ }
+ }
+
+ public void setIsMonolingualUser(final boolean isMonolingualUser) {
+ mIsMonolingualUser = isMonolingualUser;
+ }
+
+ /**
+ * Flush personalization dictionaries to dictionary files. Close dictionaries after writing
+ * files except the dictionary that is used for generating suggestions.
+ *
+ * @param personalizationDictUsedForSuggestion the personalization dictionary used for
+ * generating suggestions that won't be closed.
+ */
+ public void flushPersonalizationDictionariesToUpdate(
+ final ExpandableBinaryDictionary personalizationDictUsedForSuggestion) {
+ for (final ExpandableBinaryDictionary personalizationDict :
+ mPersonalizationDictsToUpdate.values()) {
+ personalizationDict.asyncFlushBinaryDictionary();
+ if (personalizationDict != personalizationDictUsedForSuggestion) {
+ // Close if the dictionary is not being used for suggestion.
+ personalizationDict.close();
+ }
+ }
+ mDistracterFilter.close();
+ mPersonalizationDictsToUpdate.clear();
+ }
+
+ private ExpandableBinaryDictionary getPersonalizationDictToUpdate(final Context context,
+ final Locale locale) {
+ ExpandableBinaryDictionary personalizationDict = mPersonalizationDictsToUpdate.get(locale);
+ if (personalizationDict != null) {
+ return personalizationDict;
+ }
+ personalizationDict = PersonalizationDictionary.getDictionary(context, locale,
+ null /* dictFile */, "" /* dictNamePrefix */);
+ mPersonalizationDictsToUpdate.put(locale, personalizationDict);
+ return personalizationDict;
+ }
+
+ private void addEntriesToPersonalizationDictionariesForLocale(final Locale locale,
+ final PersonalizationDataChunk personalizationDataChunk,
+ final SpacingAndPunctuations spacingAndPunctuations,
+ final AddMultipleDictionaryEntriesCallback callback) {
+ final ExpandableBinaryDictionary personalizationDict =
+ getPersonalizationDictToUpdate(mContext, locale);
+ if (personalizationDict == null) {
+ if (callback != null) {
+ callback.onFinished();
+ }
+ return;
+ }
+ final ArrayList<LanguageModelParam> languageModelParams =
+ LanguageModelParam.createLanguageModelParamsFrom(
+ personalizationDataChunk.mTokens,
+ personalizationDataChunk.mTimestampInSeconds, spacingAndPunctuations,
+ locale, new DistracterFilterCheckingIsInDictionary(
+ mDistracterFilter, personalizationDict));
+ if (languageModelParams == null || languageModelParams.isEmpty()) {
+ if (callback != null) {
+ callback.onFinished();
+ }
+ return;
+ }
+ personalizationDict.addMultipleDictionaryEntriesDynamically(languageModelParams, callback);
+ }
+
+ public void addEntriesToPersonalizationDictionariesToUpdate(final Locale defaultLocale,
+ final PersonalizationDataChunk personalizationDataChunk,
+ final SpacingAndPunctuations spacingAndPunctuations,
+ final AddMultipleDictionaryEntriesCallback callback) {
+ final String language = personalizationDataChunk.mDetectedLanguage;
+ final HashSet<Locale> locales;
+ if (mIsMonolingualUser && PersonalizationDataChunk.LANGUAGE_UNKNOWN.equals(language)
+ && mLangToLocalesMap.size() == 1) {
+ locales = mLangToLocalesMap.get(defaultLocale.getLanguage());
+ } else {
+ locales = mLangToLocalesMap.get(language);
+ }
+ if (locales == null || locales.isEmpty()) {
+ if (callback != null) {
+ callback.onFinished();
+ }
+ return;
+ }
+ final AtomicInteger remainingTaskCount = new AtomicInteger(locales.size());
+ final AddMultipleDictionaryEntriesCallback callbackForLocales =
+ new AddMultipleDictionaryEntriesCallback() {
+ @Override
+ public void onFinished() {
+ if (remainingTaskCount.decrementAndGet() == 0) {
+ // Update tasks for all locales have been finished.
+ if (callback != null) {
+ callback.onFinished();
+ }
+ }
+ }
+ };
+ for (final Locale locale : locales) {
+ addEntriesToPersonalizationDictionariesForLocale(locale, personalizationDataChunk,
+ spacingAndPunctuations, callbackForLocales);
+ }
+ }
+}
diff --git a/java/src/com/android/inputmethod/latin/personalization/PersonalizationDataChunk.java b/java/src/com/android/inputmethod/latin/personalization/PersonalizationDataChunk.java
index 6f4b09741..734ed5583 100644
--- a/java/src/com/android/inputmethod/latin/personalization/PersonalizationDataChunk.java
+++ b/java/src/com/android/inputmethod/latin/personalization/PersonalizationDataChunk.java
@@ -20,6 +20,8 @@ import java.util.Collections;
import java.util.List;
public class PersonalizationDataChunk {
+ public static final String LANGUAGE_UNKNOWN = "";
+
public final boolean mInputByUser;
public final List<String> mTokens;
public final int mTimestampInSeconds;
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index 32ff0ce18..3970963e8 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -26,6 +26,7 @@
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "utils/char_utils.h"
+#include "utils/int_array_view.h"
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
@@ -136,17 +137,15 @@ class DicNode {
}
void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
- const int unigramProbability, const int wordId, const bool isBlacklistedOrNotAWord,
- const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) {
+ const int wordId, const CodePointArrayView mergedCodePoints) {
uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
const uint16_t newLeavingDepth = static_cast<uint16_t>(
- dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
- mDicNodeProperties.init(childrenPtNodeArrayPos, mergedNodeCodePoints[0],
- unigramProbability, wordId, isBlacklistedOrNotAWord, newDepth, newLeavingDepth,
- dicNode->mDicNodeProperties.getPrevWordIds());
- mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
- mergedNodeCodePoints);
+ dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size());
+ mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0],
+ wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
+ mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(),
+ mergedCodePoints.data());
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
}
@@ -178,9 +177,6 @@ class DicNode {
// Check if the current word and the previous word can be considered as a valid multiple word
// suggestion.
bool isValidMultipleWordSuggestion() const {
- if (isBlacklistedOrNotAWord()) {
- return false;
- }
// Treat suggestion as invalid if the current and the previous word are single character
// words.
const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
@@ -217,11 +213,6 @@ class DicNode {
return mDicNodeProperties.getChildrenPtNodeArrayPos();
}
- // TODO: Remove
- int getUnigramProbability() const {
- return mDicNodeProperties.getUnigramProbability();
- }
-
AK_FORCE_INLINE bool isTerminalDicNode() const {
const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
const int currentDicNodeDepth = getNodeCodePointCount();
@@ -404,10 +395,6 @@ class DicNode {
return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
}
- bool isBlacklistedOrNotAWord() const {
- return mDicNodeProperties.isBlacklistedOrNotAWord();
- }
-
inline uint16_t getNodeCodePointCount() const {
return mDicNodeProperties.getDepth();
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index 9f03e30d1..fe5fe8448 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -18,7 +18,6 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
-#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
namespace latinime {
@@ -73,25 +72,16 @@ namespace latinime {
if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) {
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}
- const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode,
- multiBigramMap);
+ const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext(
+ dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
+ if (dicNode->hasMultipleWords()
+ && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord())) {
+ return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
+ }
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
- const float cost = static_cast<float>(MAX_PROBABILITY - probability)
+ const float cost = static_cast<float>(MAX_PROBABILITY - wordAttributes.getProbability())
/ static_cast<float>(MAX_PROBABILITY);
return cost;
}
-/* static */ int DicNodeUtils::getBigramNodeProbability(
- const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
- const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) {
- const int unigramProbability = dicNode->getUnigramProbability();
- if (multiBigramMap) {
- const int *const prevWordIds = dicNode->getPrevWordIds();
- return multiBigramMap->getBigramProbability(dictionaryStructurePolicy,
- prevWordIds, dicNode->getWordId(), unigramProbability);
- }
- return dictionaryStructurePolicy->getProbability(unigramProbability,
- NOT_A_PROBABILITY);
-}
-
} // namespace latinime
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 56ff6e3d0..961a1c29d 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -46,10 +46,6 @@ class DicNodeUtils {
DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
// Max number of bigrams to look up
static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;
-
- static int getBigramNodeProbability(
- const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
- const DicNode *const dicNode, MultiBigramMap *const multiBigramMap);
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_UTILS_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
index dfeb3fc1f..e6b758954 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dicnode/dic_node.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -59,12 +60,10 @@ class DicNodeVector {
}
void pushLeavingChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
- const int unigramProbability, const int wordId, const bool isBlacklistedOrNotAWord,
- const uint16_t mergedNodeCodePointCount, const int *const mergedNodeCodePoints) {
+ const int wordId, const CodePointArrayView mergedCodePoints) {
ASSERT(!mLock);
mDicNodes.emplace_back();
- mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, unigramProbability,
- wordId, isBlacklistedOrNotAWord, mergedNodeCodePointCount, mergedNodeCodePoints);
+ mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, wordId, mergedCodePoints);
}
DicNode *operator[](const int id) {
diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
index 6a8377a1b..6a1b84273 100644
--- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
+++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
@@ -29,21 +29,17 @@ namespace latinime {
class DicNodeProperties {
public:
AK_FORCE_INLINE DicNodeProperties()
- : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mUnigramProbability(NOT_A_PROBABILITY),
- mDicNodeCodePoint(NOT_A_CODE_POINT), mWordId(NOT_A_WORD_ID),
- mIsBlacklistedOrNotAWord(false), mDepth(0), mLeavingDepth(0) {}
+ : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT),
+ mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0) {}
~DicNodeProperties() {}
// Should be called only once per DicNode is initialized.
- void init(const int childrenPos, const int nodeCodePoint, const int unigramProbability,
- const int wordId, const bool isBlacklistedOrNotAWord, const uint16_t depth,
- const uint16_t leavingDepth, const int *const prevWordIds) {
+ void init(const int childrenPos, const int nodeCodePoint, const int wordId,
+ const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
mChildrenPtNodeArrayPos = childrenPos;
mDicNodeCodePoint = nodeCodePoint;
- mUnigramProbability = unigramProbability;
mWordId = wordId;
- mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord;
mDepth = depth;
mLeavingDepth = leavingDepth;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
@@ -53,9 +49,7 @@ class DicNodeProperties {
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
mDicNodeCodePoint = NOT_A_CODE_POINT;
- mUnigramProbability = NOT_A_PROBABILITY;
mWordId = NOT_A_WORD_ID;
- mIsBlacklistedOrNotAWord = false;
mDepth = 0;
mLeavingDepth = 0;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
@@ -64,9 +58,7 @@ class DicNodeProperties {
void initByCopy(const DicNodeProperties *const dicNodeProp) {
mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos;
mDicNodeCodePoint = dicNodeProp->mDicNodeCodePoint;
- mUnigramProbability = dicNodeProp->mUnigramProbability;
mWordId = dicNodeProp->mWordId;
- mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
@@ -76,9 +68,7 @@ class DicNodeProperties {
void init(const DicNodeProperties *const dicNodeProp, const int codePoint) {
mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos;
mDicNodeCodePoint = codePoint; // Overwrite the node char of a passing child
- mUnigramProbability = dicNodeProp->mUnigramProbability;
mWordId = dicNodeProp->mWordId;
- mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
@@ -88,10 +78,6 @@ class DicNodeProperties {
return mChildrenPtNodeArrayPos;
}
- int getUnigramProbability() const {
- return mUnigramProbability;
- }
-
int getDicNodeCodePoint() const {
return mDicNodeCodePoint;
}
@@ -113,10 +99,6 @@ class DicNodeProperties {
return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
}
- bool isBlacklistedOrNotAWord() const {
- return mIsBlacklistedOrNotAWord;
- }
-
const int *getPrevWordIds() const {
return mPrevWordIds;
}
@@ -130,12 +112,8 @@ class DicNodeProperties {
// Use a default copy constructor and an assign operator because shallow copies are ok
// for this class
int mChildrenPtNodeArrayPos;
- // TODO: Remove
- int mUnigramProbability;
int mDicNodeCodePoint;
int mWordId;
- // TODO: Remove
- bool mIsBlacklistedOrNotAWord;
uint16_t mDepth;
uint16_t mLeavingDepth;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 8f9b2aa12..1de405104 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -61,10 +61,11 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession
}
Dictionary::NgramListenerForPrediction::NgramListenerForPrediction(
- const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults,
+ const PrevWordsInfo *const prevWordsInfo, const WordIdArrayView prevWordIds,
+ SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy)
- : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults),
- mDictStructurePolicy(dictStructurePolicy) {}
+ : mPrevWordsInfo(prevWordsInfo), mPrevWordIds(prevWordIds),
+ mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {}
void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability,
const int targetWordId) {
@@ -83,19 +84,21 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
if (codePointCount <= 0) {
return;
}
- const int probability = mDictStructurePolicy->getProbability(
- unigramProbability, ngramProbability);
- mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability);
+ const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
+ mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
+ mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
+ wordAttributes.getProbability());
}
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
- NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults,
- mDictionaryStructureWithBufferPolicy.get());
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
true /* tryLowerCaseSearch */);
+ NgramListenerForPrediction listener(prevWordsInfo,
+ WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults,
+ mDictionaryStructureWithBufferPolicy.get());
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
}
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index 50951fbc1..0b54f30e9 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -26,6 +26,7 @@
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/core/suggest_interface.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -118,7 +119,7 @@ class Dictionary {
class NgramListenerForPrediction : public NgramListener {
public:
NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo,
- SuggestionResults *const suggestionResults,
+ const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy);
virtual void onVisitEntry(const int ngramProbability, const int targetWordId);
@@ -126,6 +127,7 @@ class Dictionary {
DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction);
const PrevWordsInfo *const mPrevWordsInfo;
+ const WordIdArrayView mPrevWordIds;
SuggestionResults *const mSuggestionResults;
const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy;
};
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
index 94d7c886f..f71d4c5f0 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
@@ -54,15 +54,18 @@ namespace latinime {
current.swap(next);
}
- int maxUnigramProbability = NOT_A_PROBABILITY;
+ int maxProbability = NOT_A_PROBABILITY;
for (const DicNode &dicNode : current) {
if (!dicNode.isTerminalDicNode()) {
continue;
}
+ const WordAttributes wordAttributes =
+ dictionaryStructurePolicy->getWordAttributesInContext(dicNode.getPrevWordIds(),
+ dicNode.getWordId(), nullptr /* multiBigramMap */);
// dicNode can contain case errors, accent errors, intentional omissions or digraphs.
- maxUnigramProbability = std::max(maxUnigramProbability, dicNode.getUnigramProbability());
+ maxProbability = std::max(maxProbability, wordAttributes.getProbability());
}
- return maxUnigramProbability;
+ return maxProbability;
}
/* static */ void DictionaryUtils::processChildDicNodes(
diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h
new file mode 100644
index 000000000..6e9da3570
--- /dev/null
+++ b/native/jni/src/suggest/core/dictionary/word_attributes.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_WORD_ATTRIBUTES_H
+#define LATINIME_WORD_ATTRIBUTES_H
+
+#include "defines.h"
+
+class WordAttributes {
+ public:
+ // Invalid word attributes.
+ WordAttributes()
+ : mProbability(NOT_A_PROBABILITY), mIsBlacklisted(false), mIsNotAWord(false),
+ mIsPossiblyOffensive(false) {}
+
+ WordAttributes(const int probability, const bool isBlacklisted, const bool isNotAWord,
+ const bool isPossiblyOffensive)
+ : mProbability(probability), mIsBlacklisted(isBlacklisted), mIsNotAWord(isNotAWord),
+ mIsPossiblyOffensive(isPossiblyOffensive) {}
+
+ int getProbability() const {
+ return mProbability;
+ }
+
+ bool isBlacklisted() const {
+ return mIsBlacklisted;
+ }
+
+ bool isNotAWord() const {
+ return mIsNotAWord;
+ }
+
+ bool isPossiblyOffensive() const {
+ return mIsPossiblyOffensive;
+ }
+
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(WordAttributes);
+
+ int mProbability;
+ bool mIsBlacklisted;
+ bool mIsNotAWord;
+ bool mIsPossiblyOffensive;
+};
+
+ // namespace
+#endif /* LATINIME_WORD_ATTRIBUTES_H */
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index aeeb66f93..7414f696c 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -22,6 +22,7 @@
#include "defines.h"
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
#include "suggest/core/dictionary/property/word_property.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -29,6 +30,7 @@ namespace latinime {
class DicNode;
class DicNodeVector;
class DictionaryHeaderStructurePolicy;
+class MultiBigramMap;
class NgramListener;
class PrevWordsInfo;
class UnigramProperty;
@@ -56,6 +58,10 @@ class DictionaryStructureWithBufferPolicy {
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
+ virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
+ const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
+
+ // TODO: Remove
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h
index 8ddaa0514..6dfa7e314 100644
--- a/native/jni/src/suggest/core/policy/traversal.h
+++ b/native/jni/src/suggest/core/policy/traversal.h
@@ -48,7 +48,8 @@ class Traversal {
virtual int getTerminalCacheSize() const = 0;
virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
- virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
+ virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode,
+ const int probability) const = 0;
protected:
Traversal() {}
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
index cecb4e216..6e0193772 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
@@ -85,9 +85,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode);
const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight)
+ doubleLetterCost;
- const bool isPossiblyOffensiveWord =
- traverseSession->getDictionaryStructurePolicy()->getProbability(
- terminalDicNode->getUnigramProbability(), NOT_A_PROBABILITY) <= 0;
+ const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy()
+ ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(),
+ terminalDicNode->getWordId(), nullptr /* multiBigramMap */);
const bool isExactMatch =
ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes());
const bool isExactMatchWithIntentionalOmission =
@@ -97,19 +97,19 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
// Heuristic: We exclude probability=0 first-char-uppercase words from exact match.
// (e.g. "AMD" and "and")
const bool isSafeExactMatch = isExactMatch
- && !(isPossiblyOffensiveWord && isFirstCharUppercase);
+ && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase);
const int outputTypeFlags =
- (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
+ (wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
| ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0)
| (isExactMatchWithIntentionalOmission ?
Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0);
// Entries that are blacklisted or do not represent a word should not be output.
- const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord();
+ const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord());
// When we have to block offensive words, non-exact matched offensive words should not be
// output.
const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords();
- const bool isBlockedOffensiveWord = blockOffensiveWords && isPossiblyOffensiveWord
+ const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive()
&& !isSafeExactMatch;
// Increase output score of top typing suggestion to ensure autocorrection.
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 66c87f04c..947d41f4b 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -21,6 +21,7 @@
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/dictionary/digraph_utils.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "suggest/core/layout/proximity_info.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/core/policy/traversal.h"
@@ -412,7 +413,11 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN
*/
void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode,
const bool spaceSubstitution) const {
- if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) {
+ const WordAttributes wordAttributes =
+ traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext(
+ dicNode->getPrevWordIds(), dicNode->getWordId(),
+ traverseSession->getMultiBigramMap());
+ if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) {
return;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 6480374df..9b8a50b07 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -28,6 +28,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
@@ -78,10 +79,7 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
const int wordId = isTerminal ? ptNodeParams.getHeadPos() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
- ptNodeParams.getProbability(), wordId,
- ptNodeParams.isBlacklisted()
- || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */,
- ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints());
+ wordId, ptNodeParams.getCodePointArrayView());
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -117,6 +115,35 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
+ const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+ const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
+ if (multiBigramMap) {
+ const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
+ prevWordIds, wordId, ptNodeParams.getProbability());
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ if (prevWordIds) {
+ const int probability = getProbabilityOfWord(prevWordIds, wordId);
+ if (probability != NOT_A_PROBABILITY) {
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ }
+ return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
+ ptNodeParams);
+}
+
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const {
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ ptNodeParams.getProbability() == 0);
+}
+
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
@@ -333,7 +360,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
bool addedNewBigram = false;
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
- if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
+ if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) {
mBigramCount++;
@@ -375,7 +402,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
}
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
- PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
+ PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mBigramCount--;
return true;
} else {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 562c219f4..871b556e1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -91,6 +91,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
@@ -163,6 +166,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getShortcutPositionOfPtNode(const int ptNodePos) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
+ const WordAttributes getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const;
};
} // namespace v402
} // namespace backward
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
index b2e60a837..c12fed324 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
@@ -24,6 +24,7 @@
#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "utils/char_utils.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -174,11 +175,17 @@ class PtNodeParams {
return mParentPos;
}
+ AK_FORCE_INLINE const CodePointArrayView getCodePointArrayView() const {
+ return CodePointArrayView(mCodePoints, mCodePointCount);
+ }
+
+ // TODO: Remove
// Number of code points
AK_FORCE_INLINE uint8_t getCodePointCount() const {
return mCodePointCount;
}
+ // TODO: Remove
AK_FORCE_INLINE const int *getCodePoints() const {
return mCodePoints;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index e0406ab07..e76bae97c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -21,6 +21,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/session/prev_words_info.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
@@ -281,6 +282,35 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
+const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
+ const int wordId, MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+ const PtNodeParams ptNodeParams =
+ mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ if (multiBigramMap) {
+ const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
+ prevWordIds, wordId, ptNodeParams.getProbability());
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ if (prevWordIds) {
+ const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
+ if (bigramProbability != NOT_A_PROBABILITY) {
+ return getWordAttributes(bigramProbability, ptNodeParams);
+ }
+ }
+ return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
+ ptNodeParams);
+}
+
+const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const {
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ ptNodeParams.getProbability() == 0);
+}
+
int PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
// Due to space constraints, the probability for bigrams is approximate - the lower the unigram
@@ -377,10 +407,8 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod
// Skip PtNodes don't start with Unicode code point because they represent non-word information.
if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) {
const int wordId = PatriciaTrieReadingUtils::isTerminal(flags) ? ptNodePos : NOT_A_WORD_ID;
- childDicNodes->pushLeavingChild(dicNode, childrenPos, probability, wordId,
- PatriciaTrieReadingUtils::isBlacklisted(flags)
- || PatriciaTrieReadingUtils::isNotAWord(flags),
- mergedNodeCodePointCount, mergedNodeCodePoints);
+ childDicNodes->pushLeavingChild(dicNode, childrenPos, wordId,
+ CodePointArrayView(mergedNodeCodePoints, mergedNodeCodePointCount));
}
return siblingPos;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 66df52779..8c1665d7d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -66,6 +66,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
@@ -160,6 +163,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
DicNodeVector *const childDicNodes) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
+ const WordAttributes getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const;
};
} // namespace latinime
#endif // LATINIME_PATRICIA_TRIE_POLICY_H
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index d5749e9eb..f54bb151a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
+int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+ const int wordId) const {
+ int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
+ int maxLevel = 0;
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ const int nextBitmapEntryIndex =
+ mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
+ if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ break;
+ }
+ maxLevel = i + 1;
+ bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
+ }
+
+ for (int i = maxLevel; i >= 0; --i) {
+ const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
+ if (!result.mIsValid) {
+ continue;
+ }
+ const int probability =
+ ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+ if (mHasHistoricalInfo) {
+ return std::min(
+ probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
+ MAX_PROBABILITY);
+ } else {
+ return probability;
+ }
+ }
+ // Cannot find the word.
+ return NOT_A_PROBABILITY;
+}
+
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
const WordIdArrayView prevWordIds, const int wordId) const {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index aa612e35a..4e0b47036 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,6 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
+ int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 466c49952..0472a453a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -20,6 +20,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
@@ -68,10 +69,7 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
- ptNodeParams.getProbability(), wordId,
- ptNodeParams.isBlacklisted()
- || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */,
- ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints());
+ wordId, ptNodeParams.getCodePointArrayView());
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -112,6 +110,21 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return ptNodeParams.getTerminalId();
}
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
+ const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos =
+ mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
+ const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ // TODO: Support n-gram.
+ return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
+ WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
+ ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
+}
+
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
@@ -143,7 +156,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
// TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
- IntArrayView::fromObject(prevWordIds), wordId);
+ IntArrayView::singleElementView(prevWordIds), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
@@ -171,7 +184,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
// TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(
- WordIdArrayView::fromObject(prevWordIds))) {
+ WordIdArrayView::singleElementView(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
@@ -488,7 +501,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information.
// TODO: Support n-gram.
std::vector<BigramProperty> bigrams;
- const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
+ const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 0b8eec40b..980c16e4a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -68,6 +68,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index 9910777b8..313eb6b64 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -48,6 +48,11 @@ class ForgettingCurveUtils {
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
const int bigramCount, const HeaderPolicy *const headerPolicy);
+ // TODO: Improve probability computation method and remove this.
+ static int getProbabilityBiasForNgram(const int n) {
+ return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
+ }
+
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
return static_cast<int>(static_cast<float>(maxUnigramCount)
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
index ed9df8eb3..b64ee8be4 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
@@ -161,9 +161,8 @@ class TypingTraversal : public Traversal {
return true;
}
- AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const {
- // TODO: Quit using unigram probability and use probability in the context.
- const int probability = dicNode->getUnigramProbability();
+ AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode,
+ const int probability) const {
if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) {
return false;
}
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c9c3b21d4..08256bdef 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -61,9 +61,9 @@ class IntArrayView {
return IntArrayView(array, N);
}
- // Returns a view that points one int object. Does not take ownership of the given object.
- AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) {
- return IntArrayView(object, 1);
+ // Returns a view that points one int object.
+ AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
+ return IntArrayView(ptr, 1);
}
AK_FORCE_INLINE int operator[](const size_t index) const {
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index ca8d56f27..e6f0353e3 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -26,28 +26,28 @@ namespace latinime {
namespace {
TEST(LanguageModelDictContentTest, TestUnigramProbability) {
- LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */);
+ LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
const int flag = 0xFF;
const int probability = 10;
const int wordId = 100;
const ProbabilityEntry probabilityEntry(flag, probability);
- LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
const ProbabilityEntry entry =
- LanguageModelDictContent.getProbabilityEntry(wordId);
+ languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(probability, entry.getProbability());
// Remove
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
- EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
- EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+ EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
}
TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
- LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */);
+ LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */);
const int flag = 0xF0;
const int timestamp = 0x3FFFFFFF;
@@ -56,19 +56,19 @@ TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
const int wordId = 100;
const HistoricalInfo historicalInfo(timestamp, level, count);
const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
- LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
- const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId);
EXPECT_EQ(flag, entry.getFlags());
EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
// Remove
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
- EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
- EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
- EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+ EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+ EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+ EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
}
TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
@@ -89,5 +89,31 @@ TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
EXPECT_TRUE(wordIdSet.empty());
}
+TEST(LanguageModelDictContentTest, TestGetWordProbability) {
+ LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
+
+ const int flag = 0xFF;
+ const int probability = 10;
+ const int bigramProbability = 20;
+ const int trigramProbability = 30;
+ const int wordId = 100;
+ const int prevWordIdArray[] = { 1, 2 };
+ const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
+
+ const ProbabilityEntry probabilityEntry(flag, probability);
+ languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+ const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability);
+ languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
+ &bigramProbabilityEntry);
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
+ prevWordIds[1], &probabilityEntry);
+ languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
+ &trigramProbabilityEntry);
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+}
+
} // namespace
} // namespace latinime
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 161df2f43..93bad5822 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -52,7 +52,7 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
TEST(IntArrayViewTest, TestConstructFromObject) {
const int object = 10;
- const auto intArrayView = IntArrayView::fromObject(&object);
+ const auto intArrayView = IntArrayView::singleElementView(&object);
EXPECT_EQ(1u, intArrayView.size());
EXPECT_EQ(object, intArrayView[0]);
}
diff --git a/tests/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryTests.java b/tests/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryTests.java
index e9a97ff92..4e7e8140a 100644
--- a/tests/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryTests.java
+++ b/tests/src/com/android/inputmethod/latin/personalization/PersonalizationDictionaryTests.java
@@ -29,6 +29,7 @@ import com.android.inputmethod.latin.BinaryDictionary;
import com.android.inputmethod.latin.Dictionary;
import com.android.inputmethod.latin.DictionaryFacilitator;
import com.android.inputmethod.latin.ExpandableBinaryDictionary;
+import com.android.inputmethod.latin.RichInputMethodManager;
import com.android.inputmethod.latin.ExpandableBinaryDictionary.AddMultipleDictionaryEntriesCallback;
import com.android.inputmethod.latin.makedict.CodePointUtils;
import com.android.inputmethod.latin.settings.SpacingAndPunctuations;
@@ -36,6 +37,7 @@ import com.android.inputmethod.latin.settings.SpacingAndPunctuations;
import android.test.AndroidTestCase;
import android.test.suitebuilder.annotation.LargeTest;
import android.util.Log;
+import android.view.inputmethod.InputMethodSubtype;
/**
* Unit tests for personalization dictionary
@@ -55,16 +57,28 @@ public class PersonalizationDictionaryTests extends AndroidTestCase {
final DictionaryFacilitator dictionaryFacilitator = new DictionaryFacilitator(getContext());
dictionaryFacilitator.resetDictionariesForTesting(getContext(), LOCALE_EN_US, dictTypes,
new HashMap<String, File>(), new HashMap<String, Map<String, String>>());
+ // Set subtypes.
+ RichInputMethodManager.init(getContext());
+ final RichInputMethodManager richImm = RichInputMethodManager.getInstance();
+ final ArrayList<InputMethodSubtype> subtypes = new ArrayList<>();
+ subtypes.add(richImm.findSubtypeByLocaleAndKeyboardLayoutSet(
+ LOCALE_EN_US.toString(), "qwerty"));
+ dictionaryFacilitator.updateEnabledSubtypes(subtypes);
return dictionaryFacilitator;
}
public void testAddManyTokens() {
final DictionaryFacilitator dictionaryFacilitator = getDictionaryFacilitator();
dictionaryFacilitator.clearPersonalizationDictionary();
- final int dataChunkCount = 20;
- final int wordCountInOneChunk = 2000;
+ final int dataChunkCount = 100;
+ final int wordCountInOneChunk = 200;
+ final int uniqueWordCount = 100;
final Random random = new Random(System.currentTimeMillis());
final int[] codePointSet = CodePointUtils.LATIN_ALPHABETS_LOWER;
+ final ArrayList<String> words = new ArrayList<>();
+ for (int i = 0; i < uniqueWordCount; i++) {
+ words.add(CodePointUtils.generateWord(random, codePointSet));
+ }
final SpacingAndPunctuations spacingAndPunctuations =
new SpacingAndPunctuations(getContext().getResources());
@@ -75,7 +89,7 @@ public class PersonalizationDictionaryTests extends AndroidTestCase {
for (int i = 0; i < dataChunkCount; i++) {
final ArrayList<String> tokens = new ArrayList<>();
for (int j = 0; j < wordCountInOneChunk; j++) {
- tokens.add(CodePointUtils.generateWord(random, codePointSet));
+ tokens.add(words.get(random.nextInt(words.size())));
}
final PersonalizationDataChunk personalizationDataChunk = new PersonalizationDataChunk(
true /* inputByUser */, tokens, timeStampInSeconds, DUMMY_PACKAGE_NAME,