aboutsummaryrefslogtreecommitdiffstats
path: root/native/jni/src/correction.h
diff options
context:
space:
mode:
Diffstat (limited to 'native/jni/src/correction.h')
-rw-r--r--native/jni/src/correction.h159
1 files changed, 131 insertions, 28 deletions
diff --git a/native/jni/src/correction.h b/native/jni/src/correction.h
index f016d5453..912cd838e 100644
--- a/native/jni/src/correction.h
+++ b/native/jni/src/correction.h
@@ -56,7 +56,8 @@ class Correction {
// No need to initialize it explicitly here.
}
- virtual ~Correction() {}
+ // Non virtual inline destructor -- never inherit this class
+ ~Correction() {}
void resetCorrection();
void initCorrection(
const ProximityInfo *pi, const int inputSize, const int maxWordLength);
@@ -78,14 +79,13 @@ class Correction {
return ++mTotalTraverseCount;
}
- int getFreqForSplitMultipleWords(
- const int *freqArray, const int *wordLengthArray, const int wordCount,
- const bool isSpaceProximity, const unsigned short *word);
- int getFinalProbability(const int probability, unsigned short **word, int *wordLength);
- int getFinalProbabilityForSubQueue(const int probability, unsigned short **word,
- int *wordLength, const int inputSize);
+ int getFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
+ const int wordCount, const bool isSpaceProximity, const int *word);
+ int getFinalProbability(const int probability, int **word, int *wordLength);
+ int getFinalProbabilityForSubQueue(const int probability, int **word, int *wordLength,
+ const int inputSize);
- CorrectionType processCharAndCalcState(const int32_t c, const bool isTerminal);
+ CorrectionType processCharAndCalcState(const int c, const bool isTerminal);
/////////////////////////
// Tree helper methods
@@ -110,29 +110,28 @@ class Correction {
const int inputSize);
static int calcFreqForSplitMultipleWords(const int *freqArray, const int *wordLengthArray,
const int wordCount, const Correction *correction, const bool isSpaceProximity,
- const unsigned short *word);
- static float calcNormalizedScore(const unsigned short *before, const int beforeLength,
- const unsigned short *after, const int afterLength, const int score);
- static int editDistance(const unsigned short *before,
- const int beforeLength, const unsigned short *after, const int afterLength);
+ const int *word);
+ static float calcNormalizedScore(const int *before, const int beforeLength,
+ const int *after, const int afterLength, const int score);
+ static int editDistance(const int *before, const int beforeLength, const int *after,
+ const int afterLength);
private:
- static const int CODE_SPACE = ' ';
static const int MAX_INITIAL_SCORE = 255;
};
// proximity info state
- void initInputParams(const ProximityInfo *proximityInfo, const int32_t *inputCodes,
+ void initInputParams(const ProximityInfo *proximityInfo, const int *inputCodes,
const int inputSize, const int *xCoordinates, const int *yCoordinates) {
mProximityInfoState.initInputParams(0, MAX_POINT_TO_KEY_LENGTH,
proximityInfo, inputCodes, inputSize, xCoordinates, yCoordinates, 0, 0, false);
}
- const unsigned short *getPrimaryInputWord() const {
+ const int *getPrimaryInputWord() const {
return mProximityInfoState.getPrimaryInputWord();
}
- unsigned short getPrimaryCharAt(const int index) const {
- return mProximityInfoState.getPrimaryCharAt(index);
+ int getPrimaryCodePointAt(const int index) const {
+ return mProximityInfoState.getPrimaryCodePointAt(index);
}
private:
@@ -147,7 +146,7 @@ class Correction {
}
static const int TWO_31ST_DIV_2 = S_INT_MAX / 2;
- inline static void multiplyIntCapped(const int multiplier, int *base) {
+ AK_FORCE_INLINE static void multiplyIntCapped(const int multiplier, int *base) {
const int temp = *base;
if (temp != S_INT_MAX) {
// Branch if multiplier == 2 for the optimization
@@ -170,7 +169,7 @@ class Correction {
}
}
- inline static int powerIntCapped(const int base, const int n) {
+ AK_FORCE_INLINE static int powerIntCapped(const int base, const int n) {
if (n <= 0) return 1;
if (base == 2) {
return n < 31 ? 1 << n : S_INT_MAX;
@@ -181,7 +180,7 @@ class Correction {
}
}
- inline static void multiplyRate(const int rate, int *freq) {
+ AK_FORCE_INLINE static void multiplyRate(const int rate, int *freq) {
if (*freq != S_INT_MAX) {
if (*freq > 1000000) {
*freq /= 100;
@@ -215,13 +214,13 @@ class Correction {
inline void incrementInputIndex();
inline void incrementOutputIndex();
inline void startToTraverseAllNodes();
- inline bool isSingleQuote(const unsigned short c);
- inline CorrectionType processSkipChar(
- const int32_t c, const bool isTerminal, const bool inputIndexIncremented);
+ inline bool isSingleQuote(const int c);
+ inline CorrectionType processSkipChar(const int c, const bool isTerminal,
+ const bool inputIndexIncremented);
inline CorrectionType processUnrelatedCorrectionType();
- inline void addCharToCurrentWord(const int32_t c);
- inline int getFinalProbabilityInternal(const int probability, unsigned short **word,
- int *wordLength, const int inputSize);
+ inline void addCharToCurrentWord(const int c);
+ inline int getFinalProbabilityInternal(const int probability, int **word, int *wordLength,
+ const int inputSize);
static const int TYPED_LETTER_MULTIPLIER = 2;
static const int FULL_WORD_MULTIPLIER = 2;
@@ -241,7 +240,7 @@ class Correction {
uint8_t mTotalTraverseCount;
// The following arrays are state buffer.
- unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
+ int mWord[MAX_WORD_LENGTH_INTERNAL];
int mDistances[MAX_WORD_LENGTH_INTERNAL];
// Edit distance calculation requires a buffer with (N+1)^2 length for the input length N.
@@ -275,5 +274,109 @@ class Correction {
bool mSkipping;
ProximityInfoState mProximityInfoState;
};
+
+inline void Correction::incrementInputIndex() {
+ ++mInputIndex;
+}
+
+AK_FORCE_INLINE void Correction::incrementOutputIndex() {
+ ++mOutputIndex;
+ mCorrectionStates[mOutputIndex].mParentIndex = mCorrectionStates[mOutputIndex - 1].mParentIndex;
+ mCorrectionStates[mOutputIndex].mChildCount = mCorrectionStates[mOutputIndex - 1].mChildCount;
+ mCorrectionStates[mOutputIndex].mSiblingPos = mCorrectionStates[mOutputIndex - 1].mSiblingPos;
+ mCorrectionStates[mOutputIndex].mInputIndex = mInputIndex;
+ mCorrectionStates[mOutputIndex].mNeedsToTraverseAllNodes = mNeedsToTraverseAllNodes;
+
+ mCorrectionStates[mOutputIndex].mEquivalentCharCount = mEquivalentCharCount;
+ mCorrectionStates[mOutputIndex].mProximityCount = mProximityCount;
+ mCorrectionStates[mOutputIndex].mTransposedCount = mTransposedCount;
+ mCorrectionStates[mOutputIndex].mExcessiveCount = mExcessiveCount;
+ mCorrectionStates[mOutputIndex].mSkippedCount = mSkippedCount;
+
+ mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
+ mCorrectionStates[mOutputIndex].mTransposedPos = mTransposedPos;
+ mCorrectionStates[mOutputIndex].mExcessivePos = mExcessivePos;
+
+ mCorrectionStates[mOutputIndex].mLastCharExceeded = mLastCharExceeded;
+
+ mCorrectionStates[mOutputIndex].mMatching = mMatching;
+ mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
+ mCorrectionStates[mOutputIndex].mAdditionalProximityMatching = mAdditionalProximityMatching;
+ mCorrectionStates[mOutputIndex].mTransposing = mTransposing;
+ mCorrectionStates[mOutputIndex].mExceeding = mExceeding;
+ mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
+}
+
+inline void Correction::startToTraverseAllNodes() {
+ mNeedsToTraverseAllNodes = true;
+}
+
+inline bool Correction::isSingleQuote(const int c) {
+ const int userTypedChar = mProximityInfoState.getPrimaryCodePointAt(mInputIndex);
+ return (c == KEYCODE_SINGLE_QUOTE && userTypedChar != KEYCODE_SINGLE_QUOTE);
+}
+
+AK_FORCE_INLINE Correction::CorrectionType Correction::processSkipChar(const int c,
+ const bool isTerminal, const bool inputIndexIncremented) {
+ addCharToCurrentWord(c);
+ mTerminalInputIndex = mInputIndex - (inputIndexIncremented ? 1 : 0);
+ mTerminalOutputIndex = mOutputIndex;
+ if (mNeedsToTraverseAllNodes && isTerminal) {
+ incrementOutputIndex();
+ return TRAVERSE_ALL_ON_TERMINAL;
+ } else {
+ incrementOutputIndex();
+ return TRAVERSE_ALL_NOT_ON_TERMINAL;
+ }
+}
+
+inline Correction::CorrectionType Correction::processUnrelatedCorrectionType() {
+ // Needs to set mTerminalInputIndex and mTerminalOutputIndex before returning any CorrectionType
+ mTerminalInputIndex = mInputIndex;
+ mTerminalOutputIndex = mOutputIndex;
+ return UNRELATED;
+}
+
+AK_FORCE_INLINE static void calcEditDistanceOneStep(int *editDistanceTable, const int *input,
+ const int inputSize, const int *output, const int outputLength) {
+ // TODO: Make sure that editDistance[0 ~ MAX_WORD_LENGTH_INTERNAL] is not touched.
+ // Let dp[i][j] be editDistanceTable[i * (inputSize + 1) + j].
+ // Assuming that dp[0][0] ... dp[outputLength - 1][inputSize] are already calculated,
+ // and calculate dp[ouputLength][0] ... dp[outputLength][inputSize].
+ int *const current = editDistanceTable + outputLength * (inputSize + 1);
+ const int *const prev = editDistanceTable + (outputLength - 1) * (inputSize + 1);
+ const int *const prevprev =
+ outputLength >= 2 ? editDistanceTable + (outputLength - 2) * (inputSize + 1) : 0;
+ current[0] = outputLength;
+ const int co = toBaseLowerCase(output[outputLength - 1]);
+ const int prevCO = outputLength >= 2 ? toBaseLowerCase(output[outputLength - 2]) : 0;
+ for (int i = 1; i <= inputSize; ++i) {
+ const int ci = toBaseLowerCase(input[i - 1]);
+ const uint16_t cost = (ci == co) ? 0 : 1;
+ current[i] = min(current[i - 1] + 1, min(prev[i] + 1, prev[i - 1] + cost));
+ if (i >= 2 && prevprev && ci == prevCO && co == toBaseLowerCase(input[i - 2])) {
+ current[i] = min(current[i], prevprev[i - 2] + 1);
+ }
+ }
+}
+
+AK_FORCE_INLINE void Correction::addCharToCurrentWord(const int c) {
+ mWord[mOutputIndex] = c;
+ const int *primaryInputWord = mProximityInfoState.getPrimaryInputWord();
+ calcEditDistanceOneStep(mEditDistanceTable, primaryInputWord, mInputSize, mWord,
+ mOutputIndex + 1);
+}
+
+inline int Correction::getFinalProbabilityInternal(const int probability, int **word,
+ int *wordLength, const int inputSize) {
+ const int outputIndex = mTerminalOutputIndex;
+ const int inputIndex = mTerminalInputIndex;
+ *wordLength = outputIndex + 1;
+ *word = mWord;
+ int finalProbability= Correction::RankingAlgorithm::calculateFinalProbability(
+ inputIndex, outputIndex, probability, mEditDistanceTable, this, inputSize);
+ return finalProbability;
+}
+
} // namespace latinime
#endif // LATINIME_CORRECTION_H