1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ 17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ 18 19 #include <utility> 20 21 #include "absl/status/status.h" 22 #include "absl/strings/string_view.h" 23 #include "tensorflow/lite/c/common.h" 24 #include "tensorflow_lite_support/cc/common.h" 25 #include "tensorflow_lite_support/cc/port/status_macros.h" 26 #include "tensorflow_lite_support/cc/port/statusor.h" 27 #include "tensorflow_lite_support/cc/port/tflite_wrapper.h" 28 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" 29 30 namespace tflite { 31 namespace task { 32 namespace core { 33 34 class BaseUntypedTaskApi { 35 public: BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine)36 explicit BaseUntypedTaskApi(std::unique_ptr<TfLiteEngine> engine) 37 : engine_{std::move(engine)} {} 38 39 virtual ~BaseUntypedTaskApi() = default; 40 GetTfLiteEngine()41 TfLiteEngine* GetTfLiteEngine() { return engine_.get(); } GetTfLiteEngine()42 const TfLiteEngine* GetTfLiteEngine() const { return engine_.get(); } 43 GetMetadataExtractor()44 const metadata::ModelMetadataExtractor* GetMetadataExtractor() const { 45 return engine_->metadata_extractor(); 46 } 47 48 protected: 49 std::unique_ptr<TfLiteEngine> engine_; 50 }; 51 52 template <class OutputType, class... InputTypes> 53 class BaseTaskApi : public BaseUntypedTaskApi { 54 public: BaseTaskApi(std::unique_ptr<TfLiteEngine> engine)55 explicit BaseTaskApi(std::unique_ptr<TfLiteEngine> engine) 56 : BaseUntypedTaskApi(std::move(engine)) {} 57 // BaseTaskApi is neither copyable nor movable. 58 BaseTaskApi(const BaseTaskApi&) = delete; 59 BaseTaskApi& operator=(const BaseTaskApi&) = delete; 60 61 // Cancels the current running TFLite invocation on CPU. 62 // 63 // Usually called on a different thread than the one inference is running on. 64 // Calling Cancel() will cause the underlying TFLite interpreter to return an 65 // error, which will turn into a `CANCELLED` status and empty results. Calling 66 // Cancel() at the other time will not take any effect on the current or 67 // following invocation. It is perfectly fine to run inference again on the 68 // same instance after a cancelled invocation. If the TFLite inference is 69 // partially delegated on CPU, logs a warning message and only cancels the 70 // invocation running on CPU. Other invocation which depends on the output of 71 // the CPU invocation will not be executed. Cancel()72 void Cancel() { engine_->Cancel(); } 73 74 protected: 75 // Subclasses need to populate input_tensors from api_inputs. 76 virtual absl::Status Preprocess( 77 const std::vector<TfLiteTensor*>& input_tensors, 78 InputTypes... api_inputs) = 0; 79 80 // Subclasses need to construct OutputType object from output_tensors. 81 // Original inputs are also provided as they may be needed. 82 virtual tflite::support::StatusOr<OutputType> Postprocess( 83 const std::vector<const TfLiteTensor*>& output_tensors, 84 InputTypes... api_inputs) = 0; 85 86 // Returns (the addresses of) the model's inputs. GetInputTensors()87 std::vector<TfLiteTensor*> GetInputTensors() { return engine_->GetInputs(); } 88 89 // Returns (the addresses of) the model's outputs. GetOutputTensors()90 std::vector<const TfLiteTensor*> GetOutputTensors() { 91 return engine_->GetOutputs(); 92 } 93 94 // Performs inference using tflite::support::TfLiteInterpreterWrapper 95 // InvokeWithoutFallback(). Infer(InputTypes...args)96 tflite::support::StatusOr<OutputType> Infer(InputTypes... args) { 97 tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = 98 engine_->interpreter_wrapper(); 99 // Note: AllocateTensors() is already performed by the interpreter wrapper 100 // at InitInterpreter time (see TfLiteEngine). 101 RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); 102 absl::Status status = interpreter_wrapper->InvokeWithoutFallback(); 103 if (!status.ok()) { 104 return status.GetPayload(tflite::support::kTfLiteSupportPayload) 105 .has_value() 106 ? status 107 : tflite::support::CreateStatusWithPayload(status.code(), 108 status.message()); 109 } 110 return Postprocess(GetOutputTensors(), args...); 111 } 112 113 // Performs inference using tflite::support::TfLiteInterpreterWrapper 114 // InvokeWithFallback() to benefit from automatic fallback from delegation to 115 // CPU where applicable. InferWithFallback(InputTypes...args)116 tflite::support::StatusOr<OutputType> InferWithFallback(InputTypes... args) { 117 tflite::task::core::TfLiteEngine::InterpreterWrapper* interpreter_wrapper = 118 engine_->interpreter_wrapper(); 119 // Note: AllocateTensors() is already performed by the interpreter wrapper 120 // at InitInterpreter time (see TfLiteEngine). 121 RETURN_IF_ERROR(Preprocess(GetInputTensors(), args...)); 122 auto set_inputs_nop = 123 [](tflite::task::core::TfLiteEngine::Interpreter* interpreter) 124 -> absl::Status { 125 // NOP since inputs are populated at Preprocess() time. 126 return absl::OkStatus(); 127 }; 128 absl::Status status = 129 interpreter_wrapper->InvokeWithFallback(set_inputs_nop); 130 if (!status.ok()) { 131 return status.GetPayload(tflite::support::kTfLiteSupportPayload) 132 .has_value() 133 ? status 134 : tflite::support::CreateStatusWithPayload(status.code(), 135 status.message()); 136 } 137 return Postprocess(GetOutputTensors(), args...); 138 } 139 }; 140 141 } // namespace core 142 } // namespace task 143 } // namespace tflite 144 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_BASE_TASK_API_H_ 145