about | summary | refs | log | tree | commit | diff | stats
path: root/native
diff options
context:
space:
mode:
Diffstat (limited to 'native')
-rw-r--r-- native/dicttoolkit/src/command_executors/makedict_executor.cpp | 6
-rw-r--r-- native/dicttoolkit/src/utils/arguments_and_options.h | 23
-rw-r--r-- native/dicttoolkit/src/utils/arguments_parser.cpp | 81
-rw-r--r-- native/dicttoolkit/src/utils/arguments_parser.h | 35
-rw-r--r-- native/dicttoolkit/tests/utils/arguments_parser_test.cpp | 74
-rw-r--r-- native/jni/src/defines.h | 2
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp | 59
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h | 81
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp | 17
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp | 14
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp | 15
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h | 55
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp | 30
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp | 22
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h | 3
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp | 9
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h | 86
-rw-r--r-- native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp | 21
-rw-r--r-- native/jni/src/utils/ngram_utils.h | 63
20 files changed, 420 insertions, 279 deletions
diff --git a/native/dicttoolkit/src/command_executors/makedict_executor.cpp b/native/dicttoolkit/src/command_executors/makedict_executor.cpp
index 8a84e8069..4b0a5aeea 100644
--- a/native/dicttoolkit/src/command_executors/makedict_executor.cpp
+++ b/native/dicttoolkit/src/command_executors/makedict_executor.cpp
@@ -24,6 +24,12 @@ namespace dicttoolkit {
const char *const MakedictExecutor::COMMAND_NAME = "makedict";
+// Entry point of the "makedict" command.
+// Parses argc/argv with this executor's ArgumentsParser (error messages enabled). When the
+// parsed result is not valid, prints the usage text and returns 1. Otherwise reports on stderr
+// that the command is not implemented yet and returns 0.
/* static */ int MakedictExecutor::run(const int argc, char **argv) {
+ const ArgumentsAndOptions argumentsAndOptions =
+ getArgumentsParser().parseArguments(argc, argv, true /* printErrorMessage */);
+ if (!argumentsAndOptions.isValid()) {
+ // NOTE(review): presumably an invalid result means parsing failed -- confirm against isValid().
+ printUsage();
+ return 1;
+ }
fprintf(stderr, "Command '%s' has not been implemented yet.\n", COMMAND_NAME);
return 0;
}
diff --git a/native/dicttoolkit/src/utils/arguments_and_options.h b/native/dicttoolkit/src/utils/arguments_and_options.h
index d8f5985e5..2d81b1ecb 100644
--- a/native/dicttoolkit/src/utils/arguments_and_options.h
+++ b/native/dicttoolkit/src/utils/arguments_and_options.h
@@ -42,6 +42,29 @@ class ArgumentsAndOptions {
return mOptions.find(optionName) != mOptions.end();
}
+ // Returns the value recorded for optionName.
+ // The option must be present (see hasOption()); this is only checked through ASSERT.
+ const std::string &getOptionValue(const std::string &optionName) const {
+     const auto found = mOptions.find(optionName);
+     ASSERT(found != mOptions.end());
+     return found->second;
+ }
+
+ // Returns true when at least one value has been recorded for the argument called name.
+ bool hasArgument(const std::string &name) const {
+     const auto found = mArguments.find(name);
+     if (found == mArguments.end()) {
+         return false;
+     }
+     return !found->second.empty();
+ }
+
+ // Returns the first value recorded for the argument called name.
+ // The argument must exist and hold at least one value; this is only checked through ASSERT.
+ const std::string &getSingleArgument(const std::string &name) const {
+     const auto found = mArguments.find(name);
+     ASSERT(found != mArguments.end() && !found->second.empty());
+     const std::vector<std::string> &values = found->second;
+     return values.front();
+ }
+
+ // Returns every value recorded for the argument called name, in the order they were stored.
+ // The argument must have an entry; this is only checked through ASSERT.
+ const std::vector<std::string> &getVariableLengthArguments(const std::string &name) const {
+     const auto found = mArguments.find(name);
+     ASSERT(found != mArguments.end());
+     return found->second;
+ }
+
private:
DISALLOW_ASSIGNMENT_OPERATOR(ArgumentsAndOptions);
diff --git a/native/dicttoolkit/src/utils/arguments_parser.cpp b/native/dicttoolkit/src/utils/arguments_parser.cpp
index 52cc7b21d..1451284f1 100644
--- a/native/dicttoolkit/src/utils/arguments_parser.cpp
+++ b/native/dicttoolkit/src/utils/arguments_parser.cpp
@@ -21,7 +21,7 @@
namespace latinime {
namespace dicttoolkit {
-const int ArgumentSpec::UNLIMITED_COUNT = -1;
+const size_t ArgumentSpec::UNLIMITED_COUNT = S_INT_MAX;
bool ArgumentsParser::validateSpecs() const {
std::unordered_set<std::string> argumentNameSet;
@@ -53,7 +53,7 @@ void ArgumentsParser::printUsage(const std::string &commandName,
const std::string &optionName = option.first;
const OptionSpec &spec = option.second;
printf(" [-%s", optionName.c_str());
- if (spec.takeValue()) {
+ if (spec.needsValue()) {
printf(" <%s>", spec.getValueName().c_str());
}
printf("]");
@@ -74,11 +74,11 @@ void ArgumentsParser::printUsage(const std::string &commandName,
const std::string &optionName = option.first;
const OptionSpec &spec = option.second;
printf(" -%s", optionName.c_str());
- if (spec.takeValue()) {
+ if (spec.needsValue()) {
printf(" <%s>", spec.getValueName().c_str());
}
printf("\t\t\t%s", spec.getDescription().c_str());
- if (spec.takeValue() && !spec.getDefaultValue().empty()) {
+ if (spec.needsValue() && !spec.getDefaultValue().empty()) {
printf("\tdefault: %s", spec.getDefaultValue().c_str());
}
printf("\n");
@@ -89,9 +89,76 @@ void ArgumentsParser::printUsage(const std::string &commandName,
printf("\n\n");
}
-const ArgumentsAndOptions ArgumentsParser::parseArguments(const int argc, char **argv) const {
- // TODO: Implement
- return ArgumentsAndOptions();
+// Parses a command line according to mOptionSpecs and mArgumentSpecs.
+//
+// argv[0] is skipped. A token that is longer than one character and starts with '-' is looked up
+// as an option name (without the leading '-'); an option whose spec needs a value consumes the
+// following token as that value. Every other token, including a lone "-", is a positional
+// argument and is appended to the argument spec currently being filled; once that spec holds its
+// maximum count, parsing moves on to the next spec.
+//
+// @param argc number of entries in argv; must be positive.
+// @param argv command line tokens.
+// @param printErrorMessage whether a description of a parse failure is written to stderr.
+// @return the parsed options and arguments, or a default-constructed ArgumentsAndOptions when
+//         argc is invalid, an option is unknown, an option value is missing, there are too many
+//         positional arguments, or a spec did not receive its minimum number of arguments.
+const ArgumentsAndOptions ArgumentsParser::parseArguments(const int argc, char **argv,
+        const bool printErrorMessage) const {
+    if (argc <= 0) {
+        AKLOGE("Invalid argc (%d).", argc);
+        ASSERT(false);
+        return ArgumentsAndOptions();
+    }
+    // Key-value options with a non-empty default are present even when not given explicitly.
+    std::unordered_map<std::string, std::string> options;
+    for (const auto &entry : mOptionSpecs) {
+        const OptionSpec &optionSpec = entry.second;
+        if (optionSpec.needsValue() && !optionSpec.getDefaultValue().empty()) {
+            options[entry.first] = optionSpec.getDefaultValue();
+        }
+    }
+    std::unordered_map<std::string, std::vector<std::string>> arguments;
+    auto argumentSpecIt = mArgumentSpecs.cbegin();
+    for (int i = 1; i < argc; ++i) {
+        const std::string arg = argv[i];
+        if (arg.length() > 1 && arg[0] == '-') {
+            // option
+            const std::string optionName = arg.substr(1);
+            const auto it = mOptionSpecs.find(optionName);
+            if (it == mOptionSpecs.end()) {
+                if (printErrorMessage) {
+                    fprintf(stderr, "Unknown option: '%s'\n", optionName.c_str());
+                }
+                return ArgumentsAndOptions();
+            }
+            // Switch options are stored with an empty value.
+            std::string optionValue;
+            if (it->second.needsValue()) {
+                ++i;
+                if (i >= argc) {
+                    if (printErrorMessage) {
+                        fprintf(stderr, "Missing argument for option '%s'\n", optionName.c_str());
+                    }
+                    return ArgumentsAndOptions();
+                }
+                optionValue = argv[i];
+            }
+            options[optionName] = optionValue;
+        } else {
+            // argument
+            if (argumentSpecIt == mArgumentSpecs.end()) {
+                if (printErrorMessage) {
+                    fprintf(stderr, "Too many arguments.\n");
+                }
+                return ArgumentsAndOptions();
+            }
+            // Look the value list up once and reuse it for both the append and the size check.
+            std::vector<std::string> &values = arguments[argumentSpecIt->getName()];
+            values.push_back(arg);
+            if (values.size() >= argumentSpecIt->getMaxCount()) {
+                ++argumentSpecIt;
+            }
+        }
+    }
+
+    // Every spec that was not filled up to its maximum, i.e. the one currently being filled and
+    // all specs after it, must still have received at least its minimum number of arguments.
+    for (; argumentSpecIt != mArgumentSpecs.end(); ++argumentSpecIt) {
+        const auto it = arguments.find(argumentSpecIt->getName());
+        const size_t minCount = argumentSpecIt->getMinCount();
+        const size_t actualCount = (it == arguments.end()) ? 0 : it->second.size();
+        if (minCount > actualCount) {
+            if (printErrorMessage) {
+                // %zu is the conversion for size_t; %zd is for the signed counterpart.
+                fprintf(stderr, "Not enough arguments. %zu argument(s) required for <%s>\n",
+                        minCount, argumentSpecIt->getName().c_str());
+            }
+            return ArgumentsAndOptions();
+        }
+    }
+    return ArgumentsAndOptions(std::move(options), std::move(arguments));
}
} // namespace dicttoolkit
diff --git a/native/dicttoolkit/src/utils/arguments_parser.h b/native/dicttoolkit/src/utils/arguments_parser.h
index 510a8722b..32bd328d4 100644
--- a/native/dicttoolkit/src/utils/arguments_parser.h
+++ b/native/dicttoolkit/src/utils/arguments_parser.h
@@ -35,29 +35,29 @@ class OptionSpec {
static OptionSpec keyValueOption(const std::string &valueName, const std::string &defaultValue,
const std::string &description) {
- return OptionSpec(true /* takeValue */, valueName, defaultValue, description);
+ return OptionSpec(true /* needsValue */, valueName, defaultValue, description);
}
static OptionSpec switchOption(const std::string &description) {
- return OptionSpec(false /* takeValue */, "" /* valueName */, "" /* defaultValue */,
+ return OptionSpec(false /* needsValue */, "" /* valueName */, "" /* defaultValue */,
description);
}
- bool takeValue() const { return mTakeValue; }
+ bool needsValue() const { return mNeedsValue; }
const std::string &getValueName() const { return mValueName; }
const std::string &getDefaultValue() const { return mDefaultValue; }
const std::string &getDescription() const { return mDescription; }
private:
- OptionSpec(const bool takeValue, const std::string &valueName, const std::string &defaultValue,
+ OptionSpec(const bool needsValue, const std::string &valueName, const std::string &defaultValue,
const std::string &description)
- : mTakeValue(takeValue), mValueName(valueName), mDefaultValue(defaultValue),
+ : mNeedsValue(needsValue), mValueName(valueName), mDefaultValue(defaultValue),
mDescription(description) {}
// Whether the option have to be used with a value or just a switch.
- // e.g. 'f' in "command -f /path/to/file" is mTakeValue == true.
- // 'f' in "command -f -t" is mTakeValue == false.
- bool mTakeValue;
+ // e.g. 'f' in "command -f /path/to/file" is mNeedsValue == true.
+ // 'f' in "command -f -t" is mNeedsValue == false.
+ bool mNeedsValue;
// Name of the value used to show usage.
std::string mValueName;
std::string mDefaultValue;
@@ -66,32 +66,32 @@ class OptionSpec {
class ArgumentSpec {
public:
- static const int UNLIMITED_COUNT;
+ static const size_t UNLIMITED_COUNT;
static ArgumentSpec singleArgument(const std::string &name, const std::string &description) {
return ArgumentSpec(name, 1 /* minCount */, 1 /* maxCount */, description);
}
- static ArgumentSpec variableLengthArguments(const std::string &name, const int minCount,
- const int maxCount, const std::string &description) {
+ static ArgumentSpec variableLengthArguments(const std::string &name, const size_t minCount,
+ const size_t maxCount, const std::string &description) {
return ArgumentSpec(name, minCount, maxCount, description);
}
const std::string &getName() const { return mName; }
- int getMinCount() const { return mMinCount; }
- int getMaxCount() const { return mMaxCount; }
+ size_t getMinCount() const { return mMinCount; }
+ size_t getMaxCount() const { return mMaxCount; }
const std::string &getDescription() const { return mDescription; }
private:
DISALLOW_DEFAULT_CONSTRUCTOR(ArgumentSpec);
- ArgumentSpec(const std::string &name, const int minCount, const int maxCount,
+ ArgumentSpec(const std::string &name, const size_t minCount, const size_t maxCount,
const std::string &description)
: mName(name), mMinCount(minCount), mMaxCount(maxCount), mDescription(description) {}
const std::string mName;
- const int mMinCount;
- const int mMaxCount;
+ const size_t mMinCount;
+ const size_t mMaxCount;
const std::string mDescription;
};
@@ -101,7 +101,8 @@ class ArgumentsParser {
const std::vector<ArgumentSpec> &&argumentSpecs)
: mOptionSpecs(std::move(optionSpecs)), mArgumentSpecs(std::move(argumentSpecs)) {}
- const ArgumentsAndOptions parseArguments(const int argc, char **argv) const;
+ const ArgumentsAndOptions parseArguments(const int argc, char **argv,
+ const bool printErrorMessage) const;
bool validateSpecs() const;
void printUsage(const std::string &commandName, const std::string &description) const;
diff --git a/native/dicttoolkit/tests/utils/arguments_parser_test.cpp b/native/dicttoolkit/tests/utils/arguments_parser_test.cpp
index e79425b87..58b499823 100644
--- a/native/dicttoolkit/tests/utils/arguments_parser_test.cpp
+++ b/native/dicttoolkit/tests/utils/arguments_parser_test.cpp
@@ -68,6 +68,80 @@ TEST(ArgumentsParserTests, TestValitadeSpecs) {
}
}
+// Splits mutableCommandLine in place into space separated tokens and stores a pointer to each
+// token in argv, followed by a terminating nullptr entry.
+// argv[0] always points at the start of the buffer. The first space of every run of spaces is
+// overwritten with '\0' to terminate the preceding token. The caller must provide room in argv
+// for every token plus the terminating nullptr; no bound is checked here.
+// Returns the number of tokens stored (argc).
+int initArgv(char *mutableCommandLine, char **argv) {
+    int tokenCount = 0;
+    argv[tokenCount++] = mutableCommandLine;
+    // The end is computed up front because terminators are written into the buffer while
+    // scanning.
+    char *const end = mutableCommandLine + strlen(mutableCommandLine);
+    bool inSeparator = false;
+    for (char *cursor = mutableCommandLine; cursor != end; ++cursor) {
+        if (*cursor == ' ') {
+            if (!inSeparator) {
+                inSeparator = true;
+                *cursor = '\0';
+            }
+        } else if (inSeparator) {
+            inSeparator = false;
+            argv[tokenCount++] = cursor;
+        }
+    }
+    argv[tokenCount] = nullptr;
+    return tokenCount;
+}
+
+// Exercises ArgumentsParser::parseArguments with a switch option "a", a key-value option "b"
+// whose default is "default", a single argument "arg0" and a variable length argument "arg1"
+// that accepts between 0 and 2 values.
+TEST(ArgumentsParserTests, TestParseArguments) {
+ std::unordered_map<std::string, OptionSpec> optionSpecs;
+ optionSpecs["a"] = OptionSpec::switchOption("description");
+ optionSpecs["b"] = OptionSpec::keyValueOption("valueName", "default", "description");
+ const std::vector<ArgumentSpec> argumentSpecs = {
+ ArgumentSpec::singleArgument("arg0", "description"),
+ ArgumentSpec::variableLengthArguments("arg1", 0 /* minCount */, 2 /* maxCount */,
+ "description"),
+ };
+ const ArgumentsParser parser =
+ ArgumentsParser(std::move(optionSpecs), std::move(argumentSpecs));
+
+ // Only the single argument is given: "a" is absent, "b" keeps its default, "arg1" is empty.
+ {
+ char kMutableCommandLine[1024] = "command arg";
+ char *argv[128] = {};
+ const int argc = initArgv(kMutableCommandLine, argv);
+ ASSERT_EQ(2, argc);
+ const ArgumentsAndOptions argumentsAndOptions = parser.parseArguments(
+ argc, argv, false /* printErrorMessage */);
+ EXPECT_FALSE(argumentsAndOptions.hasOption("a"));
+ EXPECT_EQ("default", argumentsAndOptions.getOptionValue("b"));
+ EXPECT_EQ("arg", argumentsAndOptions.getSingleArgument("arg0"));
+ EXPECT_FALSE(argumentsAndOptions.hasArgument("arg1"));
+ }
+ // Switch "a" plus two positional tokens: the second token goes to "arg1".
+ {
+ char kArgumentBuffer[1024] = "command -a arg arg";
+ char *argv[128] = {};
+ const int argc = initArgv(kArgumentBuffer, argv);
+ ASSERT_EQ(4, argc);
+ const ArgumentsAndOptions argumentsAndOptions = parser.parseArguments(
+ argc, argv, false /* printErrorMessage */);
+ EXPECT_TRUE(argumentsAndOptions.hasOption("a"));
+ EXPECT_EQ("default", argumentsAndOptions.getOptionValue("b"));
+ EXPECT_EQ("arg", argumentsAndOptions.getSingleArgument("arg0"));
+ EXPECT_TRUE(argumentsAndOptions.hasArgument("arg1"));
+ EXPECT_EQ(1u, argumentsAndOptions.getVariableLengthArguments("arg1").size());
+ }
+ // "-b value" overrides the default, and "arg1" is filled up to its maximum of two values.
+ {
+ char kArgumentBuffer[1024] = "command -b value arg arg1 arg2";
+ char *argv[128] = {};
+ const int argc = initArgv(kArgumentBuffer, argv);
+ ASSERT_EQ(6, argc);
+ const ArgumentsAndOptions argumentsAndOptions = parser.parseArguments(
+ argc, argv, false /* printErrorMessage */);
+ EXPECT_FALSE(argumentsAndOptions.hasOption("a"));
+ EXPECT_EQ("value", argumentsAndOptions.getOptionValue("b"));
+ EXPECT_EQ("arg", argumentsAndOptions.getSingleArgument("arg0"));
+ const std::vector<std::string> &arg1 =
+ argumentsAndOptions.getVariableLengthArguments("arg1");
+ EXPECT_EQ(2u, arg1.size());
+ EXPECT_EQ("arg1", arg1[0]);
+ EXPECT_EQ("arg2", arg1[1]);
+ }
+}
+
} // namespace
} // namespace dicttoolkit
} // namespace latinime
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 0e67b4d5a..10b930e4f 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -275,7 +275,7 @@ static inline void showStackTrace() {
#define MAX_POINTER_COUNT_G 2
// (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram is supported.
-#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 2
+#define MAX_PREV_WORD_COUNT_FOR_N_GRAM 3
#define DISALLOW_DEFAULT_CONSTRUCTOR(TypeName) \
TypeName() = delete
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
index 300e96c4e..c93f31017 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.cpp
@@ -18,6 +18,8 @@
#include <algorithm>
+#include "utils/ngram_utils.h"
+
namespace latinime {
// Note that these are corresponding definitions in Java side in DictionaryHeader.
@@ -28,9 +30,12 @@ const char *const HeaderPolicy::REQUIRES_GERMAN_UMLAUT_PROCESSING_KEY =
const char *const HeaderPolicy::IS_DECAYING_DICT_KEY = "USES_FORGETTING_CURVE";
const char *const HeaderPolicy::DATE_KEY = "date";
const char *const HeaderPolicy::LAST_DECAYED_TIME_KEY = "LAST_DECAYED_TIME";
-const char *const HeaderPolicy::UNIGRAM_COUNT_KEY = "UNIGRAM_COUNT";
-const char *const HeaderPolicy::BIGRAM_COUNT_KEY = "BIGRAM_COUNT";
-const char *const HeaderPolicy::TRIGRAM_COUNT_KEY = "TRIGRAM_COUNT";
+const char *const HeaderPolicy::NGRAM_COUNT_KEYS[] =
+ {"UNIGRAM_COUNT", "BIGRAM_COUNT", "TRIGRAM_COUNT", "QUADGRAM_COUNT"};
+const char *const HeaderPolicy::MAX_NGRAM_COUNT_KEYS[] =
+ {"MAX_UNIGRAM_ENTRY_COUNT", "MAX_BIGRAM_ENTRY_COUNT", "MAX_TRIGRAM_ENTRY_COUNT",
+ "MAX_QUADGRAM_ENTRY_COUNT"};
+const int HeaderPolicy::DEFAULT_MAX_NGRAM_COUNTS[] = {10000, 30000, 30000, 30000};
const char *const HeaderPolicy::EXTENDED_REGION_SIZE_KEY = "EXTENDED_REGION_SIZE";
// Historical info is information that is needed to support decaying such as timestamp, level and
// count.
@@ -39,18 +44,10 @@ const char *const HeaderPolicy::LOCALE_KEY = "locale"; // match Java declaration
const char *const HeaderPolicy::FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY =
"FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID";
-const char *const HeaderPolicy::MAX_UNIGRAM_COUNT_KEY = "MAX_UNIGRAM_ENTRY_COUNT";
-const char *const HeaderPolicy::MAX_BIGRAM_COUNT_KEY = "MAX_BIGRAM_ENTRY_COUNT";
-const char *const HeaderPolicy::MAX_TRIGRAM_COUNT_KEY = "MAX_TRIGRAM_ENTRY_COUNT";
-
const int HeaderPolicy::DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE = 100;
const float HeaderPolicy::MULTIPLE_WORD_COST_MULTIPLIER_SCALE = 100.0f;
const int HeaderPolicy::DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID = 3;
-const int HeaderPolicy::DEFAULT_MAX_UNIGRAM_COUNT = 10000;
-const int HeaderPolicy::DEFAULT_MAX_BIGRAM_COUNT = 30000;
-const int HeaderPolicy::DEFAULT_MAX_TRIGRAM_COUNT = 30000;
-
// Used for logging. Question mark is used to indicate that the key is not found.
void HeaderPolicy::readHeaderValueOrQuestionMark(const char *const key, int *outValue,
int outValueSize) const {
@@ -126,15 +123,22 @@ bool HeaderPolicy::fillInAndWriteHeaderToBuffer(const bool updatesLastDecayedTim
return true;
}
+namespace {
+
+// Converts an NgramType to the int used to index the per-n-gram-type tables in this file
+// (NGRAM_COUNT_KEYS, MAX_NGRAM_COUNT_KEYS and DEFAULT_MAX_NGRAM_COUNTS).
+// NOTE(review): assumes the enumerator values are 0-based and match the table order -- confirm
+// against the NgramType definition.
+int getIndexFromNgramType(const NgramType ngramType) {
+ return static_cast<int>(ngramType);
+}
+
+} // namespace
+
void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
const EntryCounts &entryCounts, const int extendedRegionSize,
DictionaryHeaderStructurePolicy::AttributeMap *outAttributeMap) const {
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, UNIGRAM_COUNT_KEY,
- entryCounts.getUnigramCount());
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, BIGRAM_COUNT_KEY,
- entryCounts.getBigramCount());
- HeaderReadWriteUtils::setIntAttribute(outAttributeMap, TRIGRAM_COUNT_KEY,
- entryCounts.getTrigramCount());
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ HeaderReadWriteUtils::setIntAttribute(outAttributeMap,
+ NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)],
+ entryCounts.getNgramCount(ngramType));
+ }
HeaderReadWriteUtils::setIntAttribute(outAttributeMap, EXTENDED_REGION_SIZE_KEY,
extendedRegionSize);
// Set the current time as the generation time.
@@ -155,4 +159,25 @@ void HeaderPolicy::fillInHeader(const bool updatesLastDecayedTime,
return attributeMap;
}
+// Reads the current entry count of every n-gram type from mAttributeMap, using 0 for keys that
+// are not present. This is a non-static member: it depends on this instance's mAttributeMap.
+const EntryCounts HeaderPolicy::readNgramCounts() const {
+ MutableEntryCounters entryCounters;
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ const int entryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ NGRAM_COUNT_KEYS[getIndexFromNgramType(ngramType)], 0 /* defaultValue */);
+ entryCounters.setNgramCount(ngramType, entryCount);
+ }
+ return entryCounters.getEntryCounts();
+}
+
+// Reads the maximum entry count of every n-gram type from mAttributeMap, falling back to the
+// matching DEFAULT_MAX_NGRAM_COUNTS entry for keys that are not present. This is a non-static
+// member: it depends on this instance's mAttributeMap.
+const EntryCounts HeaderPolicy::readMaxNgramCounts() const {
+ MutableEntryCounters entryCounters;
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ // The same index selects both the attribute key and its default value.
+ const int index = getIndexFromNgramType(ngramType);
+ const int maxEntryCount = HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
+ MAX_NGRAM_COUNT_KEYS[index], DEFAULT_MAX_NGRAM_COUNTS[index]);
+ entryCounters.setNgramCount(ngramType, maxEntryCount);
+ }
+ return entryCounters.getEntryCounts();
+}
+
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
index 7a5acd7d5..f76931baa 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/header/header_policy.h
@@ -46,12 +46,7 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
LAST_DECAYED_TIME_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- UNIGRAM_COUNT_KEY, 0 /* defaultValue */)),
- mBigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- BIGRAM_COUNT_KEY, 0 /* defaultValue */)),
- mTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
- TRIGRAM_COUNT_KEY, 0 /* defaultValue */)),
+ mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
mExtendedRegionSize(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
EXTENDED_REGION_SIZE_KEY, 0 /* defaultValue */)),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
@@ -59,12 +54,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
- mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
- mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
- mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Constructs header information using an attribute map.
@@ -82,18 +71,13 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
mLastDecayedTime(HeaderReadWriteUtils::readIntAttributeValue(&mAttributeMap,
DATE_KEY, TimeKeeper::peekCurrentTime() /* defaultValue */)),
- mUnigramCount(0), mBigramCount(0), mTrigramCount(0), mExtendedRegionSize(0),
+ mNgramCounts(readNgramCounts()), mMaxNgramCounts(readMaxNgramCounts()),
+ mExtendedRegionSize(0),
mHasHistoricalInfoOfWords(HeaderReadWriteUtils::readBoolAttributeValue(
&mAttributeMap, HAS_HISTORICAL_INFO_KEY, false /* defaultValue */)),
mForgettingCurveProbabilityValuesTableId(HeaderReadWriteUtils::readIntAttributeValue(
&mAttributeMap, FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY,
DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID)),
- mMaxUnigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_UNIGRAM_COUNT_KEY, DEFAULT_MAX_UNIGRAM_COUNT)),
- mMaxBigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_BIGRAM_COUNT_KEY, DEFAULT_MAX_BIGRAM_COUNT)),
- mMaxTrigramCount(HeaderReadWriteUtils::readIntAttributeValue(
- &mAttributeMap, MAX_TRIGRAM_COUNT_KEY, DEFAULT_MAX_TRIGRAM_COUNT)),
mCodePointTable(HeaderReadWriteUtils::readCodePointTable(&mAttributeMap)) {}
// Copy header information
@@ -105,15 +89,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
mRequiresGermanUmlautProcessing(headerPolicy->mRequiresGermanUmlautProcessing),
mIsDecayingDict(headerPolicy->mIsDecayingDict),
mDate(headerPolicy->mDate), mLastDecayedTime(headerPolicy->mLastDecayedTime),
- mUnigramCount(headerPolicy->mUnigramCount), mBigramCount(headerPolicy->mBigramCount),
- mTrigramCount(headerPolicy->mTrigramCount),
+ mNgramCounts(headerPolicy->mNgramCounts),
+ mMaxNgramCounts(headerPolicy->mMaxNgramCounts),
mExtendedRegionSize(headerPolicy->mExtendedRegionSize),
mHasHistoricalInfoOfWords(headerPolicy->mHasHistoricalInfoOfWords),
mForgettingCurveProbabilityValuesTableId(
headerPolicy->mForgettingCurveProbabilityValuesTableId),
- mMaxUnigramCount(headerPolicy->mMaxUnigramCount),
- mMaxBigramCount(headerPolicy->mMaxBigramCount),
- mMaxTrigramCount(headerPolicy->mMaxTrigramCount),
mCodePointTable(headerPolicy->mCodePointTable) {}
// Temporary dummy header.
@@ -121,10 +102,9 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
: mDictFormatVersion(FormatUtils::UNKNOWN_VERSION), mDictionaryFlags(0), mSize(0),
mAttributeMap(), mLocale(CharUtils::EMPTY_STRING), mMultiWordCostMultiplier(0.0f),
mRequiresGermanUmlautProcessing(false), mIsDecayingDict(false),
- mDate(0), mLastDecayedTime(0), mUnigramCount(0), mBigramCount(0), mTrigramCount(0),
+ mDate(0), mLastDecayedTime(0), mNgramCounts(), mMaxNgramCounts(),
mExtendedRegionSize(0), mHasHistoricalInfoOfWords(false),
- mForgettingCurveProbabilityValuesTableId(0), mMaxUnigramCount(0), mMaxBigramCount(0),
- mMaxTrigramCount(0), mCodePointTable(nullptr) {}
+ mForgettingCurveProbabilityValuesTableId(0), mCodePointTable(nullptr) {}
~HeaderPolicy() {}
@@ -186,16 +166,12 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mLastDecayedTime;
}
- AK_FORCE_INLINE int getUnigramCount() const {
- return mUnigramCount;
+ AK_FORCE_INLINE const EntryCounts &getNgramCounts() const {
+ return mNgramCounts;
}
- AK_FORCE_INLINE int getBigramCount() const {
- return mBigramCount;
- }
-
- AK_FORCE_INLINE int getTrigramCount() const {
- return mTrigramCount;
+ AK_FORCE_INLINE const EntryCounts getMaxNgramCounts() const {
+ return mMaxNgramCounts;
}
AK_FORCE_INLINE int getExtendedRegionSize() const {
@@ -219,18 +195,6 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
return mForgettingCurveProbabilityValuesTableId;
}
- AK_FORCE_INLINE int getMaxUnigramCount() const {
- return mMaxUnigramCount;
- }
-
- AK_FORCE_INLINE int getMaxBigramCount() const {
- return mMaxBigramCount;
- }
-
- AK_FORCE_INLINE int getMaxTrigramCount() const {
- return mMaxTrigramCount;
- }
-
void readHeaderValueOrQuestionMark(const char *const key,
int *outValue, int outValueSize) const;
@@ -262,24 +226,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
static const char *const IS_DECAYING_DICT_KEY;
static const char *const DATE_KEY;
static const char *const LAST_DECAYED_TIME_KEY;
- static const char *const UNIGRAM_COUNT_KEY;
- static const char *const BIGRAM_COUNT_KEY;
- static const char *const TRIGRAM_COUNT_KEY;
+ static const char *const NGRAM_COUNT_KEYS[];
+ static const char *const MAX_NGRAM_COUNT_KEYS[];
+ static const int DEFAULT_MAX_NGRAM_COUNTS[];
static const char *const EXTENDED_REGION_SIZE_KEY;
static const char *const HAS_HISTORICAL_INFO_KEY;
static const char *const LOCALE_KEY;
static const char *const FORGETTING_CURVE_OCCURRENCES_TO_LEVEL_UP_KEY;
static const char *const FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID_KEY;
static const char *const FORGETTING_CURVE_DURATION_TO_LEVEL_DOWN_IN_SECONDS_KEY;
- static const char *const MAX_UNIGRAM_COUNT_KEY;
- static const char *const MAX_BIGRAM_COUNT_KEY;
- static const char *const MAX_TRIGRAM_COUNT_KEY;
static const int DEFAULT_MULTIPLE_WORDS_DEMOTION_RATE;
static const float MULTIPLE_WORD_COST_MULTIPLIER_SCALE;
static const int DEFAULT_FORGETTING_CURVE_PROBABILITY_VALUES_TABLE_ID;
- static const int DEFAULT_MAX_UNIGRAM_COUNT;
- static const int DEFAULT_MAX_BIGRAM_COUNT;
- static const int DEFAULT_MAX_TRIGRAM_COUNT;
const FormatUtils::FORMAT_VERSION mDictFormatVersion;
const HeaderReadWriteUtils::DictionaryFlags mDictionaryFlags;
@@ -291,21 +249,18 @@ class HeaderPolicy : public DictionaryHeaderStructurePolicy {
const bool mIsDecayingDict;
const int mDate;
const int mLastDecayedTime;
- const int mUnigramCount;
- const int mBigramCount;
- const int mTrigramCount;
+ const EntryCounts mNgramCounts;
+ const EntryCounts mMaxNgramCounts;
const int mExtendedRegionSize;
const bool mHasHistoricalInfoOfWords;
const int mForgettingCurveProbabilityValuesTableId;
- const int mMaxUnigramCount;
- const int mMaxBigramCount;
- const int mMaxTrigramCount;
const int *const mCodePointTable;
const std::vector<int> readLocale() const;
float readMultipleWordCostMultiplier() const;
bool readRequiresGermanUmlautProcessing() const;
-
+ const EntryCounts readNgramCounts() const;
+ const EntryCounts readMaxNgramCounts() const;
static DictionaryHeaderStructurePolicy::AttributeMap createAttributeMapAndReadAllAttributes(
const uint8_t *const dictBuf);
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index ca7d93b0e..051aed45a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -303,7 +303,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mEntryCounters.incrementUnigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -397,7 +397,7 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
wordPos, ngramProperty, &addedNewBigram)) {
if (addedNewBigram) {
- mEntryCounters.incrementBigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Bigram);
}
return true;
} else {
@@ -438,7 +438,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
if (mUpdatingHelper.removeNgramEntry(
PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
- mEntryCounters.decrementBigramCount();
+ mEntryCounters.decrementNgramCount(NgramType::Bigram);
return true;
} else {
return false;
@@ -525,20 +525,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
+ snprintf(outResult, maxResultLength, "%d",
+ mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxUnigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxBigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 0480876ed..80b1111b4 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -76,8 +76,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mPtNodeArrayReader, &mBigramPolicy, &mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
- mHeaderPolicy->getTrigramCount()),
+ mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
virtual int getRootPosition() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
index a033d396b..985c16803 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_writing_helper.cpp
@@ -53,8 +53,8 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, "
- "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
- entryCounts.getBigramCount(), extendedRegionSize);
+ "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
+ entryCounts.getNgramCount(NgramType::Bigram), extendedRegionSize);
return false;
}
return mBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -73,9 +73,11 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFileWithGC(const int rootPtNodeAr
}
BufferWithExtendableBuffer headerBuffer(
BufferWithExtendableBuffer::DEFAULT_MAX_ADDITIONAL_BUFFER_SIZE);
+ MutableEntryCounters entryCounters;
+ entryCounters.setNgramCount(NgramType::Unigram, unigramCount);
+ entryCounters.setNgramCount(NgramType::Bigram, bigramCount);
if (!headerPolicy->fillInAndWriteHeaderToBuffer(true /* updatesLastDecayedTime */,
- EntryCounts(unigramCount, bigramCount, 0 /* trigramCount */),
- 0 /* extendedRegionSize */, &headerBuffer)) {
+ entryCounters.getEntryCounts(), 0 /* extendedRegionSize */, &headerBuffer)) {
return false;
}
return dictBuffers->flushHeaderAndDictBuffers(dictDirPath, &headerBuffer);
@@ -107,7 +109,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
}
const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
.getValidUnigramCount();
- const int maxUnigramCount = headerPolicy->getMaxUnigramCount();
+ const int maxUnigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Unigram);
if (headerPolicy->isDecayingDict() && unigramCount > maxUnigramCount) {
if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter, maxUnigramCount)) {
AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
@@ -124,7 +126,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
- const int maxBigramCount = headerPolicy->getMaxBigramCount();
+ const int maxBigramCount = headerPolicy->getMaxNgramCounts().getNgramCount(NgramType::Bigram);
if (headerPolicy->isDecayingDict() && bigramCount > maxBigramCount) {
if (!truncateBigrams(maxBigramCount)) {
AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount, maxBigramCount);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
index b0fbb3e72..025ee9932 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.cpp
@@ -18,17 +18,14 @@
namespace latinime {
-// These counts are used to provide stable probabilities even if the user's input count is small.
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_UNIGRAMS = 8192;
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_BIGRAMS = 2;
-const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNT_FOR_TRIGRAMS = 2;
+// Used to provide stable probabilities even if the user's input count is small.
+const int DynamicLanguageModelProbabilityUtils::ASSUMED_MIN_COUNTS[] = {8192, 2, 2, 1};
-// These are encoded backoff weights.
-// Note that we give positive value for trigrams that means the weight is more than 1.
+// Encoded backoff weights.
+// Note that we give positive values for trigrams and quadgrams, which means the weight is more
+// than 1.
// TODO: Apply backoff for main dictionaries and quit giving a positive backoff weight.
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS = -32;
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS = 0;
-const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS = 8;
+const int DynamicLanguageModelProbabilityUtils::ENCODED_BACKOFF_WEIGHTS[] = {-32, -4, 2, 8};
// This value is used to remove too old entries from the dictionary.
const int DynamicLanguageModelProbabilityUtils::DURATION_TO_DISCARD_ENTRY_IN_SECONDS =
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
index 88bc58fe8..644ae2ca7 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h
@@ -21,6 +21,7 @@
#include "defines.h"
#include "suggest/core/dictionary/property/historical_info.h"
+#include "utils/ngram_utils.h"
#include "utils/time_keeper.h"
namespace latinime {
@@ -28,46 +29,14 @@ namespace latinime {
class DynamicLanguageModelProbabilityUtils {
public:
static float computeRawProbabilityFromCounts(const int count, const int contextCount,
- const int matchedWordCountInContext) {
- int minCount = 0;
- switch (matchedWordCountInContext) {
- case 1:
- minCount = ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
- break;
- case 2:
- minCount = ASSUMED_MIN_COUNT_FOR_BIGRAMS;
- break;
- case 3:
- minCount = ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
- break;
- default:
- AKLOGE("computeRawProbabilityFromCounts is called with invalid "
- "matchedWordCountInContext (%d).", matchedWordCountInContext);
- ASSERT(false);
- return 0.0f;
- }
+ const NgramType ngramType) {
+ const int minCount = ASSUMED_MIN_COUNTS[static_cast<int>(ngramType)];
return static_cast<float>(count) / static_cast<float>(std::max(contextCount, minCount));
}
- static float backoff(const int ngramProbability, const int matchedWordCountInContext) {
- int probability = NOT_A_PROBABILITY;
-
- switch (matchedWordCountInContext) {
- case 1:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
- break;
- case 2:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
- break;
- case 3:
- probability = ngramProbability + ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
- break;
- default:
- AKLOGE("backoff is called with invalid matchedWordCountInContext (%d).",
- matchedWordCountInContext);
- ASSERT(false);
- return NOT_A_PROBABILITY;
- }
+ static float backoff(const int ngramProbability, const NgramType ngramType) {
+ const int probability =
+ ngramProbability + ENCODED_BACKOFF_WEIGHTS[static_cast<int>(ngramType)];
return std::min(std::max(probability, NOT_A_PROBABILITY), MAX_PROBABILITY);
}
@@ -97,16 +66,10 @@ class DynamicLanguageModelProbabilityUtils {
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DynamicLanguageModelProbabilityUtils);
- static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 2, "Max supported Ngram is Trigram.");
-
- static const int ASSUMED_MIN_COUNT_FOR_UNIGRAMS;
- static const int ASSUMED_MIN_COUNT_FOR_BIGRAMS;
- static const int ASSUMED_MIN_COUNT_FOR_TRIGRAMS;
-
- static const int ENCODED_BACKOFF_WEIGHT_FOR_UNIGRAMS;
- static const int ENCODED_BACKOFF_WEIGHT_FOR_BIGRAMS;
- static const int ENCODED_BACKOFF_WEIGHT_FOR_TRIGRAMS;
+ static_assert(MAX_PREV_WORD_COUNT_FOR_N_GRAM <= 3, "Max supported Ngram is Quadgram.");
+ static const int ASSUMED_MIN_COUNTS[];
+ static const int ENCODED_BACKOFF_WEIGHTS[];
static const int DURATION_TO_DISCARD_ENTRY_IN_SECONDS;
};
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 31b1ea696..6db7ea444 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -21,6 +21,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/content/dynamic_language_model_probability_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -89,16 +90,17 @@ const WordAttributes LanguageModelDictContent::getWordAttributes(const WordIdArr
}
contextCount = prevWordProbabilityEntry.getHistoricalInfo()->getCount();
}
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(i + 1);
const float rawProbability =
DynamicLanguageModelProbabilityUtils::computeRawProbabilityFromCounts(
- historicalInfo->getCount(), contextCount, i + 1);
+ historicalInfo->getCount(), contextCount, ngramType);
const int encodedRawProbability =
ProbabilityUtils::encodeRawProbability(rawProbability);
const int decayedProbability =
DynamicLanguageModelProbabilityUtils::getDecayedProbability(
encodedRawProbability, *historicalInfo);
probability = DynamicLanguageModelProbabilityUtils::backoff(
- decayedProbability, i + 1 /* n */);
+ decayedProbability, ngramType);
} else {
probability = probabilityEntry.getProbability();
}
@@ -198,18 +200,19 @@ bool LanguageModelDictContent::truncateEntries(const EntryCounts &currentEntryCo
MutableEntryCounters *const outEntryCounters) {
for (int prevWordCount = 0; prevWordCount <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++prevWordCount) {
const int totalWordCount = prevWordCount + 1;
- if (currentEntryCounts.getNgramCount(totalWordCount)
- <= maxEntryCounts.getNgramCount(totalWordCount)) {
- outEntryCounters->setNgramCount(totalWordCount,
- currentEntryCounts.getNgramCount(totalWordCount));
+ const NgramType ngramType = NgramUtils::getNgramTypeFromWordCount(totalWordCount);
+ if (currentEntryCounts.getNgramCount(ngramType)
+ <= maxEntryCounts.getNgramCount(ngramType)) {
+ outEntryCounters->setNgramCount(ngramType,
+ currentEntryCounts.getNgramCount(ngramType));
continue;
}
int entryCount = 0;
if (!turncateEntriesInSpecifiedLevel(headerPolicy,
- maxEntryCounts.getNgramCount(totalWordCount), prevWordCount, &entryCount)) {
+ maxEntryCounts.getNgramCount(ngramType), prevWordCount, &entryCount)) {
return false;
}
- outEntryCounters->setNgramCount(totalWordCount, entryCount);
+ outEntryCounters->setNgramCount(ngramType, entryCount);
}
return true;
}
@@ -246,7 +249,10 @@ bool LanguageModelDictContent::updateAllEntriesOnInputWord(const WordIdArrayView
mGlobalCounters.updateMaxValueOfCounters(
updatedNgramProbabilityEntry.getHistoricalInfo()->getCount());
if (!originalNgramProbabilityEntry.isValid()) {
- entryCountersToUpdate->incrementNgramCount(i + 2);
+ // (i + 2) words are used in total because the prevWords consists of (i + 1) words when
+ // looking at its i-th element.
+ entryCountersToUpdate->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(i + 2));
}
}
return true;
@@ -369,7 +375,8 @@ bool LanguageModelDictContent::updateAllProbabilityEntriesForGCInner(const int b
}
}
}
- outEntryCounters->incrementNgramCount(prevWordCount + 1);
+ outEntryCounters->incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordCount + 1));
if (!entry.hasNextLevelMap()) {
continue;
}
@@ -402,7 +409,8 @@ bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
for (int i = 0; i < entryCountToRemove; ++i) {
const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
if (!removeNgramProbabilityEntry(
- WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount), entryInfo.mKey)) {
+ WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mPrevWordCount),
+ entryInfo.mKey)) {
return false;
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 7449cd02b..a96719533 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -31,6 +31,7 @@
#include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_reader.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
#include "suggest/policyimpl/dictionary/utils/probability_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -215,7 +216,7 @@ bool Ver4PatriciaTriePolicy::addUnigramEntry(const CodePointArrayView wordCodePo
if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointArrayView, unigramProperty,
&addedNewUnigram)) {
if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
- mEntryCounters.incrementUnigramCount();
+ mEntryCounters.incrementNgramCount(NgramType::Unigram);
}
if (unigramProperty->getShortcuts().size() > 0) {
// Add shortcut target.
@@ -263,7 +264,7 @@ bool Ver4PatriciaTriePolicy::removeUnigramEntry(const CodePointArrayView wordCod
return false;
}
if (!ptNodeParams.representsNonWordInfo()) {
- mEntryCounters.decrementUnigramCount();
+ mEntryCounters.decrementNgramCount(NgramType::Unigram);
}
return true;
}
@@ -321,7 +322,8 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const NgramProperty *const ngramPrope
bool addedNewEntry = false;
if (mNodeWriter.addNgramEntry(prevWordIds, wordId, ngramProperty, &addedNewEntry)) {
if (addedNewEntry) {
- mEntryCounters.incrementNgramCount(prevWordIds.size() + 1);
+ mEntryCounters.incrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
}
return true;
} else {
@@ -359,7 +361,8 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const NgramContext *const ngramCon
return false;
}
if (mNodeWriter.removeNgramEntry(prevWordIds, wordId)) {
- mEntryCounters.decrementNgramCount(prevWordIds.size());
+ mEntryCounters.decrementNgramCount(
+ NgramUtils::getNgramTypeFromWordCount(prevWordIds.size() + 1));
return true;
} else {
return false;
@@ -477,20 +480,23 @@ void Ver4PatriciaTriePolicy::getProperty(const char *const query, const int quer
char *const outResult, const int maxResultLength) {
const int compareLength = queryLength + 1 /* terminator */;
if (strncmp(query, UNIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getUnigramCount());
+ snprintf(outResult, maxResultLength, "%d",
+ mEntryCounters.getNgramCount(NgramType::Unigram));
} else if (strncmp(query, BIGRAM_COUNT_QUERY, compareLength) == 0) {
- snprintf(outResult, maxResultLength, "%d", mEntryCounters.getBigramCount());
+ snprintf(outResult, maxResultLength, "%d", mEntryCounters.getNgramCount(NgramType::Bigram));
} else if (strncmp(query, MAX_UNIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxUnigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Unigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
} else if (strncmp(query, MAX_BIGRAM_COUNT_QUERY, compareLength) == 0) {
snprintf(outResult, maxResultLength, "%d",
mHeaderPolicy->isDecayingDict() ?
ForgettingCurveUtils::getEntryCountHardLimit(
- mHeaderPolicy->getMaxBigramCount()) :
+ mHeaderPolicy->getMaxNgramCounts().getNgramCount(
+ NgramType::Bigram)) :
static_cast<int>(Ver4DictConstants::MAX_DICTIONARY_SIZE));
}
}
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 13700b390..93faa83a0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -51,8 +51,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
&mShortcutPolicy),
mUpdatingHelper(mDictBuffer, &mNodeReader, &mNodeWriter),
mWritingHelper(mBuffers.get()),
- mEntryCounters(mHeaderPolicy->getUnigramCount(), mHeaderPolicy->getBigramCount(),
- mHeaderPolicy->getTrigramCount()),
+ mEntryCounters(mHeaderPolicy->getNgramCounts().getCountArray()),
mTerminalPtNodePositionsForIteratingWords(), mIsCorrupted(false) {};
AK_FORCE_INLINE int getRootPosition() const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 7f0604ce8..34af76c5d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -29,6 +29,7 @@
#include "suggest/policyimpl/dictionary/utils/buffer_with_extendable_buffer.h"
#include "suggest/policyimpl/dictionary/utils/file_utils.h"
#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+#include "utils/ngram_utils.h"
namespace latinime {
@@ -43,8 +44,9 @@ bool Ver4PatriciaTrieWritingHelper::writeToDictFile(const char *const dictDirPat
entryCounts, extendedRegionSize, &headerBuffer)) {
AKLOGE("Cannot write header structure to buffer. "
"updatesLastDecayedTime: %d, unigramCount: %d, bigramCount: %d, trigramCount: %d,"
- "extendedRegionSize: %d", false, entryCounts.getUnigramCount(),
- entryCounts.getBigramCount(), entryCounts.getTrigramCount(),
+ "extendedRegionSize: %d", false, entryCounts.getNgramCount(NgramType::Unigram),
+ entryCounts.getNgramCount(NgramType::Bigram),
+ entryCounts.getNgramCount(NgramType::Trigram),
extendedRegionSize);
return false;
}
@@ -86,8 +88,7 @@ bool Ver4PatriciaTrieWritingHelper::runGC(const int rootPtNodeArrayPos,
return false;
}
if (headerPolicy->isDecayingDict()) {
- const EntryCounts maxEntryCounts(headerPolicy->getMaxUnigramCount(),
- headerPolicy->getMaxBigramCount(), headerPolicy->getMaxTrigramCount());
+ const EntryCounts &maxEntryCounts = headerPolicy->getMaxNgramCounts();
if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(
outEntryCounters->getEntryCounts(), maxEntryCounts, headerPolicy,
outEntryCounters)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
index 73dc42a18..5e443026e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/entry_counters.h
@@ -20,42 +20,31 @@
#include <array>
#include "defines.h"
+#include "utils/ngram_utils.h"
namespace latinime {
// Copyable but immutable
class EntryCounts final {
public:
- EntryCounts() : mEntryCounts({{0, 0, 0}}) {}
-
- EntryCounts(const int unigramCount, const int bigramCount, const int trigramCount)
- : mEntryCounts({{unigramCount, bigramCount, trigramCount}}) {}
+ EntryCounts() : mEntryCounts({{0, 0, 0, 0}}) {}
explicit EntryCounts(const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
: mEntryCounts(counters) {}
- int getUnigramCount() const {
- return mEntryCounts[0];
- }
-
- int getBigramCount() const {
- return mEntryCounts[1];
- }
-
- int getTrigramCount() const {
- return mEntryCounts[2];
+ int getNgramCount(const NgramType ngramType) const {
+ return mEntryCounts[static_cast<int>(ngramType)];
}
- int getNgramCount(const size_t n) const {
- if (n < 1 || n > mEntryCounts.size()) {
- return 0;
- }
- return mEntryCounts[n - 1];
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &getCountArray() const {
+ return mEntryCounts;
}
private:
DISALLOW_ASSIGNMENT_OPERATOR(EntryCounts);
+ // Counts from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
+ // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounts;
};
@@ -65,68 +54,35 @@ class MutableEntryCounters final {
mEntryCounters.fill(0);
}
- MutableEntryCounters(const int unigramCount, const int bigramCount, const int trigramCount)
- : mEntryCounters({{unigramCount, bigramCount, trigramCount}}) {}
+ explicit MutableEntryCounters(
+ const std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> &counters)
+ : mEntryCounters(counters) {}
const EntryCounts getEntryCounts() const {
return EntryCounts(mEntryCounters);
}
- int getUnigramCount() const {
- return mEntryCounters[0];
- }
-
- int getBigramCount() const {
- return mEntryCounters[1];
- }
-
- int getTrigramCount() const {
- return mEntryCounters[2];
- }
-
- void incrementUnigramCount() {
- ++mEntryCounters[0];
- }
-
- void decrementUnigramCount() {
- ASSERT(mEntryCounters[0] != 0);
- --mEntryCounters[0];
- }
-
- void incrementBigramCount() {
- ++mEntryCounters[1];
- }
-
- void decrementBigramCount() {
- ASSERT(mEntryCounters[1] != 0);
- --mEntryCounters[1];
+ void incrementNgramCount(const NgramType ngramType) {
+ ++mEntryCounters[static_cast<int>(ngramType)];
}
- void incrementNgramCount(const size_t n) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- ++mEntryCounters[n - 1];
+ void decrementNgramCount(const NgramType ngramType) {
+ --mEntryCounters[static_cast<int>(ngramType)];
}
- void decrementNgramCount(const size_t n) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- ASSERT(mEntryCounters[n - 1] != 0);
- --mEntryCounters[n - 1];
+ int getNgramCount(const NgramType ngramType) const {
+ return mEntryCounters[static_cast<int>(ngramType)];
}
- void setNgramCount(const size_t n, const int count) {
- if (n < 1 || n > mEntryCounters.size()) {
- return;
- }
- mEntryCounters[n - 1] = count;
+ void setNgramCount(const NgramType ngramType, const int count) {
+ mEntryCounters[static_cast<int>(ngramType)] = count;
}
private:
DISALLOW_COPY_AND_ASSIGN(MutableEntryCounters);
+ // Counters from Unigram (0-th element) to (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram
+ // (MAX_PREV_WORD_COUNT_FOR_N_GRAM-th element)
std::array<int, MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1> mEntryCounters;
};
} // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
index 9055f7bfc..f05c6149e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp
@@ -126,20 +126,13 @@ const ForgettingCurveUtils::ProbabilityTable ForgettingCurveUtils::sProbabilityT
/* static */ bool ForgettingCurveUtils::needsToDecay(const bool mindsBlockByDecay,
const EntryCounts &entryCounts, const HeaderPolicy *const headerPolicy) {
- if (entryCounts.getUnigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxUnigramCount())) {
- // Unigram count exceeds the limit.
- return true;
- }
- if (entryCounts.getBigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxBigramCount())) {
- // Bigram count exceeds the limit.
- return true;
- }
- if (entryCounts.getTrigramCount()
- >= getEntryCountHardLimit(headerPolicy->getMaxTrigramCount())) {
- // Trigram count exceeds the limit.
- return true;
+ const EntryCounts &maxNgramCounts = headerPolicy->getMaxNgramCounts();
+ for (const auto ngramType : AllNgramTypes::ASCENDING) {
+ if (entryCounts.getNgramCount(ngramType)
+ >= getEntryCountHardLimit(maxNgramCounts.getNgramCount(ngramType))) {
+ // The entry count for this ngram type has reached or exceeded its hard limit.
+ return true;
+ }
}
if (mindsBlockByDecay) {
return false;
diff --git a/native/jni/src/utils/ngram_utils.h b/native/jni/src/utils/ngram_utils.h
new file mode 100644
index 000000000..fa85ba35f
--- /dev/null
+++ b/native/jni/src/utils/ngram_utils.h
@@ -0,0 +1,63 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_NGRAM_UTILS_H
+#define LATINIME_NGRAM_UTILS_H
+
+#include "defines.h"
+
+namespace latinime {
+
+enum class NgramType : int {
+ Unigram = 0,
+ Bigram = 1,
+ Trigram = 2,
+ Quadgram = 3,
+ NotANgramType = -1,
+};
+
+namespace AllNgramTypes {
+// Use anonymous namespace to avoid ODR (One Definition Rule) violation.
+namespace {
+
+const NgramType ASCENDING[] = {
+ NgramType::Unigram, NgramType::Bigram, NgramType::Trigram
+};
+
+const NgramType DESCENDING[] = {
+ NgramType::Trigram, NgramType::Bigram, NgramType::Unigram
+};
+
+} // namespace
+} // namespace AllNgramTypes
+
+class NgramUtils final {
+ public:
+ static AK_FORCE_INLINE NgramType getNgramTypeFromWordCount(const int wordCount) {
+ // Max supported ngram is (MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1)-gram.
+ if (wordCount <= 0 || wordCount > MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1) {
+ return NgramType::NotANgramType;
+ }
+ // Convert word count to 0-origin enum value.
+ return static_cast<NgramType>(wordCount - 1);
+ }
+
+ private:
+ DISALLOW_IMPLICIT_CONSTRUCTORS(NgramUtils);
+
+};
+}
+#endif /* LATINIME_NGRAM_UTILS_H */