1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 // Contains classes that can execute different models/parts of a model. 18 19 #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 20 #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 21 22 #include <cstdint> 23 #include <memory> 24 25 #include "utils/base/logging.h" 26 #include "utils/tensor-view.h" 27 #include "tensorflow/lite/interpreter.h" 28 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 29 #include "tensorflow/lite/kernels/register.h" 30 #include "tensorflow/lite/model.h" 31 #include "tensorflow/lite/op_resolver.h" 32 #include "tensorflow/lite/string_util.h" 33 34 namespace libtextclassifier3 { 35 36 std::unique_ptr<tflite::OpResolver> BuildOpResolver(); 37 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( 38 const tflite::Model*); 39 std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( 40 const flatbuffers::Vector<uint8_t>*); 41 42 // Executor for the text selection prediction and classification models. 43 class TfLiteModelExecutor { 44 public: FromModelSpec(const tflite::Model * model_spec)45 static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( 46 const tflite::Model* model_spec) { 47 auto model = TfLiteModelFromModelSpec(model_spec); 48 if (!model) { 49 return nullptr; 50 } 51 return std::unique_ptr<TfLiteModelExecutor>( 52 new TfLiteModelExecutor(std::move(model))); 53 } 54 FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)55 static std::unique_ptr<TfLiteModelExecutor> FromBuffer( 56 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 57 auto model = TfLiteModelFromBuffer(model_spec_buffer); 58 if (!model) { 59 return nullptr; 60 } 61 return std::unique_ptr<TfLiteModelExecutor>( 62 new TfLiteModelExecutor(std::move(model))); 63 } 64 65 // Creates an Interpreter for the model that serves as a scratch-pad for the 66 // inference. The Interpreter is NOT thread-safe. 67 std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; 68 69 template <typename T> SetInput(const int input_index,const TensorView<T> & input_data,tflite::Interpreter * interpreter)70 void SetInput(const int input_index, const TensorView<T>& input_data, 71 tflite::Interpreter* interpreter) const { 72 input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), 73 input_data.size()); 74 } 75 76 template <typename T> SetInput(const int input_index,const std::vector<T> & input_data,tflite::Interpreter * interpreter)77 void SetInput(const int input_index, const std::vector<T>& input_data, 78 tflite::Interpreter* interpreter) const { 79 std::copy(input_data.begin(), input_data.end(), 80 interpreter->typed_input_tensor<T>(input_index)); 81 } 82 83 template <typename T> SetInput(const int input_index,const T input_value,tflite::Interpreter * interpreter)84 void SetInput(const int input_index, const T input_value, 85 tflite::Interpreter* interpreter) const { 86 TfLiteTensor* input_tensor = 87 interpreter->tensor(interpreter->inputs()[input_index]); 88 switch (input_tensor->type) { 89 case kTfLiteFloat32: 90 *tflite::GetTensorData<float>(input_tensor) = input_value; 91 break; 92 case kTfLiteInt32: 93 *tflite::GetTensorData<int32_t>(input_tensor) = input_value; 94 break; 95 case kTfLiteUInt8: 96 *tflite::GetTensorData<uint8_t>(input_tensor) = input_value; 97 break; 98 case kTfLiteInt64: 99 *tflite::GetTensorData<int64_t>(input_tensor) = input_value; 100 break; 101 case kTfLiteBool: 102 *tflite::GetTensorData<bool>(input_tensor) = input_value; 103 break; 104 case kTfLiteInt16: 105 *tflite::GetTensorData<int16_t>(input_tensor) = input_value; 106 break; 107 case kTfLiteInt8: 108 *tflite::GetTensorData<int8_t>(input_tensor) = input_value; 109 break; 110 default: 111 break; 112 } 113 } 114 115 template <typename T> OutputView(const int output_index,const tflite::Interpreter * interpreter)116 TensorView<T> OutputView(const int output_index, 117 const tflite::Interpreter* interpreter) const { 118 const TfLiteTensor* output_tensor = 119 interpreter->tensor(interpreter->outputs()[output_index]); 120 return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), 121 std::vector<int>(output_tensor->dims->data, 122 output_tensor->dims->data + 123 output_tensor->dims->size)); 124 } 125 126 template <typename T> Output(const int output_index,const tflite::Interpreter * interpreter)127 std::vector<T> Output(const int output_index, 128 const tflite::Interpreter* interpreter) const { 129 TensorView<T> output_view = OutputView<T>(output_index, interpreter); 130 return std::vector<T>(output_view.data(), 131 output_view.data() + output_view.size()); 132 } 133 134 protected: 135 explicit TfLiteModelExecutor( 136 std::unique_ptr<const tflite::FlatBufferModel> model); 137 TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model, 138 std::unique_ptr<tflite::OpResolver> resolver); 139 140 std::unique_ptr<const tflite::FlatBufferModel> model_; 141 std::unique_ptr<tflite::OpResolver> resolver_; 142 }; 143 144 template <> 145 void TfLiteModelExecutor::SetInput(const int input_index, 146 const std::vector<std::string>& input_data, 147 tflite::Interpreter* interpreter) const; 148 149 template <> 150 std::vector<tflite::StringRef> TfLiteModelExecutor::Output( 151 const int output_index, const tflite::Interpreter* interpreter) const; 152 153 template <> 154 std::vector<std::string> TfLiteModelExecutor::Output( 155 const int output_index, const tflite::Interpreter* interpreter) const; 156 157 } // namespace libtextclassifier3 158 159 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 160