/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/lite/interpreter.h"

#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstring>

#include "tensorflow/contrib/lite/arena_planner.h"
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"

namespace {

// std::vector preallocation tuning.
constexpr const int kSlotsToReserve = 128;

}  // namespace
namespace tflite {

// A trivial implementation of GraphInfo around the Interpreter.
// NOTE: this interpreter info represents the subset of the
// graph that is executed according to the execution plan. Thus,
// the indices are execution plan indices rather than raw node
// indices.
class InterpreterInfo : public GraphInfo {
 public:
  explicit InterpreterInfo(Interpreter* interpreter)
      : interpreter_(interpreter) {}

  size_t num_tensors() const override { return interpreter_->tensors_size(); }
  TfLiteTensor* tensor(size_t index) override {
    return interpreter_->tensor(index);
  }
  size_t num_nodes() const override {
    return interpreter_->execution_plan().size();
  }
  const TfLiteNode& node(size_t index) const override {
    int node_index = interpreter_->execution_plan()[index];
    return interpreter_->node_and_registration(node_index)->first;
  }
  const std::vector<int>& inputs() const override {
    return interpreter_->inputs();
  }
  const std::vector<int>& outputs() const override {
    return interpreter_->outputs();
  }

 public:
  Interpreter* interpreter_;
};

Interpreter::Interpreter(ErrorReporter* error_reporter)
    : error_reporter_(error_reporter ? error_reporter
                                     : DefaultErrorReporter()) {
  context_.impl_ = static_cast<void*>(this);
  context_.ResizeTensor = ResizeTensor;
  context_.ReportError = ReportError;
  context_.AddTensors = AddTensors;
  context_.tensors = nullptr;
  context_.tensors_size = 0;
  context_.gemm_context = nullptr;

  // Invalid to call these except from a TfLiteDelegate.
  context_.GetNodeAndRegistration = nullptr;
  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
  context_.GetExecutionPlan = nullptr;

  // Reserve some space for the tensors to avoid excessive resizing.
  tensors_.reserve(kSlotsToReserve);
  nodes_and_registration_.reserve(kSlotsToReserve);
  next_execution_plan_index_to_prepare_ = 0;
  UseNNAPI(false);
}

Interpreter::~Interpreter() {
  for (auto& nodeAndReg : nodes_and_registration_) {
    TfLiteNode& node = nodeAndReg.first;
    TfLiteIntArrayFree(node.inputs);
    TfLiteIntArrayFree(node.outputs);
    TfLiteIntArrayFree(node.temporaries);
    if (node.builtin_data) free(node.builtin_data);
    OpFree(nodeAndReg.second, node.user_data);
    node.builtin_data = nullptr;
  }

  for (int i = 0; i < context_.tensors_size; i++) {
    TfLiteTensorFree(&context_.tensors[i]);
  }
}

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteContext* context, TfLiteRegistration registration,
    const TfLiteIntArray* nodes_to_replace) {
  return static_cast<Interpreter*>(context->impl_)
      ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace);
}

TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels(
    TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) {
  // Analyze the graph to find all independent subgraphs that are either
  // fully not-this-delegate or this-delegate computation.
  InterpreterInfo info(this);
  std::vector<Subgraph> subgraphs;
  PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs);

  execution_plan_.clear();
  for (auto& subgraph : subgraphs) {
    // Turn subgraph.nodes into a TfLiteIntArray-compatible data structure.
    // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way
    // in the first place.
    subgraph.nodes.insert(subgraph.nodes.begin(),
                          static_cast<int>(subgraph.nodes.size()));
    // Subgraphs claimed by the delegate should have a "macro" op created; the
    // other subgraphs (kTfNonPartition) just have their nodes added back to
    // the execution plan.
    switch (subgraph.type) {
      case Subgraph::kTfNonPartition:
        for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end();
             ++it) {
          execution_plan_.push_back(*it);
        }
        break;
      case Subgraph::kTfPartition: {
        void* builtin_data = nullptr;
        int node_index;
        // Create a node that represents computation of this subgraph.
        AddNodeWithParameters(
            subgraph.input_tensors, subgraph.output_tensors,
            reinterpret_cast<const char*>(subgraph.nodes.data()),
            subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data,
            &registration, &node_index);
      } break;
      case Subgraph::kTfUnexplored:
        return kTfLiteError;
        break;
    }
  }
  return kTfLiteOk;
}

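// Illustrative note (added commentary, not in the original source):
// ReplaceSubgraphsWithDelegateKernels above prefixes subgraph.nodes with its
// element count, so the vector's data() is laid out like a TfLiteIntArray.
// For example, a claimed subgraph with nodes {3, 4, 7} becomes {3, 3, 4, 7},
// and that buffer is handed to the delegate "macro" node as its init data.
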
// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns
// this memory and it is only guaranteed to exist during the invocation of the
// delegate prepare.
TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) {
  // TODO(aselle): Do not make a copy here.
  plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size()));
  *execution_plan = plan_cache_.get();
  static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]),
                "TfLiteIntArray and execution_plan do not contain same type.");
  memcpy(plan_cache_->data, execution_plan_.data(),
         sizeof(plan_cache_->data[0]) * execution_plan_.size());
  return kTfLiteOk;
}

// WARNING: This is an experimental interface that is subject to change.
// Entry point for C node plugin API to get the execution plan.
TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context,
                                           TfLiteIntArray** execution_plan) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetExecutionPlan(execution_plan);
}

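// Hedged usage sketch (added commentary, not in the original source): inside a
// delegate's Prepare callback, where context->GetExecutionPlan is non-null, a
// delegate might walk the plan roughly like this:
//
//   TfLiteIntArray* plan = nullptr;
//   if (context->GetExecutionPlan(context, &plan) != kTfLiteOk) {
//     return kTfLiteError;
//   }
//   for (int i = 0; i < plan->size; ++i) {
//     int node_index = plan->data[i];  // node indices in execution order
//     // ... inspect node_index ...
//   }
//
// The array is owned by the interpreter (plan_cache_) and is only valid for
// the duration of the Prepare call, as noted above.
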
TfLiteStatus Interpreter::SetInputs(std::vector<int> inputs) {
  TF_LITE_ENSURE_OK(&context_,
                    CheckTensorIndices("inputs", inputs.data(), inputs.size()));
  inputs_ = std::move(inputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::SetOutputs(std::vector<int> outputs) {
  TF_LITE_ENSURE_OK(
      &context_, CheckTensorIndices("outputs", outputs.data(), outputs.size()));
  outputs_ = std::move(outputs);
  return kTfLiteOk;
}

TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
                                             const int* indices, int length) {
  // Making sure kOptionalTensor is not re-defined to something other than -1.
  static_assert(kOptionalTensor == -1, "kOptionalTensor should be defined -1");

  for (int i = 0; i < length; i++) {
    int index = indices[i];
    if (index < kOptionalTensor || index >= context_.tensors_size) {
      ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
      consistent_ = false;
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
                                        int dims_size, size_t* bytes) {
  // TODO(aselle): Check for overflow here using overflow.h in TensorFlow
  // MultiplyWithoutOverflow.
  TF_LITE_ENSURE(&context_, bytes != nullptr);
  size_t count = 1;
  for (int k = 0; k < dims_size; k++) count *= dims[k];
  switch (type) {
    case kTfLiteFloat32:
      *bytes = sizeof(float) * count;
      break;
    case kTfLiteInt32:
      *bytes = sizeof(int32_t) * count;
      break;
    case kTfLiteUInt8:
      *bytes = sizeof(uint8_t) * count;
      break;
    case kTfLiteInt64:
      *bytes = sizeof(int64_t) * count;
      break;
    default:
      ReportError(&context_,
                  "Only float32, int32, int64, uint8 supported currently.");
      return kTfLiteError;
  }
  return kTfLiteOk;
}

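// Worked example (added commentary, not in the original source): for a
// kTfLiteFloat32 tensor with dims = {2, 3, 4}, count = 2 * 3 * 4 = 24 and
// BytesRequired reports 24 * sizeof(float) = 96 bytes. A scalar
// (dims_size == 0) yields count = 1, i.e. one element.
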
namespace {
TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector<int>& x) {
  TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size());
  for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i];
  return lite;
}
}  // namespace

TfLiteStatus Interpreter::AllocateTensors() {
  next_execution_plan_index_to_prepare_ = 0;
  if (memory_planner_) {
    TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations());
  }

  if (!consistent_) {
    ReportError(&context_, "AllocateTensors() called on inconsistent model.");
    return kTfLiteError;
  }

  TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
  invokable_ = true;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::AddNodeWithParameters(
    const std::vector<int>& inputs, const std::vector<int>& outputs,
    const char* init_data, size_t init_data_size, void* builtin_data,
    const TfLiteRegistration* registration, int* node_index) {
  invokable_ = false;

  std::unique_ptr<void, decltype(free)*> builtin_data_deleter(builtin_data,
                                                              free);

  TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("node inputs", inputs.data(),
                                                  inputs.size()));
  TF_LITE_ENSURE_OK(
      &context_,
      CheckTensorIndices("node outputs", outputs.data(), outputs.size()));

  int new_node_index = nodes_and_registration_.size();
  if (node_index) *node_index = new_node_index;
  nodes_and_registration_.resize(nodes_and_registration_.size() + 1);

  auto& node_and_reg = nodes_and_registration_.back();
  TfLiteNode& node = node_and_reg.first;
  if (node.inputs) TfLiteIntArrayFree(node.inputs);
  if (node.outputs) TfLiteIntArrayFree(node.outputs);
  if (node.temporaries) TfLiteIntArrayFree(node.temporaries);

  // NOTE: we are not using move semantics here yet, since our internal
  // representation isn't std::vector, but in the future we would like to
  // avoid copies, so we want the interface to take r-value references now.
  node.inputs = convertVectorToTfLiteIntArray(inputs);
  node.outputs = convertVectorToTfLiteIntArray(outputs);
  node.temporaries = TfLiteIntArrayCreate(0);
  if (init_data) {
    node.user_data = OpInit(*registration, init_data, init_data_size);
  } else {
    node.user_data =
        OpInit(*registration,
               reinterpret_cast<const char*>(builtin_data_deleter.get()), 0);
  }
  node.builtin_data = builtin_data_deleter.release();
  node_and_reg.second = *registration;
  execution_plan_.push_back(new_node_index);
  return kTfLiteOk;
}

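// Hedged usage sketch (added commentary, not in the original source): a model
// builder that has already called AddTensors() and SetTensorParameters*()
// might register a node roughly like this, where `reg` is a hypothetical
// TfLiteRegistration from an op resolver and `params` is a hypothetical
// malloc()-allocated builtin parameter struct:
//
//   int node_index = -1;
//   interpreter->AddNodeWithParameters(
//       /*inputs=*/{0, 1}, /*outputs=*/{2},
//       /*init_data=*/nullptr, /*init_data_size=*/0,
//       /*builtin_data=*/params,  // ownership transfers; freed with free()
//       &reg, &node_index);
//
// Per the destructor above, builtin_data is released with free(), so it must
// come from a malloc()-style allocation.
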
TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index,
                                            const std::vector<int>& dims) {
  // TODO(aselle): All bounds checks can be implemented as one-sided bounds
  // checks by casting to unsigned for efficiency. Profile before doing this.

  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  invokable_ = false;
  TfLiteIntArray* dims_lite = convertVectorToTfLiteIntArray(dims);
  return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite);
}

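// Hedged usage sketch (added commentary, not in the original source): because
// ResizeInputTensor clears invokable_, callers must re-run AllocateTensors()
// before the next Invoke(). For example, with hypothetical image dimensions:
//
//   interpreter->ResizeInputTensor(interpreter->inputs()[0], {1, 224, 224, 3});
//   interpreter->AllocateTensors();
//   // ... fill the input buffer, then:
//   interpreter->Invoke();
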
// Returns true if at least one tensor in the given list is kTfLiteDynamic.
bool HasDynamicTensor(const TfLiteContext& context,
                      const TfLiteIntArray* tensors) {
  for (int i = 0; i < tensors->size; ++i) {
    const TfLiteTensor& tensor = context.tensors[tensors->data[i]];
    if (tensor.allocation_type == kTfLiteDynamic) {
      return true;
    }
  }
  return false;
}

TfLiteStatus Interpreter::PrepareOpsStartingAt(
    int first_execution_plan_index, int* last_execution_plan_index_prepared) {
  for (int execution_plan_index = first_execution_plan_index;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    if (OpPrepare(registration, &node) == kTfLiteError) {
      return kTfLiteError;
    }

    *last_execution_plan_index_prepared = execution_plan_index;

    // Discontinue if the node has dynamic outputs. Note that we don't
    // stop for dynamic temporary tensors since they won't affect the
    // sizes of other tensors in the graph.
    if (HasDynamicTensor(context_, node.outputs)) {
      break;
    }
  }
  return kTfLiteOk;
}

TfLiteStatus Interpreter::PrepareOpsAndTensors() {
  if (!memory_planner_) {
    memory_planner_.reset(new ArenaPlanner(
        &context_, std::unique_ptr<GraphInfo>(new InterpreterInfo(this))));
    memory_planner_->PlanAllocations();
  }

  int last_exec_plan_index_prepared = 0;

  TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt(
      next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared));
  TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations(
      next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared));

  next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1;
  return kTfLiteOk;
}

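// Added commentary (not in the original source): PrepareOpsAndTensors only
// prepares and allocates nodes up to the first one that produces a
// kTfLiteDynamic output, tracked via next_execution_plan_index_to_prepare_.
// The remaining nodes are prepared lazily inside Invoke(), once the dynamic
// shapes are known, by calling PrepareOpsAndTensors() again.
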
TfLiteStatus Interpreter::Invoke() {
  if (!consistent_) {
    ReportError(&context_, "Invoke called on model that is not consistent.");
    return kTfLiteError;
  }
  if (!invokable_) {
    ReportError(&context_, "Invoke called on model that is not ready.");
    return kTfLiteError;
  }

  TfLiteStatus status = kTfLiteOk;
  if (nnapi_delegate_) {
    if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) {
      TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this));
      return kTfLiteOk;
    } else {
      // TODO(aselle): In the future, we would like this to be an
      // automatic tflite CPU fallback.
      ReportError(&context_,
                  "NNAPI was requested, but dependent sized tensors "
                  "are being used.\n");
      return kTfLiteError;
    }
  }

  // Invocations are always done in node order.
  // Note that calling Invoke repeatedly will cause the original memory plan to
  // be reused, unless either ResizeInputTensor() or AllocateTensors() has been
  // called.
  // TODO(b/71913981): we should force recalculation in the presence of dynamic
  // tensors, because they may have new values which in turn may affect shapes
  // and allocations.
  for (int execution_plan_index = 0;
       execution_plan_index < execution_plan_.size(); execution_plan_index++) {
    if (execution_plan_index == next_execution_plan_index_to_prepare_) {
      TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors());
      TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >=
                                    execution_plan_index);
    }
    int node_index = execution_plan_[execution_plan_index];
    TfLiteNode& node = nodes_and_registration_[node_index].first;
    const TfLiteRegistration& registration =
        nodes_and_registration_[node_index].second;
    if (OpInvoke(registration, &node) == kTfLiteError) {
      status = kTfLiteError;
    }
  }
  return status;
}

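// Hedged end-to-end sketch (added commentary, not in the original source): a
// typical call order against this class, once tensors and nodes exist, is:
//
//   interpreter->SetInputs({0});
//   interpreter->SetOutputs({2});
//   interpreter->AllocateTensors();  // sets invokable_
//   float* in = interpreter->typed_tensor<float>(0);
//   // ... fill `in` ...
//   interpreter->Invoke();
//   const float* out = interpreter->typed_tensor<float>(2);
//
// typed_tensor<> is assumed from the accompanying interpreter.h; the tensor
// indices here are purely illustrative.
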
TfLiteStatus Interpreter::ResizeTensor(TfLiteContext* context,
                                       TfLiteTensor* tensor,
                                       TfLiteIntArray* new_size) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ResizeTensorImpl
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->ResizeTensorImpl(tensor, new_size);
}

void Interpreter::ReportErrorImpl(const char* format, va_list args) {
  error_reporter_->Report(format, args);
}

void Interpreter::ReportError(TfLiteContext* context, const char* format, ...) {
  va_list args;
  va_start(args, format);
  auto* f = static_cast<Interpreter*>(context->impl_);
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function ReportErrorImpl
  // (this function is static).
  f->ReportErrorImpl(format, args);
  va_end(args);
}

TfLiteStatus Interpreter::AddTensors(int tensors_to_add,
                                     int* first_new_tensor_index) {
  int base_index = tensors_.size();
  if (first_new_tensor_index) *first_new_tensor_index = base_index;
  tensors_.resize(tensors_.size() + tensors_to_add);
  for (int i = base_index; i < tensors_.size(); i++) {
    memset(&tensors_[i], 0, sizeof(tensors_[i]));
  }
  context_.tensors = tensors_.data();
  context_.tensors_size = tensors_.size();
  return kTfLiteOk;
}

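// Added commentary (not in the original source): AddTensors may reallocate the
// tensors_ vector, which is why context_.tensors is refreshed above. For
// example, AddTensors(3, &first) on an empty interpreter zero-initializes
// tensors 0..2 and sets first to 0; a kernel calling the static AddTensors()
// through TfLiteContext during Prepare gets the same behavior.
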
TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add,
                                     int* first_new_tensor_index) {
  // Note here that context->impl_ is recovering the this pointer for an
  // instance of Interpreter to call into the member function AddTensors
  // (this function is static).
  return static_cast<Interpreter*>(context->impl_)
      ->AddTensors(tensors_to_add, first_new_tensor_index);
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    int node_index, TfLiteNode** node, TfLiteRegistration** registration) {
  TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0);
  TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr);
  *node = &nodes_and_registration_[node_index].first;
  *registration = &nodes_and_registration_[node_index].second;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::GetNodeAndRegistration(
    struct TfLiteContext* context, int node_index, TfLiteNode** node,
    TfLiteRegistration** registration) {
  return static_cast<Interpreter*>(context->impl_)
      ->GetNodeAndRegistration(node_index, node, registration);
}

TfLiteStatus Interpreter::SetTensorParametersReadOnly(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantizationParams quantization,
    const char* buffer, size_t bytes, const Allocation* allocation) {
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  // For most tensors we know exactly how much memory is necessary so we can
  // ensure the buffer is large enough. However, we need to skip string tensors
  // because their sizes change with the contents of the individual strings.
  if (type != kTfLiteString) {
    size_t required_bytes;
    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
                                               &required_bytes));
    TF_LITE_ENSURE_EQ(&context_, required_bytes, bytes);
  }
  invokable_ = false;
  TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
                    quantization, const_cast<char*>(buffer), bytes,
                    kTfLiteMmapRo, allocation, &context_.tensors[tensor_index]);
  return kTfLiteOk;
}

// Set the description of a read-write tensor at `tensor_index`. This variant
// takes no external buffer; the tensor's storage is allocated from the
// interpreter's arena (or dynamically, for string tensors) based on its type
// and dims.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
    int tensor_index, TfLiteType type, const char* name,
    const std::vector<int>& dims, TfLiteQuantizationParams quantization) {
  invokable_ = false;
  TF_LITE_ENSURE(&context_,
                 tensor_index < context_.tensors_size && tensor_index >= 0);
  size_t required_bytes = 0;
  if (type != kTfLiteString) {
    // These types will be allocated in our arena so we need to record how
    // many bytes we will need based on the dimensions. String tensors are
    // allocated dynamically and we can't know ahead of time how much space
    // they will require.
    TF_LITE_ENSURE_OK(&context_, BytesRequired(type, dims.data(), dims.size(),
                                               &required_bytes));
  }
  TfLiteTensorReset(type, name, convertVectorToTfLiteIntArray(dims),
                    quantization,
                    /*buffer=*/nullptr, required_bytes,
                    type == kTfLiteString ? kTfLiteDynamic : kTfLiteArenaRw,
                    nullptr, &context_.tensors[tensor_index]);
  return kTfLiteOk;
}

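// Hedged usage sketch (added commentary, not in the original source): a model
// loader might declare a float activation tensor and a constant weight tensor
// like this, where `weights` is a hypothetical pointer to flatbuffer-owned
// data whose size must match BytesRequired for the given dims:
//
//   TfLiteQuantizationParams quant = {0.0f, 0};  // no quantization
//   interpreter->SetTensorParametersReadWrite(
//       /*tensor_index=*/0, kTfLiteFloat32, "activation", {1, 16}, quant);
//   interpreter->SetTensorParametersReadOnly(
//       /*tensor_index=*/1, kTfLiteFloat32, "weights", {16, 16}, quant,
//       weights, 16 * 16 * sizeof(float), /*allocation=*/nullptr);
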
TfLiteStatus Interpreter::SetExecutionPlan(const std::vector<int>& new_plan) {
  for (int node_index : new_plan) {
    TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size());
  }
  execution_plan_ = new_plan;
  return kTfLiteOk;
}

TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor,
                                           TfLiteIntArray* new_size) {
  // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too.
  if (tensor->allocation_type == kTfLiteArenaRw ||
      tensor->allocation_type == kTfLiteDynamic) {
    if (tensor->type != kTfLiteString) {
      size_t bytesRequired;
      TfLiteStatus status = BytesRequired(tensor->type, new_size->data,
                                          new_size->size, &bytesRequired);
      if (status != kTfLiteOk) {
        TfLiteIntArrayFree(new_size);
        return kTfLiteError;
      }

      // Realloc space for kTfLiteDynamic tensors.
      TfLiteTensorRealloc(bytesRequired, tensor);
      tensor->bytes = bytesRequired;
    }
    if (tensor->dims) TfLiteIntArrayFree(tensor->dims);
    tensor->dims = new_size;

    if (tensor->allocation_type != kTfLiteDynamic) {
      tensor->data.raw = nullptr;
    }
  } else {
    // kTfLiteMmapRo tensors are stored in the flatbuffer and are therefore
    // of fixed size.
    TfLiteIntArrayFree(new_size);
    ReportError(&context_, "Attempting to resize a fixed-size tensor.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

void Interpreter::UseNNAPI(bool enable) {
  // TODO(aselle): This is a workaround for finding if NNAPI exists.
  // We also need to make sure getLibraryHandle() is renamed to be NNAPI
  // prefixed.
  if (!NNAPIExists()) enable = false;
  if (!enable) {
    nnapi_delegate_.reset();
  } else if (!nnapi_delegate_) {
    nnapi_delegate_.reset(new NNAPIDelegate);
  }
}

void Interpreter::SetNumThreads(int num_threads) {
  // TODO(ahentz): this forces us to link against gemmlowp even when the ops
  // don't use it. We should implement some dynamic mechanism for this sort of
  // library-specific initialization.
  tflite::gemm_support::SetMaxNumThreads(&context_, num_threads);
}

TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
  // TODO(aselle): Consider if it is worth storing pointers to delegates.
  // Set up the additional context interface.
  context_.GetNodeAndRegistration = GetNodeAndRegistration;
  context_.ReplaceSubgraphsWithDelegateKernels =
      ReplaceSubgraphsWithDelegateKernels;
  context_.GetExecutionPlan = GetExecutionPlan;

  TfLiteStatus status = delegate->Prepare(&context_, delegate->data_);
  // Remove the additional context info.
  context_.GetNodeAndRegistration = nullptr;
  context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
  context_.GetExecutionPlan = nullptr;
  return status;
}

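// Hedged sketch of the delegate side (added commentary, not in the original
// source): while ModifyGraphWithDelegate runs, a delegate's Prepare callback
// can combine the three temporarily-exposed context hooks roughly like this
// (`kDelegateKernel` is a hypothetical TfLiteRegistration for the macro op):
//
//   TfLiteStatus Prepare(TfLiteContext* context, void* data) {
//     TfLiteIntArray* plan = nullptr;
//     TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
//     // Decide which node indices in `plan` the delegate supports, then:
//     TfLiteIntArray* to_replace = TfLiteIntArrayCreate(num_supported);
//     // ... fill to_replace->data with the supported node indices ...
//     TfLiteStatus status = context->ReplaceSubgraphsWithDelegateKernels(
//         context, kDelegateKernel, to_replace);
//     TfLiteIntArrayFree(to_replace);
//     return status;
//   }
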
}  // namespace tflite