/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/interpreter.h"

#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include <mutex>  // NOLINT(build/c++11): only using std::call_once, not mutex.

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/graph_info.h"
#include "tensorflow/lite/memory_planner.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi_delegate.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/util.h"

namespace tflite {

namespace {

// Gets the current TfLiteQuantization from the legacy
// TfLiteQuantizationParams.
TfLiteQuantization GetQuantizationFromLegacy(
    const TfLiteQuantizationParams& legacy_quantization) {
  TfLiteQuantization quantization;
  quantization.type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(1);
  affine_quantization->zero_point = TfLiteIntArrayCreate(1);
  affine_quantization->scale->data[0] = legacy_quantization.scale;
  affine_quantization->zero_point->data[0] = legacy_quantization.zero_point;
  quantization.params = affine_quantization;

  return quantization;
}
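
// Illustrative example (values are hypothetical): a legacy
// TfLiteQuantizationParams of {scale = 0.5f, zero_point = 127} becomes an
// affine quantization with scale = {0.5f} and zero_point = {127}, following
// the usual mapping real_value = scale * (quantized_value - zero_point).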

}  // namespace

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  // Only log initialization once per-process to avoid log spam.
  static std::once_flag init_log_once_flag;
  std::call_once(init_log_once_flag, []() {
    // TODO(b/128420794): Include the TFLite runtime version in the log.
    TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Initialized TensorFlow Lite runtime.");
  });

  // There's always at least one subgraph: the primary subgraph.
  AddSubgraphs(1);
  context_ = primary_subgraph().context();

  // Initialize all external contexts to null.
  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    external_contexts_[i] = nullptr;
  }

  UseNNAPI(false);
}

Interpreter::~Interpreter() {}

void Interpreter::SetExternalContext(TfLiteExternalContextType type,
                                     TfLiteExternalContext* ctx) {
  primary_subgraph().SetExternalContext(type, ctx);
}

TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  return primary_subgraph().SetInputs(inputs);
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  return primary_subgraph().SetOutputs(outputs);
}

TfLiteStatus Interpreter::SetVariables(std::vector<int> variables) {
  return primary_subgraph().SetVariables(variables);
}

TfLiteStatus Interpreter::AllocateTensors() {
  return primary_subgraph().AllocateTensors();
}
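
// Note: AllocateTensors() must complete successfully before calling Invoke(),
// and it must be called again after any ResizeInputTensor() call before the
// next Invoke().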

void Interpreter::ReserveNodes(int count) {
  primary_subgraph().ReserveNodes(count);
}

void Interpreter::AddSubgraphs(int subgraphs_to_add,
                               int* first_new_subgraph_index) {
  const size_t base_index = subgraphs_.size();
  if (first_new_subgraph_index) *first_new_subgraph_index = base_index;

  subgraphs_.reserve(base_index + subgraphs_to_add);
  for (int i = 0; i < subgraphs_to_add; ++i) {
    Subgraph* subgraph =
        new Subgraph(error_reporter_, external_contexts_, &subgraphs_);
    subgraphs_.emplace_back(subgraph);
  }
}

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  return primary_subgraph().AddNodeWithParameters(inputs, outputs, init_data,
                                                  init_data_size, builtin_data,
                                                  registration, node_index);
}

TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  return primary_subgraph().ResizeInputTensor(tensor_index, dims);
}

TfLiteStatus Interpreter::Invoke() {
  TF_LITE_ENSURE_STATUS(primary_subgraph().Invoke());

  if (!allow_buffer_handle_output_) {
    for (int tensor_index : outputs()) {
      TF_LITE_ENSURE_STATUS(
          primary_subgraph().EnsureTensorDataIsReadable(tensor_index));
    }
  }

  return kTfLiteOk;
}
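
// Typical call sequence (sketch; construction of `interpreter`, e.g. via
// InterpreterBuilder, is assumed and not shown):
//
//   interpreter->AllocateTensors();
//   float* input = interpreter->typed_input_tensor<float>(0);
//   // ... fill `input` ...
//   if (interpreter->Invoke() == kTfLiteOk) {
//     const float* output = interpreter->typed_output_tensor<float>(0);
//     // ... read `output` ...
//   }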

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  return primary_subgraph().AddTensors(tensors_to_add, first_new_tensor_index);
}

TfLiteStatus Interpreter::ResetVariableTensors() {
  return primary_subgraph().ResetVariableTensors();
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  return primary_subgraph().SetTensorParametersReadOnly(
      tensor_index, type, name, dims.size(), dims.data(), quantization, buffer,
      bytes, allocation);
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantization quantization,
    bool is_variable) {
  return primary_subgraph().SetTensorParametersReadWrite(
      tensor_index, type, name, dims.size(), dims.data(), quantization,
      is_variable);
}

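// The two overloads below take the legacy TfLiteQuantizationParams and convert
// it to the newer TfLiteQuantization representation. On success the subgraph
// takes ownership of the converted quantization; on failure it is freed here
// before returning.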
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
    size_t bytes, const Allocation* allocation) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  if (primary_subgraph().SetTensorParametersReadOnly(
          tensor_index, type, name, rank, dims, new_quantization, buffer, bytes,
          allocation) != kTfLiteOk) {
    TfLiteQuantizationFree(&new_quantization);
    return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name, const size_t rank,
    const int* dims, TfLiteQuantizationParams quantization, bool is_variable) {
  TfLiteQuantization new_quantization = GetQuantizationFromLegacy(quantization);
  if (primary_subgraph().SetTensorParametersReadWrite(
          tensor_index, type, name, rank, dims, new_quantization,
          is_variable) != kTfLiteOk) {
    TfLiteQuantizationFree(&new_quantization);
    return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  return primary_subgraph().SetExecutionPlan(new_plan);
}

void Interpreter::UseNNAPI(bool enable) { primary_subgraph().UseNNAPI(enable); }

void Interpreter::SetNumThreads(int num_threads) {
  for (auto& subgraph : subgraphs_) {
    subgraph->context()->recommended_num_threads = num_threads;
  }

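  // Refresh any registered external contexts so they can pick up the new
  // recommended thread count.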
  for (int i = 0; i < kTfLiteMaxExternalContexts; ++i) {
    auto* c = external_contexts_[i];
    if (c && c->Refresh) {
      c->Refresh(context_);
    }
  }
}

void Interpreter::SetAllowFp16PrecisionForFp32(bool allow) {
  for (auto& subgraph : subgraphs_) {
    subgraph->context()->allow_fp32_relax_to_fp16 = allow;
  }
}

// TODO(b/121264966): Subgraphs added after cancellation is set will not get
// the cancellation function added to their context.
void Interpreter::SetCancellationFunction(void* data,
                                          bool (*check_cancelled_func)(void*)) {
  for (auto& subgraph : subgraphs_) {
    subgraph->SetCancellationFunction(data, check_cancelled_func);
  }
}
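
// Example (sketch): wiring the cancellation hook to an atomic flag owned by
// the caller; a captureless lambda converts to the required function pointer:
//
//   std::atomic<bool> cancelled(false);
//   interpreter->SetCancellationFunction(&cancelled, [](void* data) {
//     return static_cast<std::atomic<bool>*>(data)->load();
//   });
//
// Invoke() then fails with kTfLiteError once the flag is set mid-run.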

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  for (auto& subgraph : subgraphs_) {
    TF_LITE_ENSURE_OK(context_, subgraph->ModifyGraphWithDelegate(delegate));
  }
  return kTfLiteOk;
}
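
// Example (sketch; `delegate` stands in for any TfLiteDelegate implementation
// supplied by the caller):
//
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Delegation failed; handle the error (e.g. fall back to the default
//     // CPU kernels).
//   }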

TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle buffer_handle,
                                          TfLiteDelegate* delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  TF_LITE_ENSURE(context_,
                 tensor->delegate == nullptr || tensor->delegate == delegate);
  tensor->delegate = delegate;
  if (tensor->buffer_handle != kTfLiteNullBufferHandle) {
    TF_LITE_ENSURE(context_, tensor->delegate->FreeBufferHandle != nullptr);
    tensor->delegate->FreeBufferHandle(context_, tensor->delegate,
                                       &tensor->buffer_handle);
  }
  tensor->buffer_handle = buffer_handle;

  return kTfLiteOk;
}
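
// Intended pairing (illustrative): a delegate kernel writes its result into a
// backend buffer and publishes it via SetBufferHandle(); a later call to
// EnsureTensorDataIsReadable() (see Invoke() above) asks the delegate to copy
// that buffer back into the CPU-visible tensor data when needed.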

TfLiteStatus Interpreter::GetBufferHandle(int tensor_index,
                                          TfLiteBufferHandle* buffer_handle,
                                          TfLiteDelegate** delegate) {
  TF_LITE_ENSURE(context_, tensor_index < tensors_size());
  std::vector<TfLiteTensor>& tensors = primary_subgraph().tensors();
  TfLiteTensor* tensor = &tensors[tensor_index];

  *delegate = tensor->delegate;
  *buffer_handle = tensor->buffer_handle;

  return kTfLiteOk;
}

void Interpreter::SetProfiler(profiling::Profiler* profiler) {
  for (auto& subgraph : subgraphs_) subgraph->SetProfiler(profiler);
}

profiling::Profiler* Interpreter::GetProfiler() {
  return primary_subgraph().GetProfiler();
}

}  // namespace tflite