diff options
Diffstat (limited to 'native/jni/src/binary_format.h')
-rw-r--r-- | native/jni/src/binary_format.h | 127 |
1 files changed, 107 insertions, 20 deletions
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h index ad16039ef..98241532f 100644 --- a/native/jni/src/binary_format.h +++ b/native/jni/src/binary_format.h @@ -23,6 +23,7 @@ #include "bloom_filter.h" #include "char_utils.h" +#include "hash_map_compat.h" namespace latinime { @@ -63,13 +64,14 @@ class BinaryFormat { static const int UNKNOWN_FORMAT = -1; static const int SHORTCUT_LIST_SIZE_SIZE = 2; - static int detectFormat(const uint8_t *const dict); - static int getHeaderSize(const uint8_t *const dict); - static int getFlags(const uint8_t *const dict); + static int detectFormat(const uint8_t *const dict, const int dictSize); + static int getHeaderSize(const uint8_t *const dict, const int dictSize); + static int getFlags(const uint8_t *const dict, const int dictSize); static bool hasBlacklistedOrNotAWordFlag(const int flags); - static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue, - const int outValueSize); - static int readHeaderValueInt(const uint8_t *const dict, const char *const key); + static void readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize); + static int readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key); static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos); static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos); static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos); @@ -93,7 +95,13 @@ class BinaryFormat { const int unigramProbability, const int bigramProbability); static int getProbability(const int position, const std::map<int, int> *bigramMap, const uint8_t *bigramFilter, const int unigramProbability); - static float getMultiWordCostMultiplier(const uint8_t *const dict); + static int getBigramProbabilityFromHashMap(const int position, + const hash_map_compat<int, int> *bigramMap, const int unigramProbability); + static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize); + static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position, + hash_map_compat<int, int> *bigramMap); + static int getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability); // Flags for special processing // Those *must* match the flags in makedict (BinaryDictInputOutput#*_PROCESSING_FLAG) or @@ -105,6 +113,8 @@ class BinaryFormat { private: DISALLOW_IMPLICIT_CONSTRUCTORS(BinaryFormat); + static int getBigramListPositionForWordPosition(const uint8_t *const root, int position); + static const int FLAG_GROUP_ADDRESS_TYPE_NOADDRESS = 0x00; static const int FLAG_GROUP_ADDRESS_TYPE_ONEBYTE = 0x40; static const int FLAG_GROUP_ADDRESS_TYPE_TWOBYTES = 0x80; @@ -113,6 +123,8 @@ class BinaryFormat { static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20; static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30; + // Any file smaller than this is not a dictionary. + static const int DICTIONARY_MINIMUM_SIZE = 4; // Originally, format version 1 had a 16-bit magic number, then the version number `01' // then options that must be 0. Hence the first 32-bits of the format are always as follow // and it's okay to consider them a magic number as a whole. @@ -122,6 +134,8 @@ class BinaryFormat { // number, so we had to change it so that version 2 files would be rejected by older // implementations. On this occasion, we made the magic number 32 bits long. static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE + // Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12 + static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12; static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1; static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20; @@ -132,8 +146,11 @@ class BinaryFormat { static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos); }; -AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { +AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) { // The magic number is stored big-endian. + // If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't + // understand this format. + if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT; const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3]; switch (magicNumber) { case FORMAT_VERSION_1_MAGIC_NUMBER: @@ -143,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { // Options (2 bytes) must be 0x00 0x00 return 1; case FORMAT_VERSION_2_MAGIC_NUMBER: + // Version 2 dictionaries are at least 12 bytes long (see below details for the header). + // If this dictionary has the version 2 magic number but is less than 12 bytes long, then + // it's an unknown format and we need to avoid confidently reading the next bytes. + if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT; // Format 2 header is as follows: // Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE // Version number (2 bytes) 0x00 0x02 @@ -154,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) { } } -inline int BinaryFormat::getFlags(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else? default: @@ -164,11 +185,11 @@ inline int BinaryFormat::getFlags(const uint8_t *const dict) { } inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) { - return flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD); + return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0; } -inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { - switch (detectFormat(dict)) { +inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) { + switch (detectFormat(dict, dictSize)) { case 1: return FORMAT_VERSION_1_HEADER_SIZE; case 2: @@ -179,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) { } } -inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key, - int *outValue, const int outValueSize) { +inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize, + const char *const key, int *outValue, const int outValueSize) { int outValueIndex = 0; // Only format 2 and above have header attributes as {key,value} string pairs. For prior // formats, we just return an empty string, as if the key wasn't found. - if (2 <= detectFormat(dict)) { + if (2 <= detectFormat(dict, dictSize)) { const int headerOptionsOffset = 4 /* magic number */ + 2 /* dictionary version */ + 2 /* flags */; const int headerSize = @@ -227,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char if (outValueIndex >= 0) outValue[outValueIndex] = 0; } -inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) { +inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize, + const char *const key) { const int bufferSize = LARGEST_INT_DIGIT_COUNT; int intBuffer[bufferSize]; char charBuffer[bufferSize]; - BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize); + BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize); for (int i = 0; i < bufferSize; ++i) { charBuffer[i] = intBuffer[i]; } @@ -247,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t * return ((msb & 0x7F) << 8) | dict[(*pos)++]; } -inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) { - const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE"); +inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict, + const int dictSize) { + const int headerValue = readHeaderValueInt(dict, dictSize, + "MULTIPLE_WORDS_DEMOTION_RATE"); if (headerValue == S_INT_MIN) { return 1.0f; } @@ -687,5 +711,68 @@ inline int BinaryFormat::getProbability(const int position, const std::map<int, } return backoff(unigramProbability); } + +// This returns a probability in log space. +inline int BinaryFormat::getBigramProbabilityFromHashMap(const int position, + const hash_map_compat<int, int> *bigramMap, const int unigramProbability) { + if (!bigramMap) return backoff(unigramProbability); + const hash_map_compat<int, int>::const_iterator bigramProbabilityIt = bigramMap->find(position); + if (bigramProbabilityIt != bigramMap->end()) { + const int bigramProbability = bigramProbabilityIt->second; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + return backoff(unigramProbability); +} + +AK_FORCE_INLINE void BinaryFormat::fillBigramProbabilityToHashMap( + const uint8_t *const root, int position, hash_map_compat<int, int> *bigramMap) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return; + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int probability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + const int bigramPos = getAttributeAddressAndForwardPointer(root, bigramFlags, + &position); + (*bigramMap)[bigramPos] = probability; + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); +} + +AK_FORCE_INLINE int BinaryFormat::getBigramProbability(const uint8_t *const root, int position, + const int nextPosition, const int unigramProbability) { + position = getBigramListPositionForWordPosition(root, position); + if (0 == position) return backoff(unigramProbability); + + uint8_t bigramFlags; + do { + bigramFlags = getFlagsAndForwardPointer(root, &position); + const int bigramPos = getAttributeAddressAndForwardPointer( + root, bigramFlags, &position); + if (bigramPos == nextPosition) { + const int bigramProbability = MASK_ATTRIBUTE_PROBABILITY & bigramFlags; + return computeProbabilityForBigram(unigramProbability, bigramProbability); + } + } while (FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags); + return backoff(unigramProbability); +} + +// Returns a pointer to the start of the bigram list. +AK_FORCE_INLINE int BinaryFormat::getBigramListPositionForWordPosition( + const uint8_t *const root, int position) { + if (NOT_VALID_WORD == position) return 0; + const uint8_t flags = getFlagsAndForwardPointer(root, &position); + if (!(flags & FLAG_HAS_BIGRAMS)) return 0; + if (flags & FLAG_HAS_MULTIPLE_CHARS) { + position = skipOtherCharacters(root, position); + } else { + getCodePointAndForwardPointer(root, &position); + } + position = skipProbability(flags, position); + position = skipChildrenPosition(flags, position); + position = skipShortcuts(root, flags, position); + return position; +} + } // namespace latinime #endif // LATINIME_BINARY_FORMAT_H |