diff options
Diffstat (limited to 'native')
19 files changed, 215 insertions, 286 deletions
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp index ac0b4ab15..154ea9800 100644 --- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp +++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp @@ -199,47 +199,30 @@ static void latinime_BinaryDictionary_getSuggestions(JNIEnv *env, jclass clazz, ASSERT(false); return; } - int outputCodePoints[outputCodePointsLength]; - int scores[scoresLength]; - const jsize spaceIndicesLength = env->GetArrayLength(outSpaceIndicesArray); - int spaceIndices[spaceIndicesLength]; - const jsize outputTypesLength = env->GetArrayLength(outTypesArray); - int outputTypes[outputTypesLength]; const jsize outputAutoCommitFirstWordConfidenceLength = env->GetArrayLength(outAutoCommitFirstWordConfidenceArray); - // We only use the first result, as obviously we will only ever autocommit the first one ASSERT(outputAutoCommitFirstWordConfidenceLength == 1); - int outputAutoCommitFirstWordConfidence[outputAutoCommitFirstWordConfidenceLength]; - memset(outputCodePoints, 0, sizeof(outputCodePoints)); - memset(scores, 0, sizeof(scores)); - memset(spaceIndices, 0, sizeof(spaceIndices)); - memset(outputTypes, 0, sizeof(outputTypes)); - memset(outputAutoCommitFirstWordConfidence, 0, sizeof(outputAutoCommitFirstWordConfidence)); + if (outputAutoCommitFirstWordConfidenceLength != 1) { + // We only use the first result, as obviously we will only ever autocommit the first one + AKLOGE("Invalid outputAutoCommitFirstWordConfidenceLength: %d", + outputAutoCommitFirstWordConfidenceLength); + ASSERT(false); + return; + } + SuggestionResults suggestionResults(MAX_RESULTS); if (givenSuggestOptions.isGesture() || inputSize > 0) { // TODO: Use SuggestionResults to return suggestions. - count = dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, + dictionary->getSuggestions(pInfo, traverseSession, xCoordinates, yCoordinates, times, pointerIds, inputCodePoints, inputSize, prevWordCodePoints, - prevWordCodePointsLength, &givenSuggestOptions, outputCodePoints, - scores, spaceIndices, outputTypes, outputAutoCommitFirstWordConfidence); + prevWordCodePointsLength, &givenSuggestOptions, &suggestionResults); } else { - SuggestionResults suggestionResults(MAX_RESULTS); dictionary->getPredictions(prevWordCodePoints, prevWordCodePointsLength, &suggestionResults); - suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, - outScoresArray, outSpaceIndicesArray, outTypesArray, - outAutoCommitFirstWordConfidenceArray); - return; } - - // Copy back the output values - env->SetIntArrayRegion(outSuggestionCount, 0, 1 /* len */, &count); - env->SetIntArrayRegion(outCodePointsArray, 0, outputCodePointsLength, outputCodePoints); - env->SetIntArrayRegion(outScoresArray, 0, scoresLength, scores); - env->SetIntArrayRegion(outSpaceIndicesArray, 0, spaceIndicesLength, spaceIndices); - env->SetIntArrayRegion(outTypesArray, 0, outputTypesLength, outputTypes); - env->SetIntArrayRegion(outAutoCommitFirstWordConfidenceArray, 0, - outputAutoCommitFirstWordConfidenceLength, outputAutoCommitFirstWordConfidence); + suggestionResults.outputSuggestions(env, outSuggestionCount, outCodePointsArray, + outScoresArray, outSpaceIndicesArray, outTypesArray, + outAutoCommitFirstWordConfidenceArray); } static jint latinime_BinaryDictionary_getProbability(JNIEnv *env, jclass clazz, jlong dict, diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h index 4e6ff9556..3651cd523 100644 --- a/native/jni/src/defines.h +++ b/native/jni/src/defines.h @@ -103,7 +103,8 @@ AK_FORCE_INLINE static int intArrayToCharArray(const int *const source, const in #define AKLOGI(fmt, ...) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, fmt, ##__VA_ARGS__) #endif // defined(HOST_TOOL) -#define DUMP_RESULT(words, frequencies) do { dumpResult(words, frequencies); } while (0) +#define DUMP_SUGGESTION(words, frequencies, index, score) \ + do { dumpWordInfo(words, frequencies, index, score); } while (0) #define DUMP_WORD(word, length) do { dumpWord(word, length); } while (0) #define INTS_TO_CHARS(input, length, output, outlength) do { \ intArrayToCharArray(input, length, output, outlength); } while (0) @@ -165,7 +166,7 @@ static inline void showStackTrace() { #else // defined(FLAG_DO_PROFILE) || defined(FLAG_DBG) #define AKLOGE(fmt, ...) #define AKLOGI(fmt, ...) -#define DUMP_RESULT(words, frequencies) +#define DUMP_SUGGESTION(words, frequencies, index, score) #define DUMP_WORD(word, length) #undef DO_ASSERT_TEST #define ASSERT(success) diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h index 3118cdfa3..258aa9ce3 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node.h +++ b/native/jni/src/suggest/core/dicnode/dic_node.h @@ -83,14 +83,6 @@ class DicNode { #if DEBUG_DICT DicNodeProfiler mProfiler; #endif - ////////////////// - // Memory utils // - ////////////////// - AK_FORCE_INLINE static void managedDelete(DicNode *node) { - node->remove(); - } - // end - ///////////////// AK_FORCE_INLINE DicNode() : @@ -158,7 +150,7 @@ class DicNode { PROF_NODE_COPY(&dicNode->mProfiler, mProfiler); } - AK_FORCE_INLINE void remove() { + AK_FORCE_INLINE void finalize() { mIsUsed = false; if (mReleaseListener) { mReleaseListener->onReleased(this); @@ -478,17 +470,7 @@ class DicNode { mReleaseListener = releaseListener; } - AK_FORCE_INLINE bool compare(const DicNode *right) { - if (!isUsed() && !right->isUsed()) { - // Compare pointer values here for stable comparison - return this > right; - } - if (!isUsed()) { - return true; - } - if (!right->isUsed()) { - return false; - } + AK_FORCE_INLINE bool compare(const DicNode *right) const { // Promote exact matches to prevent them from being pruned. const bool leftExactMatch = ErrorTypeUtils::isExactMatch(getContainedErrorTypes()); const bool rightExactMatch = ErrorTypeUtils::isExactMatch(right->getContainedErrorTypes()); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h index 1f02731a5..213b1b968 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_priority_queue.h @@ -68,15 +68,15 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } setMaxSize(maxSize); for (int i = 0; i < mCapacity + 1; ++i) { - mDicNodesBuf[i].remove(); + mDicNodesBuf[i].finalize(); mDicNodesBuf[i].setReleaseListener(this); - mUnusedNodeIndices[i] = i == mCapacity ? NOT_A_NODE_ID : static_cast<int>(i) + 1; + mUnusedNodeIndices[i] = (i == mCapacity) ? NOT_A_NODE_ID : (i + 1); } mNextUnusedNodeId = 0; } // Copy - AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode) { + AK_FORCE_INLINE DicNode *copyPush(const DicNode *const dicNode) { return copyPush(dicNode, mMaxSize); } @@ -89,11 +89,11 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { if (dest) { DicNodeUtils::initByCopy(node, dest); } - node->remove(); + node->finalize(); mDicNodesQueue.pop(); } - void onReleased(DicNode *dicNode) { + void onReleased(const DicNode *dicNode) { const int index = static_cast<int>(dicNode - &mDicNodesBuf[0]); if (mUnusedNodeIndices[index] != NOT_A_NODE_ID) { // it's already released @@ -118,7 +118,8 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodePriorityQueue); static const int NOT_A_NODE_ID = -1; - AK_FORCE_INLINE static bool compareDicNode(DicNode *left, DicNode *right) { + AK_FORCE_INLINE static bool compareDicNode(const DicNode *const left, + const DicNode *const right) { return left->compare(right); } @@ -141,10 +142,10 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { } AK_FORCE_INLINE void pop() { - copyPop(0); + copyPop(nullptr); } - AK_FORCE_INLINE bool betterThanWorstDicNode(DicNode *dicNode) const { + AK_FORCE_INLINE bool betterThanWorstDicNode(const DicNode *const dicNode) const { DicNode *worstNode = mDicNodesQueue.top(); if (!worstNode) { return true; @@ -154,7 +155,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { AK_FORCE_INLINE DicNode *searchEmptyDicNode() { if (mCapacity == 0) { - return 0; + return nullptr; } if (mNextUnusedNodeId == NOT_A_NODE_ID) { AKLOGI("No unused node found."); @@ -163,7 +164,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { i, mDicNodesBuf[i].isUsed(), mUnusedNodeIndices[i]); } ASSERT(false); - return 0; + return nullptr; } DicNode *dicNode = &mDicNodesBuf[mNextUnusedNodeId]; markNodeAsUsed(dicNode); @@ -179,7 +180,7 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { AK_FORCE_INLINE DicNode *pushPoolNodeWithMaxSize(DicNode *dicNode, const int maxSize) { if (!dicNode) { - return 0; + return nullptr; } if (!isFull(maxSize)) { mDicNodesQueue.push(dicNode); @@ -190,16 +191,16 @@ class DicNodePriorityQueue : public DicNodeReleaseListener { mDicNodesQueue.push(dicNode); return dicNode; } - dicNode->remove(); - return 0; + dicNode->finalize(); + return nullptr; } // Copy - AK_FORCE_INLINE DicNode *copyPush(DicNode *dicNode, const int maxSize) { + AK_FORCE_INLINE DicNode *copyPush(const DicNode *const dicNode, const int maxSize) { return pushPoolNodeWithMaxSize(newDicNode(dicNode), maxSize); } - AK_FORCE_INLINE DicNode *newDicNode(DicNode *dicNode) { + AK_FORCE_INLINE DicNode *newDicNode(const DicNode *const dicNode) { DicNode *newNode = searchEmptyDicNode(); if (newNode) { DicNodeUtils::initByCopy(dicNode, newNode); diff --git a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h index 2ca4f21bd..c3f432951 100644 --- a/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h +++ b/native/jni/src/suggest/core/dicnode/dic_node_release_listener.h @@ -27,7 +27,7 @@ class DicNodeReleaseListener { public: DicNodeReleaseListener() {} virtual ~DicNodeReleaseListener() {} - virtual void onReleased(DicNode *dicNode) = 0; + virtual void onReleased(const DicNode *dicNode) = 0; private: DISALLOW_COPY_AND_ASSIGN(DicNodeReleaseListener); }; diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h index d4769e739..6b8dc8c96 100644 --- a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -100,14 +100,7 @@ class DicNodesCache { } AK_FORCE_INLINE void copyPushNextActive(DicNode *dicNode) { - DicNode *pushedDicNode = mNextActiveDicNodes->copyPush(dicNode); - if (!pushedDicNode) { - if (dicNode->isCached()) { - dicNode->remove(); - } - // We simply drop any dic node that was not cached, ignoring the slim chance - // that one of its children represents what the user really wanted. - } + mNextActiveDicNodes->copyPush(dicNode); } void popTerminal(DicNode *dest) { diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp index 07b07f725..ae4646d2e 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.cpp +++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp @@ -22,6 +22,7 @@ #include "defines.h" #include "suggest/core/policy/dictionary_header_structure_policy.h" +#include "suggest/core/result/suggestion_results.h" #include "suggest/core/session/dic_traverse_session.h" #include "suggest/core/suggest.h" #include "suggest/core/suggest_options.h" @@ -43,34 +44,25 @@ Dictionary::Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::Structu logDictionaryInfo(env); } -int Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, +void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, - const SuggestOptions *const suggestOptions, int *outWords, int *outputScores, - int *spaceIndices, int *outputTypes, int *outputAutoCommitFirstWordConfidence) const { + const SuggestOptions *const suggestOptions, + SuggestionResults *const outSuggestionResults) const { TimeKeeper::setCurrentTime(); - int result = 0; + DicTraverseSession::initSessionInstance( + traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); if (suggestOptions->isGesture()) { - DicTraverseSession::initSessionInstance( - traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); - result = mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, - ycoordinates, times, pointerIds, inputCodePoints, inputSize, outWords, - outputScores, spaceIndices, outputTypes, outputAutoCommitFirstWordConfidence); - if (DEBUG_DICT) { - DUMP_RESULT(outWords, outputScores); - } - return result; + mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, + ycoordinates, times, pointerIds, inputCodePoints, inputSize, + outSuggestionResults); } else { - DicTraverseSession::initSessionInstance( - traverseSession, this, prevWordCodePoints, prevWordLength, suggestOptions); - result = mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, + mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates, ycoordinates, times, pointerIds, inputCodePoints, inputSize, - outWords, outputScores, spaceIndices, outputTypes, - outputAutoCommitFirstWordConfidence); - if (DEBUG_DICT) { - DUMP_RESULT(outWords, outputScores); - } - return result; + outSuggestionResults); + } + if (DEBUG_DICT) { + outSuggestionResults->dumpSuggestions(); } } diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h index 4d482e742..df5fc9b7d 100644 --- a/native/jni/src/suggest/core/dictionary/dictionary.h +++ b/native/jni/src/suggest/core/dictionary/dictionary.h @@ -62,11 +62,11 @@ class Dictionary { Dictionary(JNIEnv *env, DictionaryStructureWithBufferPolicy::StructurePolicyPtr dictionaryStructureWithBufferPolicy); - int getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, + void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession, int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *prevWordCodePoints, int prevWordLength, - const SuggestOptions *const suggestOptions, int *outWords, int *outputScores, - int *spaceIndices, int *outputTypes, int *outputAutoCommitFirstWordConfidence) const; + const SuggestOptions *const suggestOptions, + SuggestionResults *const outSuggestionResults) const; void getPredictions(const int *word, int length, SuggestionResults *const outSuggestionResults) const; diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h index 0251475d5..292194bf2 100644 --- a/native/jni/src/suggest/core/policy/scoring.h +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -23,6 +23,7 @@ namespace latinime { class DicNode; class DicTraverseSession; +class SuggestionResults; // This class basically tweaks suggestions and distances apart from CompoundDistance class Scoring { @@ -30,11 +31,8 @@ class Scoring { virtual int calculateFinalScore(const float compoundDistance, const int inputSize, const ErrorTypeUtils::ErrorType containedErrorTypes, const bool forceCommit, const bool boostExactMatches) const = 0; - virtual bool getMostProbableString(const DicTraverseSession *const traverseSession, - const int terminalSize, const float languageWeight, int *const outputCodePoints, - int *const type, int *const freq) const = 0; - virtual void safetyNetForMostProbableString(const int scoreCount, - const int maxScore, int *const outputCodePoints, int *const scores) const = 0; + virtual void getMostProbableString(const DicTraverseSession *const traverseSession, + const float languageWeight, SuggestionResults *const outSuggestionResults) const = 0; virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, DicNode *const terminals, const int size) const = 0; virtual float getDoubleLetterDemotionDistanceCost( diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp index 2be757d83..da1c6bc72 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.cpp +++ b/native/jni/src/suggest/core/result/suggestion_results.cpp @@ -54,13 +54,23 @@ void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCo void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount, const int probability) { - if (codePointCount <= 0 || codePointCount > MAX_WORD_LENGTH - || probability == NOT_A_PROBABILITY) { + if (probability == NOT_A_PROBABILITY) { // Invalid word. return; } - // Use probability as a score of the word. - const int score = probability; + addSuggestion(codePoints, codePointCount, probability, Dictionary::KIND_PREDICTION, + NOT_AN_INDEX, NOT_A_FIRST_WORD_CONFIDENCE); +} + +void SuggestionResults::addSuggestion(const int *const codePoints, const int codePointCount, + const int score, const int type, const int indexToPartialCommit, + const int autocimmitFirstWordConfindence) { + if (codePointCount <= 0 || codePointCount > MAX_WORD_LENGTH) { + // Invalid word. + AKLOGE("Invalid word is added to the suggestion results. codePointCount: %d", + codePointCount); + return; + } if (getSuggestionCount() >= mMaxSuggestionCount) { const SuggestedWord &mWorstSuggestion = mSuggestedWords.top(); if (score > mWorstSuggestion.getScore() || (score == mWorstSuggestion.getScore() @@ -70,8 +80,31 @@ void SuggestionResults::addPrediction(const int *const codePoints, const int cod return; } } - mSuggestedWords.push(SuggestedWord(codePoints, codePointCount, score, - Dictionary::KIND_PREDICTION, NOT_AN_INDEX, NOT_A_FIRST_WORD_CONFIDENCE)); + mSuggestedWords.push(SuggestedWord(codePoints, codePointCount, score, type, + indexToPartialCommit, autocimmitFirstWordConfindence)); +} + +void SuggestionResults::getSortedScores(int *const outScores) const { + auto copyOfSuggestedWords = mSuggestedWords; + while (!copyOfSuggestedWords.empty()) { + const SuggestedWord &suggestedWord = copyOfSuggestedWords.top(); + outScores[copyOfSuggestedWords.size() - 1] = suggestedWord.getScore(); + copyOfSuggestedWords.pop(); + } +} + +void SuggestionResults::dumpSuggestions() const { + std::vector<SuggestedWord> suggestedWords; + auto copyOfSuggestedWords = mSuggestedWords; + while (!copyOfSuggestedWords.empty()) { + suggestedWords.push_back(copyOfSuggestedWords.top()); + copyOfSuggestedWords.pop(); + } + int index = 0; + for (auto it = suggestedWords.rbegin(); it != suggestedWords.rend(); ++it) { + DUMP_SUGGESTION(it->getCodePoint(), it->getCodePointCount(), index, it->getScore()); + index++; + } } } // namespace latinime diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h index 0b841ca19..020bab42b 100644 --- a/native/jni/src/suggest/core/result/suggestion_results.h +++ b/native/jni/src/suggest/core/result/suggestion_results.h @@ -35,8 +35,12 @@ class SuggestionResults { void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray); - void addPrediction(const int *const codePoints, const int codePointCount, const int score); + void addSuggestion(const int *const codePoints, const int codePointCount, + const int score, const int type, const int indexToPartialCommit, + const int autocimmitFirstWordConfindence); + void getSortedScores(int *const outScores) const; + void dumpSuggestions() const; int getSuggestionCount() const { return mSuggestedWords.size(); diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp index a27631510..83140f1ab 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp @@ -17,163 +17,125 @@ #include "suggest/core/result/suggestions_output_utils.h" #include <algorithm> +#include <vector> #include "suggest/core/dicnode/dic_node.h" #include "suggest/core/dicnode/dic_node_utils.h" #include "suggest/core/dictionary/binary_dictionary_shortcut_iterator.h" -#include "suggest/core/dictionary/dictionary.h" #include "suggest/core/dictionary/error_type_utils.h" #include "suggest/core/policy/scoring.h" +#include "suggest/core/result/suggestion_results.h" #include "suggest/core/session/dic_traverse_session.h" namespace latinime { const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; -// TODO: Split this method. -/* static */ int SuggestionsOutputUtils::outputSuggestions( +/* static */ void SuggestionsOutputUtils::outputSuggestions( const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, - int *outputScores, int *outputCodePoints, int *outputIndicesToPartialCommit, - int *outputTypes, int *outputAutoCommitFirstWordConfidence) { + SuggestionResults *const outSuggestionResults) { #if DEBUG_EVALUATE_MOST_PROBABLE_STRING const int terminalSize = 0; #else - const int terminalSize = std::min(MAX_RESULTS, - static_cast<int>(traverseSession->getDicTraverseCache()->terminalSize())); + const int terminalSize = traverseSession->getDicTraverseCache()->terminalSize(); #endif - DicNode terminals[MAX_RESULTS]; // Avoiding non-POD variable length array - + std::vector<DicNode> terminals(terminalSize); for (int index = terminalSize - 1; index >= 0; --index) { traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); } const float languageWeight = scoringPolicy->getAdjustedLanguageWeight( - traverseSession, terminals, terminalSize); - - int outputWordIndex = 0; - // Insert most probable word at index == 0 as long as there is one terminal at least - const bool hasMostProbableString = - scoringPolicy->getMostProbableString(traverseSession, terminalSize, languageWeight, - &outputCodePoints[0], &outputTypes[0], &outputScores[0]); - if (hasMostProbableString) { - outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; - ++outputWordIndex; - } - - int maxScore = S_INT_MIN; + traverseSession, terminals.data(), terminalSize); // Force autocorrection for obvious long multi-word suggestions when the top suggestion is // a long multiple words suggestion. // TODO: Implement a smarter auto-commit method for handling multi-word suggestions. - // traverseSession->isPartiallyCommited() always returns false because we never auto partial - // commit for now. - const bool forceCommitMultiWords = (terminalSize > 0) ? - scoringPolicy->autoCorrectsToMultiWordSuggestionIfTop() - && (traverseSession->isPartiallyCommited() - || (traverseSession->getInputSize() - >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT - && terminals[0].hasMultipleWords())) : false; + const bool forceCommitMultiWords = scoringPolicy->autoCorrectsToMultiWordSuggestionIfTop() + && (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT + && !terminals.empty() && terminals.front().hasMultipleWords()); // TODO: have partial commit work even with multiple pointers. const bool outputSecondWordFirstLetterInputIndex = traverseSession->isOnlyOnePointerUsed(0 /* pointerId */); - if (terminalSize > 0) { - // If we have no suggestions, don't write this - outputAutoCommitFirstWordConfidence[0] = - computeFirstWordConfidence(&terminals[0]); - } const bool boostExactMatches = traverseSession->getDictionaryStructurePolicy()-> getHeaderStructurePolicy()->shouldBoostExactMatches(); - // Output suggestion results here - for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; - ++terminalIndex) { - DicNode *terminalDicNode = &terminals[terminalIndex]; - if (DEBUG_GEO_FULL) { - terminalDicNode->dump("OUT:"); - } - const float doubleLetterCost = - scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); - const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) - + doubleLetterCost; - const bool isPossiblyOffensiveWord = - traverseSession->getDictionaryStructurePolicy()->getProbability( - terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; - const bool isExactMatch = - ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); - const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); - // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. - // (e.g. "AMD" and "and") - const bool isSafeExactMatch = isExactMatch - && !(isPossiblyOffensiveWord && isFirstCharUppercase); - const int outputTypeFlags = - (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) - | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); - // Entries that are blacklisted or do not represent a word should not be output. - const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); - - // Increase output score of top typing suggestion to ensure autocorrection. - // TODO: Better integration with java side autocorrection logic. - const int finalScore = scoringPolicy->calculateFinalScore( - compoundDistance, traverseSession->getInputSize(), - terminalDicNode->getContainedErrorTypes(), - (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) - || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), - boostExactMatches); - if (maxScore < finalScore && isValidWord) { - maxScore = finalScore; - } - - // Don't output invalid words. However, we still need to submit their shortcuts if any. - if (isValidWord) { - outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION | outputTypeFlags; - outputScores[outputWordIndex] = finalScore; - if (outputSecondWordFirstLetterInputIndex) { - outputIndicesToPartialCommit[outputWordIndex] = - terminalDicNode->getSecondWordFirstInputIndex( - traverseSession->getProximityInfoState(0)); - } else { - outputIndicesToPartialCommit[outputWordIndex] = NOT_AN_INDEX; - } - // Populate the outputChars array with the suggested word. - const int startIndex = outputWordIndex * MAX_WORD_LENGTH; - terminalDicNode->outputResult(&outputCodePoints[startIndex]); - ++outputWordIndex; - } + // Output suggestion results here + for (auto &terminalDicNode : terminals) { + outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode, + languageWeight, boostExactMatches, forceCommitMultiWords, + outputSecondWordFirstLetterInputIndex, outSuggestionResults); + } + scoringPolicy->getMostProbableString(traverseSession, languageWeight, outSuggestionResults); +} - if (!terminalDicNode->hasMultipleWords()) { - BinaryDictionaryShortcutIterator shortcutIt( - traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), - traverseSession->getDictionaryStructurePolicy() - ->getShortcutPositionOfPtNode(terminalDicNode->getPtNodePos())); - // Shortcut is not supported for multiple words suggestions. - // TODO: Check shortcuts during traversal for multiple words suggestions. - const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); - const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ? - scoringPolicy->calculateFinalScore(compoundDistance, - traverseSession->getInputSize(), - terminalDicNode->getContainedErrorTypes(), - true /* forceCommit */, boostExactMatches) : finalScore; - const int updatedOutputWordIndex = outputShortcuts(&shortcutIt, - outputWordIndex, shortcutBaseScore, outputCodePoints, outputScores, outputTypes, - sameAsTyped); - const int secondWordFirstInputIndex = terminalDicNode->getSecondWordFirstInputIndex( - traverseSession->getProximityInfoState(0)); - for (int i = outputWordIndex; i < updatedOutputWordIndex; ++i) { - if (outputSecondWordFirstLetterInputIndex) { - outputIndicesToPartialCommit[i] = secondWordFirstInputIndex; - } else { - outputIndicesToPartialCommit[i] = NOT_AN_INDEX; - } - } - outputWordIndex = updatedOutputWordIndex; - } - DicNode::managedDelete(terminalDicNode); +/* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode( + const Scoring *const scoringPolicy, DicTraverseSession *traverseSession, + const DicNode *const terminalDicNode, const float languageWeight, + const bool boostExactMatches, const bool forceCommitMultiWords, + const bool outputSecondWordFirstLetterInputIndex, + SuggestionResults *const outSuggestionResults) { + if (DEBUG_GEO_FULL) { + terminalDicNode->dump("OUT:"); + } + const float doubleLetterCost = + scoringPolicy->getDoubleLetterDemotionDistanceCost(terminalDicNode); + const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + + doubleLetterCost; + const bool isPossiblyOffensiveWord = + traverseSession->getDictionaryStructurePolicy()->getProbability( + terminalDicNode->getProbability(), NOT_A_PROBABILITY) <= 0; + const bool isExactMatch = + ErrorTypeUtils::isExactMatch(terminalDicNode->getContainedErrorTypes()); + const bool isFirstCharUppercase = terminalDicNode->isFirstCharUppercase(); + // Heuristic: We exclude probability=0 first-char-uppercase words from exact match. + // (e.g. "AMD" and "and") + const bool isSafeExactMatch = isExactMatch + && !(isPossiblyOffensiveWord && isFirstCharUppercase); + const int outputTypeFlags = + (isPossiblyOffensiveWord ? Dictionary::KIND_FLAG_POSSIBLY_OFFENSIVE : 0) + | ((isSafeExactMatch && boostExactMatches) ? Dictionary::KIND_FLAG_EXACT_MATCH : 0); + + // Entries that are blacklisted or do not represent a word should not be output. + const bool isValidWord = !terminalDicNode->isBlacklistedOrNotAWord(); + + // Increase output score of top typing suggestion to ensure autocorrection. + // TODO: Better integration with java side autocorrection logic. + const int finalScore = scoringPolicy->calculateFinalScore( + compoundDistance, traverseSession->getInputSize(), + terminalDicNode->getContainedErrorTypes(), + (forceCommitMultiWords && terminalDicNode->hasMultipleWords()) + || (isValidWord && scoringPolicy->doesAutoCorrectValidWord()), + boostExactMatches); + + // Don't output invalid words. However, we still need to submit their shortcuts if any. + if (isValidWord) { + int codePoints[MAX_WORD_LENGTH]; + terminalDicNode->outputResult(codePoints); + const int indexToPartialCommit = outputSecondWordFirstLetterInputIndex ? + terminalDicNode->getSecondWordFirstInputIndex( + traverseSession->getProximityInfoState(0)) : + NOT_AN_INDEX; + outSuggestionResults->addSuggestion(codePoints, + terminalDicNode->getTotalNodeCodePointCount(), + finalScore, Dictionary::KIND_CORRECTION | outputTypeFlags, + indexToPartialCommit, computeFirstWordConfidence(terminalDicNode)); } - if (hasMostProbableString) { - scoringPolicy->safetyNetForMostProbableString(outputWordIndex, maxScore, - &outputCodePoints[0], outputScores); + // Output shortcuts. + // Shortcut is not supported for multiple words suggestions. + // TODO: Check shortcuts during traversal for multiple words suggestions. + if (!terminalDicNode->hasMultipleWords()) { + BinaryDictionaryShortcutIterator shortcutIt( + traverseSession->getDictionaryStructurePolicy()->getShortcutsStructurePolicy(), + traverseSession->getDictionaryStructurePolicy() + ->getShortcutPositionOfPtNode(terminalDicNode->getPtNodePos())); + const bool sameAsTyped = scoringPolicy->sameAsTyped(traverseSession, terminalDicNode); + const int shortcutBaseScore = scoringPolicy->doesAutoCorrectValidWord() ? + scoringPolicy->calculateFinalScore(compoundDistance, + traverseSession->getInputSize(), + terminalDicNode->getContainedErrorTypes(), + true /* forceCommit */, boostExactMatches) : finalScore; + outputShortcuts(&shortcutIt, shortcutBaseScore, sameAsTyped, outSuggestionResults); } - return outputWordIndex; } /* static */ int SuggestionsOutputUtils::computeFirstWordConfidence( @@ -228,12 +190,11 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; return distanceContribution + lengthContribution + spaceContribution; } -/* static */ int SuggestionsOutputUtils::outputShortcuts( - BinaryDictionaryShortcutIterator *const shortcutIt, - int outputWordIndex, const int finalScore, int *const outputCodePoints, - int *const outputScores, int *const outputTypes, const bool sameAsTyped) { +/* static */ void SuggestionsOutputUtils::outputShortcuts( + BinaryDictionaryShortcutIterator *const shortcutIt, const int finalScore, + const bool sameAsTyped, SuggestionResults *const outSuggestionResults) { int shortcutTarget[MAX_WORD_LENGTH]; - while (shortcutIt->hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { + while (shortcutIt->hasNextShortcutTarget()) { bool isWhilelist; int shortcutTargetStringLength; shortcutIt->nextShortcutTarget(MAX_WORD_LENGTH, shortcutTarget, @@ -250,15 +211,9 @@ const int SuggestionsOutputUtils::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; shortcutScore = std::max(S_INT_MIN + 1, shortcutScore) - 1; kind = Dictionary::KIND_SHORTCUT; } - outputTypes[outputWordIndex] = kind; - outputScores[outputWordIndex] = shortcutScore; - outputScores[outputWordIndex] = std::max(S_INT_MIN + 1, shortcutScore) - 1; - const int startIndex2 = outputWordIndex * MAX_WORD_LENGTH; - // Copy shortcut target code points to the output buffer. - memmove(&outputCodePoints[startIndex2], shortcutTarget, - shortcutTargetStringLength * sizeof(shortcutTarget[0])); - ++outputWordIndex; + outSuggestionResults->addSuggestion(shortcutTarget, shortcutTargetStringLength, + std::max(S_INT_MIN + 1, shortcutScore) - 1, kind, NOT_AN_INDEX, + NOT_A_FIRST_WORD_CONFIDENCE); } - return outputWordIndex; } } // namespace latinime diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h index d456a545f..73cdb9561 100644 --- a/native/jni/src/suggest/core/result/suggestions_output_utils.h +++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h @@ -25,16 +25,15 @@ class BinaryDictionaryShortcutIterator; class DicNode; class DicTraverseSession; class Scoring; +class SuggestionResults; class SuggestionsOutputUtils { public: /** * Outputs the final list of suggestions (i.e., terminal nodes). */ - static int outputSuggestions(const Scoring *const scoringPolicy, - DicTraverseSession *traverseSession, int *outputScores, int *outputCodePoints, - int *outputIndicesToPartialCommit, int *outputTypes, - int *outputAutoCommitFirstWordConfidence); + static void outputSuggestions(const Scoring *const scoringPolicy, + DicTraverseSession *traverseSession, SuggestionResults *const outSuggestionResults); private: DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionsOutputUtils); @@ -42,11 +41,15 @@ class SuggestionsOutputUtils { // Inputs longer than this will autocorrect if the suggestion is multi-word static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; + static void outputSuggestionsOfDicNode(const Scoring *const scoringPolicy, + DicTraverseSession *traverseSession, const DicNode *const terminalDicNode, + const float languageWeight, const bool boostExactMatches, + const bool forceCommitMultiWords, const bool outputSecondWordFirstLetterInputIndex, + SuggestionResults *const outSuggestionResults); + static void outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, + const int finalScore, const bool sameAsTyped, + SuggestionResults *const outSuggestionResults); static int computeFirstWordConfidence(const DicNode *const terminalDicNode); - - static int outputShortcuts(BinaryDictionaryShortcutIterator *const shortcutIt, - int outputWordIndex, const int finalScore, int *const outputCodePoints, - int *const outputScores, int *const outputTypes, const bool sameAsTyped); }; } // namespace latinime #endif // LATINIME_SUGGESTIONS_OUTPUT_UTILS diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp index 5070491f4..77b634e07 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.cpp +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -68,7 +68,6 @@ void DicTraverseSession::resetCache(const int thresholdForNextActiveDicNodes, co mDicNodesCache.reset(thresholdForNextActiveDicNodes /* nextActiveSize */, maxWords /* terminalSize */); mMultiBigramMap.clear(); - mPartiallyCommited = false; } void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints, diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h index b718fb57a..9e5d902dd 100644 --- a/native/jni/src/suggest/core/session/dic_traverse_session.h +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -61,7 +61,7 @@ class DicTraverseSession { AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr, bool usesLargeCache) : mPrevWordPtNodePos(NOT_A_DICT_POS), mProximityInfo(nullptr), mDictionary(nullptr), mSuggestOptions(nullptr), mDicNodesCache(usesLargeCache), - mMultiBigramMap(), mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1), + mMultiBigramMap(), mInputSize(0), mMaxPointerCount(1), mMultiWordCostMultiplier(1.0f) { // NOTE: mProximityInfoStates is an array of instances. // No need to initialize it explicitly here. @@ -95,8 +95,6 @@ class DicTraverseSession { return &mProximityInfoStates[id]; } int getInputSize() const { return mInputSize; } - void setPartiallyCommited() { mPartiallyCommited = true; } - bool isPartiallyCommited() const { return mPartiallyCommited; } bool isOnlyOnePointerUsed(int *pointerId) const { // Not in the dictionary word @@ -188,7 +186,6 @@ class DicTraverseSession { ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; int mInputSize; - bool mPartiallyCommited; int mMaxPointerCount; ///////////////////////////////// diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp index f6de571a8..303182cf4 100644 --- a/native/jni/src/suggest/core/suggest.cpp +++ b/native/jni/src/suggest/core/suggest.cpp @@ -42,10 +42,9 @@ const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; * automatically activated for sequential calls that share the same starting input. * TODO: Stop detecting continuous suggestion. Start using traverseSession instead. */ -int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, +void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, - int inputSize, int *outWords, int *outputScores, int *outputIndices, - int *outputTypes, int *outputAutoCommitFirstWordConfidence) const { + int inputSize, SuggestionResults *const outSuggestionResults) const { PROF_OPEN; PROF_START(0); const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); @@ -66,11 +65,9 @@ int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, } PROF_END(1); PROF_START(2); - const int size = SuggestionsOutputUtils::outputSuggestions(SCORING, tSession, outputScores, - outWords, outputIndices, outputTypes, outputAutoCommitFirstWordConfidence); + SuggestionsOutputUtils::outputSuggestions(SCORING, tSession, outSuggestionResults); PROF_END(2); PROF_CLOSE; - return size; } /** @@ -268,7 +265,6 @@ void Suggest::processExpandedDicNode( traverseSession->getDicTraverseCache()->copyPushNextActive(dicNode); } } - DicNode::managedDelete(dicNode); } void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, @@ -391,7 +387,6 @@ void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, processExpandedDicNode(traverseSession, childDicNode2); } } - DicNode::managedDelete(childDicNodes1[i]); } } diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h index 33ea0b658..13ad621db 100644 --- a/native/jni/src/suggest/core/suggest.h +++ b/native/jni/src/suggest/core/suggest.h @@ -36,6 +36,7 @@ class DicNode; class DicTraverseSession; class ProximityInfo; class Scoring; +class SuggestionResults; class Traversal; class Weighting; @@ -46,10 +47,9 @@ class Suggest : public SuggestInterface { SCORING(suggestPolicy ? suggestPolicy->getScoring() : nullptr), WEIGHTING(suggestPolicy ? suggestPolicy->getWeighting() : nullptr) {} AK_FORCE_INLINE virtual ~Suggest() {} - int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, - int *times, int *pointerIds, int *inputCodePoints, int inputSize, int *outWords, - int *outputScores, int *outputIndices, int *outputTypes, - int *outputAutoCommitFirstWordConfidence) const; + void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, + int *times, int *pointerIds, int *inputCodePoints, int inputSize, + SuggestionResults *const outSuggestionResults) const; private: DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h index f10db830f..c3ffea9a2 100644 --- a/native/jni/src/suggest/core/suggest_interface.h +++ b/native/jni/src/suggest/core/suggest_interface.h @@ -22,13 +22,13 @@ namespace latinime { class ProximityInfo; +class SuggestionResults; class SuggestInterface { public: - virtual int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, + virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize, - int *outWords, int *outputScores, int *outputIndices, int *outputTypes, - int *outputAutoCommitFirstWordConfidence) const = 0; + SuggestionResults *const suggestionResults) const = 0; SuggestInterface() {} virtual ~SuggestInterface() {} private: diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h index 8982800b7..66ea62406 100644 --- a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -32,15 +32,8 @@ class TypingScoring : public Scoring { public: static const TypingScoring *getInstance() { return &sInstance; } - AK_FORCE_INLINE bool getMostProbableString(const DicTraverseSession *const traverseSession, - const int terminalSize, const float languageWeight, int *const outputCodePoints, - int *const type, int *const freq) const { - return false; - } - - AK_FORCE_INLINE void safetyNetForMostProbableString(const int scoreCount, const int maxScore, - int *const outputCodePoints, int *const scores) const { - } + AK_FORCE_INLINE void getMostProbableString(const DicTraverseSession *const traverseSession, + const float languageWeight, SuggestionResults *const outSuggestionResults) const {} AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, DicNode *const terminals, const int size) const { |