diff options
Diffstat (limited to 'native')
27 files changed, 201 insertions, 199 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index 688ce44be..e420f8056 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -364,10 +364,12 @@ static bool latinime_BinaryDictionary_addUnigramEntry(JNIEnv *env, jclass clazz, int codePoints[codePointCount]; env->GetIntArrayRegion(word, 0, codePointCount, codePoints); std::vector<UnigramProperty::ShortcutProperty> shortcuts; - std::vector<int> shortcutTargetCodePoints; - JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); - if (!shortcutTargetCodePoints.empty()) { - shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); + { + std::vector<int> shortcutTargetCodePoints; + JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); + if (!shortcutTargetCodePoints.empty()) { + shortcuts.emplace_back(std::move(shortcutTargetCodePoints), shortcutProbability); + } } // Use 1 for count to indicate the word has inputted. const UnigramProperty unigramProperty(isBeginningOfSentence, isNotAWord, @@ -401,11 +403,9 @@ static bool latinime_BinaryDictionary_addNgramEntry(JNIEnv *env, jclass clazz, j jsize wordLength = env->GetArrayLength(word); int wordCodePoints[wordLength]; env->GetIntArrayRegion(word, 0, wordLength, wordCodePoints); - const std::vector<int> bigramTargetCodePoints( - wordCodePoints, wordCodePoints + wordLength); // Use 1 for count to indicate the bigram has inputted. - const BigramProperty bigramProperty(&bigramTargetCodePoints, probability, - timestamp, 0 /* level */, 1 /* count */); + const BigramProperty bigramProperty(CodePointArrayView(wordCodePoints, wordLength).toVector(), + probability, timestamp, 0 /* level */, 1 /* count */); return dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); } @@ -483,12 +483,14 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j jintArray shortcutTarget = static_cast<jintArray>( env->GetObjectField(languageModelParam, shortcutTargetFieldId)); std::vector<UnigramProperty::ShortcutProperty> shortcuts; - std::vector<int> shortcutTargetCodePoints; - JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); - if (!shortcutTargetCodePoints.empty()) { - jint shortcutProbability = - env->GetIntField(languageModelParam, shortcutProbabilityFieldId); - shortcuts.emplace_back(&shortcutTargetCodePoints, shortcutProbability); + { + std::vector<int> shortcutTargetCodePoints; + JniDataUtils::jintarrayToVector(env, shortcutTarget, &shortcutTargetCodePoints); + if (!shortcutTargetCodePoints.empty()) { + jint shortcutProbability = + env->GetIntField(languageModelParam, shortcutProbabilityFieldId); + shortcuts.emplace_back(std::move(shortcutTargetCodePoints), shortcutProbability); + } } // Use 1 for count to indicate the word has inputted. const UnigramProperty unigramProperty(false /* isBeginningOfSentence */, isNotAWord, @@ -498,11 +500,10 @@ static int latinime_BinaryDictionary_addMultipleDictionaryEntries(JNIEnv *env, j &unigramProperty); if (word0) { jint bigramProbability = env->GetIntField(languageModelParam, bigramProbabilityFieldId); - const std::vector<int> bigramTargetCodePoints( - word1CodePoints, word1CodePoints + word1Length); // Use 1 for count to indicate the bigram has inputted. - const BigramProperty bigramProperty(&bigramTargetCodePoints, bigramProbability, - timestamp, 0 /* level */, 1 /* count */); + const BigramProperty bigramProperty( + CodePointArrayView(word1CodePoints, word1Length).toVector(), + bigramProbability, timestamp, 0 /* level */, 1 /* count */); const PrevWordsInfo prevWordsInfo(word0CodePoints, word0Length, false /* isBeginningOfSentence */); dictionary->addNgramEntry(&prevWordsInfo, &bigramProperty); diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index e4084b0f5..a3bb408c3 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -77,10 +77,8 @@ void Dictionary::NgramListenerForPrediction::onVisitEntry(const int ngramProbabi return; } int targetWordCodePoints[MAX_WORD_LENGTH]; - int unigramProbability = 0; - const int codePointCount = mDictStructurePolicy-> - getCodePointsAndProbabilityAndReturnCodePointCount(targetWordId, MAX_WORD_LENGTH, - targetWordCodePoints, &unigramProbability); + const int codePointCount = mDictStructurePolicy->getCodePointsAndReturnCodePointCount( + targetWordId, MAX_WORD_LENGTH, targetWordCodePoints); if (codePointCount <= 0) { return; } diff --git a/native/jni/src/suggest/core/dictionary/property/bigram_property.h b/native/jni/src/suggest/core/dictionary/property/bigram_property.h index 343af143c..9e0baa032 100644 --- a/native/jni/src/suggest/core/dictionary/property/bigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/bigram_property.h @@ -26,9 +26,9 @@ namespace latinime { // TODO: Change to NgramProperty. class BigramProperty { public: - BigramProperty(const std::vector<int> *const targetCodePoints, - const int probability, const int timestamp, const int level, const int count) - : mTargetCodePoints(*targetCodePoints), mProbability(probability), + BigramProperty(const std::vector<int> &&targetCodePoints, const int probability, + const int timestamp, const int level, const int count) + : mTargetCodePoints(std::move(targetCodePoints)), mProbability(probability), mTimestamp(timestamp), mLevel(level), mCount(count) {} const std::vector<int> *getTargetCodePoints() const { diff --git a/native/jni/src/suggest/core/dictionary/property/unigram_property.h b/native/jni/src/suggest/core/dictionary/property/unigram_property.h index 902eb000f..b7e7d6686 100644 --- a/native/jni/src/suggest/core/dictionary/property/unigram_property.h +++ b/native/jni/src/suggest/core/dictionary/property/unigram_property.h @@ -27,8 +27,9 @@ class UnigramProperty { public: class ShortcutProperty { public: - ShortcutProperty(const std::vector<int> *const targetCodePoints, const int probability) - : mTargetCodePoints(*targetCodePoints), mProbability(probability) {} + ShortcutProperty(const std::vector<int> &&targetCodePoints, const int probability) + : mTargetCodePoints(std::move(targetCodePoints)), + mProbability(probability) {} const std::vector<int> *getTargetCodePoints() const { return &mTargetCodePoints; @@ -71,6 +72,11 @@ class UnigramProperty { return mIsBlacklisted; } + bool isPossiblyOffensive() const { + // TODO: Have dedicated flag. + return mProbability == 0; + } + bool hasShortcuts() const { return !mShortcuts.empty(); } diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.h b/native/jni/src/suggest/core/dictionary/property/word_property.h index aa3e0b68a..4e6febb3f 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.h +++ b/native/jni/src/suggest/core/dictionary/property/word_property.h @@ -33,10 +33,10 @@ class WordProperty { WordProperty() : mCodePoints(), mUnigramProperty(), mBigrams() {} - WordProperty(const std::vector<int> *const codePoints, - const UnigramProperty *const unigramProperty, + WordProperty(const std::vector<int> &&codePoints, const UnigramProperty *const unigramProperty, const std::vector<BigramProperty> *const bigrams) - : mCodePoints(*codePoints), mUnigramProperty(*unigramProperty), mBigrams(*bigrams) {} + : mCodePoints(std::move(codePoints)), mUnigramProperty(*unigramProperty), + mBigrams(*bigrams) {} void outputProperties(JNIEnv *const env, jintArray outCodePoints, jbooleanArray outFlags, jintArray outProbabilityInfo, jobject outBigramTargets, jobject outBigramProbabilities, diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index a498b6f65..1546b2610 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -51,9 +51,8 @@ class DictionaryStructureWithBufferPolicy { virtual void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const = 0; - virtual int getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const = 0; + virtual int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const = 0; virtual int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index ee1403739..f752f89f1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -87,14 +87,13 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } } -int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const { +int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); readingHelper.initWithPtNodePos(ptNodePos); - const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( - maxCodePointCount, outCodePoints, outUnigramProbability); + const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount( + maxCodePointCount, outCodePoints); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); @@ -495,8 +494,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( return WordProperty(); } const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); const ProbabilityEntry probabilityEntry = mBuffers->getProbabilityDictContent()->getProbabilityEntry( ptNodeParams.getTerminalId()); @@ -521,20 +518,17 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( if (word1TerminalPtNodePos == NOT_A_DICT_POS) { continue; } - // Word (unigram) probability - int word1Probability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + const int codePointCount = getCodePointsAndReturnCodePointCount( getWordIdFromTerminalPtNodePos(word1TerminalPtNodePos), MAX_WORD_LENGTH, - bigramWord1CodePoints, &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + codePointCount); + bigramWord1CodePoints); const HistoricalInfo *const historicalInfo = bigramEntry.getHistoricalInfo(); const int probability = bigramEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( bigramEntry.getHistoricalInfo(), mHeaderPolicy) : bigramEntry.getProbability(); - bigrams.emplace_back(&word1, probability, - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), + bigrams.emplace_back( + CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), + probability, historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount()); } } @@ -551,15 +545,16 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( int shortcutProbability = NOT_A_PROBABILITY; shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget, &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos); - const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength); - shortcuts.emplace_back(&target, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(), + shortcutProbability); } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams); } int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, @@ -580,10 +575,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints, - &unigramProbability); + *outCodePointCount = getCodePointsAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 576d2abb5..8420c94d0 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -85,9 +85,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp index 40782a44f..5e4a4b166 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.cpp @@ -175,8 +175,8 @@ bool DynamicPtReadingHelper::traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFi return !isError(); } -int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( - const int maxCodePointCount, int *const outCodePoints, int *const outUnigramProbability) { +int DynamicPtReadingHelper::getCodePointsAndReturnCodePointCount(const int maxCodePointCount, + int *const outCodePoints) { // This method traverses parent nodes from the terminal by following parent pointers; thus, // node code points are stored in the buffer in the reverse order. int reverseCodePoints[maxCodePointCount]; @@ -184,11 +184,8 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( // First, read the terminal node and get its probability. if (!isValidTerminalNode(terminalPtNodeParams)) { // Node at the ptNodePos is not a valid terminal node. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } - // Store terminal node probability. - *outUnigramProbability = terminalPtNodeParams.getProbability(); // Then, following parent node link to the dictionary root and fetch node code points. int totalCodePointCount = 0; while (!isEnd()) { @@ -196,7 +193,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( totalCodePointCount = getTotalCodePointCount(ptNodeParams); if (!ptNodeParams.isValid() || totalCodePointCount > maxCodePointCount) { // The ptNodePos is not a valid terminal node position in the dictionary. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } // Store node code points to buffer in the reverse order. @@ -207,7 +203,6 @@ int DynamicPtReadingHelper::getCodePointsAndProbabilityAndReturnCodePointCount( } if (isError()) { // The node position or the dictionary is invalid. - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } // Reverse the stored code points to output them. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h index 9a7abc97f..21c287fdc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h @@ -211,8 +211,7 @@ class DynamicPtReadingHelper { bool traverseAllPtNodesInPtNodeArrayLevelPreorderDepthFirstManner( TraversingEventListener *const listener); - int getCodePointsAndProbabilityAndReturnCodePointCount(const int maxCodePointCount, - int *const outCodePoints, int *const outUnigramProbability); + int getCodePointsAndReturnCodePointCount(const int maxCodePointCount, int *const outCodePoints); int getTerminalPtNodePositionOfWord(const int *const inWord, const size_t length, const bool forceLowerCaseSearch); diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 6e7dba9ff..13cf9a5a8 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -58,6 +58,11 @@ void PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const dicNo } } +int PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { + return getCodePointsAndProbabilityAndReturnCodePointCount(wordId, maxCodePointCount, + outCodePoints, nullptr /* outUnigramProbability */); +} // This retrieves code points and the probability of the word by its id. // Due to the fact that words are ordered in the dictionary in a strict breadth-first order, // it is possible to check for this with advantageous complexity. For each PtNode array, we search @@ -82,6 +87,9 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( int pos = getRootPosition(); int wordPos = 0; const int *const codePointTable = mHeaderPolicy.getCodePointTable(); + if (outUnigramProbability) { + *outUnigramProbability = NOT_A_PROBABILITY; + } // One iteration of the outer loop iterates through PtNode arrays. As stated above, we will // only traverse PtNodes that are actually a part of the terminal we are searching, so each // time we enter this loop we are one depth level further than last time. @@ -97,7 +105,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } for (int ptNodeCount = PatriciaTrieReadingUtils::getPtNodeArraySizeAndAdvancePosition( @@ -107,7 +114,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( AKLOGE("PtNode position is invalid. pos: %d, dict size: %zd", pos, mBuffer.size()); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } const PatriciaTrieReadingUtils::NodeFlags flags = @@ -130,9 +136,11 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( mBuffer.data(), codePointTable, &pos); } } - *outUnigramProbability = - PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition(mBuffer.data(), - &pos); + if (outUnigramProbability) { + *outUnigramProbability = + PatriciaTrieReadingUtils::readProbabilityAndAdvancePosition( + mBuffer.data(), &pos); + } return ++wordPos; } // We need to skip past this PtNode, so skip any remaining code points after the @@ -234,7 +242,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( pos); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } } @@ -257,7 +264,6 @@ int PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( AKLOGE("Cannot skip bigrams. BufSize: %zd, pos: %d.", mBuffer.size(), pos); mIsCorrupted = true; ASSERT(false); - *outUnigramProbability = NOT_A_PROBABILITY; return 0; } } @@ -429,8 +435,6 @@ const WordProperty PatriciaTriePolicy::getWordProperty( const int ptNodePos = getTerminalPtNodePosFromWordId(wordId); const PtNodeParams ptNodeParams = mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); // Fetch bigram information. std::vector<BigramProperty> bigrams; const int bigramListPos = getBigramsPositionOfPtNode(ptNodePos); @@ -445,11 +449,10 @@ const WordProperty PatriciaTriePolicy::getWordProperty( const int word1CodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( getWordIdFromTerminalPtNodePos(bigramsIt.getBigramPos()), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + word1CodePointCount); const int probability = getProbability(word1Probability, bigramsIt.getProbability()); - bigrams.emplace_back(&word1, probability, - NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */); + bigrams.emplace_back( + CodePointArrayView(bigramWord1CodePoints, word1CodePointCount).toVector(), + probability, NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */); } } // Fetch shortcut information. @@ -465,17 +468,17 @@ const WordProperty PatriciaTriePolicy::getWordProperty( hasNext = ShortcutListReadingUtils::hasNext(shortcutFlags); const int shortcutTargetLength = ShortcutListReadingUtils::readShortcutTarget( mBuffer, MAX_WORD_LENGTH, shortcutTargetCodePoints, &shortcutPos); - const std::vector<int> shortcutTarget(shortcutTargetCodePoints, - shortcutTargetCodePoints + shortcutTargetLength); const int shortcutProbability = ShortcutListReadingUtils::getProbabilityFromFlags(shortcutFlags); - shortcuts.emplace_back(&shortcutTarget, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTargetCodePoints, shortcutTargetLength).toVector(), + shortcutProbability); } } const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams); } int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, @@ -497,10 +500,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC return 0; } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints, - &unigramProbability); + *outCodePointCount = getCodePointsAndReturnCodePointCount( + getWordIdFromTerminalPtNodePos(terminalPtNodePos), MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index 5f179513f..0d679c5dc 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -43,7 +43,7 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { PatriciaTriePolicy(MmappedBuffer::MmappedBufferPtr mmappedBuffer) : mMmappedBuffer(std::move(mmappedBuffer)), mHeaderPolicy(mMmappedBuffer->getReadOnlyByteArrayView().data(), - FormatUtils::detectFormatVersion(mmappedBuffer->getReadOnlyByteArrayView())), + FormatUtils::detectFormatVersion(mMmappedBuffer->getReadOnlyByteArrayView())), mBuffer(mMmappedBuffer->getReadOnlyByteArrayView().skip(mHeaderPolicy.getSize())), mBigramListPolicy(mBuffer), mShortcutListPolicy(mBuffer), mPtNodeReader(mBuffer, &mBigramListPolicy, &mShortcutListPolicy, @@ -58,9 +58,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; @@ -155,6 +154,9 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { std::vector<int> mTerminalPtNodePositionsForIteratingWords; mutable bool mIsCorrupted; + int getCodePointsAndProbabilityAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints, + int *const outUnigramProbability) const; int getShortcutPositionOfPtNode(const int ptNodePos) const; int getBigramsPositionOfPtNode(const int ptNodePos) const; int createAndGetLeavingChildNode(const DicNode *const dicNode, const int ptNodePos, diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp index 35f0f768f..139230228 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp @@ -38,7 +38,7 @@ bool LanguageModelDictContent::runGC( 0 /* nextLevelBitmapEntryIndex */, outNgramCount); } -int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds, +const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const { int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1]; bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex(); @@ -60,17 +60,29 @@ int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordI } const ProbabilityEntry probabilityEntry = ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo); + int probability = NOT_A_PROBABILITY; if (mHasHistoricalInfo) { - const int probability = ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), headerPolicy) - + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */); - return std::min(probability, MAX_PROBABILITY); + const int rawProbability = ForgettingCurveUtils::decodeProbability( + probabilityEntry.getHistoricalInfo(), headerPolicy); + if (rawProbability == NOT_A_PROBABILITY) { + // The entry should not be treated as a valid entry. + continue; + } + probability = std::min(rawProbability + + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */), + MAX_PROBABILITY); } else { - return probabilityEntry.getProbability(); + probability = probabilityEntry.getProbability(); } + // TODO: Some flags in unigramProbabilityEntry should be overwritten by flags in + // probabilityEntry. + const ProbabilityEntry unigramProbabilityEntry = getProbabilityEntry(wordId); + return WordAttributes(probability, unigramProbabilityEntry.isNotAWord(), + unigramProbabilityEntry.isBlacklisted(), + unigramProbabilityEntry.isPossiblyOffensive()); } // Cannot find the word. - return NOT_A_PROBABILITY; + return WordAttributes(); } ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry( diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h index a793af4be..b7e4af977 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h @@ -21,6 +21,7 @@ #include <vector> #include "defines.h" +#include "suggest/core/dictionary/word_attributes.h" #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h" #include "suggest/policyimpl/dictionary/structure/v4/content/terminal_position_lookup_table.h" #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h" @@ -128,7 +129,7 @@ class LanguageModelDictContent { const LanguageModelDictContent *const originalContent, int *const outNgramCount); - int getWordProbability(const WordIdArrayView prevWordIds, const int wordId, + const WordAttributes getWordAttributes(const WordIdArrayView prevWordIds, const int wordId, const HeaderPolicy *const headerPolicy) const; ProbabilityEntry getProbabilityEntry(const int wordId) const { diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h index f1bf12cb2..e1e10ca17 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h @@ -49,7 +49,9 @@ class ProbabilityEntry { // Create from unigram property. ProbabilityEntry(const UnigramProperty *const unigramProperty) - : mFlags(createFlags(unigramProperty->representsBeginningOfSentence())), + : mFlags(createFlags(unigramProperty->representsBeginningOfSentence(), + unigramProperty->isNotAWord(), unigramProperty->isBlacklisted(), + unigramProperty->isPossiblyOffensive())), mProbability(unigramProperty->getProbability()), mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(), unigramProperty->getCount()) {} @@ -85,6 +87,18 @@ class ProbabilityEntry { return (mFlags & Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE) != 0; } + bool isNotAWord() const { + return (mFlags & Ver4DictConstants::FLAG_NOT_A_WORD) != 0; + } + + bool isBlacklisted() const { + return (mFlags & Ver4DictConstants::FLAG_BLACKLISTED) != 0; + } + + bool isPossiblyOffensive() const { + return (mFlags & Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE) != 0; + } + uint64_t encode(const bool hasHistoricalInfo) const { uint64_t encodedEntry = static_cast<uint64_t>(mFlags); if (hasHistoricalInfo) { @@ -142,10 +156,20 @@ class ProbabilityEntry { (encodedEntry >> (pos * CHAR_BIT)) & ((1ull << (size * CHAR_BIT)) - 1)); } - static uint8_t createFlags(const bool representsBeginningOfSentence) { + static uint8_t createFlags(const bool representsBeginningOfSentence, + const bool isNotAWord, const bool isBlacklisted, const bool isPossiblyOffensive) { uint8_t flags = 0; if (representsBeginningOfSentence) { - flags ^= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + flags |= Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; + } + if (isNotAWord) { + flags |= Ver4DictConstants::FLAG_NOT_A_WORD; + } + if (isBlacklisted) { + flags |= Ver4DictConstants::FLAG_BLACKLISTED; + } + if (isPossiblyOffensive) { + flags |= Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE; } return flags; } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp index 39822b94a..8e6cb974b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp @@ -54,6 +54,9 @@ const int Ver4DictConstants::WORD_COUNT_FIELD_SIZE = 1; const uint8_t Ver4DictConstants::FLAG_REPRESENTS_BEGINNING_OF_SENTENCE = 0x1; const uint8_t Ver4DictConstants::FLAG_NOT_A_VALID_ENTRY = 0x2; +const uint8_t Ver4DictConstants::FLAG_NOT_A_WORD = 0x4; +const uint8_t Ver4DictConstants::FLAG_BLACKLISTED = 0x8; +const uint8_t Ver4DictConstants::FLAG_POSSIBLY_OFFENSIVE = 0x10; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE = 64; const int Ver4DictConstants::SHORTCUT_ADDRESS_TABLE_DATA_SIZE = 4; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h index dfcdd4d6f..600b5ffe4 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h @@ -52,6 +52,9 @@ class Ver4DictConstants { // Flags in probability entry. static const uint8_t FLAG_REPRESENTS_BEGINNING_OF_SENTENCE; static const uint8_t FLAG_NOT_A_VALID_ENTRY; + static const uint8_t FLAG_NOT_A_WORD; + static const uint8_t FLAG_BLACKLISTED; + static const uint8_t FLAG_POSSIBLY_OFFENSIVE; static const int SHORTCUT_ADDRESS_TABLE_BLOCK_SIZE; static const int SHORTCUT_ADDRESS_TABLE_DATA_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp index d795239fc..4110d6036 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.cpp @@ -51,26 +51,17 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce const int parentPos = DynamicPtReadingUtils::getParentPtNodePos(parentPosOffset, headPos); int codePoints[MAX_WORD_LENGTH]; - const int codePonitCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( - dictBuf, flags, MAX_WORD_LENGTH, mHeaderPolicy->getCodePointTable(), codePoints, &pos); + // Code point table is not used for ver4 dictionaries. + const int codePointCount = PatriciaTrieReadingUtils::getCharsAndAdvancePosition( + dictBuf, flags, MAX_WORD_LENGTH, nullptr /* codePointTable */, codePoints, &pos); int terminalIdFieldPos = NOT_A_DICT_POS; int terminalId = Ver4DictConstants::NOT_A_TERMINAL_ID; - int probability = NOT_A_PROBABILITY; if (PatriciaTrieReadingUtils::isTerminal(flags)) { terminalIdFieldPos = pos; if (usesAdditionalBuffer) { terminalIdFieldPos += mBuffer->getOriginalBufferSize(); } terminalId = Ver4PatriciaTrieReadingUtils::getTerminalIdAndAdvancePosition(dictBuf, &pos); - // TODO: Quit reading probability here. - const ProbabilityEntry probabilityEntry = - mLanguageModelDictContent->getProbabilityEntry(terminalId); - if (probabilityEntry.hasHistoricalInfo()) { - probability = ForgettingCurveUtils::decodeProbability( - probabilityEntry.getHistoricalInfo(), mHeaderPolicy); - } else { - probability = probabilityEntry.getProbability(); - } } int childrenPosFieldPos = pos; if (usesAdditionalBuffer) { @@ -91,8 +82,8 @@ const PtNodeParams Ver4PatriciaTrieNodeReader::fetchPtNodeInfoFromBufferAndProce // The destination position is stored at the same place as the parent position. return fetchPtNodeInfoFromBufferAndProcessMovedPtNode(parentPos, newSiblingNodePos); } else { - return PtNodeParams(headPos, flags, parentPos, codePonitCount, codePoints, - terminalIdFieldPos, terminalId, probability, childrenPosFieldPos, childrenPos, + return PtNodeParams(headPos, flags, parentPos, codePointCount, codePoints, + terminalIdFieldPos, terminalId, NOT_A_PROBABILITY, childrenPosFieldPos, childrenPos, newSiblingNodePos); } } diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h index a91ad5728..f4df544e2 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h @@ -29,15 +29,12 @@ class LanguageModelDictContent; /* * This class is used for helping to read nodes of ver4 patricia trie. This class handles moved - * node and reads node attributes including probability form language model. + * node and reads node attributes. */ class Ver4PatriciaTrieNodeReader : public PtNodeReader { public: - Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer, - const LanguageModelDictContent *const languageModelDictContent, - const HeaderPolicy *const headerPolicy) - : mBuffer(buffer), mLanguageModelDictContent(languageModelDictContent), - mHeaderPolicy(headerPolicy) {} + explicit Ver4PatriciaTrieNodeReader(const BufferWithExtendableBuffer *const buffer) + : mBuffer(buffer) {} ~Ver4PatriciaTrieNodeReader() {} @@ -50,8 +47,6 @@ class Ver4PatriciaTrieNodeReader : public PtNodeReader { DISALLOW_COPY_AND_ASSIGN(Ver4PatriciaTrieNodeReader); const BufferWithExtendableBuffer *const mBuffer; - const LanguageModelDictContent *const mLanguageModelDictContent; - const HeaderPolicy *const mHeaderPolicy; const PtNodeParams fetchPtNodeInfoFromBufferAndProcessMovedPtNode(const int ptNodePos, const int siblingNodePos) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp index 75ec16912..a1a33d27a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp @@ -191,7 +191,6 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndAdvancePosition( ptNodeWritingPos); } - bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition( const PtNodeParams *const ptNodeParams, const UnigramProperty *const unigramProperty, int *const ptNodeWritingPos) { @@ -341,8 +340,8 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition( ptNodeParams->getChildrenPos(), ptNodeWritingPos)) { return false; } - return updatePtNodeFlags(nodePos, ptNodeParams->isBlacklisted(), ptNodeParams->isNotAWord(), - isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); + return updatePtNodeFlags(nodePos, isTerminal, + ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */); } // TODO: Move probability handling code to LanguageModelDictContent. @@ -361,14 +360,13 @@ const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom( } } -bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, - const bool isBlacklisted, const bool isNotAWord, const bool isTerminal, +bool Ver4PatriciaTrieNodeWriter::updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars) { // Create node flags and write them. PatriciaTrieReadingUtils::NodeFlags nodeFlags = - PatriciaTrieReadingUtils::createAndGetFlags(isBlacklisted, isNotAWord, isTerminal, - false /* hasShortcutTargets */, false /* hasBigrams */, hasMultipleChars, - CHILDREN_POSITION_FIELD_SIZE); + PatriciaTrieReadingUtils::createAndGetFlags(false /* isNotAWord */, + false /* isBlacklisted */, isTerminal, false /* hasShortcutTargets */, + false /* hasBigrams */, hasMultipleChars, CHILDREN_POSITION_FIELD_SIZE); if (!DynamicPtWritingUtils::writeFlags(mTrieBuffer, nodeFlags, ptNodePos)) { AKLOGE("Cannot write PtNode flags. flags: %x, pos: %d", nodeFlags, ptNodePos); return false; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h index 08b7d3825..17915273b 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h @@ -103,8 +103,7 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter { const ProbabilityEntry *const originalProbabilityEntry, const ProbabilityEntry *const probabilityEntry) const; - bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord, - const bool isTerminal, const bool hasMultipleChars); + bool updatePtNodeFlags(const int ptNodePos, const bool isTerminal, const bool hasMultipleChars); static const int CHILDREN_POSITION_FIELD_SIZE; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index 8d4135679..0f0696410 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -56,21 +56,11 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d if (!ptNodeParams.isValid()) { break; } - bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted(); - if (isTerminal && mHeaderPolicy->isDecayingDict()) { - // A DecayingDict may have a terminal PtNode that has a terminal DicNode whose - // probability is NOT_A_PROBABILITY. In such case, we don't want to treat it as a - // valid terminal DicNode. - isTerminal = ptNodeParams.getProbability() != NOT_A_PROBABILITY; - } - readingHelper.readNextSiblingNode(ptNodeParams); - if (ptNodeParams.representsNonWordInfo()) { - // Skip PtNodes that represent non-word information. - continue; - } + const bool isTerminal = ptNodeParams.isTerminal() && !ptNodeParams.isDeleted(); const int wordId = isTerminal ? ptNodeParams.getTerminalId() : NOT_A_WORD_ID; childDicNodes->pushLeavingChild(dicNode, ptNodeParams.getChildrenPos(), wordId, ptNodeParams.getCodePointArrayView()); + readingHelper.readNextSiblingNode(ptNodeParams); } if (readingHelper.isError()) { mIsCorrupted = true; @@ -78,15 +68,14 @@ void Ver4PatriciaTriePolicy::createAndGetAllChildDicNodes(const DicNode *const d } } -int Ver4PatriciaTriePolicy::getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const { +int Ver4PatriciaTriePolicy::getCodePointsAndReturnCodePointCount(const int wordId, + const int maxCodePointCount, int *const outCodePoints) const { DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader); const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); readingHelper.initWithPtNodePos(ptNodePos); - const int codePointCount = readingHelper.getCodePointsAndProbabilityAndReturnCodePointCount( - maxCodePointCount, outCodePoints, outUnigramProbability); + const int codePointCount = readingHelper.getCodePointsAndReturnCodePointCount( + maxCodePointCount, outCodePoints); if (readingHelper.isError()) { mIsCorrupted = true; AKLOGE("Dictionary reading error in getCodePointsAndProbabilityAndReturnCodePointCount()."); @@ -117,13 +106,8 @@ const WordAttributes Ver4PatriciaTriePolicy::getWordAttributesInContext( if (wordId == NOT_A_WORD_ID) { return WordAttributes(); } - const int ptNodePos = - mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); - const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - const int probability = mBuffers->getLanguageModelDictContent()->getWordProbability( - prevWordIds, wordId, mHeaderPolicy); - return WordAttributes(probability, ptNodeParams.isBlacklisted(), ptNodeParams.isNotAWord(), - probability == 0); + return mBuffers->getLanguageModelDictContent()->getWordAttributes(prevWordIds, wordId, + mHeaderPolicy); } int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordIds, @@ -131,15 +115,10 @@ int Ver4PatriciaTriePolicy::getProbabilityOfWord(const WordIdArrayView prevWordI if (wordId == NOT_A_WORD_ID || prevWordIds.contains(NOT_A_WORD_ID)) { return NOT_A_PROBABILITY; } - const int ptNodePos = - mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); - const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - if (ptNodeParams.isDeleted() || ptNodeParams.isBlacklisted() || ptNodeParams.isNotAWord()) { - return NOT_A_PROBABILITY; - } const ProbabilityEntry probabilityEntry = mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(prevWordIds, wordId); - if (!probabilityEntry.isValid()) { + if (!probabilityEntry.isValid() || probabilityEntry.isBlacklisted() + || probabilityEntry.isNotAWord()) { return NOT_A_PROBABILITY; } if (mHeaderPolicy->hasHistoricalInfoOfWords()) { @@ -166,6 +145,9 @@ void Ver4PatriciaTriePolicy::iterateNgramEntries(const WordIdArrayView prevWordI for (const auto entry : languageModelDictContent->getProbabilityEntries( prevWordIds.limit(i))) { const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry(); + if (!probabilityEntry.isValid()) { + continue; + } const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability( probabilityEntry.getHistoricalInfo(), mHeaderPolicy) @@ -463,8 +445,6 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( const int ptNodePos = mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId); const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos); - std::vector<int> codePointVector(ptNodeParams.getCodePoints(), - ptNodeParams.getCodePoints() + ptNodeParams.getCodePointCount()); const ProbabilityEntry probabilityEntry = mBuffers->getLanguageModelDictContent()->getProbabilityEntry( ptNodeParams.getTerminalId()); @@ -476,19 +456,15 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( int bigramWord1CodePoints[MAX_WORD_LENGTH]; for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries( prevWordIds)) { - // Word (unigram) probability - int word1Probability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - entry.getWordId(), MAX_WORD_LENGTH, bigramWord1CodePoints, &word1Probability); - const std::vector<int> word1(bigramWord1CodePoints, - bigramWord1CodePoints + codePointCount); + const int codePointCount = getCodePointsAndReturnCodePointCount(entry.getWordId(), + MAX_WORD_LENGTH, bigramWord1CodePoints); const ProbabilityEntry probabilityEntry = entry.getProbabilityEntry(); const HistoricalInfo *const historicalInfo = probabilityEntry.getHistoricalInfo(); const int probability = probabilityEntry.hasHistoricalInfo() ? ForgettingCurveUtils::decodeProbability(historicalInfo, mHeaderPolicy) : probabilityEntry.getProbability(); - bigrams.emplace_back(&word1, probability, - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), + bigrams.emplace_back(CodePointArrayView(bigramWord1CodePoints, codePointCount).toVector(), + probability, historicalInfo->getTimeStamp(), historicalInfo->getLevel(), historicalInfo->getCount()); } // Fetch shortcut information. @@ -504,15 +480,16 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty( int shortcutProbability = NOT_A_PROBABILITY; shortcutDictContent->getShortcutEntryAndAdvancePosition(MAX_WORD_LENGTH, shortcutTarget, &shortcutTargetLength, &shortcutProbability, &hasNext, &shortcutPos); - const std::vector<int> target(shortcutTarget, shortcutTarget + shortcutTargetLength); - shortcuts.emplace_back(&target, shortcutProbability); + shortcuts.emplace_back( + CodePointArrayView(shortcutTarget, shortcutTargetLength).toVector(), + shortcutProbability); } } - const UnigramProperty unigramProperty(ptNodeParams.representsBeginningOfSentence(), - ptNodeParams.isNotAWord(), ptNodeParams.isBlacklisted(), ptNodeParams.getProbability(), - historicalInfo->getTimeStamp(), historicalInfo->getLevel(), - historicalInfo->getCount(), &shortcuts); - return WordProperty(&codePointVector, &unigramProperty, &bigrams); + const UnigramProperty unigramProperty(probabilityEntry.representsBeginningOfSentence(), + probabilityEntry.isNotAWord(), probabilityEntry.isBlacklisted(), + probabilityEntry.getProbability(), historicalInfo->getTimeStamp(), + historicalInfo->getLevel(), historicalInfo->getCount(), &shortcuts); + return WordProperty(wordCodePoints.toVector(), &unigramProperty, &bigrams); } int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, @@ -535,9 +512,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; const PtNodeParams ptNodeParams = mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(terminalPtNodePos); - int unigramProbability = NOT_A_PROBABILITY; - *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( - ptNodeParams.getTerminalId(), MAX_WORD_LENGTH, outCodePoints, &unigramProbability); + *outCodePointCount = getCodePointsAndReturnCodePointCount(ptNodeParams.getTerminalId(), + MAX_WORD_LENGTH, outCodePoints); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index a117a3614..c9bde2cf5 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -45,8 +45,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { mDictBuffer(mBuffers->getWritableTrieBuffer()), mShortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()), - mNodeReader(mDictBuffer, mBuffers->getLanguageModelDictContent(), mHeaderPolicy), - mPtNodeArrayReader(mDictBuffer), + mNodeReader(mDictBuffer), mPtNodeArrayReader(mDictBuffer), mNodeWriter(mDictBuffer, mBuffers.get(), mHeaderPolicy, &mNodeReader, &mPtNodeArrayReader, &mShortcutPolicy), mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter), @@ -62,9 +61,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { void createAndGetAllChildDicNodes(const DicNode *const dicNode, DicNodeVector *const childDicNodes) const; - int getCodePointsAndProbabilityAndReturnCodePointCount( - const int wordId, const int maxCodePointCount, int *const outCodePoints, - int *const outUnigramProbability) const; + int getCodePointsAndReturnCodePointCount(const int wordId, const int maxCodePointCount, + int *const outCodePoints) const; int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp index 63e43a544..442abadee 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp @@ -73,8 +73,7 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy, Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount, int *const outBigramCount) { - Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer(), - mBuffers->getLanguageModelDictContent(), headerPolicy); + Ver4PatriciaTrieNodeReader ptNodeReader(mBuffers->getTrieBuffer()); Ver4PtNodeArrayReader ptNodeArrayReader(mBuffers->getTrieBuffer()); Ver4ShortcutListPolicy shortcutPolicy(mBuffers->getMutableShortcutDictContent(), mBuffers->getTerminalPositionLookupTable()); @@ -137,8 +136,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos, } // Create policy instances for the GCed dictionary. - Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer(), - buffersToWrite->getLanguageModelDictContent(), headerPolicy); + Ver4PatriciaTrieNodeReader newPtNodeReader(buffersToWrite->getTrieBuffer()); Ver4PtNodeArrayReader newPtNodeArrayreader(buffersToWrite->getTrieBuffer()); Ver4ShortcutListPolicy newShortcutPolicy(buffersToWrite->getMutableShortcutDictContent(), buffersToWrite->getTerminalPositionLookupTable()); diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h index f3a8589ca..408373176 100644 --- a/native/jni/src/utils/int_array_view.h +++ b/native/jni/src/utils/int_array_view.h @@ -129,6 +129,10 @@ class IntArrayView { return mPtr[mSize - 1]; } + AK_FORCE_INLINE std::vector<int> toVector() const { + return std::vector<int>(begin(), end()); + } + private: DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView); diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp index 06f82df52..daa32c348 100644 --- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp +++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp @@ -107,15 +107,15 @@ TEST(LanguageModelDictContentTest, TestGetWordProbability) { languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId, &bigramProbabilityEntry); - EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId, - nullptr /* headerPolicy */)); + EXPECT_EQ(bigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, + nullptr /* headerPolicy */).getProbability()); const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), prevWordIds[1], &probabilityEntry); languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId, &trigramProbabilityEntry); - EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId, - nullptr /* headerPolicy */)); + EXPECT_EQ(trigramProbability, languageModelDictContent.getWordAttributes(prevWordIds, wordId, + nullptr /* headerPolicy */).getProbability()); } } // namespace diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp index 487bd04b1..4757a416b 100644 --- a/native/jni/tests/utils/int_array_view_test.cpp +++ b/native/jni/tests/utils/int_array_view_test.cpp @@ -144,5 +144,12 @@ TEST(IntArrayViewTest, TestLastOrDefault) { EXPECT_EQ(10, intArrayView.skip(6).lastOrDefault(10)); } +TEST(IntArrayViewTest, TestToVector) { + const std::vector<int> intVector = {3, 2, 1, 0, -1, -2}; + IntArrayView intArrayView(intVector); + EXPECT_EQ(intVector, intArrayView.toVector()); + EXPECT_EQ(std::vector<int>(), CodePointArrayView().toVector()); +} + } // namespace } // namespace latinime |