aboutsummaryrefslogtreecommitdiffstats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r--native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h17
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp15
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h3
-rw-r--r--native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp44
-rw-r--r--native/jni/src/utils/int_array_view.h9
-rw-r--r--native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp6
-rw-r--r--native/jni/tests/utils/int_array_view_test.cpp13
7 files changed, 68 insertions, 39 deletions
diff --git a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
index cecfc7aa9..1b796b5d4 100644
--- a/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
+++ b/native/jni/src/suggest/core/dicnode/internal/dic_node_properties.h
@@ -32,7 +32,7 @@ class DicNodeProperties {
public:
AK_FORCE_INLINE DicNodeProperties()
: mChildrenPtNodeArrayPos(NOT_A_DICT_POS), mDicNodeCodePoint(NOT_A_CODE_POINT),
- mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0) {}
+ mWordId(NOT_A_WORD_ID), mDepth(0), mLeavingDepth(0), mPrevWordCount(0) {}
~DicNodeProperties() {}
@@ -45,6 +45,7 @@ class DicNodeProperties {
mDepth = depth;
mLeavingDepth = leavingDepth;
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
+ mPrevWordCount = prevWordIds.size();
}
// Init for root with prevWordsPtNodePos which is used for n-gram
@@ -55,6 +56,7 @@ class DicNodeProperties {
mDepth = 0;
mLeavingDepth = 0;
prevWordIds.copyToArray(&mPrevWordIds, 0 /* offset */);
+ mPrevWordCount = prevWordIds.size();
}
void initByCopy(const DicNodeProperties *const dicNodeProp) {
@@ -63,8 +65,9 @@ class DicNodeProperties {
mWordId = dicNodeProp->mWordId;
mDepth = dicNodeProp->mDepth;
mLeavingDepth = dicNodeProp->mLeavingDepth;
- WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
- .copyToArray(&mPrevWordIds, 0 /* offset */);
+ const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds();
+ prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */);
+ mPrevWordCount = prevWordIdArrayView.size();
}
// Init as passing child
@@ -74,8 +77,9 @@ class DicNodeProperties {
mWordId = dicNodeProp->mWordId;
mDepth = dicNodeProp->mDepth + 1; // Increment the depth of a passing child
mLeavingDepth = dicNodeProp->mLeavingDepth;
- WordIdArrayView::fromArray(dicNodeProp->mPrevWordIds)
- .copyToArray(&mPrevWordIds, 0 /* offset */);
+ const WordIdArrayView prevWordIdArrayView = dicNodeProp->getPrevWordIds();
+ prevWordIdArrayView.copyToArray(&mPrevWordIds, 0 /* offset */);
+ mPrevWordCount = prevWordIdArrayView.size();
}
int getChildrenPtNodeArrayPos() const {
@@ -104,7 +108,7 @@ class DicNodeProperties {
}
const WordIdArrayView getPrevWordIds() const {
- return WordIdArrayView::fromArray(mPrevWordIds);
+ return WordIdArrayView::fromArray(mPrevWordIds).limit(mPrevWordCount);
}
int getWordId() const {
@@ -121,6 +125,7 @@ class DicNodeProperties {
uint16_t mDepth;
uint16_t mLeavingDepth;
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> mPrevWordIds;
+ size_t mPrevWordCount;
};
} // namespace latinime
#endif // LATINIME_DIC_NODE_PROPERTIES_H
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index f54bb151a..0675de6fa 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -39,7 +39,7 @@ bool LanguageModelDictContent::runGC(
}
int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
- const int wordId) const {
+ const int wordId, const HeaderPolicy *const headerPolicy) const {
int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
int maxLevel = 0;
@@ -58,14 +58,15 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI
if (!result.mIsValid) {
continue;
}
- const int probability =
- ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+ const ProbabilityEntry probabilityEntry =
+ ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo);
if (mHasHistoricalInfo) {
- return std::min(
- probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
- MAX_PROBABILITY);
+ const int probability = ForgettingCurveUtils::decodeProbability(
+ probabilityEntry.getHistoricalInfo(), headerPolicy)
+ + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */);
+ return std::min(probability, MAX_PROBABILITY);
} else {
- return probability;
+ return probabilityEntry.getProbability();
}
}
// Cannot find the word.
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 4e0b47036..a793af4be 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,7 +128,8 @@ class LanguageModelDictContent {
const LanguageModelDictContent *const originalContent,
int *const outNgramCount);
- int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+ int getWordProbability(const WordIdArrayView prevWordIds, const int wordId,
+ const HeaderPolicy *const headerPolicy) const;
ProbabilityEntry getProbabilityEntry(const int wordId) const {
return getNgramProbabilityEntry(WordIdArrayView(), wordId);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index e624bf338..d537711b0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -121,9 +121,10 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext(
mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
// TODO: Support n-gram.
- return WordAttributes(mBuffers->getLanguageModelDictContent()->getWordProbability(
- prevWordIds.limit(1 /* maxSize */), wordId), ptNodeParams.isBlacklisted(),
- ptNodeParams.isNotAWord(), ptNodeParams.getProbability() == 0);
+ const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability(
+ prevWordIds.limit(1 /* maxSize */), wordId, mHeaderPolicy);
+ return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(),
+ probability == 0);
}
int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds,
@@ -309,30 +310,32 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
if (prevWordIds.empty()) {
return false;
}
- // TODO: Support N-gram.
- if (prevWordIds[0] == NOT_A_WORD_ID) {
- if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
- const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
- const UnigramProperty beginningOfSentenceUnigramProperty(
- true /* representsBeginningOfSentence */, true /* isNotAWord */,
- false /* isBlacklisted */, MAX_PROBABILITY /* probability */,
- NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
- if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */),
- &beginningOfSentenceUnigramProperty)) {
- AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
- return false;
- }
- // Refresh word ids.
- prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
- } else {
+ for (size_t i = 0; i < prevWordIds.size(); ++i) {
+ if (prevWordIds[i] != NOT_A_WORD_ID) {
+ continue;
+ }
+ if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) {
return false;
}
+ const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
+ const UnigramProperty beginningOfSentenceUnigramProperty(
+ true /* representsBeginningOfSentence */, true /* isNotAWord */,
+ false /* isBlacklisted */, MAX_PROBABILITY /* probability */,
+ NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
+ if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */),
+ &beginningOfSentenceUnigramProperty)) {
+ AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
+ return false;
+ }
+ // Refresh word ids.
+ prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
}
const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()),
false /* forceLowerCaseSearch */);
if (wordId == NOT_A_WORD_ID) {
return false;
}
+ // TODO: Support N-gram.
bool addedNewEntry = false;
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
@@ -374,8 +377,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
false /* tryLowerCaseSerch */);
- // TODO: Support N-gram.
- if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
+ if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) {
return false;
}
const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index caa13d976..cc5f328ba 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -17,6 +17,7 @@
#ifndef LATINIME_INT_ARRAY_VIEW_H
#define LATINIME_INT_ARRAY_VIEW_H
+#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
@@ -92,12 +93,16 @@ class IntArrayView {
return mPtr + mSize;
}
+ AK_FORCE_INLINE bool contains(const int value) const {
+ return std::find(begin(), end(), value) != end();
+ }
+
// Returns the view whose size is smaller than or equal to the given count.
- const IntArrayView limit(const size_t maxSize) const {
+ AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const {
return IntArrayView(mPtr, std::min(maxSize, mSize));
}
- const IntArrayView skip(const size_t n) const {
+ AK_FORCE_INLINE const IntArrayView skip(const size_t n) const {
if (mSize <= n) {
return IntArrayView();
}
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index 7608b45c2..c5849d054 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -107,13 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) {
languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
&bigramProbabilityEntry);
- EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
+ nullptr /* headerPolicy */));
const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
prevWordIds[1], &probabilityEntry);
languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
&trigramProbabilityEntry);
- EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+ EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId,
+ nullptr /* headerPolicy */));
}
} // namespace
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 3bc294cdd..934e27e1c 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -58,6 +58,19 @@ TEST(IntArrayViewTest, TestConstructFromObject) {
EXPECT_EQ(object, intArrayView[0]);
}
+TEST(IntArrayViewTest, TestContains) {
+ EXPECT_FALSE(IntArrayView().contains(0));
+ EXPECT_FALSE(IntArrayView().contains(1));
+
+ const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+ IntArrayView intArrayView(intVector);
+ EXPECT_TRUE(intArrayView.contains(0));
+ EXPECT_TRUE(intArrayView.contains(3));
+ EXPECT_TRUE(intArrayView.contains(-2));
+ EXPECT_FALSE(intArrayView.contains(-3));
+ EXPECT_FALSE(intArrayView.limit(0).contains(3));
+}
+
TEST(IntArrayViewTest, TestLimit) {
const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
IntArrayView intArrayView(intVector);