aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp28
-rw-r--r--native/jni/src/defines.h2
-rw-r--r--native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp10
-rw-r--r--native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h2
-rw-r--r--native/jni/src/suggest/core/policy/scoring.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h18
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp19
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h15
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp30
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h6
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp8
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h2
-rw-r--r--native/jni/src/suggest/policyimpl/typing/typing_scoring.h4
-rw-r--r--native/jni/src/utils/char_utils.cpp2
-rw-r--r--native/jni/src/utils/char_utils.h9
18 files changed, 155 insertions, 23 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index c919ebd91..f5c3ee63c 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -29,6 +29,7 @@
#include "suggest/policyimpl/dictionary/structure/dictionary_structure_with_buffer_policy_factory.h"
#include "suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h"
#include "utils/autocorrection_threshold_utils.h"
+#include "utils/char_utils.h"
#include "utils/time_keeper.h"
namespace latinime {
@@ -37,13 +38,15 @@ class ProximityInfo;
// TODO: Move to makedict.
static jboolean latinime_BinaryDictionary_createEmptyDictFile(JNIEnv *env, jclass clazz,
- jstring filePath, jlong dictVersion, jobjectArray attributeKeyStringArray,
+ jstring filePath, jlong dictVersion, jstring locale, jobjectArray attributeKeyStringArray,
jobjectArray attributeValueStringArray) {
const jsize filePathUtf8Length = env->GetStringUTFLength(filePath);
char filePathChars[filePathUtf8Length + 1];
env->GetStringUTFRegion(filePath, 0, env->GetStringLength(filePath), filePathChars);
filePathChars[filePathUtf8Length] = '\0';
-
+ jsize localeLength = env->GetStringLength(locale);
+ jchar localeCodePoints[localeLength];
+ env->GetStringRegion(locale, 0, localeLength, localeCodePoints);
const int keyCount = env->GetArrayLength(attributeKeyStringArray);
const int valueCount = env->GetArrayLength(attributeValueStringArray);
if (keyCount != valueCount) {
@@ -73,7 +76,7 @@ static jboolean latinime_BinaryDictionary_createEmptyDictFile(JNIEnv *env, jclas
}
return DictFileWritingUtils::createEmptyDictFile(filePathChars, static_cast<int>(dictVersion),
- &attributeMap);
+ CharUtils::convertShortArrayToIntVector(localeCodePoints, localeLength), &attributeMap);
}
static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring sourceDir,
@@ -137,6 +140,17 @@ static void latinime_BinaryDictionary_close(JNIEnv *env, jclass clazz, jlong dic
delete dictionary;
}
+static void latinime_BinaryDictionary_getHeaderInfo(JNIEnv *env, jclass clazz, jlong dict,
+ jintArray outHeaderSize, jintArray outFormatVersion, jobject outAttributeKeys,
+ jobject outAttributeValues) {
+ Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
+ if (!dictionary) return;
+ const int formatVersion = dictionary->getFormatVersionNumber();
+ env->SetIntArrayRegion(outFormatVersion, 0 /* start */, 1 /* len */, &formatVersion);
+ // TODO: Implement
+ return;
+}
+
static int latinime_BinaryDictionary_getFormatVersion(JNIEnv *env, jclass clazz, jlong dict) {
Dictionary *dictionary = reinterpret_cast<Dictionary *>(dict);
if (!dictionary) return 0;
@@ -492,7 +506,8 @@ static int latinime_BinaryDictionary_setCurrentTimeForTest(JNIEnv *env, jclass c
static const JNINativeMethod sMethods[] = {
{
const_cast<char *>("createEmptyDictFileNative"),
- const_cast<char *>("(Ljava/lang/String;J[Ljava/lang/String;[Ljava/lang/String;)Z"),
+ const_cast<char *>(
+ "(Ljava/lang/String;JLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)Z"),
reinterpret_cast<void *>(latinime_BinaryDictionary_createEmptyDictFile)
},
{
@@ -511,6 +526,11 @@ static const JNINativeMethod sMethods[] = {
reinterpret_cast<void *>(latinime_BinaryDictionary_getFormatVersion)
},
{
+ const_cast<char *>("getHeaderInfoNative"),
+ const_cast<char *>("(J[I[ILjava/util/ArrayList;Ljava/util/ArrayList;)V"),
+ reinterpret_cast<void *>(latinime_BinaryDictionary_getHeaderInfo)
+ },
+ {
const_cast<char *>("flushNative"),
const_cast<char *>("(JLjava/lang/String;)V"),
reinterpret_cast<void *>(latinime_BinaryDictionary_flush)
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 1969ebae0..22cc4c02b 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -311,7 +311,7 @@ static inline void prof_out(void) {
// A special value to mean the first word confidence makes no sense in this case,
// e.g. this is not a multi-word suggestion.
-#define NOT_A_FIRST_WORD_CONFIDENCE (S_INT_MAX)
+#define NOT_A_FIRST_WORD_CONFIDENCE (S_INT_MIN)
// How high the confidence needs to be for us to auto-commit. Arbitrary.
// This needs to be the same as CONFIDENCE_FOR_AUTO_COMMIT in BinaryDictionary.java
#define CONFIDENCE_FOR_AUTO_COMMIT (1000000)
diff --git a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
index b8106377c..e37811b88 100644
--- a/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/suggestions_output_utils.cpp
@@ -78,7 +78,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
outputAutoCommitFirstWordConfidence[0] =
computeFirstWordConfidence(&terminals[0]);
}
-
+ const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()->
+ getHeaderStructurePolicy()->shouldBoostExactMatches();
// Output suggestion results here
for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS;
++terminalIndex) {
@@ -102,7 +103,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
&& !(isPossiblyOffensiveWord && isFirstCharUppercase);
const int outputTypeFlags =
(isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0)
- | (isSafeExactMatch ? Dictionary::KIND_FLAG_EXACT_MATCH : 0);
+ | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0);
// Entries that are blacklisted or do not represent a word should not be output.
const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord();
@@ -113,7 +114,8 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
compoundDistance, traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
(forceCommitMultiWords && terminalDicNode->hasMultipleWords())
- || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()));
+ || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()),
+ boostExactMatches);
if (maxScore < finalScore && isValidWord) {
maxScore = finalScore;
}
@@ -147,7 +149,7 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16;
scoringPolicy->calculateFinalScore(compoundDistance,
traverseSession->getInputSize(),
terminalDicNode->getContainedErrorTypes(),
- true /* forceCommit */) : finalScore;
+ true /* forceCommit */, boostExactMatches) : finalScore;
const int updatedOutputWordIndex = outputShortcuts(&shortcutIt,
outputWordIndex, shortcutBaseScore, outputCodePoints, frequencies, outputTypes,
sameAsTyped);
diff --git a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
index b76b13971..417620e00 100644
--- a/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_header_structure_policy.h
@@ -40,6 +40,8 @@ class DictionaryHeaderStructurePolicy {
virtual void readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const = 0;
+ virtual bool shouldBoostExactMatches() const = 0;
+
protected:
DictionaryHeaderStructurePolicy() {}
diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h
index 783383450..e581a97c3 100644
--- a/native/jni/src/suggest/core/policy/scoring.h
+++ b/native/jni/src/suggest/core/policy/scoring.h
@@ -28,7 +28,8 @@ class DicTraverseSession;
class Scoring {
public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
- const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit) const = 0;
+ const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit,
+ const bool boostExactMatches) const = 0;
virtual bool getMostProbableString(const DicTraverseSession *const traverseSession,
const int terminalSize, const float languageWeight, int *const outputCodePoints,
int *const type, int *const freq) const = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 7504524f0..b5b5ed740 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -32,6 +32,7 @@ const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
const char *const HeaderPolicy::HAS_HISTORICAL_INFO_KEY = "HAS_HISTORICAL_INFO";
+const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
@@ -59,6 +60,10 @@ void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *out
outValue[terminalIndex] = '\0';
}
+const std::vector<int> HeaderPolicy::readLocale() const {
+ return HeaderReadWriteUtils::readCodePointVectorAttributeValue(&mAttributeMap, LOCALE_KEY);
+}
+
float HeaderPolicy::readMultipleWordCostMultiplier() const {
const int demotionRate = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
MULTIPLE_WORDS_DEMOTION_RATE_KEY, DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE);
@@ -116,6 +121,7 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime, const int uni
// Set the current time as the generation time.
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, DATE_KEY,
TimeKeeper::peekCurrentTime());
+ HeaderReadWriteUtils::setCodePointVectorAttribute(outAttributeMap, LOCALE_KEY, mLocale);
if (updatesLastDecayedTime) {
// Set current time as the last updated time.
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, LAST_DECAYED_TIME_KEY,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index a44f9f0fc..a05e00c39 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -23,6 +23,7 @@
#include "suggest/core/policy/dictionary_header_structure_policy.h"
#include "suggest/policyimpl/dictionary/header/header_read_write_utils.h"
#include "suggest/policyimpl/dictionary/utils/format_utils.h"
+#include "utils/char_utils.h"
#include "utils/time_keeper.h"
namespace latinime {
@@ -35,6 +36,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mDictionaryFlags(HeaderReadWriteUtils::getFlags(dictBuf)),
mSize(HeaderReadWriteUtils::getHeaderSize(dictBuf)),
mAttributeMap(createAttributeMapAndReadAllAttributes(dictBuf)),
+ mLocale(readLocale()),
mMultiWordCostMultiplier(readMultipleWordCostMultiplier()),
mRequiresGermanUmlautProcessing(readRequiresGermanUmlautProcessing()),
mIsDecayingDict(HeaderReadWriteUtils::readBoolAttributeValue(&mAttributeMap,
@@ -54,10 +56,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
// Constructs header information using an attribute map.
HeaderPolicy(const FormatUtils::FORMAT_VERSION dictFormatVersion,
+ const std::vector<int> locale,
const HeaderReadWriteUtils::AttributeMap *const attributeMap)
: mDictFormatVersion(dictFormatVersion),
mDictionaryFlags(HeaderReadWriteUtils::createAndGetDictionaryFlagsUsingAttributeMap(
- attributeMap)), mSize(0), mAttributeMap(*attributeMap),
+ attributeMap)), mSize(0), mAttributeMap(*attributeMap), mLocale(locale),
mMultiWordCostMultiplier(readMultipleWordCostMultiplier()),
mRequiresGermanUmlautProcessing(readRequiresGermanUmlautProcessing()),
mIsDecayingDict(HeaderReadWriteUtils::readBoolAttributeValue(&mAttributeMap,
@@ -68,12 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mUnigramCount(0), mBigramCount(0), mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
- &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)) {}
+ &mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)) {
+ }
// Temporary dummy header.
HeaderPolicy()
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
- mAttributeMap(), mMultiWordCostMultiplier(0.0f),
+ mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false) {}
@@ -146,6 +150,11 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mHasHistoricalInfoOfWords;
}
+ AK_FORCE_INLINE bool shouldBoostExactMatches() const {
+ // TODO: Investigate better ways to handle exact matches for personalized dictionaries.
+ return !isDecayingDict();
+ }
+
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
@@ -169,6 +178,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const BIGRAM_COUNT_KEY;
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
+ static const char *const LOCALE_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
@@ -176,6 +186,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
const int mSize;
HeaderReadWriteUtils::AttributeMap mAttributeMap;
+ const std::vector<int> mLocale;
const float mMultiWordCostMultiplier;
const bool mRequiresGermanUmlautProcessing;
const bool mIsDecayingDict;
@@ -186,6 +197,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
+ const std::vector<int> readLocale() const;
float readMultipleWordCostMultiplier() const;
bool readRequiresGermanUmlautProcessing() const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
index 6b4598642..850b0d87f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.cpp
@@ -130,6 +130,13 @@ const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0;
return true;
}
+/* static */ void HeaderReadWriteUtils::setCodePointVectorAttribute(
+ AttributeMap *const headerAttributes, const char *const key, const std::vector<int> value) {
+ AttributeMap::key_type keyVector;
+ insertCharactersIntoVector(key, &keyVector);
+ (*headerAttributes)[keyVector] = value;
+}
+
/* static */ void HeaderReadWriteUtils::setBoolAttribute(AttributeMap *const headerAttributes,
const char *const key, const bool value) {
setIntAttribute(headerAttributes, key, value ? 1 : 0);
@@ -151,6 +158,18 @@ const HeaderReadWriteUtils::DictionaryFlags HeaderReadWriteUtils::NO_FLAGS = 0;
(*headerAttributes)[*key] = valueVector;
}
+/* static */ const std::vector<int> HeaderReadWriteUtils::readCodePointVectorAttributeValue(
+ const AttributeMap *const headerAttributes, const char *const key) {
+ AttributeMap::key_type keyVector;
+ insertCharactersIntoVector(key, &keyVector);
+ AttributeMap::const_iterator it = headerAttributes->find(keyVector);
+ if (it == headerAttributes->end()) {
+ return std::vector<int>();
+ } else {
+ return it->second;
+ }
+}
+
/* static */ bool HeaderReadWriteUtils::readBoolAttributeValue(
const AttributeMap *const headerAttributes, const char *const key,
const bool defaultValue) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h
index fc24bbdd5..3433c0494 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_read_write_utils.h
@@ -63,12 +63,18 @@ class HeaderReadWriteUtils {
/**
* Methods for header attributes.
*/
+ static void setCodePointVectorAttribute(AttributeMap *const headerAttributes,
+ const char *const key, const std::vector<int> value);
+
static void setBoolAttribute(AttributeMap *const headerAttributes,
const char *const key, const bool value);
static void setIntAttribute(AttributeMap *const headerAttributes,
const char *const key, const int value);
+ static const std::vector<int> readCodePointVectorAttributeValue(
+ const AttributeMap *const headerAttributes, const char *const key);
+
static bool readBoolAttributeValue(const AttributeMap *const headerAttributes,
const char *const key, const bool defaultValue);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
index b918e0765..824d442e4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp
@@ -28,6 +28,14 @@ const int DynamicPtReadingHelper::MAX_CHILD_COUNT_TO_AVOID_INFINITE_LOOP = 10000
const int DynamicPtReadingHelper::MAX_PT_NODE_ARRAY_COUNT_TO_AVOID_INFINITE_LOOP = 100000;
const size_t DynamicPtReadingHelper::MAX_READING_STATE_STACK_SIZE = MAX_WORD_LENGTH;
+bool DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions::onVisitingPtNode(
+ const PtNodeParams *const ptNodeParams) {
+ if (ptNodeParams->isTerminal() && !ptNodeParams->isDeleted()) {
+ mTerminalPositions->push_back(ptNodeParams->getHeadPos());
+ }
+ return true;
+}
+
// Visits all PtNodes in post-order depth first manner.
// For example, visits c -> b -> y -> x -> a for the following dictionary:
// a _ b _ c
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
index a69490943..bcc5c7857 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h
@@ -59,6 +59,21 @@ class DynamicPtReadingHelper {
DISALLOW_COPY_AND_ASSIGN(TraversingEventListener);
};
+ class TraversePolicyToGetAllTerminalPtNodePositions : public TraversingEventListener {
+ public:
+ TraversePolicyToGetAllTerminalPtNodePositions(std::vector<int> *const terminalPositions)
+ : mTerminalPositions(terminalPositions) {}
+ bool onAscend() { return true; }
+ bool onDescend(const int ptNodeArrayPos) { return true; }
+ bool onReadingPtNodeArrayTail() { return true; }
+ bool onVisitingPtNode(const PtNodeParams *const ptNodeParams);
+
+ private:
+ DISALLOW_IMPLICIT_CONSTRUCTORS(TraversePolicyToGetAllTerminalPtNodePositions);
+
+ std::vector<int> *const mTerminalPositions;
+ };
+
DynamicPtReadingHelper(const BufferWithExtendableBuffer *const buffer,
const PtNodeReader *const ptNodeReader)
: mIsError(false), mReadingState(), mBuffer(buffer),
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 1c420e070..75d85988c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -392,10 +392,32 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code
historicalInfo->getCount(), &bigrams, &shortcuts);
}
-int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token,
- int *const outCodePoints) {
- // TODO: Implement.
- return 0;
+int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) {
+ if (token == 0) {
+ mTerminalPtNodePositionsForIteratingWords.clear();
+ DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy(
+ &mTerminalPtNodePositionsForIteratingWords);
+ DynamicPtReadingHelper readingHelper(mDictBuffer, &mNodeReader);
+ readingHelper.initWithPtNodeArrayPos(getRootPosition());
+ readingHelper.traverseAllPtNodesInPostorderDepthFirstManner(&traversePolicy);
+ }
+ const int terminalPtNodePositionsVectorSize =
+ static_cast<int>(mTerminalPtNodePositionsForIteratingWords.size());
+ if (token < 0 || token >= terminalPtNodePositionsVectorSize) {
+ AKLOGE("Given token %d is invalid.", token);
+ return 0;
+ }
+ const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token];
+ int unigramProbability = NOT_A_PROBABILITY;
+ getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH,
+ outCodePoints, &unigramProbability);
+ const int nextToken = token + 1;
+ if (nextToken >= terminalPtNodePositionsVectorSize) {
+ // All words have been iterated.
+ mTerminalPtNodePositionsForIteratingWords.clear();
+ return 0;
+ }
+ return nextToken;
}
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 1bcd4ceea..9ba5be0c3 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -17,6 +17,8 @@
#ifndef LATINIME_VER4_PATRICIA_TRIE_POLICY_H
#define LATINIME_VER4_PATRICIA_TRIE_POLICY_H
+#include <vector>
+
#include "defines.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h"
@@ -50,7 +52,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
mUnigramCount(mHeaderPolicy->getUnigramCount()),
- mBigramCount(mHeaderPolicy->getBigramCount()) {};
+ mBigramCount(mHeaderPolicy->getBigramCount()),
+ mTerminalPtNodePositionsForIteratingWords() {};
AK_FORCE_INLINE int getRootPosition() const {
return 0;
@@ -134,6 +137,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
Ver4PatriciaTrieWritingHelper mWritingHelper;
int mUnigramCount;
int mBigramCount;
+ std::vector<int> mTerminalPtNodePositionsForIteratingWords;
};
} // namespace latinime
#endif // LATINIME_VER4_PATRICIA_TRIE_POLICY_H
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
index 84403c807..335ea0de0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.cpp
@@ -31,11 +31,12 @@ namespace latinime {
const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE = ".tmp";
/* static */ bool DictFileWritingUtils::createEmptyDictFile(const char *const filePath,
- const int dictVersion, const HeaderReadWriteUtils::AttributeMap *const attributeMap) {
+ const int dictVersion, const std::vector<int> localeAsCodePointVector,
+ const HeaderReadWriteUtils::AttributeMap *const attributeMap) {
TimeKeeper::setCurrentTime();
switch (dictVersion) {
case FormatUtils::VERSION_4:
- return createEmptyV4DictFile(filePath, attributeMap);
+ return createEmptyV4DictFile(filePath, localeAsCodePointVector, attributeMap);
default:
AKLOGE("Cannot create dictionary %s because format version %d is not supported.",
filePath, dictVersion);
@@ -44,8 +45,9 @@ const char *const DictFileWritingUtils::TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE =
}
/* static */ bool DictFileWritingUtils::createEmptyV4DictFile(const char *const dirPath,
+ const std::vector<int> localeAsCodePointVector,
const HeaderReadWriteUtils::AttributeMap *const attributeMap) {
- HeaderPolicy headerPolicy(FormatUtils::VERSION_4, attributeMap);
+ HeaderPolicy headerPolicy(FormatUtils::VERSION_4, localeAsCodePointVector, attributeMap);
Ver4DictBuffers::Ver4DictBuffersPtr dictBuffers =
Ver4DictBuffers::createVer4DictBuffers(&headerPolicy);
headerPolicy.fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h
index bdf9fd63c..c2ecff45e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/dict_file_writing_utils.h
@@ -31,6 +31,7 @@ class DictFileWritingUtils {
static const char *const TEMP_FILE_SUFFIX_FOR_WRITING_DICT_FILE;
static bool createEmptyDictFile(const char *const filePath, const int dictVersion,
+ const std::vector<int> localeAsCodePointVector,
const HeaderReadWriteUtils::AttributeMap *const attributeMap);
static bool flushAllHeaderAndBodyToFile(const char *const filePath,
@@ -44,6 +45,7 @@ class DictFileWritingUtils {
DISALLOW_IMPLICIT_CONSTRUCTORS(DictFileWritingUtils);
static bool createEmptyV4DictFile(const char *const filePath,
+ const std::vector<int> localeAsCodePointVector,
const HeaderReadWriteUtils::AttributeMap *const attributeMap);
static bool flushBufferToFile(const char *const filePath,
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
index c777e7238..8b405e8de 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h
@@ -50,14 +50,14 @@ class TypingScoring : public Scoring {
AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance,
const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes,
- const bool forceCommit) const {
+ const bool forceCommit, const bool boostExactMatches) const {
const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE
+ static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT;
float score = ScoringParams::TYPING_BASE_OUTPUT_SCORE - compoundDistance / maxDistance;
if (forceCommit) {
score += ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD;
}
- if (ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
+ if (boostExactMatches && ErrorTypeUtils::isExactMatch(containedErrorTypes)) {
score += ScoringParams::EXACT_MATCH_PROMOTION;
if ((ErrorTypeUtils::MATCH_WITH_CASE_ERROR & containedErrorTypes) != 0) {
score -= ScoringParams::CASE_ERROR_PENALTY_FOR_EXACT_MATCH;
diff --git a/native/jni/src/utils/char_utils.cpp b/native/jni/src/utils/char_utils.cpp
index 0e7039610..d41fc8924 100644
--- a/native/jni/src/utils/char_utils.cpp
+++ b/native/jni/src/utils/char_utils.cpp
@@ -1273,4 +1273,6 @@ static int compare_pair_capital(const void *a, const void *b) {
/* U+04F0 */ 0x0423, 0x0443, 0x0423, 0x0443, 0x0427, 0x0447, 0x04F6, 0x04F7,
/* U+04F8 */ 0x042B, 0x044B, 0x04FA, 0x04FB, 0x04FC, 0x04FD, 0x04FE, 0x04FF,
};
+
+/* static */ const std::vector<int> CharUtils::EMPTY_STRING(1 /* size */, '\0' /* value */);
} // namespace latinime
diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h
index 41663c81a..98b8966df 100644
--- a/native/jni/src/utils/char_utils.h
+++ b/native/jni/src/utils/char_utils.h
@@ -18,6 +18,7 @@
#define LATINIME_CHAR_UTILS_H
#include <cctype>
+#include <vector>
#include "defines.h"
@@ -85,7 +86,15 @@ class CharUtils {
return spaceCount;
}
+ static AK_FORCE_INLINE std::vector<int> convertShortArrayToIntVector(
+ const unsigned short *const source, const int length) {
+ std::vector<int> destination;
+ destination.insert(destination.end(), source, source + length);
+ return destination; // Copies the vector
+ }
+
static unsigned short latin_tolower(const unsigned short c);
+ static const std::vector<int> EMPTY_STRING;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);