diff options
Diffstat (limited to 'native/jni/src')
13 files changed, 53 insertions, 38 deletions
diff --git a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp index 295e760d6..56339fe48 100644 --- a/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/bigram_dictionary.cpp @@ -57,6 +57,10 @@ void BigramDictionary::getPredictions(const PrevWordsInfo *const prevWordsInfo, if (bigramsIt.getBigramPos() == NOT_A_DICT_POS) { continue; } + if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */) + && bigramsIt.getProbability() == NOT_A_PROBABILITY) { + continue; + } const int codePointCount = mDictionaryStructurePolicy-> getCodePointsAndProbabilityAndReturnCodePointCount(bigramsIt.getBigramPos(), MAX_WORD_LENGTH, bigramCodePoints, &unigramProbability); diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index bf0d0b126..e553bc0fc 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -145,10 +145,11 @@ const WordProperty Dictionary::getWordProperty(const int *const codePoints, codePoints, codePointCount); } -int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints) { +int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount) { TimeKeeper::setCurrentTime(); return mDictionaryStructureWithBufferPolicy->getNextWordAndNextToken( - token, outCodePoints); + token, outCodePoints, outCodePointCount); } void Dictionary::logDictionaryInfo(JNIEnv *const env) const { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index a96c87635..83447de44 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -103,7 +103,8 @@ class Dictionary { // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly // starts iterating the dictionary. - int getNextWordAndNextToken(const int token, int *const outCodePoints); + int getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount); const DictionaryStructureWithBufferPolicy *getDictionaryStructurePolicy() const { return mDictionaryStructureWithBufferPolicy.get(); diff --git a/native/jni/src/suggest/core/dictionary/property/word_property.cpp b/native/jni/src/suggest/core/dictionary/property/word_property.cpp index 6f5f808f8..5bdd5606b 100644 --- a/native/jni/src/suggest/core/dictionary/property/word_property.cpp +++ b/native/jni/src/suggest/core/dictionary/property/word_property.cpp @@ -28,7 +28,8 @@ void WordProperty::outputProperties(JNIEnv *const env, jintArray outCodePoints, MAX_WORD_LENGTH /* maxLength */, mCodePoints.data(), mCodePoints.size(), false /* needsNullTermination */); jboolean flags[] = {mUnigramProperty.isNotAWord(), mUnigramProperty.isBlacklisted(), - !mBigrams.empty(), mUnigramProperty.hasShortcuts()}; + !mBigrams.empty(), mUnigramProperty.hasShortcuts(), + mUnigramProperty.representsBeginningOfSentence()}; env->SetBooleanArrayRegion(outFlags, 0 /* start */, NELEMS(flags), flags); int probabilityInfo[] = {mUnigramProperty.getProbability(), mUnigramProperty.getTimestamp(), mUnigramProperty.getLevel(), mUnigramProperty.getCount()}; diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h index e2771f97c..b72601109 100644 --- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h +++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h @@ -104,7 +104,8 @@ class DictionaryStructureWithBufferPolicy { // Method to iterate all words in the dictionary. // The returned token has to be used to get the next word. If token is 0, this method newly // starts iterating the dictionary. - virtual int getNextWordAndNextToken(const int token, int *const outCodePoints) = 0; + virtual int getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount) = 0; virtual bool isCorrupted() const = 0; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp index 4ac0f406e..9780ae048 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp @@ -478,10 +478,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code return WordProperty(&codePointVector, &unigramProperty, &bigrams); } -int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { - // TODO: Return code point count like other methods. - // Null termination. - outCodePoints[0] = 0; +int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount) { + *outCodePointCount = 0; if (token == 0) { mTerminalPtNodePositionsForIteratingWords.clear(); DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( @@ -498,13 +497,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); - if (codePointCount < MAX_WORD_LENGTH) { - // Null termination. outCodePoints have to be null terminated or contain MAX_WORD_LENGTH - // code points. - outCodePoints[codePointCount] = 0; - } const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h index 2e948ac4a..16b1bd2c1 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h @@ -134,7 +134,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const WordProperty getWordProperty(const int *const codePoints, const int codePointCount) const; - int getNextWordAndNextToken(const int token, int *const outCodePoints); + int getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount); bool isCorrupted() const { return mIsCorrupted; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp index 7e1f3b233..5c62b9caf 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp @@ -391,7 +391,9 @@ const WordProperty PatriciaTriePolicy::getWordProperty(const int *const codePoin return WordProperty(&codePointVector, &unigramProperty, &bigrams); } -int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { +int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount) { + *outCodePointCount = 0; if (token == 0) { // Start iterating the dictionary. mTerminalPtNodePositionsForIteratingWords.clear(); @@ -409,8 +411,8 @@ int PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outC } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; - getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, MAX_WORD_LENGTH, - outCodePoints, &unigramProbability); + *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount(terminalPtNodePos, + MAX_WORD_LENGTH, outCodePoints, &unigramProbability); const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h index dce94363a..ec8407408 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h @@ -137,7 +137,8 @@ class PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const WordProperty getWordProperty(const int *const codePoints, const int codePointCount) const; - int getNextWordAndNextToken(const int token, int *const outCodePoints); + int getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount); bool isCorrupted() const { return mIsCorrupted; diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp index f7f2a32b4..46107d92a 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp @@ -489,10 +489,9 @@ const WordProperty Ver4PatriciaTriePolicy::getWordProperty(const int *const code return WordProperty(&codePointVector, &unigramProperty, &bigrams); } -int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints) { - // TODO: Return code point count like other methods. - // Null termination. - outCodePoints[0] = 0; +int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount) { + *outCodePointCount = 0; if (token == 0) { mTerminalPtNodePositionsForIteratingWords.clear(); DynamicPtReadingHelper::TraversePolicyToGetAllTerminalPtNodePositions traversePolicy( @@ -509,13 +508,8 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const } const int terminalPtNodePos = mTerminalPtNodePositionsForIteratingWords[token]; int unigramProbability = NOT_A_PROBABILITY; - const int codePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( + *outCodePointCount = getCodePointsAndProbabilityAndReturnCodePointCount( terminalPtNodePos, MAX_WORD_LENGTH, outCodePoints, &unigramProbability); - if (codePointCount < MAX_WORD_LENGTH) { - // Null termination. outCodePoints have to be null terminated or contain MAX_WORD_LENGTH - // code points. - outCodePoints[codePointCount] = 0; - } const int nextToken = token + 1; if (nextToken >= terminalPtNodePositionsVectorSize) { // All words have been iterated. diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h index 0a20965f3..5d66a2cce 100644 --- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h +++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h @@ -113,7 +113,8 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy { const WordProperty getWordProperty(const int *const codePoints, const int codePointCount) const; - int getNextWordAndNextToken(const int token, int *const outCodePoints); + int getNextWordAndNextToken(const int token, int *const outCodePoints, + int *const outCodePointCount); bool isCorrupted() const { return mIsCorrupted; diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h index f28ed5682..63786502b 100644 --- a/native/jni/src/utils/char_utils.h +++ b/native/jni/src/utils/char_utils.h @@ -98,6 +98,10 @@ class CharUtils { // Beginning-of-Sentence. static AK_FORCE_INLINE int attachBeginningOfSentenceMarker(int *const codePoints, const int codePointCount, const int maxCodePoint) { + if (codePointCount > 0 && codePoints[0] == CODE_POINT_BEGINNING_OF_SENTENCE) { + // Marker has already been attached. + return codePointCount; + } if (codePointCount >= maxCodePoint) { // the code points cannot be marked as a Beginning-of-Sentence. return 0; diff --git a/native/jni/src/utils/jni_data_utils.h b/native/jni/src/utils/jni_data_utils.h index 67a66fdfe..3514aeeb0 100644 --- a/native/jni/src/utils/jni_data_utils.h +++ b/native/jni/src/utils/jni_data_utils.h @@ -69,18 +69,23 @@ class JniDataUtils { static void outputCodePoints(JNIEnv *env, jintArray intArrayToOutputCodePoints, const int start, const int maxLength, const int *const codePoints, const int codePointCount, const bool needsNullTermination) { - const int outputCodePointCount = std::min(maxLength, codePointCount); - int outputCodePonts[outputCodePointCount]; - for (int i = 0; i < outputCodePointCount; ++i) { + const int codePointBufSize = std::min(maxLength, codePointCount); + int outputCodePonts[codePointBufSize]; + int outputCodePointCount = 0; + for (int i = 0; i < codePointBufSize; ++i) { const int codePoint = codePoints[i]; + int codePointToOutput = codePoint; if (!CharUtils::isInUnicodeSpace(codePoint)) { - outputCodePonts[i] = CODE_POINT_REPLACEMENT_CHARACTER; + if (codePoint == CODE_POINT_BEGINNING_OF_SENTENCE) { + // Just skip Beginning-of-Sentence marker. + continue; + } + codePointToOutput = CODE_POINT_REPLACEMENT_CHARACTER; } else if (codePoint >= 0x01 && codePoint <= 0x1F) { // Control code. - outputCodePonts[i] = CODE_POINT_REPLACEMENT_CHARACTER; - } else { - outputCodePonts[i] = codePoint; + codePointToOutput = CODE_POINT_REPLACEMENT_CHARACTER; } + outputCodePonts[outputCodePointCount++] = codePointToOutput; } env->SetIntArrayRegion(intArrayToOutputCodePoints, start, outputCodePointCount, outputCodePonts); @@ -90,6 +95,11 @@ class JniDataUtils { } } + static void putBooleanToArray(JNIEnv *env, jbooleanArray array, const int index, + const jboolean value) { + env->SetBooleanArrayRegion(array, index, 1 /* len */, &value); + } + static void putIntToArray(JNIEnv *env, jintArray array, const int index, const int value) { env->SetIntArrayRegion(array, index, 1 /* len */, &value); } |