aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src')
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node.h27
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.cpp24
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_utils.h4
-rw-r--r--native/jni/src/suggest/core/dicnode/dic_node_vector.h9
-rw-r--r--native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h34
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp19
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.h4
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary_utils.cpp5
-rw-r--r--native/jni/src/suggest/core/dictionary/word_attributes.h60
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h6
-rw-r--r--native/jni/src/suggest/core/policy/traversal.h3
-rw-r--r--native/jni/src/suggest/core/result/suggestions_output_utils.cpp14
-rw-r--r--native/jni/src/suggest/core/suggest.cpp7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp39
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h7
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp37
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h5
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp34
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h2
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp27
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h5
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_traversal.h4
-rw-r--r--native/jni/src/utils/int_array_view.h15
25 files changed, 279 insertions, 120 deletions
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index 2230dc7b8..3970963e8 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -26,6 +26,7 @@
#include "suggest/core/dictionary/error_type_utils.h"
#include "suggest/core/layout/proximity_info_state.h"
#include "utils/char_utils.h"
+#include "utils/int_array_view.h"
#if DEBUG_DICT
#define LOGI_SHOW_ADD_COST_PROP \
@@ -136,18 +137,15 @@ class DicNode {
}
void initAsChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
- const int probability, const int wordId, const bool hasChildren,
- const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount,
- const int *const mergedNodeCodePoints) {
+ const int wordId, const CodePointArrayView mergedCodePoints) {
uint16_t newDepth = static_cast<uint16_t>(dicNode->getNodeCodePointCount() + 1);
mIsCachedForNextSuggestion = dicNode->mIsCachedForNextSuggestion;
const uint16_t newLeavingDepth = static_cast<uint16_t>(
- dicNode->mDicNodeProperties.getLeavingDepth() + mergedNodeCodePointCount);
- mDicNodeProperties.init(childrenPtNodeArrayPos, mergedNodeCodePoints[0],
- probability, wordId, hasChildren, isBlacklistedOrNotAWord, newDepth,
- newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
- mDicNodeState.init(&dicNode->mDicNodeState, mergedNodeCodePointCount,
- mergedNodeCodePoints);
+ dicNode->mDicNodeProperties.getLeavingDepth() + mergedCodePoints.size());
+ mDicNodeProperties.init(childrenPtNodeArrayPos, mergedCodePoints[0],
+ wordId, newDepth, newLeavingDepth, dicNode->mDicNodeProperties.getPrevWordIds());
+ mDicNodeState.init(&dicNode->mDicNodeState, mergedCodePoints.size(),
+ mergedCodePoints.data());
PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
}
@@ -179,9 +177,6 @@ class DicNode {
// Check if the current word and the previous word can be considered as a valid multiple word
// suggestion.
bool isValidMultipleWordSuggestion() const {
- if (isBlacklistedOrNotAWord()) {
- return false;
- }
// Treat suggestion as invalid if the current and the previous word are single character
// words.
const int prevWordLen = mDicNodeState.mDicNodeStateOutput.getPrevWordsLength()
@@ -218,10 +213,6 @@ class DicNode {
return mDicNodeProperties.getChildrenPtNodeArrayPos();
}
- int getProbability() const {
- return mDicNodeProperties.getProbability();
- }
-
AK_FORCE_INLINE bool isTerminalDicNode() const {
const bool isTerminalPtNode = mDicNodeProperties.isTerminal();
const int currentDicNodeDepth = getNodeCodePointCount();
@@ -404,10 +395,6 @@ class DicNode {
return mDicNodeState.mDicNodeStateScoring.getContainedErrorTypes();
}
- bool isBlacklistedOrNotAWord() const {
- return mDicNodeProperties.isBlacklistedOrNotAWord();
- }
-
inline uint16_t getNodeCodePointCount() const {
return mDicNodeProperties.getDepth();
}
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index 87d245276..fe5fe8448 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -18,7 +18,6 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
-#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
namespace latinime {
@@ -73,25 +72,16 @@ namespace latinime {
if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) {
return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
}
- const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode,
- multiBigramMap);
+ const WordAttributes wordAttributes = dictionaryStructurePolicy->getWordAttributesInContext(
+ dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
+ if (dicNode->hasMultipleWords()
+ && (wordAttributes.isBlacklisted() || wordAttributes.isNotAWord())) {
+ return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
+ }
// TODO: This equation to calculate the improbability looks unreasonable. Investigate this.
- const float cost = static_cast<float>(MAX_PROBABILITY - probability)
+ const float cost = static_cast<float>(MAX_PROBABILITY - wordAttributes.getProbability())
/ static_cast<float>(MAX_PROBABILITY);
return cost;
}
-/* static */ int DicNodeUtils::getBigramNodeProbability(
- const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
- const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) {
- const int unigramProbability = dicNode->getProbability();
- if (multiBigramMap) {
- const int *const prevWordIds = dicNode->getPrevWordIds();
- return multiBigramMap->getBigramProbability(dictionaryStructurePolicy,
- prevWordIds, dicNode->getWordId(), unigramProbability);
- }
- return dictionaryStructurePolicy->getProbability(unigramProbability,
- NOT_A_PROBABILITY);
-}
-
} // namespace latinime
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 56ff6e3d0..961a1c29d 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -46,10 +46,6 @@ class DicNodeUtils {
DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
// Max number of bigrams to look up
static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;
-
- static int getBigramNodeProbability(
- const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
- const DicNode *const dicNode, MultiBigramMap *const multiBigramMap);
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_UTILS_H
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
index b6a195103..e6b758954 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dicnode/dic_node.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -59,14 +60,10 @@ class DicNodeVector {
}
void pushLeavingChild(const DicNode *const dicNode, const int childrenPtNodeArrayPos,
- const int probability, const int wordId, const bool hasChildren,
- const bool isBlacklistedOrNotAWord, const uint16_t mergedNodeCodePointCount,
- const int *const mergedNodeCodePoints) {
+ const int wordId, const CodePointArrayView mergedCodePoints) {
ASSERT(!mLock);
mDicNodes.emplace_back();
- mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, probability,
- wordId, hasChildren, isBlacklistedOrNotAWord, mergedNodeCodePointCount,
- mergedNodeCodePoints);
+ mDicNodes.back().initAsChild(dicNode, childrenPtNodeArrayPos, wordId, mergedCodePoints);
}
DicNode *operator[](const int id) {
diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
index be3134c91..6a1b84273 100644
--- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
+++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
@@ -29,23 +29,17 @@ namespace latinime {
class DicNodeProperties {
public:
AK_FORCE_INLINE DicNodeProperties()
- : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mProbability(NOT_A_PROBABILITY),
- mDicNodeCodePoint(NOT_A_CODE_POINT), mWordId(NOT_A_WORD_ID),
- mHasChildrenPtNodes(false), mIsBlacklistedOrNotAWord(false), mDepth(0),
- mLeavingDepth(0) {}
+ : mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT),
+ mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0) {}
~DicNodeProperties() {}
// Should be called only once per DicNode is initialized.
- void init(const int childrenPos, const int nodeCodePoint, const int probability,
- const int wordId, const bool hasChildren, const bool isBlacklistedOrNotAWord,
+ void init(const int childrenPos, const int nodeCodePoint, const int wordId,
const uint16_t depth, const uint16_t leavingDepth, const int *const prevWordIds) {
mChildrenPtNodeArrayPos = childrenPos;
mDicNodeCodePoint = nodeCodePoint;
- mProbability = probability;
mWordId = wordId;
- mHasChildrenPtNodes = hasChildren;
- mIsBlacklistedOrNotAWord = isBlacklistedOrNotAWord;
mDepth = depth;
mLeavingDepth = leavingDepth;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
@@ -55,10 +49,7 @@ class DicNodeProperties {
void init(const int rootPtNodeArrayPos, const int *const prevWordIds) {
mChildrenPtNodeArrayPos = rootPtNodeArrayPos;
mDicNodeCodePoint = NOT_A_CODE_POINT;
- mProbability = NOT_A_PROBABILITY;
mWordId = NOT_A_WORD_ID;
- mHasChildrenPtNodes = true;
- mIsBlacklistedOrNotAWord = false;
mDepth = 0;
mLeavingDepth = 0;
memmove(mPrevWordIds, prevWordIds, sizeof(mPrevWordIds));
@@ -67,10 +58,7 @@ class DicNodeProperties {
void initByCopy(const DicNodeProperties *const dicNodeProp) {
mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos;
mDicNodeCodePoint = dicNodeProp->mDicNodeCodePoint;
- mProbability = dicNodeProp->mProbability;
mWordId = dicNodeProp->mWordId;
- mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes;
- mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
@@ -80,10 +68,7 @@ class DicNodeProperties {
void init(const DicNodeProperties *const dicNodeProp, const int codePoint) {
mChildrenPtNodeArrayPos = dicNodeProp->mChildrenPtNodeArrayPos;
mDicNodeCodePoint = codePoint; // Overwrite the node char of a passing child
- mProbability = dicNodeProp->mProbability;
mWordId = dicNodeProp->mWordId;
- mHasChildrenPtNodes = dicNodeProp->mHasChildrenPtNodes;
- mIsBlacklistedOrNotAWord = dicNodeProp->mIsBlacklistedOrNotAWord;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth;
memmove(mPrevWordIds, dicNodeProp->mPrevWordIds, sizeof(mPrevWordIds));
@@ -93,10 +78,6 @@ class DicNodeProperties {
return mChildrenPtNodeArrayPos;
}
- int getProbability() const {
- return mProbability;
- }
-
int getDicNodeCodePoint() const {
return mDicNodeCodePoint;
}
@@ -115,11 +96,7 @@ class DicNodeProperties {
}
bool hasChildren() const {
- return mHasChildrenPtNodes || mDepth != mLeavingDepth;
- }
-
- bool isBlacklistedOrNotAWord() const {
- return mIsBlacklistedOrNotAWord;
+ return (mChildrenPtNodeArrayPos != NOT_A_DICT_POS) || mDepth != mLeavingDepth;
}
const int *getPrevWordIds() const {
@@ -135,11 +112,8 @@ class DicNodeProperties {
// Use a default copy constructor and an assign operator because shallow copies are ok
// for this class
int mChildrenPtNodeArrayPos;
- int mProbability;
int mDicNodeCodePoint;
int mWordId;
- bool mHasChildrenPtNodes;
- bool mIsBlacklistedOrNotAWord;
uint16_t mDepth;
uint16_t mLeavingDepth;
int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 8f9b2aa12..1de405104 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -61,10 +61,11 @@ void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession
}
Dictionary::NgramListenerForPrediction::NgramListenerForPrediction(
- const PrevWordsInfo *const prevWordsInfo, SuggestionResults *const suggestionResults,
+ const PrevWordsInfo *const prevWordsInfo, const WordIdArrayView prevWordIds,
+ SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy)
- : mPrevWordsInfo(prevWordsInfo), mSuggestionResults(suggestionResults),
- mDictStructurePolicy(dictStructurePolicy) {}
+ : mPrevWordsInfo(prevWordsInfo), mPrevWordIds(prevWordIds),
+ mSuggestionResults(suggestionResults), mDictStructurePolicy(dictStructurePolicy) {}
void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbability,
const int targetWordId) {
@@ -83,19 +84,21 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
if (codePointCount <= 0) {
return;
}
- const int probability = mDictStructurePolicy->getProbability(
- unigramProbability, ngramProbability);
- mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount, probability);
+ const WordAttributes wordAttributes = mDictStructurePolicy->getWordAttributesInContext(
+ mPrevWordIds.data(), targetWordId, nullptr /* multiBigramMap */);
+ mSuggestionResults->addPrediction(targetWordCodePoints, codePointCount,
+ wordAttributes.getProbability());
}
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
- NgramListenerForPrediction listener(prevWordsInfo, outSuggestionResults,
- mDictionaryStructureWithBufferPolicy.get());
int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds,
true /* tryLowerCaseSearch */);
+ NgramListenerForPrediction listener(prevWordsInfo,
+ WordIdArrayView::fromFixedSizeArray(prevWordIds), outSuggestionResults,
+ mDictionaryStructureWithBufferPolicy.get());
mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
}
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index 50951fbc1..0b54f30e9 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -26,6 +26,7 @@
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/core/suggest_interface.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -118,7 +119,7 @@ class Dictionary {
class NgramListenerForPrediction : public NgramListener {
public:
NgramListenerForPrediction(const PrevWordsInfo *const prevWordsInfo,
- SuggestionResults *const suggestionResults,
+ const WordIdArrayView prevWordIds, SuggestionResults *const suggestionResults,
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy);
virtual void onVisitEntry(const int ngramProbability, const int targetWordId);
@@ -126,6 +127,7 @@ class Dictionary {
DISALLOW_IMPLICIT_CONSTRUCTORS(NgramListenerForPrediction);
const PrevWordsInfo *const mPrevWordsInfo;
+ const WordIdArrayView mPrevWordIds;
SuggestionResults *const mSuggestionResults;
const DictionaryStructureWithBufferPolicy *const mDictStructurePolicy;
};
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
index b372b6b4f..f71d4c5f0 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
@@ -59,8 +59,11 @@ namespace latinime {
if (!dicNode.isTerminalDicNode()) {
continue;
}
+ const WordAttributes wordAttributes =
+ dictionaryStructurePolicy->getWordAttributesInContext(dicNode.getPrevWordIds(),
+ dicNode.getWordId(), nullptr /* multiBigramMap */);
// dicNode can contain case errors, accent errors, intentional omissions or digraphs.
- maxProbability = std::max(maxProbability, dicNode.getProbability());
+ maxProbability = std::max(maxProbability, wordAttributes.getProbability());
}
return maxProbability;
}
diff --git a/native/jni/src/suggest/core/dictionary/word_attributes.h b/native/jni/src/suggest/core/dictionary/word_attributes.h
new file mode 100644
index 000000000..6e9da3570
--- /dev/null
+++ b/native/jni/src/suggest/core/dictionary/word_attributes.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_WORD_ATTRIBUTES_H
+#define LATINIME_WORD_ATTRIBUTES_H
+
+#include "defines.h"
+
+class WordAttributes {
+ public:
+ // Invalid word attributes.
+ WordAttributes()
+ : mProbability(NOT_A_PROBABILITY), mIsBlacklisted(false), mIsNotAWord(false),
+ mIsPossiblyOffensive(false) {}
+
+ WordAttributes(const int probability, const bool isBlacklisted, const bool isNotAWord,
+ const bool isPossiblyOffensive)
+ : mProbability(probability), mIsBlacklisted(isBlacklisted), mIsNotAWord(isNotAWord),
+ mIsPossiblyOffensive(isPossiblyOffensive) {}
+
+ int getProbability() const {
+ return mProbability;
+ }
+
+ bool isBlacklisted() const {
+ return mIsBlacklisted;
+ }
+
+ bool isNotAWord() const {
+ return mIsNotAWord;
+ }
+
+ bool isPossiblyOffensive() const {
+ return mIsPossiblyOffensive;
+ }
+
+ private:
+ DISALLOW_ASSIGNMENT_OPERATOR(WordAttributes);
+
+ int mProbability;
+ bool mIsBlacklisted;
+ bool mIsNotAWord;
+ bool mIsPossiblyOffensive;
+};
+
+ // namespace
+#endif /* LATINIME_WORD_ATTRIBUTES_H */
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index aeeb66f93..7414f696c 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -22,6 +22,7 @@
#include "defines.h"
#include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h"
#include "suggest/core/dictionary/property/word_property.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "utils/int_array_view.h"
namespace latinime {
@@ -29,6 +30,7 @@ namespace latinime {
class DicNode;
class DicNodeVector;
class DictionaryHeaderStructurePolicy;
+class MultiBigramMap;
class NgramListener;
class PrevWordsInfo;
class UnigramProperty;
@@ -56,6 +58,10 @@ class DictionaryStructureWithBufferPolicy {
virtual int getWordId(const CodePointArrayView wordCodePoints,
const bool forceLowerCaseSearch) const = 0;
+ virtual const WordAttributes getWordAttributesInContext(const int *const prevWordIds,
+ const int wordId, MultiBigramMap *const multiBigramMap) const = 0;
+
+ // TODO: Remove
virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h
index 8ddaa0514..6dfa7e314 100644
--- a/native/jni/src/suggest/core/policy/traversal.h
+++ b/native/jni/src/suggest/core/policy/traversal.h
@@ -48,7 +48,8 @@ class Traversal {
virtual int getTerminalCacheSize() const = 0;
virtual bool isPossibleOmissionChildNode(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
- virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
+ virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode,
+ const int probability) const = 0;
protected:
Traversal() {}
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
index ad860c4a4..6e0193772 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
@@ -85,9 +85,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode);
const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight)
+ doubleLetterCost;
- const bool isPossiblyOffensiveWord =
- traverseSession->getDictionaryStructurePolicy()->getProbability(
- terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0;
+ const WordAttributes wordAttributes = traverseSession->getDictionaryStructurePolicy()
+ ->getWordAttributesInContext(terminalDicNode->getPrevWordIds(),
+ terminalDicNode->getWordId(), nullptr /* multiBigramMap */);
const bool isExactMatch =
ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes());
const bool isExactMatchWithIntentionalOmission =
@@ -97,19 +97,19 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
// Heuristic: We exclude probability=0 first-char-uppercase words from exact match.
// (e.g. "AMD" and "and")
const bool isSafeExactMatch = isExactMatch
- && !(isPossiblyOffensiveWord && isFirstCharUppercase);
+ && !(wordAttributes.isPossiblyOffensive() && isFirstCharUppercase);
const int outputTypeFlags =
- (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
+ (wordAttributes.isPossiblyOffensive() ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
| ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0)
| (isExactMatchWithIntentionalOmission ?
Dictionary::KIND_FLAG_EXACT_MATCH_WITH_INTENTIONAL_OMISSION : 0);
// Entries that are blacklisted or do not represent a word should not be output.
- const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord();
+ const bool isValidWord = !(wordAttributes.isBlacklisted() || wordAttributes.isNotAWord());
// When we have to block offensive words, non-exact matched offensive words should not be
// output.
const bool blockOffensiveWords = traverseSession->getSuggestOptions()->blockOffensiveWords();
- const bool isBlockedOffensiveWord = blockOffensiveWords && isPossiblyOffensiveWord
+ const bool isBlockedOffensiveWord = blockOffensiveWords && wordAttributes.isPossiblyOffensive()
&& !isSafeExactMatch;
// Increase output score of top typing suggestion to ensure autocorrection.
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 66c87f04c..947d41f4b 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -21,6 +21,7 @@
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/dictionary.h"
#include "suggest/core/dictionary/digraph_utils.h"
+#include "suggest/core/dictionary/word_attributes.h"
#include "suggest/core/layout/proximity_info.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/core/policy/traversal.h"
@@ -412,7 +413,11 @@ void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicN
*/
void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode,
const bool spaceSubstitution) const {
- if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) {
+ const WordAttributes wordAttributes =
+ traverseSession->getDictionaryStructurePolicy()->getWordAttributesInContext(
+ dicNode->getPrevWordIds(), dicNode->getWordId(),
+ traverseSession->getMultiBigramMap());
+ if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode, wordAttributes.getProbability())) {
return;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index f9013310c..9b8a50b07 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -28,6 +28,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
@@ -78,10 +79,7 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
const int wordId = isTerminal ? ptNodeParams.getHeadPos() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
- ptNodeParams.getProbability(), wordId, ptNodeParams.hasChildren(),
- ptNodeParams.isBlacklisted()
- || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */,
- ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints());
+ wordId, ptNodeParams.getCodePointArrayView());
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -117,6 +115,35 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
+ const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+ const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
+ if (multiBigramMap) {
+ const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
+ prevWordIds, wordId, ptNodeParams.getProbability());
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ if (prevWordIds) {
+ const int probability = getProbabilityOfWord(prevWordIds, wordId);
+ if (probability != NOT_A_PROBABILITY) {
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ }
+ return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
+ ptNodeParams);
+}
+
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const {
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ ptNodeParams.getProbability() == 0);
+}
+
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
@@ -333,7 +360,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
}
bool addedNewBigram = false;
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
- if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
+ if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, bigramProperty, &addedNewBigram)) {
if (addedNewBigram) {
mBigramCount++;
@@ -375,7 +402,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
}
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
- PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
+ PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
mBigramCount--;
return true;
} else {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 562c219f4..871b556e1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -91,6 +91,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
@@ -163,6 +166,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getShortcutPositionOfPtNode(const int ptNodePos) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
+ const WordAttributes getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const;
};
} // namespace v402
} // namespace backward
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
index b2e60a837..c12fed324 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/pt_node_params.h
@@ -24,6 +24,7 @@
#include "suggest/policyimpl/dictionary/structure/pt_common/patricia_trie_reading_utils.h"
#include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
#include "utils/char_utils.h"
+#include "utils/int_array_view.h"
namespace latinime {
@@ -174,11 +175,17 @@ class PtNodeParams {
return mParentPos;
}
+ AK_FORCE_INLINE const CodePointArrayView getCodePointArrayView() const {
+ return CodePointArrayView(mCodePoints, mCodePointCount);
+ }
+
+ // TODO: Remove
// Number of code points
AK_FORCE_INLINE uint8_t getCodePointCount() const {
return mCodePointCount;
}
+ // TODO: Remove
AK_FORCE_INLINE const int *getCodePoints() const {
return mCodePoints;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index b36c6f4df..e76bae97c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -21,6 +21,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/session/prev_words_info.h"
#include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
@@ -281,6 +282,35 @@ int PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return getWordIdFromTerminalPtNodePos(ptNodePos);
}
+const WordAttributes PatriciaTriePolicy::getWordAttributesInContext(const int *const prevWordIds,
+ const int wordId, MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+ const PtNodeParams ptNodeParams =
+ mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ if (multiBigramMap) {
+ const int probability = multiBigramMap->getBigramProbability(this /* structurePolicy */,
+ prevWordIds, wordId, ptNodeParams.getProbability());
+ return getWordAttributes(probability, ptNodeParams);
+ }
+ if (prevWordIds) {
+ const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
+ if (bigramProbability != NOT_A_PROBABILITY) {
+ return getWordAttributes(bigramProbability, ptNodeParams);
+ }
+ }
+ return getWordAttributes(getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY),
+ ptNodeParams);
+}
+
+const WordAttributes PatriciaTriePolicy::getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const {
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ ptNodeParams.getProbability() == 0);
+}
+
int PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
// Due to space constraints, the probability for bigrams is approximate - the lower the unigram
@@ -377,11 +407,8 @@ int PatriciaTriePolicy::createAndGetLeavingChildNode(const DicNode *const dicNod
// Skip PtNodes don't start with Unicode code point because they represent non-word information.
if (CharUtils::isInUnicodeSpace(mergedNodeCodePoints[0])) {
const int wordId = PatriciaTrieReadingUtils::isTerminal(flags) ? ptNodePos : NOT_A_WORD_ID;
- childDicNodes->pushLeavingChild(dicNode, childrenPos, probability, wordId,
- PatriciaTrieReadingUtils::hasChildrenInFlags(flags),
- PatriciaTrieReadingUtils::isBlacklisted(flags)
- || PatriciaTrieReadingUtils::isNotAWord(flags),
- mergedNodeCodePointCount, mergedNodeCodePoints);
+ childDicNodes->pushLeavingChild(dicNode, childrenPos, wordId,
+ CodePointArrayView(mergedNodeCodePoints, mergedNodeCodePointCount));
}
return siblingPos;
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 66df52779..8c1665d7d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -66,6 +66,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
@@ -160,6 +163,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
DicNodeVector *const childDicNodes) const;
int getWordIdFromTerminalPtNodePos(const int ptNodePos) const;
int getTerminalPtNodePosFromWordId(const int wordId) const;
+ const WordAttributes getWordAttributes(const int probability,
+ const PtNodeParams &ptNodeParams) const;
};
} // namespace latinime
#endif // LATINIME_PATRICIA_TRIE_POLICY_H
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index d5749e9eb..f54bb151a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,6 +38,40 @@ bool LanguageModelDictContent::runGC(
0 /* nextLevelBitmapEntryIndex */, outNgramCount);
}
+int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+ const int wordId) const {
+ int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+ bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
+ int maxLevel = 0;
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ const int nextBitmapEntryIndex =
+ mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
+ if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+ break;
+ }
+ maxLevel = i + 1;
+ bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
+ }
+
+ for (int i = maxLevel; i >= 0; --i) {
+ const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
+ if (!result.mIsValid) {
+ continue;
+ }
+ const int probability =
+ ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+ if (mHasHistoricalInfo) {
+ return std::min(
+ probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
+ MAX_PROBABILITY);
+ } else {
+ return probability;
+ }
+ }
+ // Cannot find the word.
+ return NOT_A_PROBABILITY;
+}
+
ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
const WordIdArrayView prevWordIds, const int wordId) const {
const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index aa612e35a..4e0b47036 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,6 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
+ int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index aca2f6cae..0472a453a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -20,6 +20,7 @@
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
#include "suggest/core/dictionary/ngram_listener.h"
#include "suggest/core/dictionary/property/bigram_property.h"
#include "suggest/core/dictionary/property/unigram_property.h"
@@ -68,10 +69,7 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d
}
const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID;
childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(),
- ptNodeParams.getProbability(), wordId, ptNodeParams.hasChildren(),
- ptNodeParams.isBlacklisted()
- || ptNodeParams.isNotAWord() /* isBlacklistedOrNotAWord */,
- ptNodeParams.getCodePointCount(), ptNodeParams.getCodePoints());
+ wordId, ptNodeParams.getCodePointArrayView());
}
if (readingHelper.isError()) {
mIsCorrupted = true;
@@ -112,6 +110,21 @@ int Ver4PatriciaTriePolicy::getWordId(const CodePointArrayView wordCodePoints,
return ptNodeParams.getTerminalId();
}
+const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
+ const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const {
+ if (wordId == NOT_A_WORD_ID) {
+ return WordAttributes();
+ }
+ const int ptNodePos =
+ mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
+ const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+ // TODO: Support n-gram.
+ return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
+ WordIdArrayView::singleElementView(prevWordIds), wordId), ptNodeParams.isBlacklisted(),
+ ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
+}
+
int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
const int bigramProbability) const {
if (mHeaderPolicy->isDecayingDict()) {
@@ -143,7 +156,7 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const int *const prevWordIds,
// TODO: Support n-gram.
const ProbabilityEntry probabilityEntry =
mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
- IntArrayView::fromObject(prevWordIds), wordId);
+ IntArrayView::singleElementView(prevWordIds), wordId);
if (!probabilityEntry.isValid()) {
return NOT_A_PROBABILITY;
}
@@ -171,7 +184,7 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const int *const prevWordIds,
// TODO: Support n-gram.
const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
for (const auto entry : languageModelDictContent->getProbabilityEntries(
- WordIdArrayView::fromObject(prevWordIds))) {
+ WordIdArrayView::singleElementView(prevWordIds))) {
const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
const int probability = probabilityEntry.hasHistoricalInfo() ?
ForgettingCurveUtils::decodeProbability(
@@ -488,7 +501,7 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(
// Fetch bigram information.
// TODO: Support n-gram.
std::vector<BigramProperty> bigrams;
- const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
+ const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
int bigramWord1CodePoints[MAX_WORD_LENGTH];
for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
prevWordIds)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 0b8eec40b..980c16e4a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -68,6 +68,9 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
+ const WordAttributes getWordAttributesInContext(const int *const prevWordIds, const int wordId,
+ MultiBigramMap *const multiBigramMap) const;
+
int getProbability(const int unigramProbability, const int bigramProbability) const;
int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index 9910777b8..313eb6b64 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -48,6 +48,11 @@ class ForgettingCurveUtils {
static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
const int bigramCount, const HeaderPolicy *const headerPolicy);
+ // TODO: Improve probability computation method and remove this.
+ static int getProbabilityBiasForNgram(const int n) {
+ return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
+ }
+
AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
return static_cast<int>(static_cast<float>(maxUnigramCount)
* UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
index cb3dfac70..b64ee8be4 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h
@@ -161,8 +161,8 @@ class TypingTraversal : public Traversal {
return true;
}
- AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const {
- const int probability = dicNode->getProbability();
+ AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode,
+ const int probability) const {
if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) {
return false;
}
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c9c3b21d4..c39add9fe 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -17,8 +17,9 @@
#ifndef LATINIME_INT_ARRAY_VIEW_H
#define LATINIME_INT_ARRAY_VIEW_H
+#include <array>
#include <cstdint>
-#include <cstdlib>
+#include <cstring>
#include <vector>
#include "defines.h"
@@ -61,9 +62,9 @@ class IntArrayView {
return IntArrayView(array, N);
}
- // Returns a view that points one int object. Does not take ownership of the given object.
- AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) {
- return IntArrayView(object, 1);
+ // Returns a view that points one int object.
+ AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
+ return IntArrayView(ptr, 1);
}
AK_FORCE_INLINE int operator[](const size_t index) const {
@@ -103,6 +104,12 @@ class IntArrayView {
return IntArrayView(mPtr + n, mSize - n);
}
+ template <size_t N>
+ void copyToArray(std::array<int, N> *const buffer, const size_t offset) const {
+ ASSERT(mSize + offset <= N);
+ memmove(buffer->data() + offset, mPtr, sizeof(int) * mSize);
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);