aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--native/src/correction.cpp137
-rw-r--r--native/src/correction.h6
-rw-r--r--native/src/correction_state.h2
-rw-r--r--native/src/defines.h16
-rw-r--r--native/src/proximity_info.cpp4
-rw-r--r--native/src/proximity_info.h4
-rw-r--r--native/src/unigram_dictionary.cpp3
7 files changed, 157 insertions, 15 deletions
diff --git a/native/src/correction.cpp b/native/src/correction.cpp
index f8f73ddf5..a4090a966 100644
--- a/native/src/correction.cpp
+++ b/native/src/correction.cpp
@@ -21,6 +21,7 @@
#define LOG_TAG "LatinIME: correction.cpp"
#include "correction.h"
+#include "dictionary.h"
#include "proximity_info.h"
namespace latinime {
@@ -93,16 +94,11 @@ int Correction::getFinalFreq(const int freq, unsigned short **word, int *wordLen
return -1;
}
- // TODO: Remove this
- if (mSkipPos >= 0 && mSkippedCount <= 0) {
- return -1;
- }
-
*word = mWord;
const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
: (mInputLength == inputIndex + 1);
return Correction::RankingAlgorithm::calculateFinalFreq(
- inputIndex, outputIndex, freq, sameLength, this);
+ inputIndex, outputIndex, freq, sameLength, mEditDistanceTable, this);
}
bool Correction::initProcessState(const int outputIndex) {
@@ -117,6 +113,7 @@ bool Correction::initProcessState(const int outputIndex) {
mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
mSkipPos = mCorrectionStates[outputIndex].mSkipPos;
mSkipping = false;
+ mProximityMatching = false;
mMatching = false;
return true;
}
@@ -160,6 +157,7 @@ void Correction::incrementOutputIndex() {
mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
mCorrectionStates[mOutputIndex].mMatching = mMatching;
+ mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
}
void Correction::startToTraverseAllNodes() {
@@ -207,6 +205,20 @@ Correction::CorrectionType Correction::processCharAndCalcState(
}
if (mNeedsToTraverseAllNodes || isQuote(c)) {
+ const bool checkProximityChars =
+ !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
+ // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of
+ // proximity chars of "s", but it should rather be handled as a skipped char.
+ if (checkProximityChars
+ && mInputIndex > 0
+ && mCorrectionStates[mOutputIndex].mProximityMatching
+ && mCorrectionStates[mOutputIndex].mSkipping
+ && mProximityInfo->getMatchedProximityId(
+ mInputIndex - 1, c, false)
+ == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) {
+ ++mSkippedCount;
+ --mProximityCount;
+ }
return processSkipChar(c, isTerminal);
} else {
int inputIndexForProximity = mInputIndex;
@@ -220,16 +232,27 @@ Correction::CorrectionType Correction::processCharAndCalcState(
}
}
+ // TODO: sum counters
const bool checkProximityChars =
- !(mSkipPos >= 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
+ !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
inputIndexForProximity, c, checkProximityChars);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
- if (skip) {
+ if (skip && mProximityCount == 0) {
// Skip this letter and continue deeper
++mSkippedCount;
return processSkipChar(c, isTerminal);
+ } else if (checkProximityChars
+ && inputIndexForProximity > 0
+ && mCorrectionStates[mOutputIndex].mProximityMatching
+ && mCorrectionStates[mOutputIndex].mSkipping
+ && mProximityInfo->getMatchedProximityId(
+ inputIndexForProximity - 1, c, false)
+ == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) {
+ ++mSkippedCount;
+ --mProximityCount;
+ return processSkipChar(c, isTerminal);
} else {
return UNRELATED;
}
@@ -238,6 +261,7 @@ Correction::CorrectionType Correction::processCharAndCalcState(
// proximity chars. So, we don't need to check proximity.
mMatching = true;
} else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
+ mProximityMatching = true;
incrementProximityCount();
}
@@ -320,29 +344,116 @@ inline static void multiplyRate(const int rate, int *freq) {
}
}
+/* static */
+inline static int editDistance(
+ int* editDistanceTable, const unsigned short* input,
+ const int inputLength, const unsigned short* output, const int outputLength) {
+ // dp[li][lo] dp[a][b] = dp[ a * lo + b]
+ int* dp = editDistanceTable;
+ const int li = inputLength + 1;
+ const int lo = outputLength + 1;
+ for (int i = 0; i < li; ++i) {
+ dp[lo * i] = i;
+ }
+ for (int i = 0; i < lo; ++i) {
+ dp[i] = i;
+ }
+
+ for (int i = 0; i < li - 1; ++i) {
+ for (int j = 0; j < lo - 1; ++j) {
+ const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
+ const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
+ const uint16_t cost = (ci == co) ? 0 : 1;
+ dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
+ min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
+ if (li > 0 && lo > 0
+ && ci == Dictionary::toBaseLowerCase(output[j - 1])
+ && co == Dictionary::toBaseLowerCase(input[i - 1])) {
+ dp[(i + 1) * lo + (j + 1)] = min(
+ dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
+ }
+ }
+ }
+
+ if (DEBUG_EDIT_DISTANCE) {
+ LOGI("IN = %d, OUT = %d", inputLength, outputLength);
+ for (int i = 0; i < li; ++i) {
+ for (int j = 0; j < lo; ++j) {
+ LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
+ }
+ }
+ }
+ return dp[li * lo - 1];
+}
+
//////////////////////
// RankingAlgorithm //
//////////////////////
/* static */
int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex,
- const int freq, const bool sameLength, const Correction* correction) {
+ const int freq, const bool sameLength, int* editDistanceTable,
+ const Correction* correction) {
const int excessivePos = correction->getExcessivePos();
const int transposedPos = correction->getTransposedPos();
const int inputLength = correction->mInputLength;
const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
const ProximityInfo *proximityInfo = correction->mProximityInfo;
+ const int skipCount = correction->mSkippedCount;
+ const int proximityMatchedCount = correction->mProximityCount;
// TODO: use mExcessiveCount
- const int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0);
- const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
+ int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0);
const unsigned short* word = correction->mWord;
- const bool skipped = correction->mSkippedCount > 0;
+ const bool skipped = skipCount > 0;
+
+ // ----- TODO: use edit distance here as follows? ---------------------- /
+ //if (!skipped && excessivePos < 0 && transposedPos < 0) {
+ // const int ed = editDistance(dp, proximityInfo->getInputWord(),
+ // inputLength, word, outputIndex + 1);
+ // matchCount = outputIndex + 1 - ed;
+ // if (ed == 1 && !sameLength) ++matchCount;
+ //}
+ // const int ed = editDistance(dp, proximityInfo->getInputWord(),
+ // inputLength, word, outputIndex + 1);
+ // if (ed == 1 && !sameLength) ++matchCount; ------------------------ /
+ int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
// TODO: Demote by edit distance
int finalFreq = freq * matchWeight;
+ // +1 +11/-12
+ /*if (inputLength == outputIndex && !skipped && excessivePos < 0 && transposedPos < 0) {
+ const int ed = editDistance(dp, proximityInfo->getInputWord(),
+ inputLength, word, outputIndex + 1);
+ if (ed == 1) {
+ multiplyRate(160, &finalFreq);
+ }
+ }*/
+ if (inputLength == outputIndex && excessivePos < 0 && transposedPos < 0
+ && (proximityMatchedCount > 0 || skipped)) {
+ const int ed = editDistance(editDistanceTable, proximityInfo->getPrimaryInputWord(),
+ inputLength, word, outputIndex + 1);
+ if (ed == 1) {
+ multiplyRate(160, &finalFreq);
+ }
+ }
+
+ // TODO: Promote properly?
+ //if (skipCount == 1 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex
+ // && !sameLength) {
+ // multiplyRate(150, &finalFreq);
+ //}
+ //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex
+ // && !sameLength) {
+ // multiplyRate(150, &finalFreq);
+ //}
+ //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0
+ // && inputLength == outputIndex + 1) {
+ // multiplyRate(150, &finalFreq);
+ //}
+
if (skipped) {
if (inputLength >= 2) {
const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE
@@ -389,7 +500,7 @@ int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const
multiplyIntCapped(typedLetterMultiplier, &finalFreq);
multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
}
- if (DEBUG_DICT) {
+ if (DEBUG_DICT_FULL) {
LOGI("calc: %d, %d", outputIndex, sameLength);
}
if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
diff --git a/native/src/correction.h b/native/src/correction.h
index 2fa8c905d..9d385a44e 100644
--- a/native/src/correction.h
+++ b/native/src/correction.h
@@ -120,6 +120,8 @@ private:
int mTerminalInputIndex;
int mTerminalOutputIndex;
unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
+ // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
+ int mEditDistanceTable[MAX_WORD_LENGTH_INTERNAL * MAX_WORD_LENGTH_INTERNAL];
CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
@@ -132,11 +134,13 @@ private:
bool mNeedsToTraverseAllNodes;
bool mMatching;
bool mSkipping;
+ bool mProximityMatching;
class RankingAlgorithm {
public:
static int calculateFinalFreq(const int inputIndex, const int depth,
- const int freq, const bool sameLength, const Correction* correction);
+ const int freq, const bool sameLength, int *editDistanceTable,
+ const Correction* correction);
static int calcFreqForSplitTwoWords(const int firstFreq, const int secondFreq,
const Correction* correction);
};
diff --git a/native/src/correction_state.h b/native/src/correction_state.h
index d30d13c85..267deda9b 100644
--- a/native/src/correction_state.h
+++ b/native/src/correction_state.h
@@ -33,6 +33,7 @@ struct CorrectionState {
int8_t mSkipPos; // should be signed
bool mMatching;
bool mSkipping;
+ bool mProximityMatching;
bool mNeedsToTraverseAllNodes;
};
@@ -47,6 +48,7 @@ inline static void initCorrectionState(CorrectionState *state, const int rootPos
state->mSkippedCount = 0;
state->mMatching = false;
state->mSkipping = false;
+ state->mProximityMatching = false;
state->mNeedsToTraverseAllNodes = traverseAll;
state->mSkipPos = -1;
}
diff --git a/native/src/defines.h b/native/src/defines.h
index c1838d341..c1d08e695 100644
--- a/native/src/defines.h
+++ b/native/src/defines.h
@@ -94,20 +94,36 @@ static void prof_out(void) {
#endif
#define DEBUG_DICT true
#define DEBUG_DICT_FULL false
+#define DEBUG_EDIT_DISTANCE false
#define DEBUG_SHOW_FOUND_WORD DEBUG_DICT_FULL
#define DEBUG_NODE DEBUG_DICT_FULL
#define DEBUG_TRACE DEBUG_DICT_FULL
#define DEBUG_PROXIMITY_INFO true
+#define DUMP_WORD(word, length) do { dumpWord(word, length); } while(0)
+
+static char charBuf[50];
+
+static void dumpWord(const unsigned short* word, const int length) {
+ for (int i = 0; i < length; ++i) {
+ charBuf[i] = word[i];
+ }
+ charBuf[length] = 0;
+ LOGI("[ %s ]", charBuf);
+}
+
#else // FLAG_DBG
#define DEBUG_DICT false
#define DEBUG_DICT_FULL false
+#define DEBUG_EDIT_DISTANCE false
#define DEBUG_SHOW_FOUND_WORD false
#define DEBUG_NODE false
#define DEBUG_TRACE false
#define DEBUG_PROXIMITY_INFO false
+#define DUMP_WORD(word, length)
+
#endif // FLAG_DBG
#ifndef U_SHORT_MAX
diff --git a/native/src/proximity_info.cpp b/native/src/proximity_info.cpp
index d437e251a..361bdacbf 100644
--- a/native/src/proximity_info.cpp
+++ b/native/src/proximity_info.cpp
@@ -68,6 +68,10 @@ bool ProximityInfo::hasSpaceProximity(const int x, const int y) const {
void ProximityInfo::setInputParams(const int* inputCodes, const int inputLength) {
mInputCodes = inputCodes;
mInputLength = inputLength;
+ for (int i = 0; i < inputLength; ++i) {
+ mPrimaryInputWord[i] = getPrimaryCharAt(i);
+ }
+ mPrimaryInputWord[inputLength] = 0;
}
inline const int* ProximityInfo::getProximityCharsAt(const int index) const {
diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h
index d9ed46f5b..75fc8fb63 100644
--- a/native/src/proximity_info.h
+++ b/native/src/proximity_info.h
@@ -46,6 +46,9 @@ public:
ProximityType getMatchedProximityId(
const int index, const unsigned short c, const bool checkProximityChars) const;
bool sameAsTyped(const unsigned short *word, int length) const;
+ const unsigned short* getPrimaryInputWord() const {
+ return mPrimaryInputWord;
+ }
private:
int getStartIndexFromCoordinates(const int x, const int y) const;
@@ -59,6 +62,7 @@ private:
const int *mInputCodes;
uint32_t *mProximityCharsArray;
int mInputLength;
+ unsigned short mPrimaryInputWord[MAX_WORD_LENGTH_INTERNAL];
};
} // namespace latinime
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index 6517bc0b8..6bc350505 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -187,8 +187,9 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
mCorrection->initCorrection(mProximityInfo, mInputLength, maxDepth);
PROF_END(0);
+ // TODO: remove
PROF_START(1);
- getSuggestionCandidates(-1, -1, -1);
+ // Note: This line is intentionally left blank
PROF_END(1);
PROF_START(2);