1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "TfLiteMotionPredictor"
18 #include <input/TfLiteMotionPredictor.h>
19 
20 #include <fcntl.h>
21 #include <sys/mman.h>
22 #include <unistd.h>
23 
24 #include <algorithm>
25 #include <cmath>
26 #include <cstddef>
27 #include <cstdint>
28 #include <memory>
29 #include <span>
30 #include <type_traits>
31 #include <utility>
32 
33 #include <android-base/file.h>
34 #include <android-base/logging.h>
35 #include <android-base/mapped_file.h>
36 #define ATRACE_TAG ATRACE_TAG_INPUT
37 #include <cutils/trace.h>
38 #include <log/log.h>
39 #include <utils/Timers.h>
40 
41 #include "tensorflow/lite/core/api/error_reporter.h"
42 #include "tensorflow/lite/core/api/op_resolver.h"
43 #include "tensorflow/lite/interpreter.h"
44 #include "tensorflow/lite/kernels/builtin_op_kernels.h"
45 #include "tensorflow/lite/model.h"
46 #include "tensorflow/lite/mutable_op_resolver.h"
47 
48 #include "tinyxml2.h"
49 
50 namespace android {
51 namespace {
52 
// Key of the model signature whose runner is used for inference.
constexpr char SIGNATURE_KEY[]{"serving_default"};

// Names used to look up the input tensors on the signature runner.
constexpr char INPUT_R[]{"r"};
constexpr char INPUT_PHI[]{"phi"};
constexpr char INPUT_PRESSURE[]{"pressure"};
constexpr char INPUT_TILT[]{"tilt"};
constexpr char INPUT_ORIENTATION[]{"orientation"};

// Names used to look up the output tensors on the signature runner.
constexpr char OUTPUT_R[]{"r"};
constexpr char OUTPUT_PHI[]{"phi"};
constexpr char OUTPUT_PRESSURE[]{"pressure"};
66 
67 // Ideally, we would just use std::filesystem::exists here, but it requires libc++fs, which causes
68 // build issues in other parts of the system.
69 #if defined(__ANDROID__)
// Returns true when stat() succeeds for `filename`, false otherwise.
bool fileExists(const char* filename) {
    struct stat statResult;
    const int status = stat(filename, &statResult);
    return status == 0;
}
74 #endif
75 
getModelPath()76 std::string getModelPath() {
77 #if defined(__ANDROID__)
78     static const char* oemModel = "/vendor/etc/motion_predictor_model.tflite";
79     if (fileExists(oemModel)) {
80         return oemModel;
81     }
82     return "/system/etc/motion_predictor_model.tflite";
83 #else
84     return base::GetExecutableDirectory() + "/motion_predictor_model.tflite";
85 #endif
86 }
87 
getConfigPath()88 std::string getConfigPath() {
89     // The config file should be alongside the model file.
90     return base::Dirname(getModelPath()) + "/motion_predictor_config.xml";
91 }
92 
parseXMLInt64(const tinyxml2::XMLElement & configRoot,const char * elementName)93 int64_t parseXMLInt64(const tinyxml2::XMLElement& configRoot, const char* elementName) {
94     const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
95     LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
96 
97     int64_t value = 0;
98     LOG_ALWAYS_FATAL_IF(element->QueryInt64Text(&value) != tinyxml2::XML_SUCCESS,
99                         "Failed to parse %s: %s", elementName, element->GetText());
100     return value;
101 }
102 
parseXMLFloat(const tinyxml2::XMLElement & configRoot,const char * elementName)103 float parseXMLFloat(const tinyxml2::XMLElement& configRoot, const char* elementName) {
104     const tinyxml2::XMLElement* element = configRoot.FirstChildElement(elementName);
105     LOG_ALWAYS_FATAL_IF(!element, "Could not find '%s' element", elementName);
106 
107     float value = 0;
108     LOG_ALWAYS_FATAL_IF(element->QueryFloatText(&value) != tinyxml2::XML_SUCCESS,
109                         "Failed to parse %s: %s", elementName, element->GetText());
110     return value;
111 }
112 
113 // A TFLite ErrorReporter that logs to logcat.
114 class LoggingErrorReporter : public tflite::ErrorReporter {
115 public:
Report(const char * format,va_list args)116     int Report(const char* format, va_list args) override {
117         return LOG_PRI_VA(ANDROID_LOG_ERROR, LOG_TAG, format, args);
118     }
119 };
120 
121 // Searches a runner for an input tensor.
findInputTensor(const char * name,tflite::SignatureRunner * runner)122 TfLiteTensor* findInputTensor(const char* name, tflite::SignatureRunner* runner) {
123     TfLiteTensor* tensor = runner->input_tensor(name);
124     LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find input tensor '%s'", name);
125     return tensor;
126 }
127 
128 // Searches a runner for an output tensor.
findOutputTensor(const char * name,tflite::SignatureRunner * runner)129 const TfLiteTensor* findOutputTensor(const char* name, tflite::SignatureRunner* runner) {
130     const TfLiteTensor* tensor = runner->output_tensor(name);
131     LOG_ALWAYS_FATAL_IF(!tensor, "Failed to find output tensor '%s'", name);
132     return tensor;
133 }
134 
135 // Returns the buffer for a tensor of type T.
136 template <typename T>
getTensorBuffer(typename std::conditional<std::is_const<T>::value,const TfLiteTensor *,TfLiteTensor * >::type tensor)137 std::span<T> getTensorBuffer(typename std::conditional<std::is_const<T>::value, const TfLiteTensor*,
138                                                        TfLiteTensor*>::type tensor) {
139     LOG_ALWAYS_FATAL_IF(!tensor);
140 
141     const TfLiteType type = tflite::typeToTfLiteType<typename std::remove_cv<T>::type>();
142     LOG_ALWAYS_FATAL_IF(tensor->type != type, "Unexpected type for '%s' tensor: %s (expected %s)",
143                         tensor->name, TfLiteTypeGetName(tensor->type), TfLiteTypeGetName(type));
144 
145     LOG_ALWAYS_FATAL_IF(!tensor->data.data);
146     return std::span<T>(reinterpret_cast<T*>(tensor->data.data), tensor->bytes / sizeof(T));
147 }
148 
149 // Verifies that a tensor exists and has an underlying buffer of type T.
150 template <typename T>
checkTensor(const TfLiteTensor * tensor)151 void checkTensor(const TfLiteTensor* tensor) {
152     LOG_ALWAYS_FATAL_IF(!tensor);
153 
154     const auto buffer = getTensorBuffer<const T>(tensor);
155     LOG_ALWAYS_FATAL_IF(buffer.empty(), "No buffer for tensor '%s'", tensor->name);
156 }
157 
createOpResolver()158 std::unique_ptr<tflite::OpResolver> createOpResolver() {
159     auto resolver = std::make_unique<tflite::MutableOpResolver>();
160     resolver->AddBuiltin(::tflite::BuiltinOperator_CONCATENATION,
161                          ::tflite::ops::builtin::Register_CONCATENATION());
162     resolver->AddBuiltin(::tflite::BuiltinOperator_FULLY_CONNECTED,
163                          ::tflite::ops::builtin::Register_FULLY_CONNECTED());
164     resolver->AddBuiltin(::tflite::BuiltinOperator_GELU, ::tflite::ops::builtin::Register_GELU());
165     return resolver;
166 }
167 
168 } // namespace
169 
TfLiteMotionPredictorBuffers(size_t inputLength)170 TfLiteMotionPredictorBuffers::TfLiteMotionPredictorBuffers(size_t inputLength)
171       : mInputR(inputLength, 0),
172         mInputPhi(inputLength, 0),
173         mInputPressure(inputLength, 0),
174         mInputTilt(inputLength, 0),
175         mInputOrientation(inputLength, 0) {
176     LOG_ALWAYS_FATAL_IF(inputLength == 0, "Buffer input size must be greater than 0");
177 }
178 
reset()179 void TfLiteMotionPredictorBuffers::reset() {
180     std::fill(mInputR.begin(), mInputR.end(), 0);
181     std::fill(mInputPhi.begin(), mInputPhi.end(), 0);
182     std::fill(mInputPressure.begin(), mInputPressure.end(), 0);
183     std::fill(mInputTilt.begin(), mInputTilt.end(), 0);
184     std::fill(mInputOrientation.begin(), mInputOrientation.end(), 0);
185     mAxisFrom.reset();
186     mAxisTo.reset();
187 }
188 
copyTo(TfLiteMotionPredictorModel & model) const189 void TfLiteMotionPredictorBuffers::copyTo(TfLiteMotionPredictorModel& model) const {
190     LOG_ALWAYS_FATAL_IF(mInputR.size() != model.inputLength(),
191                         "Buffer length %zu doesn't match model input length %zu", mInputR.size(),
192                         model.inputLength());
193     LOG_ALWAYS_FATAL_IF(!isReady(), "Buffers are incomplete");
194 
195     std::copy(mInputR.begin(), mInputR.end(), model.inputR().begin());
196     std::copy(mInputPhi.begin(), mInputPhi.end(), model.inputPhi().begin());
197     std::copy(mInputPressure.begin(), mInputPressure.end(), model.inputPressure().begin());
198     std::copy(mInputTilt.begin(), mInputTilt.end(), model.inputTilt().begin());
199     std::copy(mInputOrientation.begin(), mInputOrientation.end(), model.inputOrientation().begin());
200 }
201 
pushSample(int64_t timestamp,const TfLiteMotionPredictorSample sample)202 void TfLiteMotionPredictorBuffers::pushSample(int64_t timestamp,
203                                               const TfLiteMotionPredictorSample sample) {
204     // Convert the sample (x, y) into polar (r, φ) based on a reference axis
205     // from the preceding two points (mAxisFrom/mAxisTo).
206 
207     mTimestamp = timestamp;
208 
209     if (!mAxisTo) { // First point.
210         mAxisTo = sample;
211         return;
212     }
213 
214     // Vector from the last point to the current sample point.
215     const TfLiteMotionPredictorSample::Point v = sample.position - mAxisTo->position;
216 
217     const float r = std::hypot(v.x, v.y);
218     float phi = 0;
219     float orientation = 0;
220 
221     if (!mAxisFrom && r > 0) { // Second point.
222         // We can only determine the distance from the first point, and not any
223         // angle. However, if the second point forms an axis, the orientation can
224         // be transformed relative to that axis.
225         const float axisPhi = std::atan2(v.y, v.x);
226         // A MotionEvent's orientation is measured clockwise from the vertical
227         // axis, but axisPhi is measured counter-clockwise from the horizontal
228         // axis.
229         orientation = M_PI_2 - sample.orientation - axisPhi;
230     } else {
231         const TfLiteMotionPredictorSample::Point axis = mAxisTo->position - mAxisFrom->position;
232         const float axisPhi = std::atan2(axis.y, axis.x);
233         phi = std::atan2(v.y, v.x) - axisPhi;
234 
235         if (std::hypot(axis.x, axis.y) > 0) {
236             // See note above.
237             orientation = M_PI_2 - sample.orientation - axisPhi;
238         }
239     }
240 
241     // Update the axis for the next point.
242     if (r > 0) {
243         mAxisFrom = mAxisTo;
244         mAxisTo = sample;
245     }
246 
247     // Push the current sample onto the end of the input buffers.
248     mInputR.pushBack(r);
249     mInputPhi.pushBack(phi);
250     mInputPressure.pushBack(sample.pressure);
251     mInputTilt.pushBack(sample.tilt);
252     mInputOrientation.pushBack(orientation);
253 }
254 
create()255 std::unique_ptr<TfLiteMotionPredictorModel> TfLiteMotionPredictorModel::create() {
256     const std::string modelPath = getModelPath();
257     android::base::unique_fd fd(open(modelPath.c_str(), O_RDONLY));
258     if (fd == -1) {
259         PLOG(FATAL) << "Could not read model from " << modelPath;
260     }
261 
262     const off_t fdSize = lseek(fd, 0, SEEK_END);
263     if (fdSize == -1) {
264         PLOG(FATAL) << "Failed to determine file size";
265     }
266 
267     std::unique_ptr<android::base::MappedFile> modelBuffer =
268             android::base::MappedFile::FromFd(fd, /*offset=*/0, fdSize, PROT_READ);
269     if (!modelBuffer) {
270         PLOG(FATAL) << "Failed to mmap model";
271     }
272 
273     const std::string configPath = getConfigPath();
274     tinyxml2::XMLDocument configDocument;
275     LOG_ALWAYS_FATAL_IF(configDocument.LoadFile(configPath.c_str()) != tinyxml2::XML_SUCCESS,
276                         "Failed to load config file from %s", configPath.c_str());
277 
278     // Parse configuration file.
279     const tinyxml2::XMLElement* configRoot = configDocument.FirstChildElement("motion-predictor");
280     LOG_ALWAYS_FATAL_IF(!configRoot);
281     Config config{
282             .predictionInterval = parseXMLInt64(*configRoot, "prediction-interval"),
283             .distanceNoiseFloor = parseXMLFloat(*configRoot, "distance-noise-floor"),
284             .lowJerk = parseXMLFloat(*configRoot, "low-jerk"),
285             .highJerk = parseXMLFloat(*configRoot, "high-jerk"),
286             .jerkAlpha = parseXMLFloat(*configRoot, "jerk-alpha"),
287     };
288 
289     return std::unique_ptr<TfLiteMotionPredictorModel>(
290             new TfLiteMotionPredictorModel(std::move(modelBuffer), std::move(config)));
291 }
292 
TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,Config config)293 TfLiteMotionPredictorModel::TfLiteMotionPredictorModel(
294         std::unique_ptr<android::base::MappedFile> model, Config config)
295       : mFlatBuffer(std::move(model)), mConfig(std::move(config)) {
296     CHECK(mFlatBuffer);
297     mErrorReporter = std::make_unique<LoggingErrorReporter>();
298     mModel = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(mFlatBuffer->data(),
299                                                                mFlatBuffer->size(),
300                                                                /*extra_verifier=*/nullptr,
301                                                                mErrorReporter.get());
302     LOG_ALWAYS_FATAL_IF(!mModel);
303 
304     auto resolver = createOpResolver();
305     tflite::InterpreterBuilder builder(*mModel, *resolver);
306 
307     if (builder(&mInterpreter) != kTfLiteOk || !mInterpreter) {
308         LOG_ALWAYS_FATAL("Failed to build interpreter");
309     }
310 
311     mRunner = mInterpreter->GetSignatureRunner(SIGNATURE_KEY);
312     LOG_ALWAYS_FATAL_IF(!mRunner, "Failed to find runner for signature '%s'", SIGNATURE_KEY);
313 
314     allocateTensors();
315 }
316 
~TfLiteMotionPredictorModel()317 TfLiteMotionPredictorModel::~TfLiteMotionPredictorModel() {}
318 
allocateTensors()319 void TfLiteMotionPredictorModel::allocateTensors() {
320     if (mRunner->AllocateTensors() != kTfLiteOk) {
321         LOG_ALWAYS_FATAL("Failed to allocate tensors");
322     }
323 
324     attachInputTensors();
325     attachOutputTensors();
326 
327     checkTensor<float>(mInputR);
328     checkTensor<float>(mInputPhi);
329     checkTensor<float>(mInputPressure);
330     checkTensor<float>(mInputTilt);
331     checkTensor<float>(mInputOrientation);
332     checkTensor<float>(mOutputR);
333     checkTensor<float>(mOutputPhi);
334     checkTensor<float>(mOutputPressure);
335 
336     const auto checkInputTensorSize = [this](const TfLiteTensor* tensor) {
337         const size_t size = getTensorBuffer<const float>(tensor).size();
338         LOG_ALWAYS_FATAL_IF(size != inputLength(),
339                             "Tensor '%s' length %zu does not match input length %zu", tensor->name,
340                             size, inputLength());
341     };
342 
343     checkInputTensorSize(mInputR);
344     checkInputTensorSize(mInputPhi);
345     checkInputTensorSize(mInputPressure);
346     checkInputTensorSize(mInputTilt);
347     checkInputTensorSize(mInputOrientation);
348 }
349 
attachInputTensors()350 void TfLiteMotionPredictorModel::attachInputTensors() {
351     mInputR = findInputTensor(INPUT_R, mRunner);
352     mInputPhi = findInputTensor(INPUT_PHI, mRunner);
353     mInputPressure = findInputTensor(INPUT_PRESSURE, mRunner);
354     mInputTilt = findInputTensor(INPUT_TILT, mRunner);
355     mInputOrientation = findInputTensor(INPUT_ORIENTATION, mRunner);
356 }
357 
attachOutputTensors()358 void TfLiteMotionPredictorModel::attachOutputTensors() {
359     mOutputR = findOutputTensor(OUTPUT_R, mRunner);
360     mOutputPhi = findOutputTensor(OUTPUT_PHI, mRunner);
361     mOutputPressure = findOutputTensor(OUTPUT_PRESSURE, mRunner);
362 }
363 
invoke()364 bool TfLiteMotionPredictorModel::invoke() {
365     ATRACE_BEGIN("TfLiteMotionPredictorModel::invoke");
366     TfLiteStatus result = mRunner->Invoke();
367     ATRACE_END();
368 
369     if (result != kTfLiteOk) {
370         return false;
371     }
372 
373     // Invoke() might reallocate tensors, so they need to be reattached.
374     attachInputTensors();
375     attachOutputTensors();
376 
377     if (outputR().size() != outputPhi().size() || outputR().size() != outputPressure().size()) {
378         LOG_ALWAYS_FATAL("Output size mismatch: (r: %zu, phi: %zu, pressure: %zu)",
379                          outputR().size(), outputPhi().size(), outputPressure().size());
380     }
381 
382     return true;
383 }
384 
inputLength() const385 size_t TfLiteMotionPredictorModel::inputLength() const {
386     return getTensorBuffer<const float>(mInputR).size();
387 }
388 
outputLength() const389 size_t TfLiteMotionPredictorModel::outputLength() const {
390     return getTensorBuffer<const float>(mOutputR).size();
391 }
392 
inputR()393 std::span<float> TfLiteMotionPredictorModel::inputR() {
394     return getTensorBuffer<float>(mInputR);
395 }
396 
inputPhi()397 std::span<float> TfLiteMotionPredictorModel::inputPhi() {
398     return getTensorBuffer<float>(mInputPhi);
399 }
400 
inputPressure()401 std::span<float> TfLiteMotionPredictorModel::inputPressure() {
402     return getTensorBuffer<float>(mInputPressure);
403 }
404 
inputTilt()405 std::span<float> TfLiteMotionPredictorModel::inputTilt() {
406     return getTensorBuffer<float>(mInputTilt);
407 }
408 
inputOrientation()409 std::span<float> TfLiteMotionPredictorModel::inputOrientation() {
410     return getTensorBuffer<float>(mInputOrientation);
411 }
412 
outputR() const413 std::span<const float> TfLiteMotionPredictorModel::outputR() const {
414     return getTensorBuffer<const float>(mOutputR);
415 }
416 
outputPhi() const417 std::span<const float> TfLiteMotionPredictorModel::outputPhi() const {
418     return getTensorBuffer<const float>(mOutputPhi);
419 }
420 
outputPressure() const421 std::span<const float> TfLiteMotionPredictorModel::outputPressure() const {
422     return getTensorBuffer<const float>(mOutputPressure);
423 }
424 
425 } // namespace android
426