diff options
Diffstat (limited to 'native/src/unigram_dictionary.cpp')
-rw-r--r-- | native/src/unigram_dictionary.cpp | 221 |
1 files changed, 70 insertions, 151 deletions
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp index b95da99a3..93d2b8418 100644 --- a/native/src/unigram_dictionary.cpp +++ b/native/src/unigram_dictionary.cpp @@ -181,14 +181,14 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, PROF_START(0); initSuggestions( proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies); - mCorrectionState->initCorrectionState(mProximityInfo, mInputLength); if (DEBUG_DICT) assert(codesSize == mInputLength); - const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); + const int maxDepth = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH); + mCorrectionState->initCorrectionState(mProximityInfo, mInputLength, maxDepth); PROF_END(0); PROF_START(1); - getSuggestionCandidates(-1, -1, -1, MAX_DEPTH); + getSuggestionCandidates(-1, -1, -1); PROF_END(1); PROF_START(2); @@ -198,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest missing characters %d", i); } - getSuggestionCandidates(i, -1, -1, MAX_DEPTH); + getSuggestionCandidates(i, -1, -1); } } PROF_END(2); @@ -211,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest excessive characters %d", i); } - getSuggestionCandidates(-1, i, -1, MAX_DEPTH); + getSuggestionCandidates(-1, i, -1); } } PROF_END(3); @@ -224,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo, if (DEBUG_DICT) { LOGI("--- Suggest transposed characters %d", i); } - getSuggestionCandidates(-1, -1, i, mInputLength - 1); + getSuggestionCandidates(-1, -1, i); } } PROF_END(4); @@ -272,7 +272,6 @@ void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int mFrequencies = frequencies; mOutputChars = outWords; mInputLength = codesSize; - mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2; proximityInfo->setInputParams(codes, codesSize); mProximityInfo = proximityInfo; } @@ -342,9 +341,8 @@ static const char QUOTE = '\''; static const char SPACE = ' '; void UnigramDictionary::getSuggestionCandidates(const int skipPos, - const int excessivePos, const int transposedPos, const int maxDepth) { + const int excessivePos, const int transposedPos) { if (DEBUG_DICT) { - LOGI("getSuggestionCandidates %d", maxDepth); assert(transposedPos + 1 < mInputLength); assert(excessivePos < mInputLength); assert(missingPos < mInputLength); @@ -368,32 +366,26 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos, while (depth >= 0) { if (mStackChildCount[depth] > 0) { --mStackChildCount[depth]; - bool traverseAllNodes = mStackTraverseAll[depth]; - int diffs = mStackDiffs[depth]; int siblingPos = mStackSiblingPos[depth]; int firstChildPos; mCorrectionState->initProcessState( - mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]); + mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth], + mStackTraverseAll[depth], mStackDiffs[depth]); - // depth will never be greater than maxDepth because in that case, // needsToTraverseChildrenNodes should be false const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, - maxDepth, traverseAllNodes, diffs, - mCorrectionState, &childCount, - &firstChildPos, &traverseAllNodes, &diffs, - &siblingPos); + mCorrectionState, &childCount, &firstChildPos, &siblingPos); // Update next sibling pos mStackSiblingPos[depth] = siblingPos; if (needsToTraverseChildrenNodes) { // Goes to child node ++depth; mStackChildCount[depth] = childCount; - mStackTraverseAll[depth] = traverseAllNodes; - mStackDiffs[depth] = diffs; mStackSiblingPos[depth] = firstChildPos; mCorrectionState->getProcessState(&mStackMatchedCount[depth], - &mStackInputIndex[depth], &mStackOutputIndex[depth]); + &mStackInputIndex[depth], &mStackOutputIndex[depth], + &mStackTraverseAll[depth], &mStackDiffs[depth]); } } else { // Goes to parent sibling node @@ -437,12 +429,12 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c, return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth; } - -inline void UnigramDictionary::onTerminal( - unsigned short int* word, const int freq, CorrectionState *correctionState) { - const int finalFreq = correctionState->getFinalFreq(word, freq); +inline void UnigramDictionary::onTerminal(const int freq, CorrectionState *correctionState) { + int wordLength; + unsigned short* wordPointer; + const int finalFreq = correctionState->getFinalFreq(freq, &wordPointer, &wordLength); if (finalFreq >= 0) { - addWord(word, correctionState->getOutputIndex() + 1, finalFreq); + addWord(wordPointer, wordLength, finalFreq); } } @@ -657,20 +649,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs // there aren't any more nodes at this level, it merely returns the address of the first byte after // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any // given level, as output into newCount when traversing this level's parent. -inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth, - const bool initialTraverseAllNodes, const int initialDiffs, - CorrectionState *correctionState, int *newCount, int *newChildrenPosition, - bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) { - const int skipPos = correctionState->getSkipPos(); - const int excessivePos = correctionState->getExcessivePos(); - const int transposedPos = correctionState->getTransposedPos(); +inline bool UnigramDictionary::processCurrentNode(const int initialPos, + CorrectionState *correctionState, int *newCount, + int *newChildrenPosition, int *nextSiblingPosition) { if (DEBUG_DICT) { correctionState->checkState(); } int pos = initialPos; - int traverseAllNodes = initialTraverseAllNodes; - int diffs = initialDiffs; - const int initialInputIndex = correctionState->getInputIndex(); // Flags contain the following information: // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits: @@ -682,6 +667,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in // - FLAG_HAS_BIGRAMS: whether this node has bigrams or not const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos); const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags)); + const bool isTerminalNode = (0 != (FLAG_IS_TERMINAL & flags)); + + bool needsToInvokeOnTerminal = false; // This gets only ONE character from the stream. Next there will be: // if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node @@ -707,111 +695,21 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in const bool isLastChar = (NOT_A_CHARACTER == nextc); // If there are more chars in this nodes, then this virtual node is not a terminal. // If we are on the last char, this virtual node is a terminal if this node is. - const bool isTerminal = isLastChar && (0 != (FLAG_IS_TERMINAL & flags)); - // If there are more chars in this node, then this virtual node has children. - // If we are on the last char, this virtual node has children if this node has. - const bool hasChildren = (!isLastChar) || BinaryFormat::hasChildrenInFlags(flags); - - // This has to be done for each virtual char (this forwards the "inputIndex" which - // is the index in the user-inputted chars, as read by proximity chars. - if (excessivePos == correctionState->getOutputIndex() - && correctionState->getInputIndex() < mInputLength - 1) { - correctionState->incrementInputIndex(); - } - if (traverseAllNodes || needsToSkipCurrentNode( - c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) { - mWord[correctionState->getOutputIndex()] = c; - if (traverseAllNodes && isTerminal) { - // The frequency should be here, because we come here only if this is actually - // a terminal node, and we are on its last char. - const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, freq, mCorrectionState); - } - if (!hasChildren) { - // If we don't have children here, that means we finished processing all - // characters of this node (we are on the last virtual node), AND we are in - // traverseAllNodes mode, which means we are searching for *completions*. We - // should skip the frequency if we have a terminal, and report the position - // of the next sibling. We don't have to return other values because we are - // returning false, as in "don't traverse children". - if (isTerminal) pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - } else { - int inputIndexForProximity = correctionState->getInputIndex(); - - if (transposedPos >= 0) { - if (correctionState->getInputIndex() == transposedPos) { - ++inputIndexForProximity; - } - if (correctionState->getInputIndex() == (transposedPos + 1)) { - --inputIndexForProximity; - } - } - - int matchedProximityCharId = mProximityInfo->getMatchedProximityId( - inputIndexForProximity, c, mCorrectionState); - if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) { - // We found that this is an unrelated character, so we should give up traversing - // this node and its children entirely. - // However we may not be on the last virtual node yet so we skip the remaining - // characters in this node, the frequency if it's there, read the next sibling - // position to output it, then return false. - // We don't have to output other values because we return false, as in - // "don't traverse children". - if (!isLastChar) { - pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos); - } - pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - mWord[correctionState->getOutputIndex()] = c; - // If inputIndex is greater than mInputLength, that means there is no - // proximity chars. So, we don't need to check proximity. - if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) { - correctionState->charMatched(); - } - const bool isSameAsUserTypedLength = mInputLength - == correctionState->getInputIndex() + 1 - || (excessivePos == mInputLength - 1 - && correctionState->getInputIndex() == mInputLength - 2); - if (isSameAsUserTypedLength && isTerminal) { - const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); - onTerminal(mWord, freq, mCorrectionState); - } - // Start traversing all nodes after the index exceeds the user typed length - traverseAllNodes = isSameAsUserTypedLength; - diffs = diffs - + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0); - // Finally, we are ready to go to the next character, the next "virtual node". - // We should advance the input index. - // We do this in this branch of the 'if traverseAllNodes' because we are still matching - // characters to input; the other branch is not matching them but searching for - // completions, this is why it does not have to do it. - correctionState->incrementInputIndex(); - - // This character matched the typed character (enough to traverse the node at least) - // so we just evaluated it. Now we should evaluate this virtual node's children - that - // is, if it has any. If it has no children, we're done here - so we skip the end of - // the node, output the siblings position, and return false "don't traverse children". - // Note that !hasChildren implies isLastChar, so we know we don't have to skip any - // remaining char in this group for there can't be any. - if (!hasChildren) { - pos = BinaryFormat::skipFrequency(flags, pos); - *nextSiblingPosition = - BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); - return false; - } - } - // Optimization: Prune out words that are too long compared to how much was typed. - if (isTerminal - && (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance)) { - // We are giving up parsing this node and its children. Skip the rest of the node, - // output the sibling position, and return that we don't want to traverse children. + const bool isTerminal = isLastChar && isTerminalNode; + + CorrectionState::CorrectionStateType stateType = correctionState->processCharAndCalcState( + c, isTerminal); + if (stateType == CorrectionState::TRAVERSE_ALL_ON_TERMINAL + || stateType == CorrectionState::ON_TERMINAL) { + needsToInvokeOnTerminal = true; + } else if (stateType == CorrectionState::UNRELATED) { + // We found that this is an unrelated character, so we should give up traversing + // this node and its children entirely. + // However we may not be on the last virtual node yet so we skip the remaining + // characters in this node, the frequency if it's there, read the next sibling + // position to output it, then return false. + // We don't have to output other values because we return false, as in + // "don't traverse children". if (!isLastChar) { pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos); } @@ -820,8 +718,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); return false; } - // Also, the next char is one "virtual node" depth more than this char. - correctionState->incrementOutputIndex(); // Prepare for the next character. Promote the prefetched char to current char - the loop // will take care of prefetching the next. If we finally found our last char, nextc will @@ -829,16 +725,39 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in c = nextc; } while (NOT_A_CHARACTER != c); - // If inputIndex is greater than mInputLength, that means there are no proximity chars. - // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength. - if (mInputLength <= initialInputIndex) { - traverseAllNodes = true; - } + if (isTerminalNode) { + if (needsToInvokeOnTerminal) { + // The frequency should be here, because we come here only if this is actually + // a terminal node, and we are on its last char. + const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos); + onTerminal(freq, mCorrectionState); + } + + // If there are more chars in this node, then this virtual node has children. + // If we are on the last char, this virtual node has children if this node has. + const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags); + + // This character matched the typed character (enough to traverse the node at least) + // so we just evaluated it. Now we should evaluate this virtual node's children - that + // is, if it has any. If it has no children, we're done here - so we skip the end of + // the node, output the siblings position, and return false "don't traverse children". + // Note that !hasChildren implies isLastChar, so we know we don't have to skip any + // remaining char in this group for there can't be any. + if (!hasChildren) { + pos = BinaryFormat::skipFrequency(flags, pos); + *nextSiblingPosition = + BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); + return false; + } - // All the output values that are purely computation by this function are held in local - // variables. Output them to the caller. - *newTraverseAllNodes = traverseAllNodes; - *newDiffs = diffs; + // Optimization: Prune out words that are too long compared to how much was typed. + if (correctionState->needsToPrune()) { + pos = BinaryFormat::skipFrequency(flags, pos); + *nextSiblingPosition = + BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos); + return false; + } + } // Now we finished processing this node, and we want to traverse children. If there are no // children, we can't come here. |