aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary.cpp18
-rw-r--r--native/jni/src/suggest/core/dictionary/dictionary_utils.cpp9
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.cpp4
-rw-r--r--native/jni/src/suggest/core/session/dic_traverse_session.h14
-rw-r--r--native/jni/src/suggest/core/session/prev_words_info.h36
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp17
7 files changed, 67 insertions, 48 deletions
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index ec261cfbf..b843791d6 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -93,13 +93,13 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi
void Dictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
- prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(
+ mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray,
true /* tryLowerCaseSearch */);
- const WordIdArrayView prevWordIdArrayView = WordIdArrayView::fromArray(prevWordIds);
- NgramListenerForPrediction listener(prevWordsInfo, prevWordIdArrayView, outSuggestionResults,
+ NgramListenerForPrediction listener(prevWordsInfo, prevWordIds, outSuggestionResults,
mDictionaryStructureWithBufferPolicy.get());
- mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIdArrayView, &listener);
+ mDictionaryStructureWithBufferPolicy->iterateNgramEntries(prevWordIds, &listener);
}
int Dictionary::getProbability(const int *word, int length) const {
@@ -121,11 +121,11 @@ int Dictionary::getNgramProbability(const PrevWordsInfo *const prevWordsInfo, co
if (!prevWordsInfo) {
return getDictionaryStructurePolicy()->getProbabilityOfWord(WordIdArrayView(), wordId);
}
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
- prevWordsInfo->getPrevWordIds(mDictionaryStructureWithBufferPolicy.get(), prevWordIds.data(),
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds
+ (mDictionaryStructureWithBufferPolicy.get(), &prevWordIdArray,
true /* tryLowerCaseSearch */);
- return getDictionaryStructurePolicy()->getProbabilityOfWord(
- IntArrayView::fromArray(prevWordIds), wordId);
+ return getDictionaryStructurePolicy()->getProbabilityOfWord(prevWordIds, wordId);
}
bool Dictionary::addUnigramEntry(const int *const word, const int length,
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
index 617bf5a90..d09266e29 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
@@ -35,12 +35,11 @@ namespace latinime {
// No prev words information.
PrevWordsInfo emptyPrevWordsInfo;
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
- emptyPrevWordsInfo.getPrevWordIds(dictionaryStructurePolicy, prevWordIds.data(),
- false /* tryLowerCaseSearch */);
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = emptyPrevWordsInfo.getPrevWordIds(
+ dictionaryStructurePolicy, &prevWordIdArray, false /* tryLowerCaseSearch */);
current.emplace_back();
- DicNodeUtils::initAsRoot(dictionaryStructurePolicy,
- IntArrayView::fromArray(prevWordIds), &current.front());
+ DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordIds, &current.front());
for (int i = 0; i < codePointCount; ++i) {
// The base-lower input is used to ignore case errors and accent errors.
const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
index 4f58c0c54..4d7505a55 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp
@@ -35,8 +35,8 @@ void DicTraverseSession::init(const Dictionary *const dictionary,
mMultiWordCostMultiplier = getDictionaryStructurePolicy()->getHeaderStructurePolicy()
->getMultiWordCostMultiplier();
mSuggestOptions = suggestOptions;
- prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(), mPrevWordsIds.data(),
- true /* tryLowerCaseSearch */);
+ mPrevWordIdCount = prevWordsInfo->getPrevWordIds(getDictionaryStructurePolicy(),
+ &mPrevWordIdArray, true /* tryLowerCaseSearch */).size();
}
void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,
diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h
index af199071e..9f841aa3c 100644
--- a/native/jni/src/suggest/core/session/dic_traverse_session.h
+++ b/native/jni/src/suggest/core/session/dic_traverse_session.h
@@ -51,12 +51,11 @@ class DicTraverseSession {
}
AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache)
- : mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr),
- mDicNodesCache(usesLargeCache), mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1),
- mMultiWordCostMultiplier(1.0f) {
+ : mPrevWordIdCount(0), mProximityInfo(nullptr), mDictionary(nullptr),
+ mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), mMultiBigramMap(),
+ mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) {
// NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here.
- mPrevWordsIds.fill(NOT_A_DICT_POS);
}
// Non virtual inline destructor -- never inherit this class
@@ -78,7 +77,9 @@ class DicTraverseSession {
//--------------------
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
const SuggestOptions *getSuggestOptions() const { return mSuggestOptions; }
- const WordIdArrayView getPrevWordIds() const { return IntArrayView::fromArray(mPrevWordsIds); }
+ const WordIdArrayView getPrevWordIds() const {
+ return WordIdArrayView::fromArray(mPrevWordIdArray).limit(mPrevWordIdCount);
+ }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
MultiBigramMap *getMultiBigramMap() { return &mMultiBigramMap; }
const ProximityInfoState *getProximityInfoState(int id) const {
@@ -165,7 +166,8 @@ class DicTraverseSession {
const int *const inputYs, const int *const times, const int *const pointerIds,
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordsIds;
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIdArray;
+ size_t mPrevWordIdCount;
const ProximityInfo *mProximityInfo;
const Dictionary *mDictionary;
const SuggestOptions *mSuggestOptions;
diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h
index fc9a35968..02e82a8e0 100644
--- a/native/jni/src/suggest/core/session/prev_words_info.h
+++ b/native/jni/src/suggest/core/session/prev_words_info.h
@@ -17,6 +17,8 @@
#ifndef LATINIME_PREV_WORDS_INFO_H
#define LATINIME_PREV_WORDS_INFO_H
+#include <array>
+
#include "defines.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
#include "utils/char_utils.h"
@@ -27,12 +29,13 @@ namespace latinime {
class PrevWordsInfo {
public:
// No prev word information.
- PrevWordsInfo() {
+ PrevWordsInfo() : mPrevWordCount(0) {
clear();
}
- PrevWordsInfo(PrevWordsInfo &&prevWordsInfo) {
- for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
+ PrevWordsInfo(PrevWordsInfo &&prevWordsInfo)
+ : mPrevWordCount(prevWordsInfo.mPrevWordCount) {
+ for (size_t i = 0; i < mPrevWordCount; ++i) {
mPrevWordCodePointCount[i] = prevWordsInfo.mPrevWordCodePointCount[i];
memmove(mPrevWordCodePoints[i], prevWordsInfo.mPrevWordCodePoints[i],
sizeof(mPrevWordCodePoints[i][0]) * mPrevWordCodePointCount[i]);
@@ -43,9 +46,10 @@ class PrevWordsInfo {
// Construct from previous words.
PrevWordsInfo(const int prevWordCodePoints[][MAX_WORD_LENGTH],
const int *const prevWordCodePointCount, const bool *const isBeginningOfSentence,
- const size_t prevWordCount) {
+ const size_t prevWordCount)
+ : mPrevWordCount(std::min(NELEMS(mPrevWordCodePoints), prevWordCount)) {
clear();
- for (size_t i = 0; i < std::min(NELEMS(mPrevWordCodePoints), prevWordCount); ++i) {
+ for (size_t i = 0; i < mPrevWordCount; ++i) {
if (prevWordCodePointCount[i] < 0 || prevWordCodePointCount[i] > MAX_WORD_LENGTH) {
continue;
}
@@ -58,7 +62,7 @@ class PrevWordsInfo {
// Construct from a previous word.
PrevWordsInfo(const int *const prevWordCodePoints, const int prevWordCodePointCount,
- const bool isBeginningOfSentence) {
+ const bool isBeginningOfSentence) : mPrevWordCount(1) {
clear();
if (prevWordCodePointCount > MAX_WORD_LENGTH || !prevWordCodePoints) {
return;
@@ -79,26 +83,29 @@ class PrevWordsInfo {
return false;
}
- void getPrevWordIds(const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
- int *const outPrevWordIds, const bool tryLowerCaseSearch) const {
- for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
- outPrevWordIds[i] = getWordId(dictStructurePolicy,
+ template<size_t N>
+ const WordIdArrayView getPrevWordIds(
+ const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
+ std::array<int, N> *const prevWordIdBuffer, const bool tryLowerCaseSearch) const {
+ for (size_t i = 0; i < std::min(mPrevWordCount, N); ++i) {
+ prevWordIdBuffer->at(i) = getWordId(dictStructurePolicy,
mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
mIsBeginningOfSentence[i], tryLowerCaseSearch);
}
+ return WordIdArrayView::fromArray(*prevWordIdBuffer).limit(mPrevWordCount);
}
// n is 1-indexed.
- const CodePointArrayView getNthPrevWordCodePoints(const int n) const {
- if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+ const CodePointArrayView getNthPrevWordCodePoints(const size_t n) const {
+ if (n <= 0 || n > mPrevWordCount) {
return CodePointArrayView();
}
return CodePointArrayView(mPrevWordCodePoints[n - 1], mPrevWordCodePointCount[n - 1]);
}
// n is 1-indexed.
- bool isNthPrevWordBeginningOfSentence(const int n) const {
- if (n <= 0 || n > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+ bool isNthPrevWordBeginningOfSentence(const size_t n) const {
+ if (n <= 0 || n > mPrevWordCount) {
return false;
}
return mIsBeginningOfSentence[n - 1];
@@ -142,6 +149,7 @@ class PrevWordsInfo {
}
}
+ const size_t mPrevWordCount;
int mPrevWordCodePoints[MAX_PREV_WORD_COUNT_FOR_N_GRAM][MAX_WORD_LENGTH];
int mPrevWordCodePointCount[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
bool mIsBeginningOfSentence[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index dfc3d2d9b..41b9a11b1 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -332,8 +332,12 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
- int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
- prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ false /* tryLowerCaseSearch */);
+ if (prevWordIds.empty()) {
+ return false;
+ }
if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
@@ -347,7 +351,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
// Refresh word ids.
- prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSearch */);
+ prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
} else {
return false;
}
@@ -390,9 +394,10 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size());
}
- int prevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
- prevWordsInfo->getPrevWordIds(this, prevWordIds, false /* tryLowerCaseSerch */);
- if (prevWordIds[0] == NOT_A_WORD_ID) {
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ false /* tryLowerCaseSerch */);
+ if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
return false;
}
const int wordPos = getTerminalPtNodePosFromWordId(getWordId(wordCodePoints,
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 5b1907f49..e624bf338 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -303,8 +303,12 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
"length: %zd", bigramProperty->getTargetCodePoints()->size());
return false;
}
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
- prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ false /* tryLowerCaseSearch */);
+ if (prevWordIds.empty()) {
+ return false;
+ }
// TODO: Support N-gram.
if (prevWordIds[0] == NOT_A_WORD_ID) {
if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
@@ -319,7 +323,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
return false;
}
// Refresh word ids.
- prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSearch */);
+ prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
} else {
return false;
}
@@ -367,10 +371,11 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
AKLOGE("word is too long to remove n-gram entry form the dictionary. length: %zd",
wordCodePoints.size());
}
- WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIds;
- prevWordsInfo->getPrevWordIds(this, prevWordIds.data(), false /* tryLowerCaseSerch */);
+ WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
+ const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
+ false /* tryLowerCaseSerch */);
// TODO: Support N-gram.
- if (prevWordIds[0] == NOT_A_WORD_ID) {
+ if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
return false;
}
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);