diff options
39 files changed, 3120 insertions, 44 deletions
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml index 80cd08569..17d11c01d 100644 --- a/java/AndroidManifest.xml +++ b/java/AndroidManifest.xml @@ -95,6 +95,12 @@ </intent-filter> </receiver> + <receiver android:name=".DictionaryPackInstallBroadcastReceiver"> + <intent-filter> + <action android:name="com.android.inputmethod.dictionarypack.UNKNOWN_CLIENT" /> + </intent-filter> + </receiver> + <provider android:name="com.android.inputmethod.dictionarypack.DictionaryProvider" android:grantUriPermissions="true" android:exported="false" diff --git a/java/res/values/strings.xml b/java/res/values/strings.xml index fb341acc3..ebcd3d956 100644 --- a/java/res/values/strings.xml +++ b/java/res/values/strings.xml @@ -498,9 +498,9 @@ <!-- Message about some dictionary indicating the file is installed, but the dictionary is disabled --> <string name="dictionary_disabled">Installed, disabled</string> - <!-- Message to display in the dictionaries setting screen when some error prevented us to list installed dictionaries [CHAR LIMIT=50] --> + <!-- Message to display in the dictionaries setting screen when some error prevented us to list installed dictionaries [CHAR LIMIT=20] --> <string name="cannot_connect_to_dict_service">Problem connecting to dictionary service</string> - <!-- Message to display in the dictionaries setting screen when we found that no dictionaries are available [CHAR LIMIT=50]--> + <!-- Message to display in the dictionaries setting screen when we found that no dictionaries are available [CHAR LIMIT=20]--> <string name="no_dictionaries_available">No dictionaries available</string> <!-- Title of the options to press to refresh the list (as in, check for updates now) [CHAR_LIMIT=50] --> diff --git a/java/src/com/android/inputmethod/dictionarypack/DictionaryPackConstants.java b/java/src/com/android/inputmethod/dictionarypack/DictionaryPackConstants.java index 0c8b466a4..69615887f 100644 --- a/java/src/com/android/inputmethod/dictionarypack/DictionaryPackConstants.java +++ b/java/src/com/android/inputmethod/dictionarypack/DictionaryPackConstants.java @@ -25,16 +25,34 @@ package com.android.inputmethod.dictionarypack; */ public class DictionaryPackConstants { /** + * The root domain for the dictionary pack, upon which authorities and actions will append + * their own distinctive strings. + */ + private static final String DICTIONARY_DOMAIN = "com.android.inputmethod.dictionarypack"; + + /** * Authority for the ContentProvider protocol. */ // TODO: find some way to factorize this string with the one in the resources - public static final String AUTHORITY = "com.android.inputmethod.dictionarypack.aosp"; + public static final String AUTHORITY = DICTIONARY_DOMAIN + ".aosp"; /** * The action of the intent for publishing that new dictionary data is available. */ // TODO: make this different across different packages. A suggested course of action is // to use the package name inside this string. - public static final String NEW_DICTIONARY_INTENT_ACTION = - "com.android.inputmethod.dictionarypack.newdict"; + // NOTE: The appended string should be uppercase like all other actions, but it's not for + // historical reasons. + public static final String NEW_DICTIONARY_INTENT_ACTION = DICTIONARY_DOMAIN + ".newdict"; + + /** + * The action of the intent sent by the dictionary pack to ask for a client to make + * itself known. This is used when the settings activity is brought up for a client the + * dictionary pack does not know about. + */ + public static final String UNKNOWN_DICTIONARY_PROVIDER_CLIENT = DICTIONARY_DOMAIN + + ".UNKNOWN_CLIENT"; + // In the above intents, the name of the string extra that contains the name of the client + // we want information about. + public static final String DICTIONARY_PROVIDER_CLIENT_EXTRA = "client"; } diff --git a/java/src/com/android/inputmethod/dictionarypack/DictionaryProvider.java b/java/src/com/android/inputmethod/dictionarypack/DictionaryProvider.java index 77b3b8e2e..f8d1c4fc9 100644 --- a/java/src/com/android/inputmethod/dictionarypack/DictionaryProvider.java +++ b/java/src/com/android/inputmethod/dictionarypack/DictionaryProvider.java @@ -509,6 +509,11 @@ public final class DictionaryProvider extends ContentProvider { } catch (final BadFormatException e) { Log.w(TAG, "Not enough information to insert this dictionary " + values, e); } + // We just received new information about the list of dictionary for this client. + // For all intents and purposes, this is new metadata, so we should publish it + // so that any listeners (like the Settings interface for example) can update + // themselves. + UpdateHandler.publishUpdateMetadataCompleted(getContext(), true); break; case DICTIONARY_V1_WHOLE_LIST: case DICTIONARY_V1_DICT_INFO: diff --git a/java/src/com/android/inputmethod/dictionarypack/DictionarySettingsFragment.java b/java/src/com/android/inputmethod/dictionarypack/DictionarySettingsFragment.java index e85bb0d4a..9e27c1f3f 100644 --- a/java/src/com/android/inputmethod/dictionarypack/DictionarySettingsFragment.java +++ b/java/src/com/android/inputmethod/dictionarypack/DictionarySettingsFragment.java @@ -110,6 +110,15 @@ public final class DictionarySettingsFragment extends PreferenceFragment super.onResume(); mChangedSettings = false; UpdateHandler.registerUpdateEventListener(this); + final Activity activity = getActivity(); + if (!MetadataDbHelper.isClientKnown(activity, mClientId)) { + Log.i(TAG, "Unknown dictionary pack client: " + mClientId + ". Requesting info."); + final Intent unknownClientBroadcast = + new Intent(DictionaryPackConstants.UNKNOWN_DICTIONARY_PROVIDER_CLIENT); + unknownClientBroadcast.putExtra( + DictionaryPackConstants.DICTIONARY_PROVIDER_CLIENT_EXTRA, mClientId); + activity.sendBroadcast(unknownClientBroadcast); + } final IntentFilter filter = new IntentFilter(); filter.addAction(ConnectivityManager.CONNECTIVITY_ACTION); getActivity().registerReceiver(mConnectivityChangedReceiver, filter); @@ -363,7 +372,12 @@ public final class DictionarySettingsFragment extends PreferenceFragment getActivity(), android.R.anim.fade_out)); preferenceView.startAnimation(AnimationUtils.loadAnimation( getActivity(), android.R.anim.fade_in)); - mUpdateNowMenu.setTitle(R.string.check_for_updates_now); + // The menu is created by the framework asynchronously after the activity, + // which means it's possible to have the activity running but the menu not + // created yet - hence the necessity for a null check here. + if (null != mUpdateNowMenu) { + mUpdateNowMenu.setTitle(R.string.check_for_updates_now); + } } }); } diff --git a/java/src/com/android/inputmethod/dictionarypack/EventHandler.java b/java/src/com/android/inputmethod/dictionarypack/EventHandler.java index 96c4a8305..d8aa33bb8 100644 --- a/java/src/com/android/inputmethod/dictionarypack/EventHandler.java +++ b/java/src/com/android/inputmethod/dictionarypack/EventHandler.java @@ -16,13 +16,9 @@ package com.android.inputmethod.dictionarypack; -import com.android.inputmethod.latin.LatinIME; -import com.android.inputmethod.latin.R; - import android.content.BroadcastReceiver; import android.content.Context; import android.content.Intent; -import android.util.Log; public final class EventHandler extends BroadcastReceiver { private static final String TAG = EventHandler.class.getName(); diff --git a/java/src/com/android/inputmethod/dictionarypack/UpdateHandler.java b/java/src/com/android/inputmethod/dictionarypack/UpdateHandler.java index b4727509c..e05a79b7b 100644 --- a/java/src/com/android/inputmethod/dictionarypack/UpdateHandler.java +++ b/java/src/com/android/inputmethod/dictionarypack/UpdateHandler.java @@ -444,7 +444,19 @@ public final class UpdateHandler { manager.remove(fileId); } - private static void publishUpdateMetadataCompleted(final Context context, + /** + * Sends a broadcast informing listeners that the dictionaries were updated. + * + * This will call all local listeners through the UpdateEventListener#downloadedMetadata + * callback (for example, the dictionary provider interface uses this to stop the Loading + * animation) and send a broadcast about the metadata having been updated. For a client of + * the dictionary pack like Latin IME, this means it should re-query the dictionary pack + * for any relevant new data. + * + * @param context the context, to send the broadcast. + * @param downloadSuccessful whether the download of the metadata was successful or not. + */ + public static void publishUpdateMetadataCompleted(final Context context, final boolean downloadSuccessful) { // We need to warn all listeners of what happened. But some listeners may want to // remove themselves or re-register something in response. Hence we should take a diff --git a/java/src/com/android/inputmethod/keyboard/internal/GesturePreviewTrail.java b/java/src/com/android/inputmethod/keyboard/internal/GesturePreviewTrail.java index b047fe038..e3e6d39e4 100644 --- a/java/src/com/android/inputmethod/keyboard/internal/GesturePreviewTrail.java +++ b/java/src/com/android/inputmethod/keyboard/internal/GesturePreviewTrail.java @@ -44,6 +44,7 @@ final class GesturePreviewTrail { // The wall time of the zero value in {@link #mEventTimes} private long mCurrentTimeBase; private int mTrailStartIndex; + private int mLastInterpolatedDrawIndex; static final class Params { public final int mTrailColor; @@ -96,6 +97,17 @@ final class GesturePreviewTrail { } final int[] eventTimes = mEventTimes.getPrimitiveArray(); final int strokeId = stroke.getGestureStrokeId(); + // Because interpolation algorithm in {@link GestureStrokeWithPreviewPoints} can't determine + // the interpolated points in the last segment of gesture stroke, it may need recalculation + // of interpolation when new segments are added to the stroke. + // {@link #mLastInterpolatedDrawIndex} holds the start index of the last segment. It may + // be updated by the interpolation + // {@link GestureStrokeWithPreviewPoints#interpolatePreviewStroke} + // or by animation {@link #drawGestureTrail(Canvas,Paint,Rect,Params)} below. + final int lastInterpolatedIndex = (strokeId == mCurrentStrokeId) + ? mLastInterpolatedDrawIndex : trailSize; + mLastInterpolatedDrawIndex = stroke.interpolateStrokeAndReturnStartIndexOfLastSegment( + lastInterpolatedIndex, mEventTimes, mXCoordinates, mYCoordinates); if (strokeId != mCurrentStrokeId) { final int elapsedTime = (int)(downTime - mCurrentTimeBase); for (int i = mTrailStartIndex; i < trailSize; i++) { @@ -216,6 +228,10 @@ final class GesturePreviewTrail { System.arraycopy(eventTimes, startIndex, eventTimes, 0, newSize); System.arraycopy(xCoords, startIndex, xCoords, 0, newSize); System.arraycopy(yCoords, startIndex, yCoords, 0, newSize); + // The start index of the last segment of the stroke + // {@link mLastInterpolatedDrawIndex} should also be updated because all array + // elements have just been shifted for compaction. + mLastInterpolatedDrawIndex = Math.max(mLastInterpolatedDrawIndex - startIndex, 0); } mEventTimes.setLength(newSize); mXCoordinates.setLength(newSize); diff --git a/java/src/com/android/inputmethod/keyboard/internal/GestureStrokeWithPreviewPoints.java b/java/src/com/android/inputmethod/keyboard/internal/GestureStrokeWithPreviewPoints.java index fc81410ff..3315954c1 100644 --- a/java/src/com/android/inputmethod/keyboard/internal/GestureStrokeWithPreviewPoints.java +++ b/java/src/com/android/inputmethod/keyboard/internal/GestureStrokeWithPreviewPoints.java @@ -21,19 +21,32 @@ import com.android.inputmethod.latin.ResizableIntArray; public final class GestureStrokeWithPreviewPoints extends GestureStroke { public static final int PREVIEW_CAPACITY = 256; + private static final boolean ENABLE_INTERPOLATION = true; + private final ResizableIntArray mPreviewEventTimes = new ResizableIntArray(PREVIEW_CAPACITY); private final ResizableIntArray mPreviewXCoordinates = new ResizableIntArray(PREVIEW_CAPACITY); private final ResizableIntArray mPreviewYCoordinates = new ResizableIntArray(PREVIEW_CAPACITY); private int mStrokeId; private int mLastPreviewSize; + private final HermiteInterpolator mInterpolator = new HermiteInterpolator(); + private int mLastInterpolatedPreviewIndex; - private int mMinPreviewSampleLengthSquare; + private int mMinPreviewSamplingDistanceSquared; private int mLastX; private int mLastY; + private double mMinPreviewSamplingDistance; + private double mDistanceFromLastSample; - // TODO: Move this to resource. - private static final float MIN_PREVIEW_SAMPLE_LENGTH_RATIO_TO_KEY_WIDTH = 0.1f; + // TODO: Move these constants to resource. + // The minimum linear distance between sample points for preview in keyWidth unit. + private static final float MIN_PREVIEW_SAMPLING_RATIO_TO_KEY_WIDTH = 0.1f; + // The minimum trail distance between sample points for preview in keyWidth unit when using + // interpolation. + private static final float MIN_PREVIEW_SAMPLING_RATIO_TO_KEY_WIDTH_WITH_INTERPOLATION = 0.2f; + // The angular threshold to use interpolation in radian. PI/12 is 15 degree. + private static final double INTERPOLATION_ANGULAR_THRESHOLD = Math.PI / 12.0d; + private static final int MAX_INTERPOLATION_PARTITION = 4; public GestureStrokeWithPreviewPoints(final int pointerId, final GestureStrokeParams params) { super(pointerId, params); @@ -44,6 +57,7 @@ public final class GestureStrokeWithPreviewPoints extends GestureStroke { super.reset(); mStrokeId++; mLastPreviewSize = 0; + mLastInterpolatedPreviewIndex = 0; mPreviewEventTimes.setLength(0); mPreviewXCoordinates.setLength(0); mPreviewYCoordinates.setLength(0); @@ -53,35 +67,49 @@ public final class GestureStrokeWithPreviewPoints extends GestureStroke { return mStrokeId; } - public int getGestureStrokePreviewSize() { - return mPreviewEventTimes.getLength(); - } - @Override public void setKeyboardGeometry(final int keyWidth, final int keyboardHeight) { super.setKeyboardGeometry(keyWidth, keyboardHeight); - final float sampleLength = keyWidth * MIN_PREVIEW_SAMPLE_LENGTH_RATIO_TO_KEY_WIDTH; - mMinPreviewSampleLengthSquare = (int)(sampleLength * sampleLength); + final float samplingRatioToKeyWidth = ENABLE_INTERPOLATION + ? MIN_PREVIEW_SAMPLING_RATIO_TO_KEY_WIDTH_WITH_INTERPOLATION + : MIN_PREVIEW_SAMPLING_RATIO_TO_KEY_WIDTH; + mMinPreviewSamplingDistance = keyWidth * samplingRatioToKeyWidth; + mMinPreviewSamplingDistanceSquared = (int)( + mMinPreviewSamplingDistance * mMinPreviewSamplingDistance); } - private boolean needsSampling(final int x, final int y) { + private boolean needsSampling(final int x, final int y, final boolean isMajorEvent) { + if (ENABLE_INTERPOLATION) { + mDistanceFromLastSample += Math.hypot(x - mLastX, y - mLastY); + mLastX = x; + mLastY = y; + if (mDistanceFromLastSample >= mMinPreviewSamplingDistance) { + mDistanceFromLastSample = 0.0d; + return true; + } + return false; + } + final int dx = x - mLastX; final int dy = y - mLastY; - return dx * dx + dy * dy >= mMinPreviewSampleLengthSquare; + if (isMajorEvent || dx * dx + dy * dy >= mMinPreviewSamplingDistanceSquared) { + mLastX = x; + mLastY = y; + return true; + } + return false; } @Override public boolean addPointOnKeyboard(final int x, final int y, final int time, final boolean isMajorEvent) { - final boolean onValidArea = super.addPointOnKeyboard(x, y, time, isMajorEvent); - if (isMajorEvent || needsSampling(x, y)) { + if (needsSampling(x, y, isMajorEvent)) { mPreviewEventTimes.add(time); mPreviewXCoordinates.add(x); mPreviewYCoordinates.add(y); - mLastX = x; - mLastY = y; } - return onValidArea; + return super.addPointOnKeyboard(x, y, time, isMajorEvent); + } public void appendPreviewStroke(final ResizableIntArray eventTimes, @@ -95,4 +123,82 @@ public final class GestureStrokeWithPreviewPoints extends GestureStroke { yCoords.append(mPreviewYCoordinates, mLastPreviewSize, length); mLastPreviewSize = mPreviewEventTimes.getLength(); } + + /** + * Calculate interpolated points between the last interpolated point and the end of the trail. + * And return the start index of the last interpolated segment of input arrays because it + * may need to recalculate the interpolated points in the segment if further segments are + * added to this stroke. + * + * @param lastInterpolatedIndex the start index of the last interpolated segment of + * <code>eventTimes</code>, <code>xCoords</code>, and <code>yCoords</code>. + * @param eventTimes the event time array of gesture preview trail to be drawn. + * @param xCoords the x-coordinates array of gesture preview trail to be drawn. + * @param yCoords the y-coordinates array of gesture preview trail to be drawn. + * @return the start index of the last interpolated segment of input arrays. + */ + public int interpolateStrokeAndReturnStartIndexOfLastSegment(final int lastInterpolatedIndex, + final ResizableIntArray eventTimes, final ResizableIntArray xCoords, + final ResizableIntArray yCoords) { + if (!ENABLE_INTERPOLATION) { + return lastInterpolatedIndex; + } + final int size = mPreviewEventTimes.getLength(); + final int[] pt = mPreviewEventTimes.getPrimitiveArray(); + final int[] px = mPreviewXCoordinates.getPrimitiveArray(); + final int[] py = mPreviewYCoordinates.getPrimitiveArray(); + mInterpolator.reset(px, py, 0, size); + // The last segment of gesture stroke needs to be interpolated again because the slope of + // the tangent at the last point isn't determined. + int lastInterpolatedDrawIndex = lastInterpolatedIndex; + int d1 = lastInterpolatedIndex; + for (int p2 = mLastInterpolatedPreviewIndex + 1; p2 < size; p2++) { + final int p1 = p2 - 1; + final int p0 = p1 - 1; + final int p3 = p2 + 1; + mLastInterpolatedPreviewIndex = p1; + lastInterpolatedDrawIndex = d1; + mInterpolator.setInterval(p0, p1, p2, p3); + final double m1 = Math.atan2(mInterpolator.mSlope1Y, mInterpolator.mSlope1X); + final double m2 = Math.atan2(mInterpolator.mSlope2Y, mInterpolator.mSlope2X); + final double dm = Math.abs(angularDiff(m2, m1)); + final int partition = Math.min((int)Math.ceil(dm / INTERPOLATION_ANGULAR_THRESHOLD), + MAX_INTERPOLATION_PARTITION); + final int t1 = eventTimes.get(d1); + final int dt = pt[p2] - pt[p1]; + d1++; + for (int i = 1; i < partition; i++) { + final float t = i / (float)partition; + mInterpolator.interpolate(t); + eventTimes.add(d1, (int)(dt * t) + t1); + xCoords.add(d1, (int)mInterpolator.mInterpolatedX); + yCoords.add(d1, (int)mInterpolator.mInterpolatedY); + d1++; + } + eventTimes.add(d1, pt[p2]); + xCoords.add(d1, px[p2]); + yCoords.add(d1, py[p2]); + } + return lastInterpolatedDrawIndex; + } + + private static final double TWO_PI = Math.PI * 2.0d; + + /** + * Calculate the angular of rotation from <code>a0</code> to <code>a1</code>. + * + * @param a1 the angular to which the rotation ends. + * @param a0 the angular from which the rotation starts. + * @return the angular rotation value from a0 to a1, normalized to [-PI, +PI]. + */ + private static double angularDiff(final double a1, final double a0) { + double deltaAngle = a1 - a0; + while (deltaAngle > Math.PI) { + deltaAngle -= TWO_PI; + } + while (deltaAngle < -Math.PI) { + deltaAngle += TWO_PI; + } + return deltaAngle; + } } diff --git a/java/src/com/android/inputmethod/keyboard/internal/HermiteInterpolator.java b/java/src/com/android/inputmethod/keyboard/internal/HermiteInterpolator.java new file mode 100644 index 000000000..0ec8153f5 --- /dev/null +++ b/java/src/com/android/inputmethod/keyboard/internal/HermiteInterpolator.java @@ -0,0 +1,166 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.inputmethod.keyboard.internal; + +import com.android.inputmethod.annotations.UsedForTesting; + +/** + * Interpolates XY-coordinates using Cubic Hermite Curve. + */ +public final class HermiteInterpolator { + private int[] mXCoords; + private int[] mYCoords; + private int mMinPos; + private int mMaxPos; + + // Working variable to calculate interpolated value. + /** The coordinates of the start point of the interval. */ + public int mP1X, mP1Y; + /** The coordinates of the end point of the interval. */ + public int mP2X, mP2Y; + /** The slope of the tangent at the start point. */ + public float mSlope1X, mSlope1Y; + /** The slope of the tangent at the end point. */ + public float mSlope2X, mSlope2Y; + /** The interpolated coordinates. + * The return variables of {@link #interpolate(float)} to avoid instantiations. + */ + public float mInterpolatedX, mInterpolatedY; + + public HermiteInterpolator() { + // Nothing to do with here. + } + + /** + * Reset this interpolator to point XY-coordinates data. + * @param xCoords the array of x-coordinates. Valid data are in left-open interval + * <code>[minPos, maxPos)</code>. + * @param yCoords the array of y-coordinates. Valid data are in left-open interval + * <code>[minPos, maxPos)</code>. + * @param minPos the minimum index of left-open interval of valid data. + * @param maxPos the maximum index of left-open interval of valid data. + */ + @UsedForTesting + public void reset(final int[] xCoords, final int[] yCoords, final int minPos, + final int maxPos) { + mXCoords = xCoords; + mYCoords = yCoords; + mMinPos = minPos; + mMaxPos = maxPos; + } + + /** + * Set interpolation interval. + * <p> + * The start and end coordinates of the interval will be set in {@link #mP1X}, {@link #mP1Y}, + * {@link #mP2X}, and {@link #mP2Y}. The slope of the tangents at start and end points will be + * set in {@link #mSlope1X}, {@link #mSlope1Y}, {@link #mSlope2X}, and {@link #mSlope2Y}. + * + * @param p0 the index just before interpolation interval. If <code>p1</code> points the start + * of valid points, <code>p0</code> must be less than <code>minPos</code> of + * {@link #reset(int[],int[],int,int)}. + * @param p1 the start index of interpolation interval. + * @param p2 the end index of interpolation interval. + * @param p3 the index just after interpolation interval. If <code>p2</code> points the end of + * valid points, <code>p3</code> must be equal or greater than <code>maxPos</code> of + * {@link #reset(int[],int[],int,int)}. + */ + @UsedForTesting + public void setInterval(final int p0, final int p1, final int p2, final int p3) { + mP1X = mXCoords[p1]; + mP1Y = mYCoords[p1]; + mP2X = mXCoords[p2]; + mP2Y = mYCoords[p2]; + // A(ax,ay) is the vector p1->p2. + final int ax = mP2X - mP1X; + final int ay = mP2Y - mP1Y; + + // Calculate the slope of the tangent at p1. + if (p0 >= mMinPos) { + // p1 has previous valid point p0. + // The slope of the tangent is half of the vector p0->p2. + mSlope1X = (mP2X - mXCoords[p0]) / 2.0f; + mSlope1Y = (mP2Y - mYCoords[p0]) / 2.0f; + } else if (p3 < mMaxPos) { + // p1 has no previous valid point, but p2 has next valid point p3. + // B(bx,by) is the slope vector of the tangent at p2. + final float bx = (mXCoords[p3] - mP1X) / 2.0f; + final float by = (mYCoords[p3] - mP1Y) / 2.0f; + final float crossProdAB = ax * by - ay * bx; + final float dotProdAB = ax * bx + ay * by; + final float normASquare = ax * ax + ay * ay; + final float invHalfNormASquare = 1.0f / normASquare / 2.0f; + // The slope of the tangent is the mirror image of vector B to vector A. + mSlope1X = invHalfNormASquare * (dotProdAB * ax + crossProdAB * ay); + mSlope1Y = invHalfNormASquare * (dotProdAB * ay - crossProdAB * ax); + } else { + // p1 and p2 have no previous valid point. (Interval has only point p1 and p2) + mSlope1X = ax; + mSlope1Y = ay; + } + + // Calculate the slope of the tangent at p2. + if (p3 < mMaxPos) { + // p2 has next valid point p3. + // The slope of the tangent is half of the vector p1->p3. + mSlope2X = (mXCoords[p3] - mP1X) / 2.0f; + mSlope2Y = (mYCoords[p3] - mP1Y) / 2.0f; + } else if (p0 >= mMinPos) { + // p2 has no next valid point, but p1 has previous valid point p0. + // B(bx,by) is the slope vector of the tangent at p1. + final float bx = (mP2X - mXCoords[p0]) / 2.0f; + final float by = (mP2Y - mYCoords[p0]) / 2.0f; + final float crossProdAB = ax * by - ay * bx; + final float dotProdAB = ax * bx + ay * by; + final float normASquare = ax * ax + ay * ay; + final float invHalfNormASquare = 1.0f / normASquare / 2.0f; + // The slope of the tangent is the mirror image of vector B to vector A. + mSlope2X = invHalfNormASquare * (dotProdAB * ax + crossProdAB * ay); + mSlope2Y = invHalfNormASquare * (dotProdAB * ay - crossProdAB * ax); + } else { + // p1 and p2 has no previous valid point. (Interval has only point p1 and p2) + mSlope2X = ax; + mSlope2Y = ay; + } + } + + /** + * Calculate interpolation value at <code>t</code> in unit interval <code>[0,1]</code>. + * <p> + * On the unit interval [0,1], given a starting point p1 at t=0 and an ending point p2 at t=1 + * with the slope of the tangent m1 at p1 and m2 at p2, the polynomial of cubic Hermite curve + * can be defined by + * p(t) = (1+2t)(1-t)(1-t)*p1 + t(1-t)(1-t)*m1 + (3-2t)t^2*p2 + (t-1)t^2*m2 + * where t is an element of [0,1]. + * <p> + * The interpolated XY-coordinates will be set in {@link #mInterpolatedX} and + * {@link #mInterpolatedY}. + * + * @param t the interpolation parameter. The value must be in close interval <code>[0,1]</code>. + */ + @UsedForTesting + public void interpolate(final float t) { + final float omt = 1.0f - t; + final float tm2 = 2.0f * t; + final float k1 = 1.0f + tm2; + final float k2 = 3.0f - tm2; + final float omt2 = omt * omt; + final float t2 = t * t; + mInterpolatedX = (k1 * mP1X + t * mSlope1X) * omt2 + (k2 * mP2X - omt * mSlope2X) * t2; + mInterpolatedY = (k1 * mP1Y + t * mSlope1Y) * omt2 + (k2 * mP2Y - omt * mSlope2Y) * t2; + } +} diff --git a/java/src/com/android/inputmethod/latin/BinaryDictionaryFileDumper.java b/java/src/com/android/inputmethod/latin/BinaryDictionaryFileDumper.java index 4bec99c04..562e1d0b7 100644 --- a/java/src/com/android/inputmethod/latin/BinaryDictionaryFileDumper.java +++ b/java/src/com/android/inputmethod/latin/BinaryDictionaryFileDumper.java @@ -450,4 +450,25 @@ public final class BinaryDictionaryFileDumper { info.toContentValues()); } } + + /** + * Initialize a client record with the dictionary content provider. + * + * This merely acquires the content provider and calls + * #reinitializeClientRecordInDictionaryContentProvider. + * + * @param context the context for resources and providers. + * @param clientId the client ID to use. + */ + public static void initializeClientRecordHelper(final Context context, + final String clientId) { + try { + final ContentProviderClient client = context.getContentResolver(). + acquireContentProviderClient(getProviderUriBuilder("").build()); + if (null == client) return; + reinitializeClientRecordInDictionaryContentProvider(context, client, clientId); + } catch (RemoteException e) { + Log.e(TAG, "Cannot contact the dictionary content provider", e); + } + } } diff --git a/java/src/com/android/inputmethod/latin/DictionaryPackInstallBroadcastReceiver.java b/java/src/com/android/inputmethod/latin/DictionaryPackInstallBroadcastReceiver.java index 35f3119ea..41fcb83e6 100644 --- a/java/src/com/android/inputmethod/latin/DictionaryPackInstallBroadcastReceiver.java +++ b/java/src/com/android/inputmethod/latin/DictionaryPackInstallBroadcastReceiver.java @@ -25,14 +25,35 @@ import android.content.pm.PackageInfo; import android.content.pm.PackageManager; import android.content.pm.ProviderInfo; import android.net.Uri; +import android.util.Log; /** - * Takes action to reload the necessary data when a dictionary pack was added/removed. + * Receives broadcasts pertaining to dictionary management and takes the appropriate action. + * + * This object receives three types of broadcasts. + * - Package installed/added. When a dictionary provider application is added or removed, we + * need to query the dictionaries. + * - New dictionary broadcast. The dictionary provider broadcasts new dictionary availability. When + * this happens, we need to re-query the dictionaries. + * - Unknown client. If the dictionary provider is in urgent need of data about some client that + * it does not know, it sends this broadcast. When we receive this, we need to tell the dictionary + * provider about ourselves. This happens when the settings for the dictionary pack are accessed, + * but Latin IME never got a chance to register itself. */ public final class DictionaryPackInstallBroadcastReceiver extends BroadcastReceiver { + private static final String TAG = DictionaryPackInstallBroadcastReceiver.class.getSimpleName(); final LatinIME mService; + public DictionaryPackInstallBroadcastReceiver() { + // This empty constructor is necessary for the system to instantiate this receiver. + // This happens when the dictionary pack says it can't find a record for our client, + // which happens when the dictionary pack settings are called before the keyboard + // was ever started once. + Log.i(TAG, "Latin IME dictionary broadcast receiver instantiated from the framework."); + mService = null; + } + public DictionaryPackInstallBroadcastReceiver(final LatinIME service) { mService = service; } @@ -44,6 +65,11 @@ public final class DictionaryPackInstallBroadcastReceiver extends BroadcastRecei // We need to reread the dictionary if a new dictionary package is installed. if (action.equals(Intent.ACTION_PACKAGE_ADDED)) { + if (null == mService) { + Log.e(TAG, "Called with intent " + action + " but we don't know the service: this " + + "should never happen"); + return; + } final Uri packageUri = intent.getData(); if (null == packageUri) return; // No package name : we can't do anything final String packageName = packageUri.getSchemeSpecificPart(); @@ -71,6 +97,11 @@ public final class DictionaryPackInstallBroadcastReceiver extends BroadcastRecei return; } else if (action.equals(Intent.ACTION_PACKAGE_REMOVED) && !intent.getBooleanExtra(Intent.EXTRA_REPLACING, false)) { + if (null == mService) { + Log.e(TAG, "Called with intent " + action + " but we don't know the service: this " + + "should never happen"); + return; + } // When the dictionary package is removed, we need to reread dictionary (to use the // next-priority one, or stop using a dictionary at all if this was the only one, // since this is the user request). @@ -82,7 +113,28 @@ public final class DictionaryPackInstallBroadcastReceiver extends BroadcastRecei // read dictionary from? mService.resetSuggestMainDict(); } else if (action.equals(DictionaryPackConstants.NEW_DICTIONARY_INTENT_ACTION)) { + if (null == mService) { + Log.e(TAG, "Called with intent " + action + " but we don't know the service: this " + + "should never happen"); + return; + } mService.resetSuggestMainDict(); + } else if (action.equals(DictionaryPackConstants.UNKNOWN_DICTIONARY_PROVIDER_CLIENT)) { + if (null != mService) { + // Careful! This is returning if the service is NOT null. This is because we + // should come here instantiated by the framework in reaction to a broadcast of + // the above action, so we should gave gone through the no-args constructor. + Log.e(TAG, "Called with intent " + action + " but we have a reference to the " + + "service: this should never happen"); + return; + } + // The dictionary provider does not know about some client. We check that it's really + // us that it needs to know about, and if it's the case, we register with the provider. + final String wantedClientId = + intent.getStringExtra(DictionaryPackConstants.DICTIONARY_PROVIDER_CLIENT_EXTRA); + final String myClientId = context.getString(R.string.dictionary_pack_client_id); + if (!wantedClientId.equals(myClientId)) return; // Not for us + BinaryDictionaryFileDumper.initializeClientRecordHelper(context, myClientId); } } } diff --git a/java/src/com/android/inputmethod/latin/LatinIME.java b/java/src/com/android/inputmethod/latin/LatinIME.java index 92b68dcd7..0fc26a80e 100644 --- a/java/src/com/android/inputmethod/latin/LatinIME.java +++ b/java/src/com/android/inputmethod/latin/LatinIME.java @@ -1143,11 +1143,11 @@ public final class LatinIME extends InputMethodService implements KeyboardAction if (!mWordComposer.isComposingWord()) return; final String typedWord = mWordComposer.getTypedWord(); if (typedWord.length() > 0) { - commitChosenWord(typedWord, LastComposedWord.COMMIT_TYPE_USER_TYPED_WORD, - separatorString); if (ProductionFlag.USES_DEVELOPMENT_ONLY_DIAGNOSTICS) { ResearchLogger.getInstance().onWordFinished(typedWord, mWordComposer.isBatchMode()); } + commitChosenWord(typedWord, LastComposedWord.COMMIT_TYPE_USER_TYPED_WORD, + separatorString); } } diff --git a/java/src/com/android/inputmethod/research/ResearchLogger.java b/java/src/com/android/inputmethod/research/ResearchLogger.java index aa4c03357..fbfa9c977 100644 --- a/java/src/com/android/inputmethod/research/ResearchLogger.java +++ b/java/src/com/android/inputmethod/research/ResearchLogger.java @@ -1636,8 +1636,7 @@ public class ResearchLogger implements SharedPreferences.OnSharedPreferenceChang final String scrubbedAutoCorrection = scrubDigitsFromString(autoCorrection); final ResearchLogger researchLogger = getInstance(); researchLogger.mCurrentLogUnit.initializeSuggestions(suggestedWords); - researchLogger.commitCurrentLogUnitAsWord(scrubbedAutoCorrection, Long.MAX_VALUE, - isBatchMode); + researchLogger.onWordFinished(scrubbedAutoCorrection, isBatchMode); // Add the autocorrection logStatement at the end of the logUnit for the committed word. // We have to do this after calling commitCurrentLogUnitAsWord, because it may split the diff --git a/native/jni/Android.mk b/native/jni/Android.mk index 12f99eb52..b476fc3d1 100644 --- a/native/jni/Android.mk +++ b/native/jni/Android.mk @@ -29,7 +29,9 @@ LATIN_IME_SRC_FULLPATH_DIR := $(LOCAL_PATH)/$(LATIN_IME_SRC_DIR) LOCAL_C_INCLUDES += \ $(LATIN_IME_SRC_FULLPATH_DIR) \ $(LATIN_IME_SRC_FULLPATH_DIR)/suggest \ - $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/dicnode + $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core \ + $(addprefix $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/, dicnode dictionary policy session) \ + $(LATIN_IME_SRC_FULLPATH_DIR)/suggest/policyimpl/typing LOCAL_CFLAGS += -Werror -Wall -Wextra -Weffc++ -Wformat=2 -Wcast-qual -Wcast-align \ -Wwrite-strings -Wfloat-equal -Wpointer-arith -Winit-self -Wredundant-decls -Wno-system-headers @@ -63,7 +65,16 @@ LATIN_IME_CORE_SRC_FILES := \ unigram_dictionary.cpp \ words_priority_queue.cpp \ suggest/core/dicnode/dic_node.cpp \ + suggest/core/dicnode/dic_nodes_cache.cpp \ suggest/core/dicnode/dic_node_utils.cpp \ + suggest/core/policy/weighting.cpp \ + suggest/core/session/dic_traverse_session.cpp \ + suggest/core/suggest.cpp \ + suggest/policyimpl/typing/scoring_params.cpp \ + suggest/policyimpl/typing/typing_scoring.cpp \ + suggest/policyimpl/typing/typing_suggest_policy.cpp \ + suggest/policyimpl/typing/typing_traversal.cpp \ + suggest/policyimpl/typing/typing_weighting.cpp \ suggest/gesture_suggest.cpp \ suggest/typing_suggest.cpp diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp new file mode 100644 index 000000000..b9a60780b --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.cpp @@ -0,0 +1,59 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <list> + +#include "defines.h" +#include "dic_node_priority_queue.h" +#include "dic_node_utils.h" +#include "dic_nodes_cache.h" + +namespace latinime { + +/** + * Truncates all of the dicNodes so that they start at the given commit point. + * Only called for multi-word typing input. + */ +DicNode *DicNodesCache::setCommitPoint(int commitPoint) { + std::list<DicNode> dicNodesList; + while (mCachedDicNodesForContinuousSuggestion->getSize() > 0) { + DicNode dicNode; + mCachedDicNodesForContinuousSuggestion->copyPop(&dicNode); + dicNodesList.push_front(dicNode); + } + + // Get the starting words of the top scoring dicNode (last dicNode popped from priority queue) + // up to the commit point. These words have already been committed to the text view. + DicNode *topDicNode = &dicNodesList.front(); + DicNode topDicNodeCopy; + DicNodeUtils::initByCopy(topDicNode, &topDicNodeCopy); + + // Keep only those dicNodes that match the same starting words. + std::list<DicNode>::iterator iter; + for (iter = dicNodesList.begin(); iter != dicNodesList.end(); iter++) { + DicNode *dicNode = &*iter; + if (dicNode->truncateNode(&topDicNodeCopy, commitPoint)) { + mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } else { + // Top dicNode should be reprocessed. + ASSERT(dicNode != topDicNode); + DicNode::managedDelete(dicNode); + } + } + mInputIndex -= commitPoint; + return topDicNode; +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h new file mode 100644 index 000000000..a62aa422a --- /dev/null +++ b/native/jni/src/suggest/core/dicnode/dic_nodes_cache.h @@ -0,0 +1,185 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_NODES_CACHE_H +#define LATINIME_DIC_NODES_CACHE_H + +#include <stdint.h> + +#include "defines.h" +#include "dic_node_priority_queue.h" + +#define INITIAL_QUEUE_ID_ACTIVE 0 +#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1 +#define INITIAL_QUEUE_ID_TERMINAL 2 +#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3 +#define PRIORITY_QUEUES_SIZE 4 + +namespace latinime { + +class DicNode; + +/** + * Class for controlling dicNode search priority queue and lexicon trie traversal. + */ +class DicNodesCache { + public: + AK_FORCE_INLINE DicNodesCache() + : mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]), + mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]), + mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]), + mCachedDicNodesForContinuousSuggestion( + &mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]), + mInputIndex(0), mLastCachedInputIndex(0) { + } + + AK_FORCE_INLINE virtual ~DicNodesCache() {} + + AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) { + mInputIndex = 0; + mLastCachedInputIndex = 0; + mActiveDicNodes->reset(); + mNextActiveDicNodes->clearAndResize(nextActiveSize); + mTerminalDicNodes->clearAndResize(terminalSize); + mCachedDicNodesForContinuousSuggestion->reset(); + } + + AK_FORCE_INLINE void continueSearch() { + resetTemporaryCaches(); + restoreActiveDicNodesFromCache(); + } + + AK_FORCE_INLINE void advanceActiveDicNodes() { + if (DEBUG_DICT) { + AKLOGI("Advance active %d nodes.", mNextActiveDicNodes->getSize()); + } + if (DEBUG_DICT_FULL) { + mNextActiveDicNodes->dump(); + } + mNextActiveDicNodes = + moveNodesAndReturnReusableEmptyQueue(mNextActiveDicNodes, &mActiveDicNodes); + } + + DicNode *setCommitPoint(int commitPoint); + + int activeSize() const { return mActiveDicNodes->getSize(); } + int terminalSize() const { return mTerminalDicNodes->getSize(); } + bool isLookAheadCorrectionInputIndex(const int inputIndex) const { + return inputIndex == mInputIndex - 1; + } + void advanceInputIndex(const int inputSize) { + if (mInputIndex < inputSize) { + mInputIndex++; + } + } + + AK_FORCE_INLINE void copyPushTerminal(DicNode *dicNode) { + mTerminalDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushActive(DicNode *dicNode) { + mActiveDicNodes->copyPush(dicNode); + } + + AK_FORCE_INLINE bool copyPushContinue(DicNode *dicNode) { + return mCachedDicNodesForContinuousSuggestion->copyPush(dicNode); + } + + AK_FORCE_INLINE void copyPushNextActive(DicNode *dicNode) { + DicNode *pushedDicNode = mNextActiveDicNodes->copyPush(dicNode); + if (!pushedDicNode) { + if (dicNode->isCached()) { + dicNode->remove(); + } + // We simply drop any dic node that was not cached, ignoring the slim chance + // that one of its children represents what the user really wanted. + } + } + + void popTerminal(DicNode *dest) { + mTerminalDicNodes->copyPop(dest); + } + + void popActive(DicNode *dest) { + mActiveDicNodes->copyPop(dest); + } + + bool hasCachedDicNodesForContinuousSuggestion() const { + return mCachedDicNodesForContinuousSuggestion + && mCachedDicNodesForContinuousSuggestion->getSize() > 0; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + // TODO: Move this variable to header + static const int CACHE_BACK_LENGTH = 3; + const int cacheInputIndex = inputSize - CACHE_BACK_LENGTH; + const bool shouldCache = (cacheInputIndex == mInputIndex) + && (cacheInputIndex != mLastCachedInputIndex); + return shouldCache; + } + + AK_FORCE_INLINE void updateLastCachedInputIndex() { + mLastCachedInputIndex = mInputIndex; + } + + private: + DISALLOW_COPY_AND_ASSIGN(DicNodesCache); + + AK_FORCE_INLINE void restoreActiveDicNodesFromCache() { + if (DEBUG_DICT) { + AKLOGI("Restore %d nodes. inputIndex = %d.", + mCachedDicNodesForContinuousSuggestion->getSize(), mLastCachedInputIndex); + } + if (DEBUG_DICT_FULL || DEBUG_CACHE) { + mCachedDicNodesForContinuousSuggestion->dump(); + } + mInputIndex = mLastCachedInputIndex; + mCachedDicNodesForContinuousSuggestion = + moveNodesAndReturnReusableEmptyQueue( + mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes); + } + + AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue( + DicNodePriorityQueue *src, DicNodePriorityQueue **dest) { + const int srcMaxSize = src->getMaxSize(); + const int destMaxSize = (*dest)->getMaxSize(); + DicNodePriorityQueue *tmp = *dest; + *dest = src; + (*dest)->setMaxSize(destMaxSize); + tmp->clearAndResize(srcMaxSize); + return tmp; + } + + AK_FORCE_INLINE void resetTemporaryCaches() { + mActiveDicNodes->clear(); + mNextActiveDicNodes->clear(); + mTerminalDicNodes->clear(); + } + + DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE]; + // Active dicNodes currently being expanded. + DicNodePriorityQueue *mActiveDicNodes; + // Next dicNodes to be expanded. + DicNodePriorityQueue *mNextActiveDicNodes; + // Current top terminal dicNodes. + DicNodePriorityQueue *mTerminalDicNodes; + // Cached dicNodes used for continuous suggestion. + DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion; + int mInputIndex; + int mLastCachedInputIndex; +}; +} // namespace latinime +#endif // LATINIME_DIC_NODES_CACHE_H diff --git a/native/jni/src/suggest/core/dictionary/shortcut_utils.h b/native/jni/src/suggest/core/dictionary/shortcut_utils.h new file mode 100644 index 000000000..e592136cc --- /dev/null +++ b/native/jni/src/suggest/core/dictionary/shortcut_utils.h @@ -0,0 +1,65 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SHORTCUT_UTILS +#define LATINIME_SHORTCUT_UTILS + +#include "defines.h" +#include "dic_node_utils.h" +#include "terminal_attributes.h" + +namespace latinime { + +class ShortcutUtils { + public: + static int outputShortcuts(const TerminalAttributes *const terminalAttributes, + int outputWordIndex, const int finalScore, int *const outputCodePoints, + int *const frequencies, int *const outputTypes, const bool sameAsTyped) { + TerminalAttributes::ShortcutIterator iterator = terminalAttributes->getShortcutIterator(); + while (iterator.hasNextShortcutTarget() && outputWordIndex < MAX_RESULTS) { + int shortcutTarget[MAX_WORD_LENGTH]; + int shortcutProbability; + const int shortcutTargetStringLength = iterator.getNextShortcutTarget( + MAX_WORD_LENGTH, shortcutTarget, &shortcutProbability); + int shortcutScore; + int kind; + if (shortcutProbability == BinaryFormat::WHITELIST_SHORTCUT_PROBABILITY + && sameAsTyped) { + shortcutScore = S_INT_MAX; + kind = Dictionary::KIND_WHITELIST; + } else { + // shortcut entry's score == its base entry's score - 1 + shortcutScore = finalScore; + // Protection against int underflow + shortcutScore = max(S_INT_MIN + 1, shortcutScore) - 1; + kind = Dictionary::KIND_CORRECTION; + } + outputTypes[outputWordIndex] = kind; + frequencies[outputWordIndex] = shortcutScore; + frequencies[outputWordIndex] = max(S_INT_MIN + 1, shortcutScore) - 1; + const int startIndex2 = outputWordIndex * MAX_WORD_LENGTH; + DicNodeUtils::appendTwoWords(0, 0, shortcutTarget, shortcutTargetStringLength, + &outputCodePoints[startIndex2]); + ++outputWordIndex; + } + return outputWordIndex; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ShortcutUtils); +}; +} // namespace latinime +#endif // LATINIME_SHORTCUT_UTILS diff --git a/native/jni/src/suggest/core/policy/scoring.h b/native/jni/src/suggest/core/policy/scoring.h new file mode 100644 index 000000000..b8c10e25a --- /dev/null +++ b/native/jni/src/suggest/core/policy/scoring.h @@ -0,0 +1,57 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SCORING_H +#define LATINIME_SCORING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; + +// This class basically tweaks suggestions and distances apart from CompoundDistance +class Scoring { + public: + virtual int calculateFinalScore(const float compoundDistance, const int inputSize, + const bool forceCommit) const = 0; + virtual bool getMostProbableString( + const DicTraverseSession *const traverseSession, const int terminalSize, + const float languageWeight, int *const outputCodePoints, int *const type, + int *const freq) const = 0; + virtual void safetyNetForMostProbableString(const int terminalSize, + const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0; + // TODO: Make more generic + virtual void searchWordWithDoubleLetter(DicNode *terminals, + const int terminalSize, int *doubleLetterTerminalIndex, + DoubleLetterLevel *doubleLetterLevel) const = 0; + virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, + DicNode *const terminals, const int size) const = 0; + virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex, + const int doubleLetterTerminalIndex, + const DoubleLetterLevel doubleLetterLevel) const = 0; + virtual bool doesAutoCorrectValidWord() const = 0; + + protected: + Scoring() {} + virtual ~Scoring() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Scoring); +}; +} // namespace latinime +#endif // LATINIME_SCORING_H diff --git a/native/jni/src/suggest/core/policy/suggest_policy.h b/native/jni/src/suggest/core/policy/suggest_policy.h new file mode 100644 index 000000000..885e214f7 --- /dev/null +++ b/native/jni/src/suggest/core/policy/suggest_policy.h @@ -0,0 +1,39 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_POLICY_H +#define LATINIME_SUGGEST_POLICY_H + +#include "defines.h" + +namespace latinime { +class Traversal; +class Scoring; +class Weighting; + +class SuggestPolicy { + public: + SuggestPolicy() {} + virtual ~SuggestPolicy() {} + virtual const Traversal *getTraversal() const = 0; + virtual const Scoring *getScoring() const = 0; + virtual const Weighting *getWeighting() const = 0; + + private: + DISALLOW_COPY_AND_ASSIGN(SuggestPolicy); +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_POLICY_H diff --git a/native/jni/src/suggest/core/policy/traversal.h b/native/jni/src/suggest/core/policy/traversal.h new file mode 100644 index 000000000..1d5082ff8 --- /dev/null +++ b/native/jni/src/suggest/core/policy/traversal.h @@ -0,0 +1,61 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TRAVERSAL_H +#define LATINIME_TRAVERSAL_H + +#include "defines.h" + +namespace latinime { +class Traversal { + public: + virtual int getMaxPointerCount() const = 0; + virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0; + virtual bool isOmission(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const = 0; + virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const = 0; + virtual bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual ProximityType getProximityType( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + const DicNode *const childDicNode) const = 0; + virtual bool sameAsTyped(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + virtual bool needsToTraverseAllUserInput() const = 0; + virtual float getMaxSpatialDistance() const = 0; + virtual bool allowPartialCommit() const = 0; + virtual int getDefaultExpandDicNodeSize() const = 0; + virtual int getMaxCacheSize() const = 0; + virtual bool isPossibleOmissionChildNode( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0; + + protected: + Traversal() {} + virtual ~Traversal() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Traversal); +}; +} // namespace latinime +#endif // LATINIME_TRAVERSAL_H diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp new file mode 100644 index 000000000..4d08fa0fa --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.cpp @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "char_utils.h" +#include "defines.h" +#include "dic_node.h" +#include "dic_node_profiler.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "hash_map_compat.h" +#include "weighting.h" + +namespace latinime { + +static inline void profile(const CorrectionType correctionType, DicNode *const node) { +#if DEBUG_DICT + switch (correctionType) { + case CT_OMISSION: + PROF_OMISSION(node->mProfiler); + return; + case CT_ADDITIONAL_PROXIMITY: + PROF_ADDITIONAL_PROXIMITY(node->mProfiler); + return; + case CT_SUBSTITUTION: + PROF_SUBSTITUTION(node->mProfiler); + return; + case CT_NEW_WORD: + PROF_NEW_WORD(node->mProfiler); + return; + case CT_MATCH: + PROF_MATCH(node->mProfiler); + return; + case CT_COMPLETION: + PROF_COMPLETION(node->mProfiler); + return; + case CT_TERMINAL: + PROF_TERMINAL(node->mProfiler); + return; + case CT_SPACE_SUBSTITUTION: + PROF_SPACE_SUBSTITUTION(node->mProfiler); + return; + case CT_INSERTION: + PROF_INSERTION(node->mProfiler); + return; + case CT_TRANSPOSITION: + PROF_TRANSPOSITION(node->mProfiler); + return; + default: + // do nothing + return; + } +#else + // do nothing +#endif +} + +/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) { + const int inputSize = traverseSession->getInputSize(); + DicNode_InputStateG inputStateG; + inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default + const float spatialCost = Weighting::getSpatialCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, &inputStateG); + const float languageCost = Weighting::getLanguageCost(weighting, correctionType, + traverseSession, parentDicNode, dicNode, bigramCacheMap); + const bool edit = Weighting::isEditCorrection(correctionType); + const bool proximity = Weighting::isProximityCorrection(weighting, correctionType, + traverseSession, dicNode); + profile(correctionType, dicNode); + if (inputStateG.mNeedsToUpdateInputStateG) { + dicNode->updateInputIndexG(&inputStateG); + } else { + dicNode->forwardInputIndex(0, getForwardInputCount(correctionType), + (correctionType == CT_TRANSPOSITION)); + } + dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(), + inputSize, edit, proximity); +} + +/* static */ float Weighting::getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) { + switch(correctionType) { + case CT_OMISSION: + return weighting->getOmissionCost(parentDicNode, dicNode); + case CT_ADDITIONAL_PROXIMITY: + // only used for typing + return weighting->getAdditionalProximityCost(); + case CT_SUBSTITUTION: + // only used for typing + return weighting->getSubstitutionCost(); + case CT_NEW_WORD: + return weighting->getNewWordCost(dicNode); + case CT_MATCH: + return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); + case CT_COMPLETION: + return weighting->getCompletionCost(traverseSession, dicNode); + case CT_TERMINAL: + return weighting->getTerminalSpatialCost(traverseSession, dicNode); + case CT_SPACE_SUBSTITUTION: + return weighting->getSpaceSubstitutionCost(); + case CT_INSERTION: + return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode); + case CT_TRANSPOSITION: + return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode); + default: + return 0.0f; + } +} + +/* static */ float Weighting::getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) { + switch(correctionType) { + case CT_OMISSION: + return 0.0f; + case CT_SUBSTITUTION: + return 0.0f; + case CT_NEW_WORD: + return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap); + case CT_MATCH: + return 0.0f; + case CT_COMPLETION: + return 0.0f; + case CT_TERMINAL: { + const float languageImprobability = + DicNodeUtils::getBigramNodeImprobability( + traverseSession->getOffsetDict(), dicNode, bigramCacheMap); + return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability); + } + case CT_SPACE_SUBSTITUTION: + return 0.0f; + case CT_INSERTION: + return 0.0f; + case CT_TRANSPOSITION: + return 0.0f; + default: + return 0.0f; + } +} + +/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return true; + case CT_ADDITIONAL_PROXIMITY: + // Should return true? + return false; + case CT_SUBSTITUTION: + // Should return true? + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return false; + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return true; + case CT_TRANSPOSITION: + return true; + default: + return false; + } +} + +/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) { + switch(correctionType) { + case CT_OMISSION: + return false; + case CT_ADDITIONAL_PROXIMITY: + return false; + case CT_SUBSTITUTION: + return false; + case CT_NEW_WORD: + return false; + case CT_MATCH: + return weighting->isProximityDicNode(traverseSession, dicNode); + case CT_COMPLETION: + return false; + case CT_TERMINAL: + return false; + case CT_SPACE_SUBSTITUTION: + return false; + case CT_INSERTION: + return false; + case CT_TRANSPOSITION: + return false; + default: + return false; + } +} + +/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) { + switch(correctionType) { + case CT_OMISSION: + return 0; + case CT_ADDITIONAL_PROXIMITY: + return 0; + case CT_SUBSTITUTION: + return 0; + case CT_NEW_WORD: + return 0; + case CT_MATCH: + return 1; + case CT_COMPLETION: + return 0; + case CT_TERMINAL: + return 0; + case CT_SPACE_SUBSTITUTION: + return 1; + case CT_INSERTION: + return 2; + case CT_TRANSPOSITION: + return 2; + default: + return 0; + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h new file mode 100644 index 000000000..83a0f4b45 --- /dev/null +++ b/native/jni/src/suggest/core/policy/weighting.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_WEIGHTING_H +#define LATINIME_WEIGHTING_H + +#include "defines.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; +struct DicNode_InputStateG; + +class Weighting { + public: + static void addCostAndForwardInputIndex(const Weighting *const weighting, + const CorrectionType correctionType, + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap); + + protected: + virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getOmissionCost( + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getMatchedCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + DicNode_InputStateG *inputStateG) const = 0; + + virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTranspositionCost( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const = 0; + + virtual float getInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0; + + virtual float getNewWordCost(const DicNode *const dicNode) const = 0; + + virtual float getNewWordBigramCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0; + + virtual float getCompletionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const = 0; + + virtual float getTerminalLanguageCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + float dicNodeLanguageImprobability) const = 0; + + virtual bool needsToNormalizeCompoundDistance() const = 0; + + virtual float getAdditionalProximityCost() const = 0; + + virtual float getSubstitutionCost() const = 0; + + virtual float getSpaceSubstitutionCost() const = 0; + + Weighting() {} + virtual ~Weighting() {} + + private: + DISALLOW_COPY_AND_ASSIGN(Weighting); + + static float getSpatialCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + DicNode_InputStateG *const inputStateG); + static float getLanguageCost(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isEditCorrection(const CorrectionType correctionType); + // TODO: Move to TypingWeighting and GestureWeighting? + static bool isProximityCorrection(const Weighting *const weighting, + const CorrectionType correctionType, const DicTraverseSession *const traverseSession, + const DicNode *const dicNode); + // TODO: Move to TypingWeighting and GestureWeighting? + static int getForwardInputCount(const CorrectionType correctionType); +}; +} // namespace latinime +#endif // LATINIME_WEIGHTING_H diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.cpp b/native/jni/src/suggest/core/session/dic_traverse_session.cpp new file mode 100644 index 000000000..1f781dd43 --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.cpp @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "defines.h" +#include "dictionary.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "dic_traverse_wrapper.h" +#include "jni.h" + +namespace latinime { + +const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20; + +// A factory method for DicTraverseSession +static void *getSessionInstance(JNIEnv *env, jstring localeStr) { + return new DicTraverseSession(env, localeStr); +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary, + const int *prevWord, const int prevWordLength) { + if (traverseSession) { + DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); + tSession->init(dictionary, prevWord, prevWordLength); + } +} + +// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down. +static void releaseSessionInstance(void *traverseSession) { + delete static_cast<DicTraverseSession *>(traverseSession); +} + +// An ad-hoc internal class to register the factory method defined above +class TraverseSessionFactoryRegisterer { + public: + TraverseSessionFactoryRegisterer() { + DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance); + DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance); + DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance); + } + private: + DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer); +}; + +// To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor. +static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer; + +void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord, + int prevWordLength) { + mDictionary = dictionary; + if (!prevWord) { + mPrevWordPos = NOT_VALID_WORD; + return; + } + mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength); +} + +void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo, + const int *inputCodePoints, const int inputSize, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const float maxSpatialDistance, const int maxPointerCount) { + mProximityInfo = pInfo; + mMaxPointerCount = maxPointerCount; + initializeProximityInfoStates(inputCodePoints, inputXs, inputYs, times, pointerIds, inputSize, + maxSpatialDistance, maxPointerCount); +} + +const uint8_t *DicTraverseSession::getOffsetDict() const { + return mDictionary->getOffsetDict(); +} + +void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) { + mDicNodesCache.reset(nextActiveCacheSize, maxWords); + mBigramCacheMap.clear(); + mPartiallyCommited = false; +} + +void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints, + const int *const inputXs, const int *const inputYs, const int *const times, + const int *const pointerIds, const int inputSize, const float maxSpatialDistance, + const int maxPointerCount) { + ASSERT(1 <= maxPointerCount && maxPointerCount <= MAX_POINTER_COUNT_G); + mInputSize = 0; + for (int i = 0; i < maxPointerCount; ++i) { + mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(), + inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds, + maxPointerCount == MAX_POINTER_COUNT_G + /* TODO: this is a hack. fix proximity info state */); + mInputSize += mProximityInfoStates[i].size(); + } +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/session/dic_traverse_session.h b/native/jni/src/suggest/core/session/dic_traverse_session.h new file mode 100644 index 000000000..af036f82b --- /dev/null +++ b/native/jni/src/suggest/core/session/dic_traverse_session.h @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_DIC_TRAVERSE_SESSION_H +#define LATINIME_DIC_TRAVERSE_SESSION_H + +#include <stdint.h> +#include <vector> + +#include "defines.h" +#include "dic_nodes_cache.h" +#include "hash_map_compat.h" +#include "jni.h" +#include "proximity_info_state.h" + +namespace latinime { + +class Dictionary; +class ProximityInfo; + +class DicTraverseSession { + public: + AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr) + : mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0), + mDictionary(0), mDicNodesCache(), mBigramCacheMap(), + mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) { + // NOTE: mProximityInfoStates is an array of instances. + // No need to initialize it explicitly here. + } + + // Non virtual inline destructor -- never inherit this class + AK_FORCE_INLINE ~DicTraverseSession() {} + + void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength); + // TODO: Remove and merge into init + void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints, + const int inputSize, const int *const inputXs, const int *const inputYs, + const int *const times, const int *const pointerIds, const float maxSpatialDistance, + const int maxPointerCount); + void resetCache(const int nextActiveCacheSize, const int maxWords); + + const uint8_t *getOffsetDict() const; + bool canUseCache() const; + + //-------------------- + // getters and setters + //-------------------- + const ProximityInfo *getProximityInfo() const { return mProximityInfo; } + int getPrevWordPos() const { return mPrevWordPos; } + // TODO: REMOVE + void setPrevWordPos(int pos) { mPrevWordPos = pos; } + // TODO: Use proper parameter when changed + int getDicRootPos() const { return 0; } + DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; } + hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; } + const ProximityInfoState *getProximityInfoState(int id) const { + return &mProximityInfoStates[id]; + } + int getInputSize() const { return mInputSize; } + void setPartiallyCommited() { mPartiallyCommited = true; } + bool isPartiallyCommited() const { return mPartiallyCommited; } + + bool isOnlyOnePointerUsed(int *pointerId) const { + // Not in the dictionary word + int usedPointerCount = 0; + int usedPointerId = 0; + for (int i = 0; i < mMaxPointerCount; ++i) { + if (mProximityInfoStates[i].isUsed()) { + ++usedPointerCount; + usedPointerId = i; + } + } + if (usedPointerCount != 1) { + return false; + } + *pointerId = usedPointerId; + return true; + } + + void getSearchKeys(const DicNode *node, std::vector<int> *const outputSearchKeyVector) const { + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + const std::vector<int> *const searchKeyVector = + mProximityInfoStates[i].getSearchKeyVector(pointerId); + outputSearchKeyVector->insert(outputSearchKeyVector->end(), searchKeyVector->begin(), + searchKeyVector->end()); + } + } + + ProximityType getProximityTypeG(const DicNode *const node, const int childCodePoint) const { + ProximityType proximityType = UNRELATED_CHAR; + for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) { + if (!mProximityInfoStates[i].isUsed()) { + continue; + } + const int pointerId = node->getInputIndex(i); + proximityType = mProximityInfoStates[i].getProximityTypeG(pointerId, childCodePoint); + ASSERT(proximityType == UNRELATED_CHAR || proximityType == MATCH_CHAR); + // TODO: Make this more generic + // Currently we assume there are only two types here -- UNRELATED_CHAR + // and MATCH_CHAR + if (proximityType != UNRELATED_CHAR) { + return proximityType; + } + } + return proximityType; + } + + AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const { + return mDicNodesCache.isCacheBorderForTyping(inputSize); + } + + /** + * Returns whether or not it is possible to continue suggestion from the previous search. + */ + // TODO: Remove. No need to check once the session is fully implemented. + bool isContinuousSuggestionPossible() const { + if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) { + return false; + } + ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G); + for (int i = 0; i < mMaxPointerCount; ++i) { + const ProximityInfoState *const pInfoState = getProximityInfoState(i); + // If a proximity info state is not continuous suggestion possible, + // do not continue searching. + if (pInfoState->isUsed() && !pInfoState->isContinuousSuggestionPossible()) { + return false; + } + } + return true; + } + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession); + // threshold to start caching + static const int CACHE_START_INPUT_LENGTH_THRESHOLD; + void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs, + const int *const inputYs, const int *const times, const int *const pointerIds, + const int inputSize, const float maxSpatialDistance, const int maxPointerCount); + + int mPrevWordPos; + const ProximityInfo *mProximityInfo; + const Dictionary *mDictionary; + + DicNodesCache mDicNodesCache; + // Temporary cache for bigram frequencies + hash_map_compat<int, int16_t> mBigramCacheMap; + ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G]; + + int mInputSize; + bool mPartiallyCommited; + int mMaxPointerCount; +}; +} // namespace latinime +#endif // LATINIME_DIC_TRAVERSE_SESSION_H diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp new file mode 100644 index 000000000..7fba1d504 --- /dev/null +++ b/native/jni/src/suggest/core/suggest.cpp @@ -0,0 +1,518 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "char_utils.h" +#include "dictionary.h" +#include "dic_node_priority_queue.h" +#include "dic_node_vector.h" +#include "dic_traverse_session.h" +#include "proximity_info.h" +#include "scoring.h" +#include "shortcut_utils.h" +#include "suggest.h" +#include "terminal_attributes.h" +#include "traversal.h" +#include "weighting.h" + +namespace latinime { + +// Initialization of class constants. +const int Suggest::LOOKAHEAD_DIC_NODES_CACHE_SIZE = 25; +const int Suggest::MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT = 16; +const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; +const float Suggest::AUTOCORRECT_CLASSIFICATION_THRESHOLD = 0.33f; +const float Suggest::AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD = 0.6f; + +const bool Suggest::CORRECT_SPACE_OMISSION = true; +const bool Suggest::CORRECT_TRANSPOSITION = true; +const bool Suggest::CORRECT_INSERTION = true; +const bool Suggest::CORRECT_OMISSION_G = true; + +/** + * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates + * whether to prematurely commit the suggested words up to the given point for sentence-level + * suggestion. + * + * Note: Currently does not support concurrent calls across threads. Continuous suggestion is + * automatically activated for sequential calls that share the same starting input. + * TODO: Stop detecting continuous suggestion. Start using traverseSession instead. + */ +int Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, + int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, + int inputSize, int commitPoint, int *outWords, int *frequencies, int *outputIndices, + int *outputTypes) const { + PROF_OPEN; + PROF_START(0); + const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); + DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); + tSession->setupForGetSuggestions(pInfo, inputCodePoints, inputSize, inputXs, inputYs, times, + pointerIds, maxSpatialDistance, TRAVERSAL->getMaxPointerCount()); + // TODO: Add the way to evaluate cache + + initializeSearch(tSession, commitPoint); + PROF_END(0); + PROF_START(1); + + // keep expanding search dicNodes until all have terminated. + while (tSession->getDicTraverseCache()->activeSize() > 0) { + expandCurrentDicNodes(tSession); + tSession->getDicTraverseCache()->advanceActiveDicNodes(); + tSession->getDicTraverseCache()->advanceInputIndex(inputSize); + } + PROF_END(1); + PROF_START(2); + const int size = outputSuggestions(tSession, frequencies, outWords, outputIndices, outputTypes); + PROF_END(2); + PROF_CLOSE; + return size; +} + +/** + * Initializes the search at the root of the lexicon trie. Note that when possible the search will + * continue suggestion from where it left off during the last call. + */ +void Suggest::initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const { + if (!traverseSession->getProximityInfoState(0)->isUsed()) { + return; + } + if (TRAVERSAL->allowPartialCommit()) { + commitPoint = 0; + } + + if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE + && traverseSession->isContinuousSuggestionPossible()) { + if (commitPoint == 0) { + // Continue suggestion + traverseSession->getDicTraverseCache()->continueSearch(); + } else { + // Continue suggestion after partial commit. + DicNode *topDicNode = + traverseSession->getDicTraverseCache()->setCommitPoint(commitPoint); + traverseSession->setPrevWordPos(topDicNode->getPrevWordNodePos()); + traverseSession->getDicTraverseCache()->continueSearch(); + traverseSession->setPartiallyCommited(); + } + } else { + // Restart recognition at the root. + traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(), MAX_RESULTS); + // Create a new dic node here + DicNode rootNode; + DicNodeUtils::initAsRoot(traverseSession->getDicRootPos(), + traverseSession->getOffsetDict(), traverseSession->getPrevWordPos(), &rootNode); + traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); + } +} + +/** + * Outputs the final list of suggestions (i.e., terminal nodes). + */ +int Suggest::outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, + int *outputCodePoints, int *spaceIndices, int *outputTypes) const { + const int terminalSize = min(MAX_RESULTS, + static_cast<int>(traverseSession->getDicTraverseCache()->terminalSize())); + DicNode terminals[MAX_RESULTS]; // Avoiding non-POD variable length array + + for (int index = terminalSize - 1; index >= 0; --index) { + traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]); + } + + const float languageWeight = SCORING->getAdjustedLanguageWeight( + traverseSession, terminals, terminalSize); + + int outputWordIndex = 0; + // Insert most probable word at index == 0 as long as there is one terminal at least + const bool hasMostProbableString = + SCORING->getMostProbableString(traverseSession, terminalSize, languageWeight, + &outputCodePoints[0], &outputTypes[0], &frequencies[0]); + if (hasMostProbableString) { + ++outputWordIndex; + } + + // Initial value of the loop index for terminal nodes (words) + int doubleLetterTerminalIndex = -1; + DoubleLetterLevel doubleLetterLevel = NOT_A_DOUBLE_LETTER; + SCORING->searchWordWithDoubleLetter(terminals, terminalSize, + &doubleLetterTerminalIndex, &doubleLetterLevel); + + int maxScore = S_INT_MIN; + // Output suggestion results here + for (int terminalIndex = 0; terminalIndex < terminalSize && outputWordIndex < MAX_RESULTS; + ++terminalIndex) { + DicNode *terminalDicNode = &terminals[terminalIndex]; + if (DEBUG_GEO_FULL) { + terminalDicNode->dump("OUT:"); + } + const float doubleLetterCost = SCORING->getDoubleLetterDemotionDistanceCost( + terminalIndex, doubleLetterTerminalIndex, doubleLetterLevel); + const float compoundDistance = terminalDicNode->getCompoundDistance(languageWeight) + + doubleLetterCost; + const TerminalAttributes terminalAttributes(traverseSession->getOffsetDict(), + terminalDicNode->getFlags(), terminalDicNode->getAttributesPos()); + const int originalTerminalProbability = terminalDicNode->getProbability(); + + // Do not suggest words with a 0 probability, or entries that are blacklisted or do not + // represent a word. However, we should still submit their shortcuts if any. + const bool isValidWord = + originalTerminalProbability > 0 && !terminalAttributes.isBlacklistedOrNotAWord(); + // Increase output score of top typing suggestion to ensure autocorrection. + // TODO: Better integration with java side autocorrection logic. + // Force autocorrection for obvious long multi-word suggestions. + const bool isForceCommitMultiWords = TRAVERSAL->allowPartialCommit() + && (traverseSession->isPartiallyCommited() + || (traverseSession->getInputSize() >= MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT + && terminalDicNode->hasMultipleWords())); + + const int finalScore = SCORING->calculateFinalScore( + compoundDistance, traverseSession->getInputSize(), + isForceCommitMultiWords || (isValidWord && SCORING->doesAutoCorrectValidWord())); + + maxScore = max(maxScore, finalScore); + + if (TRAVERSAL->allowPartialCommit()) { + // Index for top typing suggestion should be 0. + if (isValidWord && outputWordIndex == 0) { + terminalDicNode->outputSpacePositionsResult(spaceIndices); + } + } + + // Do not suggest words with a 0 probability, or entries that are blacklisted or do not + // represent a word. However, we should still submit their shortcuts if any. + if (isValidWord) { + outputTypes[outputWordIndex] = Dictionary::KIND_CORRECTION; + frequencies[outputWordIndex] = finalScore; + // Populate the outputChars array with the suggested word. + const int startIndex = outputWordIndex * MAX_WORD_LENGTH; + terminalDicNode->outputResult(&outputCodePoints[startIndex]); + ++outputWordIndex; + } + + const bool sameAsTyped = TRAVERSAL->sameAsTyped(traverseSession, terminalDicNode); + outputWordIndex = ShortcutUtils::outputShortcuts(&terminalAttributes, outputWordIndex, + finalScore, outputCodePoints, frequencies, outputTypes, sameAsTyped); + DicNode::managedDelete(terminalDicNode); + } + + if (hasMostProbableString) { + SCORING->safetyNetForMostProbableString(terminalSize, maxScore, + &outputCodePoints[0], &frequencies[0]); + } + return outputWordIndex; +} + +/** + * Expands the dicNodes in the current search priority queue by advancing to the possible child + * nodes based on the next touch point(s) (or no touch points for lookahead) + */ +void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { + const int inputSize = traverseSession->getInputSize(); + DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize()); + DicNode omissionDicNode; + + // TODO: Find more efficient caching + const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession); + if (shouldDepthLevelCache) { + traverseSession->getDicTraverseCache()->updateLastCachedInputIndex(); + } + if (DEBUG_CACHE) { + AKLOGI("expandCurrentDicNodes depth level cache = %d, inputSize = %d", + shouldDepthLevelCache, inputSize); + } + while (traverseSession->getDicTraverseCache()->activeSize() > 0) { + DicNode dicNode; + traverseSession->getDicTraverseCache()->popActive(&dicNode); + if (dicNode.isTotalInputSizeExceedingLimit()) { + return; + } + childDicNodes.clear(); + const int point0Index = dicNode.getInputIndex(0); + const bool canDoLookAheadCorrection = + TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode); + const bool isLookAheadCorrection = canDoLookAheadCorrection + && traverseSession->getDicTraverseCache()-> + isLookAheadCorrectionInputIndex(static_cast<int>(point0Index)); + const bool isCompletion = dicNode.isCompletion(inputSize); + + const bool shouldNodeLevelCache = + TRAVERSAL->shouldNodeLevelCache(traverseSession, &dicNode); + if (shouldDepthLevelCache || shouldNodeLevelCache) { + if (DEBUG_CACHE) { + dicNode.dump("PUSH_CACHE"); + } + traverseSession->getDicTraverseCache()->copyPushContinue(&dicNode); + dicNode.setCached(); + } + + if (isLookAheadCorrection) { + // The algorithm maintains a small set of "deferred" nodes that have not consumed the + // latest touch point yet. These are needed to apply look-ahead correction operations + // that require special handling of the latest touch point. For example, with insertions + // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. + if (CORRECT_TRANSPOSITION) { + processDicNodeAsTransposition(traverseSession, &dicNode); + } + if (CORRECT_INSERTION) { + processDicNodeAsInsertion(traverseSession, &dicNode); + } + } else { // !isLookAheadCorrection + // Only consider typing error corrections if the normalized compound distance is + // below a spatial distance threshold. + // NOTE: the threshold may need to be updated if scoring model changes. + // TODO: Remove. Do not prune node here. + const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode); + // Process for handling space substitution (e.g., hevis => he is) + if (allowsErrorCorrections + && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { + createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */); + } + + DicNodeUtils::getAllChildDicNodes( + &dicNode, traverseSession->getOffsetDict(), &childDicNodes); + + const int childDicNodesSize = childDicNodes.getSizeAndLock(); + for (int i = 0; i < childDicNodesSize; ++i) { + DicNode *const childDicNode = childDicNodes[i]; + if (isCompletion) { + // Handle forward lookahead when the lexicon letter exceeds the input size. + processDicNodeAsMatch(traverseSession, childDicNode); + continue; + } + if (allowsErrorCorrections + && TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode)) { + // TODO: (Gesture) Change weight between omission and substitution errors + // TODO: (Gesture) Terminal node should not be handled as omission + omissionDicNode.initByCopy(childDicNode); + processDicNodeAsOmission(traverseSession, &omissionDicNode); + } + const ProximityType proximityType = TRAVERSAL->getProximityType( + traverseSession, &dicNode, childDicNode); + switch (proximityType) { + // TODO: Consider the difference of proximityType here + case MATCH_CHAR: + case PROXIMITY_CHAR: + processDicNodeAsMatch(traverseSession, childDicNode); + break; + case ADDITIONAL_PROXIMITY_CHAR: + if (allowsErrorCorrections) { + processDicNodeAsAdditionalProximityChar(traverseSession, &dicNode, + childDicNode); + } + break; + case SUBSTITUTION_CHAR: + if (allowsErrorCorrections) { + processDicNodeAsSubstitution(traverseSession, &dicNode, childDicNode); + } + break; + case UNRELATED_CHAR: + // Just drop this node and do nothing. + break; + default: + // Just drop this node and do nothing. + break; + } + } + + // Push the node for look-ahead correction + if (allowsErrorCorrections && canDoLookAheadCorrection) { + traverseSession->getDicTraverseCache()->copyPushNextActive(&dicNode); + } + } + } +} + +void Suggest::processTerminalDicNode( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + return; + } + if (!dicNode->isTerminalWordNode()) { + return; + } + if (TRAVERSAL->needsToTraverseAllUserInput() + && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { + return; + } + + if (dicNode->shouldBeFilterdBySafetyNetForBigram()) { + return; + } + // Create a non-cached node here. + DicNode terminalDicNode; + DicNodeUtils::initByCopy(dicNode, &terminalDicNode); + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, + &terminalDicNode, traverseSession->getBigramCacheMap()); + traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); +} + +/** + * Adds the expanded dicNode to the next search priority queue. Also creates an additional next word + * (by the space omission error correction) search path if input dicNode is on a terminal node. + */ +void Suggest::processExpandedDicNode( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + processTerminalDicNode(traverseSession, dicNode); + if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { + if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) { + createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */); + } + const int allowsLookAhead = !(dicNode->hasMultipleWords() + && dicNode->isCompletion(traverseSession->getInputSize())); + if (dicNode->hasChildren() && allowsLookAhead) { + traverseSession->getDicTraverseCache()->copyPushNextActive(dicNode); + } + } + DicNode::managedDelete(dicNode); +} + +void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, + DicNode *childDicNode) const { + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, + traverseSession, dicNode, childDicNode, 0 /* bigramCacheMap */); + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + weightChildNode(traverseSession, childDicNode); + processExpandedDicNode(traverseSession, childDicNode); +} + +/** + * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider + * matches for all possible next letters. Note that just skipping the current letter without any + * other conditions tends to flood the search dic nodes cache with omission nodes. Instead, check + * the possible *next* letters after the omission to better limit search to plausible omissions. + * Note that apostrophes are handled as omissions. + */ +void Suggest::processDicNodeAsOmission( + DicTraverseSession *traverseSession, DicNode *dicNode) const { + // If the omission is surely intentional that it should incur zero cost. + const bool isZeroCostOmission = dicNode->isZeroCostOmission(); + DicNodeVector childDicNodes; + + DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getOffsetDict(), &childDicNodes); + + const int size = childDicNodes.getSizeAndLock(); + for (int i = 0; i < size; i++) { + DicNode *const childDicNode = childDicNodes[i]; + if (!isZeroCostOmission) { + // Treat this word as omission + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + } + weightChildNode(traverseSession, childDicNode); + + if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { + DicNode::managedDelete(childDicNode); + continue; + } + processExpandedDicNode(traverseSession, childDicNode); + } +} + +/** + * Handle the dicNode as an insertion error (e.g., thiis => this). Skip the current touch point and + * consider matches for the next touch point. + */ +void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, + DicNode *dicNode) const { + const int16_t pointIndex = dicNode->getInputIndex(0); + DicNodeVector childDicNodes; + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex + 1, true, &childDicNodes); + const int size = childDicNodes.getSizeAndLock(); + for (int i = 0; i < size; i++) { + DicNode *const childDicNode = childDicNodes[i]; + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, + dicNode, childDicNode, 0 /* bigramCacheMap */); + processExpandedDicNode(traverseSession, childDicNode); + } +} + +/** + * Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points. + */ +void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, + DicNode *dicNode) const { + const int16_t pointIndex = dicNode->getInputIndex(0); + DicNodeVector childDicNodes1; + DicNodeUtils::getProximityChildDicNodes(dicNode, traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex + 1, false, &childDicNodes1); + const int childSize1 = childDicNodes1.getSizeAndLock(); + for (int i = 0; i < childSize1; i++) { + if (childDicNodes1[i]->hasChildren()) { + DicNodeVector childDicNodes2; + DicNodeUtils::getProximityChildDicNodes( + childDicNodes1[i], traverseSession->getOffsetDict(), + traverseSession->getProximityInfoState(0), pointIndex, false, &childDicNodes2); + const int childSize2 = childDicNodes2.getSizeAndLock(); + for (int j = 0; j < childSize2; j++) { + DicNode *const childDicNode2 = childDicNodes2[j]; + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, + traverseSession, childDicNodes1[i], childDicNode2, 0 /* bigramCacheMap */); + processExpandedDicNode(traverseSession, childDicNode2); + } + } + DicNode::managedDelete(childDicNodes1[i]); + } +} + +/** + * Weight child node by aligning it to the key + */ +void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const { + const int inputSize = traverseSession->getInputSize(); + if (dicNode->isCompletion(inputSize)) { + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, + 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + } else { // completion + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, + 0 /* parentDicNode */, dicNode, 0 /* bigramCacheMap */); + } +} + +/** + * Creates a new dicNode that represents a space insertion at the end of the input dicNode. Also + * incorporates the unigram / bigram score for the ending word into the new dicNode. + */ +void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, + const bool spaceSubstitution) const { + if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { + return; + } + + // Create a non-cached node here. + DicNode newDicNode; + DicNodeUtils::initAsRootWithPreviousWord(traverseSession->getDicRootPos(), + traverseSession->getOffsetDict(), dicNode, &newDicNode); + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_NEW_WORD, traverseSession, dicNode, + &newDicNode, traverseSession->getBigramCacheMap()); + if (spaceSubstitution) { + // Merge this with CT_NEW_WORD + Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SPACE_SUBSTITUTION, + traverseSession, 0, &newDicNode, 0 /* bigramCacheMap */); + } + traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); +} +} // namespace latinime diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h new file mode 100644 index 000000000..75d646bdd --- /dev/null +++ b/native/jni/src/suggest/core/suggest.h @@ -0,0 +1,95 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SUGGEST_IMPL_H +#define LATINIME_SUGGEST_IMPL_H + +#include "defines.h" +#include "suggest_interface.h" +#include "suggest_policy.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; +class ProximityInfo; +class Scoring; +class Traversal; +class Weighting; + +class Suggest : public SuggestInterface { + public: + AK_FORCE_INLINE Suggest(const SuggestPolicy *const suggestPolicy) + : TRAVERSAL(suggestPolicy->getTraversal()), + SCORING(suggestPolicy->getScoring()), WEIGHTING(suggestPolicy->getWeighting()) {} + AK_FORCE_INLINE virtual ~Suggest() {} + int getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs, + int *times, int *pointerIds, int *inputCodePoints, int inputSize, int commitPoint, + int *outWords, int *frequencies, int *outputIndices, int *outputTypes) const; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest); + void createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, + const bool spaceSubstitution) const; + int outputSuggestions(DicTraverseSession *traverseSession, int *frequencies, + int *outputCodePoints, int *outputIndices, int *outputTypes) const; + void initializeSearch(DicTraverseSession *traverseSession, int commitPoint) const; + void expandCurrentDicNodes(DicTraverseSession *traverseSession) const; + void processTerminalDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processExpandedDicNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const; + float getAutocorrectScore(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void generateFeatures( + DicTraverseSession *traverseSession, DicNode *dicNode, float *features) const; + void processDicNodeAsOmission(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsTransposition(DicTraverseSession *traverseSession, + DicNode *dicNode) const; + void processDicNodeAsInsertion(DicTraverseSession *traverseSession, DicNode *dicNode) const; + void processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, + DicNode *dicNode, DicNode *childDicNode) const; + void processDicNodeAsSubstitution(DicTraverseSession *traverseSession, DicNode *dicNode, + DicNode *childDicNode) const; + void processDicNodeAsMatch(DicTraverseSession *traverseSession, + DicNode *childDicNode) const; + + // Dic nodes cache size for lookahead (autocompletion) + static const int LOOKAHEAD_DIC_NODES_CACHE_SIZE; + // Max characters to lookahead + static const int MAX_LOOKAHEAD; + // Inputs longer than this will autocorrect if the suggestion is multi-word + static const int MIN_LEN_FOR_MULTI_WORD_AUTOCORRECT; + static const int MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE; + // Base value for converting costs into scores (low so will not autocorrect without classifier) + static const float BASE_OUTPUT_SCORE; + + // Threshold for autocorrection classifier + static const float AUTOCORRECT_CLASSIFICATION_THRESHOLD; + // Threshold for computing the language model feature for autocorrect classification + static const float AUTOCORRECT_LANGUAGE_FEATURE_THRESHOLD; + + // Typing error correction settings + static const bool CORRECT_SPACE_OMISSION; + static const bool CORRECT_TRANSPOSITION; + static const bool CORRECT_INSERTION; + + const Traversal *const TRAVERSAL; + const Scoring *const SCORING; + const Weighting *const WEIGHTING; + + static const bool CORRECT_OMISSION_G; +}; +} // namespace latinime +#endif // LATINIME_SUGGEST_IMPL_H diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp new file mode 100644 index 000000000..90985d0fe --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp @@ -0,0 +1,52 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "scoring_params.h" + +namespace latinime { +// TODO: RENAME all +const float ScoringParams::MAX_SPATIAL_DISTANCE = 1.0f; +const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY = 40; +const int ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED = 120; +const float ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD = 1.0f; +const int ScoringParams::MAX_CACHE_DIC_NODE_SIZE = 125; +const int ScoringParams::THRESHOLD_SHORT_WORD_LENGTH = 4; + +const float ScoringParams::DISTANCE_WEIGHT_LENGTH = 0.132f; +const float ScoringParams::PROXIMITY_COST = 0.086f; +const float ScoringParams::FIRST_PROXIMITY_COST = 0.104f; +const float ScoringParams::OMISSION_COST = 0.388f; +const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.431f; +const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.532f; +const float ScoringParams::INSERTION_COST = 0.670f; +const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.526f; +const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.563f; +const float ScoringParams::TRANSPOSITION_COST = 0.494f; +const float ScoringParams::SPACE_SUBSTITUTION_COST = 0.239f; +const float ScoringParams::ADDITIONAL_PROXIMITY_COST = 0.380f; +const float ScoringParams::SUBSTITUTION_COST = 0.363f; +const float ScoringParams::COST_NEW_WORD = 0.054f; +const float ScoringParams::COST_NEW_WORD_CAPITALIZED = 0.174f; +const float ScoringParams::DISTANCE_WEIGHT_LANGUAGE = 1.123f; +const float ScoringParams::COST_FIRST_LOOKAHEAD = 0.462f; +const float ScoringParams::COST_LOOKAHEAD = 0.092f; +const float ScoringParams::HAS_PROXIMITY_TERMINAL_COST = 0.126f; +const float ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST = 0.056f; +const float ScoringParams::HAS_MULTI_WORD_TERMINAL_COST = 0.136f; +const float ScoringParams::TYPING_BASE_OUTPUT_SCORE = 1.0f; +const float ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT = 0.1f; +const float ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT = 0.1f; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h new file mode 100644 index 000000000..8f104b362 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h @@ -0,0 +1,66 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_SCORING_PARAMS_H +#define LATINIME_SCORING_PARAMS_H + +#include "defines.h" + +namespace latinime { + +class ScoringParams { + public: + // Fixed model parameters + static const float MAX_SPATIAL_DISTANCE; + static const int THRESHOLD_NEXT_WORD_PROBABILITY; + static const int THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; + static const float AUTOCORRECT_OUTPUT_THRESHOLD; + static const int MAX_CACHE_DIC_NODE_SIZE; + static const int THRESHOLD_SHORT_WORD_LENGTH; + + // Numerically optimized parameters (currently for tap typing only). + // TODO: add ability to modify these constants programmatically. + // TODO: explore optimization of gesture parameters. + static const float DISTANCE_WEIGHT_LENGTH; + static const float PROXIMITY_COST; + static const float FIRST_PROXIMITY_COST; + static const float OMISSION_COST; + static const float OMISSION_COST_SAME_CHAR; + static const float OMISSION_COST_FIRST_CHAR; + static const float INSERTION_COST; + static const float INSERTION_COST_SAME_CHAR; + static const float INSERTION_COST_FIRST_CHAR; + static const float TRANSPOSITION_COST; + static const float SPACE_SUBSTITUTION_COST; + static const float ADDITIONAL_PROXIMITY_COST; + static const float SUBSTITUTION_COST; + static const float COST_NEW_WORD; + static const float COST_NEW_WORD_CAPITALIZED; + static const float DISTANCE_WEIGHT_LANGUAGE; + static const float COST_FIRST_LOOKAHEAD; + static const float COST_LOOKAHEAD; + static const float HAS_PROXIMITY_TERMINAL_COST; + static const float HAS_EDIT_CORRECTION_TERMINAL_COST; + static const float HAS_MULTI_WORD_TERMINAL_COST; + static const float TYPING_BASE_OUTPUT_SCORE; + static const float TYPING_MAX_OUTPUT_SCORE_PER_INPUT; + static const float MAX_NORM_DISTANCE_FOR_EDIT; + + private: + DISALLOW_IMPLICIT_CONSTRUCTORS(ScoringParams); +}; +} // namespace latinime +#endif // LATINIME_SCORING_PARAMS_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp b/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp new file mode 100644 index 000000000..53f68f20f --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.cpp @@ -0,0 +1,21 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "typing_scoring.h" + +namespace latinime { +const TypingScoring TypingScoring::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_scoring.h b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h new file mode 100644 index 000000000..ed941f0ae --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_scoring.h @@ -0,0 +1,82 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_SCORING_H +#define LATINIME_TYPING_SCORING_H + +#include "defines.h" +#include "scoring.h" +#include "scoring_params.h" + +namespace latinime { + +class DicNode; +class DicTraverseSession; + +class TypingScoring : public Scoring { + public: + static const TypingScoring *getInstance() { return &sInstance; } + + AK_FORCE_INLINE bool getMostProbableString( + const DicTraverseSession *const traverseSession, const int terminalSize, + const float languageWeight, int *const outputCodePoints, int *const type, + int *const freq) const { + return false; + } + + AK_FORCE_INLINE void safetyNetForMostProbableString(const int terminalSize, + const int maxScore, int *const outputCodePoints, int *const frequencies) const { + } + + AK_FORCE_INLINE void searchWordWithDoubleLetter(DicNode *terminals, + const int terminalSize, int *doubleLetterTerminalIndex, + DoubleLetterLevel *doubleLetterLevel) const { + } + + AK_FORCE_INLINE float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession, + DicNode *const terminals, const int size) const { + return 1.0f; + } + + AK_FORCE_INLINE int calculateFinalScore(const float compoundDistance, + const int inputSize, const bool forceCommit) const { + const float maxDistance = ScoringParams::DISTANCE_WEIGHT_LANGUAGE + + static_cast<float>(inputSize) * ScoringParams::TYPING_MAX_OUTPUT_SCORE_PER_INPUT; + return static_cast<int>((ScoringParams::TYPING_BASE_OUTPUT_SCORE + - (compoundDistance / maxDistance) + + (forceCommit ? ScoringParams::AUTOCORRECT_OUTPUT_THRESHOLD : 0.0f)) + * SUGGEST_INTERFACE_OUTPUT_SCALE); + } + + AK_FORCE_INLINE float getDoubleLetterDemotionDistanceCost(const int terminalIndex, + const int doubleLetterTerminalIndex, + const DoubleLetterLevel doubleLetterLevel) const { + return 0.0f; + } + + AK_FORCE_INLINE bool doesAutoCorrectValidWord() const { + return false; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingScoring); + static const TypingScoring sInstance; + + TypingScoring() {} + ~TypingScoring() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_SCORING_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp new file mode 100644 index 000000000..ebba37531 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "suggest.h" +#include "typing_suggest.h" +#include "typing_suggest_policy.h" + +namespace latinime { + +const TypingSuggestPolicy TypingSuggestPolicy::sInstance; + +// A factory method for a "typing" Suggest instance +static SuggestInterface *getTypingSuggestInstance() { + return new Suggest(TypingSuggestPolicy::getInstance()); +} + +// An ad-hoc internal class to register the factory method getTypingSuggestInstance() defined above +class TypingSuggestFactoryRegisterer { + public: + TypingSuggestFactoryRegisterer() { + TypingSuggest::setTypingSuggestFactoryMethod(getTypingSuggestInstance); + } + private: + DISALLOW_COPY_AND_ASSIGN(TypingSuggestFactoryRegisterer); +}; + +// To invoke the TypingSuggestFactoryRegisterer's constructor in the global constructor +static TypingSuggestFactoryRegisterer typingSuggestFactoryregisterer; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h new file mode 100644 index 000000000..55668fc25 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_suggest_policy.h @@ -0,0 +1,55 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_SUGGEST_POLICY_H +#define LATINIME_TYPING_SUGGEST_POLICY_H + +#include "defines.h" +#include "suggest_policy.h" +#include "typing_scoring.h" +#include "typing_traversal.h" +#include "typing_weighting.h" + +namespace latinime { + +class Scoring; +class Traversal; +class Weighting; + +class TypingSuggestPolicy : public SuggestPolicy { + public: + static const TypingSuggestPolicy *getInstance() { return &sInstance; } + + TypingSuggestPolicy() {} + virtual ~TypingSuggestPolicy() {} + AK_FORCE_INLINE const Traversal *getTraversal() const { + return TypingTraversal::getInstance(); + } + + AK_FORCE_INLINE const Scoring *getScoring() const { + return TypingScoring::getInstance(); + } + + AK_FORCE_INLINE const Weighting *getWeighting() const { + return TypingWeighting::getInstance(); + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingSuggestPolicy); + static const TypingSuggestPolicy sInstance; +}; +} // namespace latinime +#endif // LATINIME_TYPING_SUGGEST_POLICY_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp new file mode 100644 index 000000000..68c614e77 --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "typing_traversal.h" + +namespace latinime { +const bool TypingTraversal::CORRECT_OMISSION = true; +const bool TypingTraversal::CORRECT_SPACE_SUBSTITUTION = true; +const bool TypingTraversal::CORRECT_SPACE_OMISSION = true; +const TypingTraversal TypingTraversal::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_traversal.h b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h new file mode 100644 index 000000000..16153f8bb --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_traversal.h @@ -0,0 +1,184 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_TRAVERSAL_H +#define LATINIME_TYPING_TRAVERSAL_H + +#include <stdint.h> + +#include "char_utils.h" +#include "defines.h" +#include "dic_node.h" +#include "dic_node_vector.h" +#include "dic_traverse_session.h" +#include "proximity_info_state.h" +#include "scoring_params.h" +#include "traversal.h" + +namespace latinime { +class TypingTraversal : public Traversal { + public: + static const TypingTraversal *getInstance() { return &sInstance; } + + AK_FORCE_INLINE int getMaxPointerCount() const { + return MAX_POINTER_COUNT; + } + + AK_FORCE_INLINE bool allowsErrorCorrections(const DicNode *const dicNode) const { + return dicNode->getNormalizedSpatialDistance() + < ScoringParams::MAX_NORM_DISTANCE_FOR_EDIT; + } + + AK_FORCE_INLINE bool isOmission(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const DicNode *const childDicNode) const { + if (!CORRECT_OMISSION) { + return false; + } + const int inputSize = traverseSession->getInputSize(); + // TODO: Don't refer to isCompletion? + if (dicNode->isCompletion(inputSize)) { + return false; + } + if (dicNode->canBeIntentionalOmission()) { + return true; + } + const int point0Index = dicNode->getInputIndex(0); + const int currentBaseLowerCodePoint = + toBaseLowerCase(childDicNode->getNodeCodePoint()); + const int typedBaseLowerCodePoint = + toBaseLowerCase(traverseSession->getProximityInfoState(0) + ->getPrimaryCodePointAt(point0Index)); + return (currentBaseLowerCodePoint != typedBaseLowerCodePoint); + } + + AK_FORCE_INLINE bool isSpaceSubstitutionTerminal( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + if (!CORRECT_SPACE_SUBSTITUTION) { + return false; + } + if (!canDoLookAheadCorrection(traverseSession, dicNode)) { + return false; + } + const int point0Index = dicNode->getInputIndex(0); + return dicNode->isTerminalWordNode() + && traverseSession->getProximityInfoState(0)-> + hasSpaceProximity(point0Index); + } + + AK_FORCE_INLINE bool isSpaceOmissionTerminal( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + if (!CORRECT_SPACE_OMISSION) { + return false; + } + const int inputSize = traverseSession->getInputSize(); + // TODO: Don't refer to isCompletion? + if (dicNode->isCompletion(inputSize)) { + return false; + } + if (!dicNode->isTerminalWordNode()) { + return false; + } + const int16_t pointIndex = dicNode->getInputIndex(0); + return pointIndex <= inputSize && !dicNode->isTotalInputSizeExceedingLimit() + && !dicNode->shouldBeFilterdBySafetyNetForBigram(); + } + + AK_FORCE_INLINE bool shouldDepthLevelCache( + const DicTraverseSession *const traverseSession) const { + const int inputSize = traverseSession->getInputSize(); + return traverseSession->isCacheBorderForTyping(inputSize); + } + + AK_FORCE_INLINE bool shouldNodeLevelCache( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + return false; + } + + AK_FORCE_INLINE bool canDoLookAheadCorrection( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + const int inputSize = traverseSession->getInputSize(); + return dicNode->canDoLookAheadCorrection(inputSize); + } + + AK_FORCE_INLINE ProximityType getProximityType( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + const DicNode *const childDicNode) const { + return traverseSession->getProximityInfoState(0)->getProximityType( + dicNode->getInputIndex(0), childDicNode->getNodeCodePoint(), + true /* checkProximityChars */); + } + + AK_FORCE_INLINE bool needsToTraverseAllUserInput() const { + return true; + } + + AK_FORCE_INLINE float getMaxSpatialDistance() const { + return ScoringParams::MAX_SPATIAL_DISTANCE; + } + + AK_FORCE_INLINE bool allowPartialCommit() const { + return true; + } + + AK_FORCE_INLINE int getDefaultExpandDicNodeSize() const { + return DicNodeVector::DEFAULT_NODES_SIZE_FOR_OPTIMIZATION; + } + + AK_FORCE_INLINE bool sameAsTyped( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + return traverseSession->getProximityInfoState(0)->sameAsTyped( + dicNode->getOutputWordBuf(), dicNode->getDepth()); + } + + AK_FORCE_INLINE int getMaxCacheSize() const { + return ScoringParams::MAX_CACHE_DIC_NODE_SIZE; + } + + AK_FORCE_INLINE bool isPossibleOmissionChildNode( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const { + const ProximityType proximityType = + getProximityType(traverseSession, parentDicNode, dicNode); + if (!DicNodeUtils::isProximityChar(proximityType)) { + return false; + } + return true; + } + + AK_FORCE_INLINE bool isGoodToTraverseNextWord(const DicNode *const dicNode) const { + const int probability = dicNode->getProbability(); + if (probability < ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY) { + return false; + } + const int c = dicNode->getOutputWordBuf()[0]; + const bool shortCappedWord = dicNode->getDepth() + < ScoringParams::THRESHOLD_SHORT_WORD_LENGTH && isAsciiUpper(c); + return !shortCappedWord + || probability >= ScoringParams::THRESHOLD_NEXT_WORD_PROBABILITY_FOR_CAPPED; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingTraversal); + static const bool CORRECT_OMISSION; + static const bool CORRECT_SPACE_SUBSTITUTION; + static const bool CORRECT_SPACE_OMISSION; + static const TypingTraversal sInstance; + + TypingTraversal() {} + ~TypingTraversal() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_TRAVERSAL_H diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp new file mode 100644 index 000000000..6e4b2fb6a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "dic_node.h" +#include "scoring_params.h" +#include "typing_weighting.h" + +namespace latinime { +const TypingWeighting TypingWeighting::sInstance; +} // namespace latinime diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h new file mode 100644 index 000000000..e8075f41a --- /dev/null +++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h @@ -0,0 +1,194 @@ +/* + * Copyright (C) 2012 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LATINIME_TYPING_WEIGHTING_H +#define LATINIME_TYPING_WEIGHTING_H + +#include "defines.h" +#include "dic_node_utils.h" +#include "dic_traverse_session.h" +#include "weighting.h" + +namespace latinime { + +class DicNode; +struct DicNode_InputStateG; + +class TypingWeighting : public Weighting { + public: + static const TypingWeighting *getInstance() { return &sInstance; } + + protected: + float getTerminalSpatialCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + float cost = 0.0f; + if (dicNode->hasMultipleWords()) { + cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; + } + if (dicNode->getProximityCorrectionCount() > 0) { + cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; + } + if (dicNode->getEditCorrectionCount() > 0) { + cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; + } + return cost; + } + + float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const { + bool sameCodePoint = false; + bool isFirstLetterOmission = false; + float cost = 0.0f; + sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); + // If the traversal omitted the first letter then the dicNode should now be on the second. + isFirstLetterOmission = dicNode->getDepth() == 2; + if (isFirstLetterOmission) { + cost = ScoringParams::OMISSION_COST_FIRST_CHAR; + } else { + cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR + : ScoringParams::OMISSION_COST; + } + return cost; + } + + float getMatchedCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + DicNode_InputStateG *inputStateG) const { + const int pointIndex = dicNode->getInputIndex(0); + // Note: min() required since length can be MAX_POINT_TO_KEY_LENGTH for characters not on + // the keyboard (like accented letters) + const float length = min(ScoringParams::MAX_SPATIAL_DISTANCE, + traverseSession->getProximityInfoState(0)->getPointToKeyLength( + pointIndex, dicNode->getNodeCodePoint())); + const float weightedDistance = length * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const bool isFirstChar = pointIndex == 0; + const bool isProximity = isProximityDicNode(traverseSession, dicNode); + const float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_PROXIMITY_COST + : ScoringParams::PROXIMITY_COST) : 0.0f; + return weightedDistance + cost; + } + + bool isProximityDicNode( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode) const { + const int pointIndex = dicNode->getInputIndex(0); + const int primaryCodePoint = toBaseLowerCase( + traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex)); + const int dicNodeChar = toBaseLowerCase(dicNode->getNodeCodePoint()); + return primaryCodePoint != dicNodeChar; + } + + float getTranspositionCost( + const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode, + const DicNode *const dicNode) const { + const int16_t parentPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = parentDicNode->getNodeCodePoint(); + const float distance1 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex + 1, prevCodePoint); + const int codePoint = dicNode->getNodeCodePoint(); + const float distance2 = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex, codePoint); + const float distance = distance1 + distance2; + const float weightedLengthDistance = + distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; + return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; + } + + float getInsertionCost( + const DicTraverseSession *const traverseSession, + const DicNode *const parentDicNode, const DicNode *const dicNode) const { + const int16_t parentPointIndex = parentDicNode->getInputIndex(0); + const int prevCodePoint = + traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(parentPointIndex); + + const int currentCodePoint = dicNode->getNodeCodePoint(); + const bool sameCodePoint = prevCodePoint == currentCodePoint; + const float dist = traverseSession->getProximityInfoState(0)->getPointToKeyLength( + parentPointIndex + 1, currentCodePoint); + const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; + const bool singleChar = dicNode->getDepth() == 1; + const float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f) + + (sameCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR + : ScoringParams::INSERTION_COST); + return cost + weightedDistance; + } + + float getNewWordCost(const DicNode *const dicNode) const { + const bool isCapitalized = dicNode->isCapitalized(); + return isCapitalized ? + ScoringParams::COST_NEW_WORD_CAPITALIZED : ScoringParams::COST_NEW_WORD; + } + + float getNewWordBigramCost( + const DicTraverseSession *const traverseSession, const DicNode *const dicNode, + hash_map_compat<int, int16_t> *const bigramCacheMap) const { + return DicNodeUtils::getBigramNodeImprobability(traverseSession->getOffsetDict(), + dicNode, bigramCacheMap); + } + + float getCompletionCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode) const { + // The auto completion starts when the input index is same as the input size + const bool firstCompletion = dicNode->getInputIndex(0) + == traverseSession->getInputSize(); + // TODO: Change the cost for the first completion for the gesture? + const float cost = firstCompletion ? ScoringParams::COST_FIRST_LOOKAHEAD + : ScoringParams::COST_LOOKAHEAD; + return cost; + } + + float getTerminalLanguageCost(const DicTraverseSession *const traverseSession, + const DicNode *const dicNode, const float dicNodeLanguageImprobability) const { + const bool hasEditCount = dicNode->getEditCorrectionCount() > 0; + const bool isSameLength = dicNode->getDepth() == traverseSession->getInputSize(); + const bool hasMultipleWords = dicNode->hasMultipleWords(); + const bool hasProximityErrors = dicNode->getProximityCorrectionCount() > 0; + // Gesture input is always assumed to have proximity errors + // because the input word shouldn't be treated as perfect + const bool isExactMatch = !hasEditCount && !hasMultipleWords + && !hasProximityErrors && isSameLength; + + const float totalPrevWordsLanguageCost = dicNode->getTotalPrevWordsLanguageCost(); + const float languageImprobability = isExactMatch ? 0.0f : dicNodeLanguageImprobability; + const float languageWeight = ScoringParams::DISTANCE_WEIGHT_LANGUAGE; + // TODO: Caveat: The following equation should be: + // totalPrevWordsLanguageCost + (languageImprobability * languageWeight); + return (totalPrevWordsLanguageCost + languageImprobability) * languageWeight; + } + + AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { + return false; + } + + AK_FORCE_INLINE float getAdditionalProximityCost() const { + return ScoringParams::ADDITIONAL_PROXIMITY_COST; + } + + AK_FORCE_INLINE float getSubstitutionCost() const { + return ScoringParams::SUBSTITUTION_COST; + } + + AK_FORCE_INLINE float getSpaceSubstitutionCost() const { + return ScoringParams::SPACE_SUBSTITUTION_COST; + } + + private: + DISALLOW_COPY_AND_ASSIGN(TypingWeighting); + static const TypingWeighting sInstance; + + TypingWeighting() {} + ~TypingWeighting() {} +}; +} // namespace latinime +#endif // LATINIME_TYPING_WEIGHTING_H diff --git a/tests/src/com/android/inputmethod/keyboard/internal/HermiteInterpolatorTests.java b/tests/src/com/android/inputmethod/keyboard/internal/HermiteInterpolatorTests.java new file mode 100644 index 000000000..3ff5aa485 --- /dev/null +++ b/tests/src/com/android/inputmethod/keyboard/internal/HermiteInterpolatorTests.java @@ -0,0 +1,203 @@ +/* + * Copyright (C) 2013 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.android.inputmethod.keyboard.internal; + +import android.test.AndroidTestCase; +import android.test.suitebuilder.annotation.SmallTest; + +@SmallTest +public class HermiteInterpolatorTests extends AndroidTestCase { + private final HermiteInterpolator mInterpolator = new HermiteInterpolator(); + + @Override + protected void setUp() throws Exception { + super.setUp(); + } + + private static final float EPSLION = 0.0000005f; + + private static void assertFloatEquals(final String message, float expected, float actual) { + if (Math.abs(expected - actual) >= EPSLION) { + fail(String.format("%s expected:<%s> but was:<%s>", message, expected, actual)); + } + } + + // t=0 p0=(0,1) + // t=1 p1=(1,0) + // t=2 p2=(3,2) + // t=3 p3=(2,3) + // y + // | + // 3 + o p3 + // | + // 2 + o p2 + // | + // 1 o p0 + // | p1 + // 0 +---o---+---+-- x + // 0 1 2 3 + private final int[] mXCoords = { 0, 1, 3, 2 }; + private final int[] mYCoords = { 1, 0, 2, 3 }; + private static final int p0 = 0; + private static final int p1 = 1; + private static final int p2 = 2; + private static final int p3 = 3; + + public void testP0P1() { + // [(p0 p1) p2 p3] + mInterpolator.reset(mXCoords, mYCoords, p0, p3 + 1); + mInterpolator.setInterval(p0 - 1, p0, p1, p1 + 1); + assertEquals("p0x", mXCoords[p0], mInterpolator.mP1X); + assertEquals("p0y", mYCoords[p0], mInterpolator.mP1Y); + assertEquals("p1x", mXCoords[p1], mInterpolator.mP2X); + assertEquals("p1y", mYCoords[p1], mInterpolator.mP2Y); + // XY-slope at p0=3.0 (-0.75/-0.25) + assertFloatEquals("slope x p0", -0.25f, mInterpolator.mSlope1X); + assertFloatEquals("slope y p0", -0.75f, mInterpolator.mSlope1Y); + // XY-slope at p1=1/3.0 (0.50/1.50) + assertFloatEquals("slope x p1", 1.50f, mInterpolator.mSlope2X); + assertFloatEquals("slope y p1", 0.50f, mInterpolator.mSlope2Y); + // t=0.0 (p0) + mInterpolator.interpolate(0.0f); + assertFloatEquals("t=0.0 x", 0.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.0 y", 1.0f, mInterpolator.mInterpolatedY); + // t=0.2 + mInterpolator.interpolate(0.2f); + assertFloatEquals("t=0.2 x", 0.02400f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.2 y", 0.78400f, mInterpolator.mInterpolatedY); + // t=0.5 + mInterpolator.interpolate(0.5f); + assertFloatEquals("t=0.5 x", 0.28125f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.5 y", 0.34375f, mInterpolator.mInterpolatedY); + // t=0.8 + mInterpolator.interpolate(0.8f); + assertFloatEquals("t=0.8 x", 0.69600f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.8 y", 0.01600f, mInterpolator.mInterpolatedY); + // t=1.0 (p1) + mInterpolator.interpolate(1.0f); + assertFloatEquals("t=1.0 x", 1.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=1.0 y", 0.0f, mInterpolator.mInterpolatedY); + } + + public void testP1P2() { + // [p0 (p1 p2) p3] + mInterpolator.reset(mXCoords, mYCoords, p0, p3 + 1); + mInterpolator.setInterval(p1 - 1, p1, p2, p2 + 1); + assertEquals("p1x", mXCoords[p1], mInterpolator.mP1X); + assertEquals("p1y", mYCoords[p1], mInterpolator.mP1Y); + assertEquals("p2x", mXCoords[p2], mInterpolator.mP2X); + assertEquals("p2y", mYCoords[p2], mInterpolator.mP2Y); + // XY-slope at p1=1/3.0 (0.50/1.50) + assertFloatEquals("slope x p1", 1.50f, mInterpolator.mSlope1X); + assertFloatEquals("slope y p1", 0.50f, mInterpolator.mSlope1Y); + // XY-slope at p2=3.0 (1.50/0.50) + assertFloatEquals("slope x p2", 0.50f, mInterpolator.mSlope2X); + assertFloatEquals("slope y p2", 1.50f, mInterpolator.mSlope2Y); + // t=0.0 (p1) + mInterpolator.interpolate(0.0f); + assertFloatEquals("t=0.0 x", 1.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.0 y", 0.0f, mInterpolator.mInterpolatedY); + // t=0.2 + mInterpolator.interpolate(0.2f); + assertFloatEquals("t=0.2 x", 1.384f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.2 y", 0.224f, mInterpolator.mInterpolatedY); + // t=0.5 + mInterpolator.interpolate(0.5f); + assertFloatEquals("t=0.5 x", 2.125f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.5 y", 0.875f, mInterpolator.mInterpolatedY); + // t=0.8 + mInterpolator.interpolate(0.8f); + assertFloatEquals("t=0.8 x", 2.776f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.8 y", 1.616f, mInterpolator.mInterpolatedY); + // t=1.0 (p2) + mInterpolator.interpolate(1.0f); + assertFloatEquals("t=1.0 x", 3.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=1.0 y", 2.0f, mInterpolator.mInterpolatedY); + } + + public void testP2P3() { + // [p0 p1 (p2 p3)] + mInterpolator.reset(mXCoords, mYCoords, p0, p3 + 1); + mInterpolator.setInterval(p2 - 1, p2, p3, p3 + 1); + assertEquals("p2x", mXCoords[p2], mInterpolator.mP1X); + assertEquals("p2y", mYCoords[p2], mInterpolator.mP1Y); + assertEquals("p3x", mXCoords[p3], mInterpolator.mP2X); + assertEquals("p3y", mYCoords[p3], mInterpolator.mP2Y); + // XY-slope at p2=3.0 (1.50/0.50) + assertFloatEquals("slope x p2", 0.50f, mInterpolator.mSlope1X); + assertFloatEquals("slope y p2", 1.50f, mInterpolator.mSlope1Y); + // XY-slope at p3=1/3.0 (-0.25/-0.75) + assertFloatEquals("slope x p3", -0.75f, mInterpolator.mSlope2X); + assertFloatEquals("slope y p3", -0.25f, mInterpolator.mSlope2Y); + // t=0.0 (p2) + mInterpolator.interpolate(0.0f); + assertFloatEquals("t=0.0 x", 3.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.0 y", 2.0f, mInterpolator.mInterpolatedY); + // t=0.2 + mInterpolator.interpolate(0.2f); + assertFloatEquals("t=0.2 x", 2.98400f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.2 y", 2.30400f, mInterpolator.mInterpolatedY); + // t=0.5 + mInterpolator.interpolate(0.5f); + assertFloatEquals("t=0.5 x", 2.65625f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.5 y", 2.71875f, mInterpolator.mInterpolatedY); + // t=0.8 + mInterpolator.interpolate(0.8f); + assertFloatEquals("t=0.8 x", 2.21600f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.8 y", 2.97600f, mInterpolator.mInterpolatedY); + // t=1.0 (p3) + mInterpolator.interpolate(1.0f); + assertFloatEquals("t=1.0 x", 2.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=1.0 y", 3.0f, mInterpolator.mInterpolatedY); + } + + public void testJustP1P2() { + // [(p1 p2)] + mInterpolator.reset(mXCoords, mYCoords, p1, p2 + 1); + mInterpolator.setInterval(p1 - 1, p1, p2, p2 + 1); + assertEquals("p1x", mXCoords[p1], mInterpolator.mP1X); + assertEquals("p1y", mYCoords[p1], mInterpolator.mP1Y); + assertEquals("p2x", mXCoords[p2], mInterpolator.mP2X); + assertEquals("p2y", mYCoords[p2], mInterpolator.mP2Y); + // XY-slope at p1=1.0 (2.0/2.0) + assertFloatEquals("slope x p1", 2.00f, mInterpolator.mSlope1X); + assertFloatEquals("slope y p1", 2.00f, mInterpolator.mSlope1Y); + // XY-slope at p2=1.0 (2.0/2.0) + assertFloatEquals("slope x p2", 2.00f, mInterpolator.mSlope2X); + assertFloatEquals("slope y p2", 2.00f, mInterpolator.mSlope2Y); + // t=0.0 (p1) + mInterpolator.interpolate(0.0f); + assertFloatEquals("t=0.0 x", 1.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.0 y", 0.0f, mInterpolator.mInterpolatedY); + // t=0.2 + mInterpolator.interpolate(0.2f); + assertFloatEquals("t=0.2 x", 1.4f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.2 y", 0.4f, mInterpolator.mInterpolatedY); + // t=0.5 + mInterpolator.interpolate(0.5f); + assertFloatEquals("t=0.5 x", 2.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.5 y", 1.0f, mInterpolator.mInterpolatedY); + // t=0.8 + mInterpolator.interpolate(0.8f); + assertFloatEquals("t=0.8 x", 2.6f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=0.8 y", 1.6f, mInterpolator.mInterpolatedY); + // t=1.0 (p2) + mInterpolator.interpolate(1.0f); + assertFloatEquals("t=1.0 x", 3.0f, mInterpolator.mInterpolatedX); + assertFloatEquals("t=1.0 y", 2.0f, mInterpolator.mInterpolatedY); + } +} diff --git a/tests/src/com/android/inputmethod/latin/InputTestsBase.java b/tests/src/com/android/inputmethod/latin/InputTestsBase.java index 4ccbf4857..04e1f932a 100644 --- a/tests/src/com/android/inputmethod/latin/InputTestsBase.java +++ b/tests/src/com/android/inputmethod/latin/InputTestsBase.java @@ -130,7 +130,9 @@ public class InputTestsBase extends ServiceTestCase<LatinIME> { protected void setUp() throws Exception { super.setUp(); mTextView = new MyTextView(getContext()); - mTextView.setInputType(InputType.TYPE_CLASS_TEXT); + final int inputType = InputType.TYPE_CLASS_TEXT | InputType.TYPE_TEXT_FLAG_AUTO_CORRECT + | InputType.TYPE_TEXT_FLAG_MULTI_LINE; + mTextView.setInputType(inputType); mTextView.setEnabled(true); setupService(); mLatinIME = getService(); @@ -138,9 +140,7 @@ public class InputTestsBase extends ServiceTestCase<LatinIME> { mLatinIME.onCreate(); setDebugMode(previousDebugSetting); final EditorInfo ei = new EditorInfo(); - ei.inputType = InputType.TYPE_CLASS_TEXT | InputType.TYPE_TEXT_FLAG_AUTO_CORRECT; final InputConnection ic = mTextView.onCreateInputConnection(ei); - ei.inputType = InputType.TYPE_CLASS_TEXT | InputType.TYPE_TEXT_FLAG_AUTO_CORRECT; final LayoutInflater inflater = (LayoutInflater)getContext().getSystemService(Context.LAYOUT_INFLATER_SERVICE); final ViewGroup vg = new FrameLayout(getContext()); @@ -181,17 +181,21 @@ public class InputTestsBase extends ServiceTestCase<LatinIME> { // a message that calls it instead of calling it directly. Looper.loop(); - // Once #quit() has been called, the message queue has an "mQuiting" field that prevents - // any subsequent post in this queue. However the queue itself is still fully functional! - // If we have a way of resetting "queue.mQuiting" then we can continue using it as normal, - // coming back to this method to run the messages. + // Once #quit() has been called, the looper is not functional any more (it used to be, + // but now it SIGSEGV's if it's used again). + // It won't accept creating a new looper for this thread and switching to it... + // ...unless we can trick it into throwing out the old looper and believing it hasn't + // been initialized before. MessageQueue queue = Looper.myQueue(); try { - // However there is no way of doing it externally, and mQuiting is private. + // However there is no way of doing it externally, and the static ThreadLocal + // field into which it's stored is private. // So... get out the big guns. - java.lang.reflect.Field f = MessageQueue.class.getDeclaredField("mQuiting"); - f.setAccessible(true); // What do you mean "private"? - f.setBoolean(queue, false); + java.lang.reflect.Field f = Looper.class.getDeclaredField("sThreadLocal"); + f.setAccessible(true); // private lolwut + final ThreadLocal<Looper> a = (ThreadLocal<Looper>) f.get(looper); + a.set(null); + looper.prepare(); } catch (NoSuchFieldException e) { throw new RuntimeException(e); } catch (IllegalAccessException e) { |