/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/lite/interpreter.h"
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"

namespace {

// std::vector preallocation tuning.
constexpr const int kSlotsToReserve = 128;

}  // namespace

namespace tflite {

// A trivial implementation of GraphInfo around the Interpreter.
// NOTE: this interpreter info represents the subset of the
// graph that is executed according to the execution plan. Thus,
// the indices are execution plan indices rather than raw node
// indices.
class InterpreterInfo : public GraphInfo {
 public:
  explicit InterpreterInfo(Interpreter* interpreter)
      : interpreter_(interpreter) {}

  size_t num_tensors() const override { return interpreter_->tensors_size(); }
  TfLiteTensor* tensor(size_t index) override {
    return interpreter_->tensor(index);
  }
  size_t num_nodes() const override {
    return interpreter_->execution_plan().size();
  }
  const TfLiteNode& node(size_t index) const override {
    int node_index = interpreter_->execution_plan()[index];
    return interpreter_->node_and_registration(node_index)->first;
  }
  const std::vector<int>& inputs() const override {
    return interpreter_->inputs();
  }
  const std::vector<int>& outputs() const override {
    return interpreter_->outputs();
  }

 public:
  Interpreter* interpreter_;
};
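
// Worked example (illustrative): if the interpreter's execution plan is
// {2, 0, 1}, then InterpreterInfo::node(0) returns raw node 2 and num_nodes()
// is 3, regardless of how many nodes the interpreter holds in total.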

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  context_.impl_ = static_cast<void*>(this);
  context_.ResizeTensor = ResizeTensor;
  context_.ReportError = ReportError;
  context_.AddTensors = AddTensors;
  context_.tensors = nullptr;
  context_.tensors_size = 0;
  context_.gemm_context = nullptr;

  // Invalid to call these except from TfLiteDelegate.
  context_.GetNodeAndRegistration = nullptr;
  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
  context_.GetExecutionPlan = nullptr;

  // Reserve some space for the tensors to avoid excessive resizing.
  tensors_.reserve(kSlotsToReserve);
  nodes_and_registration_.reserve(kSlotsToReserve);
  next_execution_plan_index_to_prepare_ = 0;
  UseNNAPI(false);
}

Interpreter::~Interpreter() {
  for (auto& nodeAndReg : nodes_and_registration_) {
    TfLiteNode& node = nodeAndReg.first;
    TfLiteIntArrayFree(node.inputs);
    TfLiteIntArrayFree(node.outputs);
    TfLiteIntArrayFree(node.temporaries);
    if (node.builtin_data) free(node.builtin_data);
    OpFree(nodeAndReg.second, node.user_data);
    node.builtin_data = nullptr;
  }

  for (int i = 0; i < context_.tensors_size; i++) {
    TfLiteTensorFree(&context_.tensors[i]);
  }
}

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteContext* context, TfLiteRegistration registration,
    const TfLiteIntArray* nodes_to_replace) {
  return static_cast<Interpreter*>(context->impl_)
      ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
}

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
  // Analyze the graph to find all independent subgraphs that consist entirely
  // of nodes claimed by this delegate or entirely of nodes it does not claim.
  InterpreterInfo info(this);
  std::vector<Subgraph> subgraphs;
  PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs);

  execution_plan_.clear();
  for (auto& subgraph : subgraphs) {
    // Turn subgraph.nodes into a TfLiteIntArray-compatible data structure.
    // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
    // in the first place.
    subgraph.nodes.insert(subgraph.nodes.begin(),
                          static_cast<int>(subgraph.nodes.size()));
    // Subgraphs claimed by the delegate get a single "macro" op created for
    // them; the other subgraphs (kTfNonPartition) just have their nodes added
    // back to the execution plan.
    switch (subgraph.type) {
      case Subgraph::kTfNonPartition:
        for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
             ++it) {
          execution_plan_.push_back(*it);
        }
        break;
      case Subgraph::kTfPartition: {
        void* builtin_data = nullptr;
        int node_index;
        // Create a node that represents computation of this subgraph.
        AddNodeWithParameters(
            subgraph.input_tensors, subgraph.output_tensors,
            reinterpret_cast<const char*>(subgraph.nodes.data()),
            subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
            &registration, &node_index);
      } break;
      case Subgraph::kTfUnexplored:
        return kTfLiteError;
        break;
    }
  }
  return kTfLiteOk;
}
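
// Worked example for the length-prefix trick above (illustrative): if a
// delegated subgraph contains raw nodes {3, 4, 7}, subgraph.nodes becomes
// {3, 3, 4, 7} after the insert, which matches the in-memory layout of a
// TfLiteIntArray (size followed by data). That buffer is then handed to the
// "macro" node as its custom init_data.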

// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate's Prepare.
TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
  // TODO(aselle): Do not make a copy here
  plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
  *execution_plan = plan_cache_.get();
  static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
                "TfLiteIntArray and execution_plan do not contain same type.");
  memcpy(plan_cache_->data, execution_plan_.data(),
         sizeof(plan_cache_->data[0]) * execution_plan_.size());
  return kTfLiteOk;
}
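
// A minimal sketch of how a delegate's Prepare callback might consume this
// API through the C context interface:
//
//   TfLiteIntArray* plan = nullptr;  // owned by the interpreter
//   TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
//   for (int i = 0; i < plan->size; ++i) {
//     int node_index = plan->data[i];  // a raw node index
//     // Inspect the node and decide whether to claim it...
//   }
//
// The returned array is only valid for the duration of the Prepare call; a
// delegate that needs it later must copy the contents.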

// WARNING: This is an experimental interface that is subject to change.
// Entry point for the C node plugin API to get the execution plan.
TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context,
                                           TfLiteIntArray** execution_plan) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetExecutionPlan(execution_plan);
}

TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  TF_LITE_ENSURE_OK(&context_,
                    CheckTensorIndices("inputs", inputs.data(), inputs.size()));
  inputs_ = std::move(inputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  TF_LITE_ENSURE_OK(
      &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
  outputs_ = std::move(outputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
                                             const int* indices, int length) {
  // Making sure kOptionalTensor is not re-defined to something other than -1.
  static_assert(kOptionalTensor == -1,
                "kOptionalTensor should be defined as -1");

  for (int i = 0; i < length; i++) {
    int index = indices[i];
    if (index < kOptionalTensor || index >= context_.tensors_size) {
      ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
      consistent_ = false;
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
                                        int dims_size, size_t* bytes) {
  // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
  // MultiplyWithoutOverflow.
  TF_LITE_ENSURE(&context_, bytes != nullptr);
  size_t count = 1;
  for (int k = 0; k < dims_size; k++) count *= dims[k];
  switch (type) {
    case kTfLiteFloat32:
      *bytes = sizeof(float) * count;
      break;
    case kTfLiteInt32:
      *bytes = sizeof(int32_t) * count;
      break;
    case kTfLiteUInt8:
      *bytes = sizeof(uint8_t) * count;
      break;
    case kTfLiteInt64:
      *bytes = sizeof(int64_t) * count;
      break;
    default:
      ReportError(&context_,
                  "Only float32, int32, int64, uint8 supported currently.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}
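
// Worked example (illustrative): a kTfLiteFloat32 tensor with dims {2, 3, 4}
// has 2 * 3 * 4 = 24 elements, so BytesRequired reports
// 24 * sizeof(float) = 96 bytes. Note that, as the TODO above says, the
// element-count multiplication is not yet checked for overflow.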

namespace {
TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
  TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
  for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
  return lite;
}
}  // namespace

TfLiteStatus Interpreter::AllocateTensors() {
  next_execution_plan_index_to_prepare_ = 0;
  if (memory_planner_) {
    TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
  }

  if (!consistent_) {
    ReportError(&context_, "AllocateTensors() called on inconsistent model.");
    return kTfLiteError;
  }

  TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
  invokable_ = true;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  invokable_ = false;

  std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                              free);

  TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
                                                  inputs.size()));
  TF_LITE_ENSURE_OK(
      &context_,
      CheckTensorIndices("node outputs", outputs.data(), outputs.size()));

  int new_node_index = nodes_and_registration_.size();
  if (node_index) *node_index = new_node_index;
  nodes_and_registration_.resize(nodes_and_registration_.size() + 1);

  auto& node_and_reg = nodes_and_registration_.back();
  TfLiteNode& node = node_and_reg.first;
  if (node.inputs) TfLiteIntArrayFree(node.inputs);
  if (node.outputs) TfLiteIntArrayFree(node.outputs);
  if (node.temporaries) TfLiteIntArrayFree(node.temporaries);

  // NOTE: here we are not using move semantics yet, since our internal
  // representation isn't std::vector, but in the future we would like to avoid
  // copies, so we want the interface to take r-value references now.
  node.inputs = convertVectorToTfLiteIntArray(inputs);
  node.outputs = convertVectorToTfLiteIntArray(outputs);
  node.temporaries = TfLiteIntArrayCreate(0);
  if (init_data) {
    node.user_data = OpInit(*registration, init_data, init_data_size);
  } else {
    node.user_data =
        OpInit(*registration,
               reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
  }
  node.builtin_data = builtin_data_deleter.release();
  node_and_reg.second = *registration;
  execution_plan_.push_back(new_node_index);
  return kTfLiteOk;
}
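
// A minimal usage sketch for building a graph by hand (the tensor indices and
// the `resolver` / `add_params` names below are hypothetical stand-ins for an
// op resolver and a malloc()-allocated builtin params struct; error handling
// omitted):
//
//   Interpreter interpreter;
//   int base = 0;
//   interpreter.AddTensors(3, &base);  // tensors 0, 1, 2
//   interpreter.SetInputs({0, 1});
//   interpreter.SetOutputs({2});
//   const TfLiteRegistration* add_op = resolver.FindOp(BuiltinOperator_ADD);
//   int node_index = -1;
//   interpreter.AddNodeWithParameters({0, 1}, {2}, /*init_data=*/nullptr,
//                                     /*init_data_size=*/0,
//                                     /*builtin_data=*/add_params, add_op,
//                                     &node_index);
//
// Ownership of `add_params` transfers to the interpreter, which frees it in
// its destructor. Each tensor also needs its parameters set (e.g. via
// SetTensorParametersReadWrite()) before AllocateTensors() is called.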

TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  // TODO(aselle): All bounds checks can be implemented as one-sided bounds
  // checks by casting to unsigned for efficiency. Profile before doing this.

  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  invokable_ = false;
  TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
  return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}
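
// Typical caller pattern (illustrative shapes): resizing an input invalidates
// the current memory plan, so AllocateTensors() must run again before
// Invoke():
//
//   interpreter.ResizeInputTensor(interpreter.inputs()[0], {1, 224, 224, 3});
//   interpreter.AllocateTensors();
//   interpreter.Invoke();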

// Returns true if at least one tensor in the given list is kTfLiteDynamic.
bool HasDynamicTensor(const TfLiteContext& context,
                      const TfLiteIntArray* tensors) {
  for (int i = 0; i < tensors->size; ++i) {
    const TfLiteTensor& tensor = context.tensors[tensors->data[i]];
    if (tensor.allocation_type == kTfLiteDynamic) {
      return true;
    }
  }
  return false;
}

TfLiteStatus Interpreter::PrepareOpsStartingAt(
    int first_execution_plan_index, int* last_execution_plan_index_prepared) {
  for (int execution_plan_index = first_execution_plan_index;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    if (OpPrepare(registration, &node) == kTfLiteError) {
      return kTfLiteError;
    }

    *last_execution_plan_index_prepared = execution_plan_index;

    // Discontinue if the node has dynamic outputs. Note that we don't
    // stop for dynamic temporary tensors since they won't affect the
    // sizes of other tensors in the graph.
    if (HasDynamicTensor(context_, node.outputs)) {
      break;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::PrepareOpsAndTensors() {
  if (!memory_planner_) {
    memory_planner_.reset(new ArenaPlanner(
        &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
    memory_planner_->PlanAllocations();
  }

  int last_exec_plan_index_prepared = 0;

  TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt(
      next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared));
  TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations(
      next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared));

  next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::Invoke() {
  if (!consistent_) {
    ReportError(&context_, "Invoke called on model that is not consistent.");
    return kTfLiteError;
  }
  if (!invokable_) {
    ReportError(&context_, "Invoke called on model that is not ready.");
    return kTfLiteError;
  }

  TfLiteStatus status = kTfLiteOk;
  if (nnapi_delegate_) {
    if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
      TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
      return kTfLiteOk;
    } else {
      // TODO(aselle): In the future, we would like this to be an
      // automatic tflite CPU fallback.
      ReportError(&context_,
                  "NNAPI was requested, but dependent sized tensors "
                  "are being used.\n");
      return kTfLiteError;
    }
  }

  // Invocations are always done in node order.
  // Note that calling Invoke repeatedly will cause the original memory plan to
  // be reused, unless either ResizeInputTensor() or AllocateTensors() has been
  // called.
  // TODO(b/71913981): we should force recalculation in the presence of dynamic
  // tensors, because they may have new values which in turn may affect shapes
  // and allocations.
  for (int execution_plan_index = 0;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    if (execution_plan_index == next_execution_plan_index_to_prepare_) {
      TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
      TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
                                    execution_plan_index);
    }
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    if (OpInvoke(registration, &node) == kTfLiteError) {
      status = kTfLiteError;
    }
  }
  return status;
}
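
// A minimal end-to-end sketch of the intended calling sequence (assuming the
// interpreter was already populated with tensors and nodes; `input` and
// `output` are hypothetical indices of kTfLiteFloat32 tensors):
//
//   interpreter.AllocateTensors();
//   float* in = interpreter.typed_tensor<float>(input);
//   in[0] = 1.0f;  // fill input buffer(s)
//   if (interpreter.Invoke() == kTfLiteOk) {
//     float* out = interpreter.typed_tensor<float>(output);
//     // read results from out[...]
//   }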

TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
                                       TfLiteTensor* tensor,
                                       TfLiteIntArray* new_size) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ResizeTensorImpl
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->ResizeTensorImpl(tensor, new_size);
}

void Interpreter::ReportErrorImpl(const char* format, va_list args) {
  error_reporter_->Report(format, args);
}

void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
  va_list args;
  va_start(args, format);
  auto* f = static_cast<Interpreter*>(context->impl_);
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ReportErrorImpl
  // (this function is static).
  f->ReportErrorImpl(format, args);
  va_end(args);
}

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  int base_index = tensors_.size();
  if (first_new_tensor_index) *first_new_tensor_index = base_index;
  tensors_.resize(tensors_.size() + tensors_to_add);
  for (int i = base_index; i < tensors_.size(); i++) {
    memset(&tensors_[i], 0, sizeof(tensors_[i]));
  }
  context_.tensors = tensors_.data();
  context_.tensors_size = tensors_.size();
  return kTfLiteOk;
}

TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
                                     int* first_new_tensor_index) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function AddTensors
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->AddTensors(tensors_to_add, first_new_tensor_index);
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
  TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0);
  TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
  *node = &nodes_and_registration_[node_index].first;
  *registration = &nodes_and_registration_[node_index].second;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    struct TfLiteContext* context, int node_index, TfLiteNode** node,
    TfLiteRegistration** registration) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetNodeAndRegistration(node_index, node, registration);
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantizationParams quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  // For most tensors we know exactly how much memory is necessary, so we can
  // ensure the buffer is large enough. However, we need to skip string tensors
  // because their sizes change with the contents of the individual strings.
  if (type != kTfLiteString) {
    size_t required_bytes;
    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
                                               &required_bytes));
    TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
  }
  invokable_ = false;
  TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
                    quantization, const_cast<char*>(buffer), bytes,
                    kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
  return kTfLiteOk;
}

// Sets parameters for the tensor at `tensor_index`. This read/write variant
// takes no external buffer: non-string tensors are allocated in the
// interpreter's arena (sized via BytesRequired), while string tensors are
// marked dynamic and allocated as their contents become known.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
  invokable_ = false;
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  size_t required_bytes = 0;
  if (type != kTfLiteString) {
    // These types will be allocated in our arena, so we need to record how
    // many bytes we will need based on the dimensions. String tensors are
    // allocated dynamically and we can't know ahead of time how much space
    // they will require.
    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
                                               &required_bytes));
  }
  TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
                    quantization,
                    /*buffer=*/nullptr, required_bytes,
                    type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
                    nullptr, &context_.tensors[tensor_index]);
  return kTfLiteOk;
}
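
// A minimal sketch of declaring tensors before AllocateTensors() (the indices,
// names, and shapes are illustrative; `quant` is a zero-initialized
// TfLiteQuantizationParams):
//
//   TfLiteQuantizationParams quant = {0.0f, 0};
//   interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "input",
//                                            {1, 4}, quant);
//   interpreter.SetTensorParametersReadWrite(1, kTfLiteFloat32, "output",
//                                            {1, 4}, quant);
//
// Constant weights that live in the model file would instead use
// SetTensorParametersReadOnly() with a pointer into that buffer, which must
// outlive the interpreter.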

TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  for (int node_index : new_plan) {
    TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size());
  }
  execution_plan_ = new_plan;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
                                           TfLiteIntArray* new_size) {
  // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
  if (tensor->allocation_type == kTfLiteArenaRw ||
      tensor->allocation_type == kTfLiteDynamic) {
    if (tensor->type != kTfLiteString) {
      size_t bytesRequired;
      TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
                                          new_size->size, &bytesRequired);
      if (status != kTfLiteOk) {
        TfLiteIntArrayFree(new_size);
        return kTfLiteError;
      }

      // Realloc space for kTfLiteDynamic tensors.
      TfLiteTensorRealloc(bytesRequired, tensor);
      tensor->bytes = bytesRequired;
    }
    if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
    tensor->dims = new_size;

    if (tensor->allocation_type != kTfLiteDynamic) {
      tensor->data.raw = nullptr;
    }
  } else {
    // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
    // of fixed size.
    TfLiteIntArrayFree(new_size);
    ReportError(&context_, "Attempting to resize a fixed-size tensor.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

void Interpreter::UseNNAPI(bool enable) {
  // TODO(aselle): This is a workaround for finding if NNAPI exists.
  // We also need to make sure getLibraryHandle() is renamed to be NNAPI
  // prefixed.
  if (!NNAPIExists()) enable = false;
  if (!enable) {
    nnapi_delegate_.reset();
  } else if (!nnapi_delegate_) {
    nnapi_delegate_.reset(new NNAPIDelegate);
  }
}

void Interpreter::SetNumThreads(int num_threads) {
  // TODO(ahentz): this forces us to link against gemmlowp even when the ops
  // don't use it. We should implement some dynamic mechanism for this sort of
  // library-specific initialization.
  tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
}

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Set up the additional context interface.
  context_.GetNodeAndRegistration = GetNodeAndRegistration;
  context_.ReplaceSubgraphsWithDelegateKernels =
      ReplaceSubgraphsWithDelegateKernels;
  context_.GetExecutionPlan = GetExecutionPlan;

  TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
  // Remove the additional context info.
  context_.GetNodeAndRegistration = nullptr;
  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
  context_.GetExecutionPlan = nullptr;
  return status;
}
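
// A minimal sketch of a delegate interacting with the hooks installed above
// (the kernel registration and the claim-everything policy are hypothetical;
// a real delegate would inspect nodes via context->GetNodeAndRegistration
// before claiming them):
//
//   TfLiteStatus MyDelegatePrepare(TfLiteContext* context, void* data) {
//     TfLiteIntArray* plan = nullptr;
//     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
//     // Claim every node in the plan for the delegate's "macro" kernel.
//     return context->ReplaceSubgraphsWithDelegateKernels(
//         context, my_delegate_kernel_registration, plan);
//   }
//
//   TfLiteDelegate my_delegate;
//   my_delegate.data_ = nullptr;
//   my_delegate.Prepare = MyDelegatePrepare;
//   interpreter.ModifyGraphWithDelegate(&my_delegate);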

}  // namespace tflite