/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.

#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

std::unique_ptr<tflite::OpResolver> BuildOpResolver();
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec(
    const tflite::Model*);
std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer(
    const flatbuffers::Vector<uint8_t>*);

// Executor for the text selection prediction and classification models.
41 class TfLiteModelExecutor { 42 public: FromModelSpec(const tflite::Model * model_spec)43 static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( 44 const tflite::Model* model_spec) { 45 auto model = TfLiteModelFromModelSpec(model_spec); 46 if (!model) { 47 return nullptr; 48 } 49 return std::unique_ptr<TfLiteModelExecutor>( 50 new TfLiteModelExecutor(std::move(model))); 51 } 52 FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)53 static std::unique_ptr<TfLiteModelExecutor> FromBuffer( 54 const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 55 auto model = TfLiteModelFromBuffer(model_spec_buffer); 56 if (!model) { 57 return nullptr; 58 } 59 return std::unique_ptr<TfLiteModelExecutor>( 60 new TfLiteModelExecutor(std::move(model))); 61 } 62 63 // Creates an Interpreter for the model that serves as a scratch-pad for the 64 // inference. The Interpreter is NOT thread-safe. 65 std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; 66 67 template <typename T> SetInput(const int input_index,const TensorView<T> & input_data,tflite::Interpreter * interpreter)68 void SetInput(const int input_index, const TensorView<T>& input_data, 69 tflite::Interpreter* interpreter) const { 70 input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), 71 input_data.size()); 72 } 73 74 template <typename T> SetInput(const int input_index,const std::vector<T> & input_data,tflite::Interpreter * interpreter)75 void SetInput(const int input_index, const std::vector<T>& input_data, 76 tflite::Interpreter* interpreter) const { 77 std::copy(input_data.begin(), input_data.end(), 78 interpreter->typed_input_tensor<T>(input_index)); 79 } 80 81 template <typename T> SetInput(const int input_index,const T input_value,tflite::Interpreter * interpreter)82 void SetInput(const int input_index, const T input_value, 83 tflite::Interpreter* interpreter) const { 84 TfLiteTensor* input_tensor = 85 interpreter->tensor(interpreter->inputs()[input_index]); 86 switch 
(input_tensor->type) { 87 case kTfLiteFloat32: 88 *(input_tensor->data.f) = input_value; 89 break; 90 case kTfLiteInt32: 91 *(input_tensor->data.i32) = input_value; 92 break; 93 case kTfLiteUInt8: 94 *(input_tensor->data.uint8) = input_value; 95 break; 96 case kTfLiteInt64: 97 *(input_tensor->data.i64) = input_value; 98 break; 99 case kTfLiteBool: 100 *(input_tensor->data.b) = input_value; 101 break; 102 case kTfLiteInt16: 103 *(input_tensor->data.i16) = input_value; 104 break; 105 case kTfLiteInt8: 106 *(input_tensor->data.int8) = input_value; 107 break; 108 default: 109 break; 110 } 111 } 112 113 template <typename T> OutputView(const int output_index,const tflite::Interpreter * interpreter)114 TensorView<T> OutputView(const int output_index, 115 const tflite::Interpreter* interpreter) const { 116 const TfLiteTensor* output_tensor = 117 interpreter->tensor(interpreter->outputs()[output_index]); 118 return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), 119 std::vector<int>(output_tensor->dims->data, 120 output_tensor->dims->data + 121 output_tensor->dims->size)); 122 } 123 124 template <typename T> Output(const int output_index,const tflite::Interpreter * interpreter)125 std::vector<T> Output(const int output_index, 126 const tflite::Interpreter* interpreter) const { 127 TensorView<T> output_view = OutputView<T>(output_index, interpreter); 128 return std::vector<T>(output_view.data(), 129 output_view.data() + output_view.size()); 130 } 131 132 protected: 133 explicit TfLiteModelExecutor( 134 std::unique_ptr<const tflite::FlatBufferModel> model); 135 136 std::unique_ptr<const tflite::FlatBufferModel> model_; 137 std::unique_ptr<tflite::OpResolver> resolver_; 138 }; 139 140 template <> 141 void TfLiteModelExecutor::SetInput(const int input_index, 142 const std::vector<std::string>& input_data, 143 tflite::Interpreter* interpreter) const; 144 145 template <> 146 std::vector<tflite::StringRef> TfLiteModelExecutor::Output( 147 const int output_index, 
const tflite::Interpreter* interpreter) const; 148 149 template <> 150 std::vector<std::string> TfLiteModelExecutor::Output( 151 const int output_index, const tflite::Interpreter* interpreter) const; 152 153 } // namespace libtextclassifier3 154 155 #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 156