aboutsummaryrefslogtreecommitdiffstats
path: root/native/src/unigram_dictionary.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'native/src/unigram_dictionary.cpp')
-rw-r--r--native/src/unigram_dictionary.cpp221
1 files changed, 70 insertions, 151 deletions
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index b95da99a3..93d2b8418 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -181,14 +181,14 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
PROF_START(0);
initSuggestions(
proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies);
- mCorrectionState->initCorrectionState(mProximityInfo, mInputLength);
if (DEBUG_DICT) assert(codesSize == mInputLength);
- const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH);
+ const int maxDepth = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH);
+ mCorrectionState->initCorrectionState(mProximityInfo, mInputLength, maxDepth);
PROF_END(0);
PROF_START(1);
- getSuggestionCandidates(-1, -1, -1, MAX_DEPTH);
+ getSuggestionCandidates(-1, -1, -1);
PROF_END(1);
PROF_START(2);
@@ -198,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest missing characters %d", i);
}
- getSuggestionCandidates(i, -1, -1, MAX_DEPTH);
+ getSuggestionCandidates(i, -1, -1);
}
}
PROF_END(2);
@@ -211,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest excessive characters %d", i);
}
- getSuggestionCandidates(-1, i, -1, MAX_DEPTH);
+ getSuggestionCandidates(-1, i, -1);
}
}
PROF_END(3);
@@ -224,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
if (DEBUG_DICT) {
LOGI("--- Suggest transposed characters %d", i);
}
- getSuggestionCandidates(-1, -1, i, mInputLength - 1);
+ getSuggestionCandidates(-1, -1, i);
}
}
PROF_END(4);
@@ -272,7 +272,6 @@ void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int
mFrequencies = frequencies;
mOutputChars = outWords;
mInputLength = codesSize;
- mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
proximityInfo->setInputParams(codes, codesSize);
mProximityInfo = proximityInfo;
}
@@ -342,9 +341,8 @@ static const char QUOTE = '\'';
static const char SPACE = ' ';
void UnigramDictionary::getSuggestionCandidates(const int skipPos,
- const int excessivePos, const int transposedPos, const int maxDepth) {
+ const int excessivePos, const int transposedPos) {
if (DEBUG_DICT) {
- LOGI("getSuggestionCandidates %d", maxDepth);
assert(transposedPos + 1 < mInputLength);
assert(excessivePos < mInputLength);
assert(missingPos < mInputLength);
@@ -368,32 +366,26 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
while (depth >= 0) {
if (mStackChildCount[depth] > 0) {
--mStackChildCount[depth];
- bool traverseAllNodes = mStackTraverseAll[depth];
- int diffs = mStackDiffs[depth];
int siblingPos = mStackSiblingPos[depth];
int firstChildPos;
mCorrectionState->initProcessState(
- mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]);
+ mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth],
+ mStackTraverseAll[depth], mStackDiffs[depth]);
- // depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
- maxDepth, traverseAllNodes, diffs,
- mCorrectionState, &childCount,
- &firstChildPos, &traverseAllNodes, &diffs,
- &siblingPos);
+ mCorrectionState, &childCount, &firstChildPos, &siblingPos);
// Update next sibling pos
mStackSiblingPos[depth] = siblingPos;
if (needsToTraverseChildrenNodes) {
// Goes to child node
++depth;
mStackChildCount[depth] = childCount;
- mStackTraverseAll[depth] = traverseAllNodes;
- mStackDiffs[depth] = diffs;
mStackSiblingPos[depth] = firstChildPos;
mCorrectionState->getProcessState(&mStackMatchedCount[depth],
- &mStackInputIndex[depth], &mStackOutputIndex[depth]);
+ &mStackInputIndex[depth], &mStackOutputIndex[depth],
+ &mStackTraverseAll[depth], &mStackDiffs[depth]);
}
} else {
// Goes to parent sibling node
@@ -437,12 +429,12 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth;
}
-
-inline void UnigramDictionary::onTerminal(
- unsigned short int* word, const int freq, CorrectionState *correctionState) {
- const int finalFreq = correctionState->getFinalFreq(word, freq);
+inline void UnigramDictionary::onTerminal(const int freq, CorrectionState *correctionState) {
+ int wordLength;
+ unsigned short* wordPointer;
+ const int finalFreq = correctionState->getFinalFreq(freq, &wordPointer, &wordLength);
if (finalFreq >= 0) {
- addWord(word, correctionState->getOutputIndex() + 1, finalFreq);
+ addWord(wordPointer, wordLength, finalFreq);
}
}
@@ -657,20 +649,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
// there aren't any more nodes at this level, it merely returns the address of the first byte after
// the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
// given level, as output into newCount when traversing this level's parent.
-inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth,
- const bool initialTraverseAllNodes, const int initialDiffs,
- CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
- bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
- const int skipPos = correctionState->getSkipPos();
- const int excessivePos = correctionState->getExcessivePos();
- const int transposedPos = correctionState->getTransposedPos();
+inline bool UnigramDictionary::processCurrentNode(const int initialPos,
+ CorrectionState *correctionState, int *newCount,
+ int *newChildrenPosition, int *nextSiblingPosition) {
if (DEBUG_DICT) {
correctionState->checkState();
}
int pos = initialPos;
- int traverseAllNodes = initialTraverseAllNodes;
- int diffs = initialDiffs;
- const int initialInputIndex = correctionState->getInputIndex();
// Flags contain the following information:
// - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@@ -682,6 +667,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
// - FLAG_HAS_BIGRAMS: whether this node has bigrams or not
const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos);
const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags));
+ const bool isTerminalNode = (0 != (FLAG_IS_TERMINAL & flags));
+
+ bool needsToInvokeOnTerminal = false;
// This gets only ONE character from the stream. Next there will be:
// if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node
@@ -707,111 +695,21 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
const bool isLastChar = (NOT_A_CHARACTER == nextc);
// If there are more chars in this nodes, then this virtual node is not a terminal.
// If we are on the last char, this virtual node is a terminal if this node is.
- const bool isTerminal = isLastChar && (0 != (FLAG_IS_TERMINAL & flags));
- // If there are more chars in this node, then this virtual node has children.
- // If we are on the last char, this virtual node has children if this node has.
- const bool hasChildren = (!isLastChar) || BinaryFormat::hasChildrenInFlags(flags);
-
- // This has to be done for each virtual char (this forwards the "inputIndex" which
- // is the index in the user-inputted chars, as read by proximity chars.
- if (excessivePos == correctionState->getOutputIndex()
- && correctionState->getInputIndex() < mInputLength - 1) {
- correctionState->incrementInputIndex();
- }
- if (traverseAllNodes || needsToSkipCurrentNode(
- c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
- mWord[correctionState->getOutputIndex()] = c;
- if (traverseAllNodes && isTerminal) {
- // The frequency should be here, because we come here only if this is actually
- // a terminal node, and we are on its last char.
- const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
- onTerminal(mWord, freq, mCorrectionState);
- }
- if (!hasChildren) {
- // If we don't have children here, that means we finished processing all
- // characters of this node (we are on the last virtual node), AND we are in
- // traverseAllNodes mode, which means we are searching for *completions*. We
- // should skip the frequency if we have a terminal, and report the position
- // of the next sibling. We don't have to return other values because we are
- // returning false, as in "don't traverse children".
- if (isTerminal) pos = BinaryFormat::skipFrequency(flags, pos);
- *nextSiblingPosition =
- BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
- return false;
- }
- } else {
- int inputIndexForProximity = correctionState->getInputIndex();
-
- if (transposedPos >= 0) {
- if (correctionState->getInputIndex() == transposedPos) {
- ++inputIndexForProximity;
- }
- if (correctionState->getInputIndex() == (transposedPos + 1)) {
- --inputIndexForProximity;
- }
- }
-
- int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
- inputIndexForProximity, c, mCorrectionState);
- if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
- // We found that this is an unrelated character, so we should give up traversing
- // this node and its children entirely.
- // However we may not be on the last virtual node yet so we skip the remaining
- // characters in this node, the frequency if it's there, read the next sibling
- // position to output it, then return false.
- // We don't have to output other values because we return false, as in
- // "don't traverse children".
- if (!isLastChar) {
- pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos);
- }
- pos = BinaryFormat::skipFrequency(flags, pos);
- *nextSiblingPosition =
- BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
- return false;
- }
- mWord[correctionState->getOutputIndex()] = c;
- // If inputIndex is greater than mInputLength, that means there is no
- // proximity chars. So, we don't need to check proximity.
- if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
- correctionState->charMatched();
- }
- const bool isSameAsUserTypedLength = mInputLength
- == correctionState->getInputIndex() + 1
- || (excessivePos == mInputLength - 1
- && correctionState->getInputIndex() == mInputLength - 2);
- if (isSameAsUserTypedLength && isTerminal) {
- const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
- onTerminal(mWord, freq, mCorrectionState);
- }
- // Start traversing all nodes after the index exceeds the user typed length
- traverseAllNodes = isSameAsUserTypedLength;
- diffs = diffs
- + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
- // Finally, we are ready to go to the next character, the next "virtual node".
- // We should advance the input index.
- // We do this in this branch of the 'if traverseAllNodes' because we are still matching
- // characters to input; the other branch is not matching them but searching for
- // completions, this is why it does not have to do it.
- correctionState->incrementInputIndex();
-
- // This character matched the typed character (enough to traverse the node at least)
- // so we just evaluated it. Now we should evaluate this virtual node's children - that
- // is, if it has any. If it has no children, we're done here - so we skip the end of
- // the node, output the siblings position, and return false "don't traverse children".
- // Note that !hasChildren implies isLastChar, so we know we don't have to skip any
- // remaining char in this group for there can't be any.
- if (!hasChildren) {
- pos = BinaryFormat::skipFrequency(flags, pos);
- *nextSiblingPosition =
- BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
- return false;
- }
- }
- // Optimization: Prune out words that are too long compared to how much was typed.
- if (isTerminal
- && (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance)) {
- // We are giving up parsing this node and its children. Skip the rest of the node,
- // output the sibling position, and return that we don't want to traverse children.
+ const bool isTerminal = isLastChar && isTerminalNode;
+
+ CorrectionState::CorrectionStateType stateType = correctionState->processCharAndCalcState(
+ c, isTerminal);
+ if (stateType == CorrectionState::TRAVERSE_ALL_ON_TERMINAL
+ || stateType == CorrectionState::ON_TERMINAL) {
+ needsToInvokeOnTerminal = true;
+ } else if (stateType == CorrectionState::UNRELATED) {
+ // We found that this is an unrelated character, so we should give up traversing
+ // this node and its children entirely.
+ // However we may not be on the last virtual node yet so we skip the remaining
+ // characters in this node, the frequency if it's there, read the next sibling
+ // position to output it, then return false.
+ // We don't have to output other values because we return false, as in
+ // "don't traverse children".
if (!isLastChar) {
pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos);
}
@@ -820,8 +718,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
return false;
}
- // Also, the next char is one "virtual node" depth more than this char.
- correctionState->incrementOutputIndex();
// Prepare for the next character. Promote the prefetched char to current char - the loop
// will take care of prefetching the next. If we finally found our last char, nextc will
@@ -829,16 +725,39 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
c = nextc;
} while (NOT_A_CHARACTER != c);
- // If inputIndex is greater than mInputLength, that means there are no proximity chars.
- // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
- if (mInputLength <= initialInputIndex) {
- traverseAllNodes = true;
- }
+ if (isTerminalNode) {
+ if (needsToInvokeOnTerminal) {
+ // The frequency should be here, because we come here only if this is actually
+ // a terminal node, and we are on its last char.
+ const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
+ onTerminal(freq, mCorrectionState);
+ }
+
+ // If there are more chars in this node, then this virtual node has children.
+ // If we are on the last char, this virtual node has children if this node has.
+ const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
+
+ // This character matched the typed character (enough to traverse the node at least)
+ // so we just evaluated it. Now we should evaluate this virtual node's children - that
+ // is, if it has any. If it has no children, we're done here - so we skip the end of
+ // the node, output the siblings position, and return false "don't traverse children".
+ // Note that !hasChildren implies isLastChar, so we know we don't have to skip any
+ // remaining char in this group for there can't be any.
+ if (!hasChildren) {
+ pos = BinaryFormat::skipFrequency(flags, pos);
+ *nextSiblingPosition =
+ BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
+ return false;
+ }
- // All the output values that are purely computation by this function are held in local
- // variables. Output them to the caller.
- *newTraverseAllNodes = traverseAllNodes;
- *newDiffs = diffs;
+ // Optimization: Prune out words that are too long compared to how much was typed.
+ if (correctionState->needsToPrune()) {
+ pos = BinaryFormat::skipFrequency(flags, pos);
+ *nextSiblingPosition =
+ BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
+ return false;
+ }
+ }
// Now we finished processing this node, and we want to traverse children. If there are no
// children, we can't come here.