1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Contains classes that can execute different models/parts of a model. 18 19 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 20 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 21 22 #include <cstdint> 23 #include <memory> 24 25 #include "utils/base/logging.h" 26 #include "utils/tensor-view.h" 27 #include "tensorflow/lite/interpreter.h" 28 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 29 #include "tensorflow/lite/kernels/register.h" 30 #include "tensorflow/lite/model.h" 31 #include "tensorflow/lite/mutable_op_resolver.h" 32 #include "tensorflow/lite/op_resolver.h" 33 #include "tensorflow/lite/string_util.h" 34 35 namespace libtextclassifier3 { 36 37 // Creates a TF.Lite Op resolver in default configuration, with ops for 38 // Annotator and Actions models. 39 std::unique_ptr<tflite::OpResolver> BuildOpResolver(); 40 41 // Like above, but allows passage of a function that can register additional 42 // ops. 43 std::unique_ptr<tflite::OpResolver> BuildOpResolver( 44 const std::function<void(tflite::MutableOpResolver*)>& customize_fn); 45 46 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( 47 const tflite::Model*); 48 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( 49 const flatbuffers::Vector<uint8_t>*); 50 51 // Executor for the text selection prediction and classification models. 52 class TfLiteModelExecutor { 53 public: FromModelSpec(const tflite::Model * model_spec)54 static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( 55 const tflite::Model* model_spec) { 56 auto model = TfLiteModelFromModelSpec(model_spec); 57 if (!model) { 58 return nullptr; 59 } 60 return std::unique_ptr<TfLiteModelExecutor>( 61 new TfLiteModelExecutor(std::move(model))); 62 } 63 FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)64 static std::unique_ptr<TfLiteModelExecutor> FromBuffer( 65 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 66 auto model = TfLiteModelFromBuffer(model_spec_buffer); 67 if (!model) { 68 return nullptr; 69 } 70 return std::unique_ptr<TfLiteModelExecutor>( 71 new TfLiteModelExecutor(std::move(model))); 72 } 73 74 // Creates an Interpreter for the model that serves as a scratch-pad for the 75 // inference. The Interpreter is NOT thread-safe. 76 std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; 77 78 template <typename T> SetInput(const int input_index,const TensorView<T> & input_data,tflite::Interpreter * interpreter)79 void SetInput(const int input_index, const TensorView<T>& input_data, 80 tflite::Interpreter* interpreter) const { 81 input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), 82 input_data.size()); 83 } 84 85 template <typename T> SetInput(const int input_index,const std::vector<T> & input_data,tflite::Interpreter * interpreter)86 void SetInput(const int input_index, const std::vector<T>& input_data, 87 tflite::Interpreter* interpreter) const { 88 std::copy(input_data.begin(), input_data.end(), 89 interpreter->typed_input_tensor<T>(input_index)); 90 } 91 92 template <typename T> SetInput(const int input_index,const T input_value,tflite::Interpreter * interpreter)93 void SetInput(const int input_index, const T input_value, 94 tflite::Interpreter* interpreter) const { 95 TfLiteTensor* input_tensor = 96 interpreter->tensor(interpreter->inputs()[input_index]); 97 switch (input_tensor->type) { 98 case kTfLiteFloat32: 99 *tflite::GetTensorData<float>(input_tensor) = input_value; 100 break; 101 case kTfLiteInt32: 102 *tflite::GetTensorData<int32_t>(input_tensor) = input_value; 103 break; 104 case kTfLiteUInt8: 105 *tflite::GetTensorData<uint8_t>(input_tensor) = input_value; 106 break; 107 case kTfLiteInt64: 108 *tflite::GetTensorData<int64_t>(input_tensor) = input_value; 109 break; 110 case kTfLiteBool: 111 *tflite::GetTensorData<bool>(input_tensor) = input_value; 112 break; 113 case kTfLiteInt16: 114 *tflite::GetTensorData<int16_t>(input_tensor) = input_value; 115 break; 116 case kTfLiteInt8: 117 *tflite::GetTensorData<int8_t>(input_tensor) = input_value; 118 break; 119 default: 120 break; 121 } 122 } 123 124 template <typename T> OutputView(const int output_index,const tflite::Interpreter * interpreter)125 TensorView<T> OutputView(const int output_index, 126 const tflite::Interpreter* interpreter) const { 127 const TfLiteTensor* output_tensor = 128 interpreter->tensor(interpreter->outputs()[output_index]); 129 return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), 130 std::vector<int>(output_tensor->dims->data, 131 output_tensor->dims->data + 132 output_tensor->dims->size)); 133 } 134 135 template <typename T> Output(const int output_index,const tflite::Interpreter * interpreter)136 std::vector<T> Output(const int output_index, 137 const tflite::Interpreter* interpreter) const { 138 TensorView<T> output_view = OutputView<T>(output_index, interpreter); 139 return std::vector<T>(output_view.data(), 140 output_view.data() + output_view.size()); 141 } 142 143 protected: 144 explicit TfLiteModelExecutor( 145 std::unique_ptr<const tflite::FlatBufferModel> model); 146 TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model, 147 std::unique_ptr<tflite::OpResolver> resolver); 148 149 std::unique_ptr<const tflite::FlatBufferModel> model_; 150 std::unique_ptr<tflite::OpResolver> resolver_; 151 }; 152 153 template <> 154 void TfLiteModelExecutor::SetInput(const int input_index, 155 const std::vector<std::string>& input_data, 156 tflite::Interpreter* interpreter) const; 157 158 template <> 159 std::vector<tflite::StringRef> TfLiteModelExecutor::Output( 160 const int output_index, const tflite::Interpreter* interpreter) const; 161 162 template <> 163 std::vector<std::string> TfLiteModelExecutor::Output( 164 const int output_index, const tflite::Interpreter* interpreter) const; 165 166 } // namespace libtextclassifier3 167 168 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 169