• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "MotionPredictor"
18 
#include <input/MotionPredictor.h>

#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

#include <android-base/strings.h>
#include <android/input.h>
#include <log/log.h>

#include <attestation/HmacKeyManager.h>
#include <ftl/enum.h>
#include <input/TfLiteMotionPredictor.h>
35 
36 namespace android {
37 namespace {
38 
39 /**
40  * Log debug messages about predictions.
41  * Enable this via "adb shell setprop log.tag.MotionPredictor DEBUG"
42  */
isDebug()43 bool isDebug() {
44     return __android_log_is_loggable(ANDROID_LOG_DEBUG, LOG_TAG, ANDROID_LOG_INFO);
45 }
46 
47 // Converts a prediction of some polar (r, phi) to Cartesian (x, y) when applied to an axis.
convertPrediction(const TfLiteMotionPredictorSample::Point & axisFrom,const TfLiteMotionPredictorSample::Point & axisTo,float r,float phi)48 TfLiteMotionPredictorSample::Point convertPrediction(
49         const TfLiteMotionPredictorSample::Point& axisFrom,
50         const TfLiteMotionPredictorSample::Point& axisTo, float r, float phi) {
51     const TfLiteMotionPredictorSample::Point axis = axisTo - axisFrom;
52     const float axis_phi = std::atan2(axis.y, axis.x);
53     const float x_delta = r * std::cos(axis_phi + phi);
54     const float y_delta = r * std::sin(axis_phi + phi);
55     return {.x = axisTo.x + x_delta, .y = axisTo.y + y_delta};
56 }
57 
58 } // namespace
59 
60 // --- MotionPredictor ---
61 
MotionPredictor(nsecs_t predictionTimestampOffsetNanos,std::function<bool ()> checkMotionPredictionEnabled)62 MotionPredictor::MotionPredictor(nsecs_t predictionTimestampOffsetNanos,
63                                  std::function<bool()> checkMotionPredictionEnabled)
64       : mPredictionTimestampOffsetNanos(predictionTimestampOffsetNanos),
65         mCheckMotionPredictionEnabled(std::move(checkMotionPredictionEnabled)) {}
66 
record(const MotionEvent & event)67 android::base::Result<void> MotionPredictor::record(const MotionEvent& event) {
68     if (mLastEvent && mLastEvent->getDeviceId() != event.getDeviceId()) {
69         // We still have an active gesture for another device. The provided MotionEvent is not
70         // consistent with the previous gesture.
71         LOG(ERROR) << "Inconsistent event stream: last event is " << *mLastEvent << ", but "
72                    << __func__ << " is called with " << event;
73         return android::base::Error()
74                 << "Inconsistent event stream: still have an active gesture from device "
75                 << mLastEvent->getDeviceId() << ", but received " << event;
76     }
77     if (!isPredictionAvailable(event.getDeviceId(), event.getSource())) {
78         ALOGE("Prediction not supported for device %d's %s source", event.getDeviceId(),
79               inputEventSourceToString(event.getSource()).c_str());
80         return {};
81     }
82 
83     // Initialise the model now that it's likely to be used.
84     if (!mModel) {
85         mModel = TfLiteMotionPredictorModel::create();
86         LOG_ALWAYS_FATAL_IF(!mModel);
87     }
88 
89     if (!mBuffers) {
90         mBuffers = std::make_unique<TfLiteMotionPredictorBuffers>(mModel->inputLength());
91     }
92 
93     const int32_t action = event.getActionMasked();
94     if (action == AMOTION_EVENT_ACTION_UP || action == AMOTION_EVENT_ACTION_CANCEL) {
95         ALOGD_IF(isDebug(), "End of event stream");
96         mBuffers->reset();
97         mLastEvent.reset();
98         return {};
99     } else if (action != AMOTION_EVENT_ACTION_DOWN && action != AMOTION_EVENT_ACTION_MOVE) {
100         ALOGD_IF(isDebug(), "Skipping unsupported %s action",
101                  MotionEvent::actionToString(action).c_str());
102         return {};
103     }
104 
105     if (event.getPointerCount() != 1) {
106         ALOGD_IF(isDebug(), "Prediction not supported for multiple pointers");
107         return {};
108     }
109 
110     const ToolType toolType = event.getPointerProperties(0)->toolType;
111     if (toolType != ToolType::STYLUS) {
112         ALOGD_IF(isDebug(), "Prediction not supported for non-stylus tool: %s",
113                  ftl::enum_string(toolType).c_str());
114         return {};
115     }
116 
117     for (size_t i = 0; i <= event.getHistorySize(); ++i) {
118         if (event.isResampled(0, i)) {
119             continue;
120         }
121         const PointerCoords* coords = event.getHistoricalRawPointerCoords(0, i);
122         mBuffers->pushSample(event.getHistoricalEventTime(i),
123                              {
124                                      .position.x = coords->getAxisValue(AMOTION_EVENT_AXIS_X),
125                                      .position.y = coords->getAxisValue(AMOTION_EVENT_AXIS_Y),
126                                      .pressure = event.getHistoricalPressure(0, i),
127                                      .tilt = event.getHistoricalAxisValue(AMOTION_EVENT_AXIS_TILT,
128                                                                           0, i),
129                                      .orientation = event.getHistoricalOrientation(0, i),
130                              });
131     }
132 
133     if (!mLastEvent) {
134         mLastEvent = MotionEvent();
135     }
136     mLastEvent->copyFrom(&event, /*keepHistory=*/false);
137 
138     // Pass input event to the MetricsManager.
139     if (!mMetricsManager) {
140         mMetricsManager.emplace(mModel->config().predictionInterval, mModel->outputLength());
141     }
142     mMetricsManager->onRecord(event);
143 
144     return {};
145 }
146 
predict(nsecs_t timestamp)147 std::unique_ptr<MotionEvent> MotionPredictor::predict(nsecs_t timestamp) {
148     if (mBuffers == nullptr || !mBuffers->isReady()) {
149         return nullptr;
150     }
151 
152     LOG_ALWAYS_FATAL_IF(!mModel);
153     mBuffers->copyTo(*mModel);
154     LOG_ALWAYS_FATAL_IF(!mModel->invoke());
155 
156     // Read out the predictions.
157     const std::span<const float> predictedR = mModel->outputR();
158     const std::span<const float> predictedPhi = mModel->outputPhi();
159     const std::span<const float> predictedPressure = mModel->outputPressure();
160 
161     TfLiteMotionPredictorSample::Point axisFrom = mBuffers->axisFrom().position;
162     TfLiteMotionPredictorSample::Point axisTo = mBuffers->axisTo().position;
163 
164     if (isDebug()) {
165         ALOGD("axisFrom: %f, %f", axisFrom.x, axisFrom.y);
166         ALOGD("axisTo: %f, %f", axisTo.x, axisTo.y);
167         ALOGD("mInputR: %s", base::Join(mModel->inputR(), ", ").c_str());
168         ALOGD("mInputPhi: %s", base::Join(mModel->inputPhi(), ", ").c_str());
169         ALOGD("mInputPressure: %s", base::Join(mModel->inputPressure(), ", ").c_str());
170         ALOGD("mInputTilt: %s", base::Join(mModel->inputTilt(), ", ").c_str());
171         ALOGD("mInputOrientation: %s", base::Join(mModel->inputOrientation(), ", ").c_str());
172         ALOGD("predictedR: %s", base::Join(predictedR, ", ").c_str());
173         ALOGD("predictedPhi: %s", base::Join(predictedPhi, ", ").c_str());
174         ALOGD("predictedPressure: %s", base::Join(predictedPressure, ", ").c_str());
175     }
176 
177     LOG_ALWAYS_FATAL_IF(!mLastEvent);
178     const MotionEvent& event = *mLastEvent;
179     bool hasPredictions = false;
180     std::unique_ptr<MotionEvent> prediction = std::make_unique<MotionEvent>();
181     int64_t predictionTime = mBuffers->lastTimestamp();
182     const int64_t futureTime = timestamp + mPredictionTimestampOffsetNanos;
183 
184     for (int i = 0; i < predictedR.size() && predictionTime <= futureTime; ++i) {
185         if (predictedR[i] < mModel->config().distanceNoiseFloor) {
186             // Stop predicting when the predicted output is below the model's noise floor.
187             //
188             // We assume that all subsequent predictions in the batch are unreliable because later
189             // predictions are conditional on earlier predictions, and a state of noise is not a
190             // good basis for prediction.
191             //
192             // The UX trade-off is that this potentially sacrifices some predictions when the input
193             // device starts to speed up, but avoids producing noisy predictions as it slows down.
194             break;
195         }
196         // TODO(b/266747654): Stop predictions if confidence is < some threshold.
197 
198         const TfLiteMotionPredictorSample::Point predictedPoint =
199                 convertPrediction(axisFrom, axisTo, predictedR[i], predictedPhi[i]);
200 
201         ALOGD_IF(isDebug(), "prediction %d: %f, %f", i, predictedPoint.x, predictedPoint.y);
202         PointerCoords coords;
203         coords.clear();
204         coords.setAxisValue(AMOTION_EVENT_AXIS_X, predictedPoint.x);
205         coords.setAxisValue(AMOTION_EVENT_AXIS_Y, predictedPoint.y);
206         coords.setAxisValue(AMOTION_EVENT_AXIS_PRESSURE, predictedPressure[i]);
207 
208         predictionTime += mModel->config().predictionInterval;
209         if (i == 0) {
210             hasPredictions = true;
211             prediction->initialize(InputEvent::nextId(), event.getDeviceId(), event.getSource(),
212                                    event.getDisplayId(), INVALID_HMAC, AMOTION_EVENT_ACTION_MOVE,
213                                    event.getActionButton(), event.getFlags(), event.getEdgeFlags(),
214                                    event.getMetaState(), event.getButtonState(),
215                                    event.getClassification(), event.getTransform(),
216                                    event.getXPrecision(), event.getYPrecision(),
217                                    event.getRawXCursorPosition(), event.getRawYCursorPosition(),
218                                    event.getRawTransform(), event.getDownTime(), predictionTime,
219                                    event.getPointerCount(), event.getPointerProperties(), &coords);
220         } else {
221             prediction->addSample(predictionTime, &coords);
222         }
223 
224         axisFrom = axisTo;
225         axisTo = predictedPoint;
226     }
227 
228     if (!hasPredictions) {
229         return nullptr;
230     }
231 
232     // Pass predictions to the MetricsManager.
233     LOG_ALWAYS_FATAL_IF(!mMetricsManager);
234     mMetricsManager->onPredict(*prediction);
235 
236     return prediction;
237 }
238 
isPredictionAvailable(int32_t,int32_t source)239 bool MotionPredictor::isPredictionAvailable(int32_t /*deviceId*/, int32_t source) {
240     // Global flag override
241     if (!mCheckMotionPredictionEnabled()) {
242         ALOGD_IF(isDebug(), "Prediction not available due to flag override");
243         return false;
244     }
245 
246     // Prediction is only supported for stylus sources.
247     if (!isFromSource(source, AINPUT_SOURCE_STYLUS)) {
248         ALOGD_IF(isDebug(), "Prediction not available for non-stylus source: %s",
249                  inputEventSourceToString(source).c_str());
250         return false;
251     }
252     return true;
253 }
254 
255 } // namespace android
256