about | summary | refs | log | tree | commit | diff | stats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r-- native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp | 10
-rw-r--r-- native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h | 2
-rw-r--r-- native/jni/src/suggest/core/policy/scoring.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h | 5
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp | 8
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h | 15
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp | 30
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h | 6
-rw-r--r-- native/jni/src/suggest/policyimpl/typing/typing_scoring.h | 4
9 files changed, 71 insertions, 12 deletions
diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
index b8106377c..e37811b88 100644
--- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
@@ -78,7 +78,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
outputAutoCommitFirstWordConfidence[0] =
computeFirstWordConfidence(&terminals[0]);
}
-
+ const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()->
+ getHeaderStructurePolicy()->shouldBoostExactMatches();
// Output suggestion results here
for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS;
++terminalIndex) {
@@ -102,7 +103,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
&& !(isPossiblyOffensiveWord && isFirstCharUppercase);
const int outputTypeFlags =
(isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
- | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0);
+ | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0);
// Entries that are blacklisted or do not represent a word should not be output.
const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord();
@@ -113,7 +114,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
compoundDistance, traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
(forceCommitMultiWords && terminalDicNode->hasMultipleWords())
- || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()));
+ || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()),
+ boostExactMatches);
if (maxScore < finalScore && isValidWord) {
maxScore = finalScore;
}
@@ -147,7 +149,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
scoringPolicy->calculateFinalScore(compoundDistance,
traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
- true /* forceCommit */) : finalScore;
+ true /* forceCommit */, boostExactMatches) : finalScore;
const int updatedOutputWordIndex = outputShortcuts(&shortcutIt,
outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes,
sameAsTyped);
diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
index b76b13971..417620e00 100644
--- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
@@ -40,6 +40,8 @@ class DictionaryHeaderStructurePolicy {
virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const = 0;
+ virtual bool shouldBoostExactMatches() const = 0;
+
protected:
DictionaryHeaderStructurePolicy() {}
diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h
index 783383450..e581a97c3 100644
--- a/native/jni/src/suggest/core/policy/scoring.h
+++ b/native/jni/src/suggest/core/policy/scoring.h
@@ -28,7 +28,8 @@ class DicTraverseSession;
class Scoring {
public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
- const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0;
+ const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
+ const bool boostExactMatches) const = 0;
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
const int terminalSize, const float languageWeight, int *const outputCodePoints,
int *const type, int *const freq) const = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index a44f9f0fc..1320c6560 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -146,6 +146,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mHasHistoricalInfoOfWords;
}
+    // Reports whether exact matches coming from this dictionary should be boosted.
+    // The answer is the negation of isDecayingDict(): decaying dictionaries get no boost.
+    // TODO: Investigate better ways to handle exact matches for personalized dictionaries.
+    AK_FORCE_INLINE bool shouldBoostExactMatches() const {
+        const bool isDecaying = isDecayingDict();
+        return !isDecaying;
+    }
+
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
index b918e0765..824d442e4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
@@ -28,6 +28,14 @@ const int DynamicPtReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 10000
const int DynamicPtReadingHelper::MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000;
const size_t DynamicPtReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH;
+// Appends the head position of the visited PtNode to the output vector when the node is a
+// terminal that has not been deleted; every other node is ignored.
+// Always returns true (presumably "keep traversing" -- confirm against the traversal code).
+bool DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions::onVisitingPtNode(
+        const PtNodeParams *const ptNodeParams) {
+    if (!ptNodeParams->isTerminal() || ptNodeParams->isDeleted()) {
+        // Not a live terminal: nothing to record.
+        return true;
+    }
+    const int headPos = ptNodeParams->getHeadPos();
+    mTerminalPositions->push_back(headPos);
+    return true;
+}
+
// Visits all PtNodes in post-order depth first manner.
// For example, visits c -> b -> y -> x -> a for the following dictionary:
// a _ b _ c
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
index a69490943..bcc5c7857 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
@@ -59,6 +59,21 @@ class DynamicPtReadingHelper {
DISALLOW_COPY_AND_ASSIGN(TraversingEventListener);
};
+    // Traversing event listener that gathers PtNode head positions into a caller-supplied
+    // vector. Only onVisitingPtNode() does any work; the other callbacks just return true.
+    class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener {
+     public:
+        // terminalPositions: vector that receives the collected positions. Only the raw
+        // pointer is stored; this class never frees it. onVisitingPtNode() dereferences it
+        // without a null check.
+        // explicit: a bare std::vector<int> * must not silently convert into a listener.
+        explicit TraversePolicyToGetAllTerminalPtNodePositions(
+                std::vector<int> *const terminalPositions)
+                : mTerminalPositions(terminalPositions) {}
+
+        bool onAscend() { return true; }
+        // The PtNode array position is not needed, so the parameter is left unnamed.
+        bool onDescend(const int /* ptNodeArrayPos */) { return true; }
+        bool onReadingPtNodeArrayTail() { return true; }
+        // Defined in dynamic_pt_reading_helper.cpp.
+        bool onVisitingPtNode(const PtNodeParams *const ptNodeParams);
+
+     private:
+        DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions);
+
+        std::vector<int> *const mTerminalPositions;
+    };
+
DynamicPtReadingHelper(const BufferWithExtendableBuffer *const buffer,
const PtNodeReader *const ptNodeReader)
: mIsError(false), mReadingState(), mBuffer(buffer),
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 1c420e070..75d85988c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -392,10 +392,32 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code
historicalInfo->getCount(), &bigrams, &shortcuts);
}
-int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token,
- int *const outCodePoints) {
- // TODO: Implement.
- return 0;
+// Iterates over the words stored in the dictionary, one word per call.
+//
+// token: 0 starts a new iteration and re-collects the positions of all terminal PtNodes;
+//         any other value must be the token returned by the previous call.
+// outCodePoints: receives the code points of the word selected by token. MAX_WORD_LENGTH is
+//         passed as the limit, so the buffer presumably needs that capacity -- confirm
+//         against getCodePointsAndProbabilityAndReturnCodePointCount().
+// Returns the token for the next call, or 0 when there is no further word, when the
+// dictionary contains no words, or when token is out of range (the last case is logged).
+int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
+    if (token == 0) {
+        // Beginning of an iteration: rebuild the list of terminal PtNode positions.
+        mTerminalPtNodePositionsForIteratingWords.clear();
+        DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
+                &mTerminalPtNodePositionsForIteratingWords);
+        DynamicPtReadingHelper readingHelper(mDictBuffer, &mNodeReader);
+        readingHelper.initWithPtNodeArrayPos(getRootPosition());
+        readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
+        if (mTerminalPtNodePositionsForIteratingWords.empty()) {
+            // Nothing to iterate. Token 0 is still a legitimate start token, so this must
+            // not be reported as an invalid-token error below.
+            return 0;
+        }
+    }
+    const int terminalPtNodePositionsVectorSize =
+            static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
+    if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
+        AKLOGE("Given token %d is invalid.", token);
+        return 0;
+    }
+    const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
+    int unigramProbability = NOT_A_PROBABILITY;
+    // Only the code points are used here; the probability and the returned count are ignored.
+    getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH,
+            outCodePoints, &unigramProbability);
+    const int nextToken = token + 1;
+    if (nextToken >= terminalPtNodePositionsVectorSize) {
+        // All words have been iterated.
+        mTerminalPtNodePositionsForIteratingWords.clear();
+        return 0;
+    }
+    return nextToken;
}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 1bcd4ceea..9ba5be0c3 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -17,6 +17,8 @@
#ifndef LATINIME_VER4_PATRICIA_TRIE_POLICY_H
#define LATINIME_VER4_PATRICIA_TRIE_POLICY_H
+#include <vector>
+
#include "defines.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h"
@@ -50,7 +52,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()) {};
+ mBigramCount(mHeaderPolicy->getBigramCount()),
+ mTerminalPtNodePositionsForIteratingWords() {};
AK_FORCE_INLINE int getRootPosition() const {
return 0;
@@ -134,6 +137,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieWritingHelper mWritingHelper;
int mUnigramCount;
int mBigramCount;
+ std::vector<int> mTerminalPtNodePositionsForIteratingWords;
};
} // namespace latinime
#endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
index c777e7238..8b405e8de 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
@@ -50,14 +50,14 @@ class TypingScoring : public Scoring {
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance,
const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes,
- const bool forceCommit) const {
+ const bool forceCommit, const bool boostExactMatches) const {
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance;
if (forceCommit) {
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
}
- if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
+ if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
score += ScoringParams::EXACT_MATCH_PROMOTION;
if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) {
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;