• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"
16 
17 #include <stdarg.h>
18 
19 #include <cstring>
20 #include <functional>
21 #include <memory>
22 #include <sstream>
23 #include <string>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_format.h"
27 #include "tensorflow/lite/c/common.h"
28 #include "tensorflow/lite/core/api/error_reporter.h"
29 #include "tensorflow/lite/core/api/op_resolver.h"
30 #include "tensorflow/lite/interpreter.h"
31 #include "tensorflow/lite/kernels/internal/compatibility.h"
32 #include "tensorflow/lite/kernels/register.h"
33 #include "tensorflow/lite/kernels/register_ref.h"
34 #include "tensorflow/lite/model.h"
35 #include "tensorflow/lite/mutable_op_resolver.h"
36 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
37 #include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
38 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
39 #include "tensorflow/lite/shared_library.h"
40 #include "tensorflow/lite/string_util.h"
41 #include "tensorflow/lite/util.h"
42 
// Helper macros shared by the InterpreterWrapper methods below. Each one may
// `return` from the enclosing function, so they are only usable in functions
// returning `PyObject*`. They refer to `interpreter_` (and TFLITE_PY_CHECK to
// `error_reporter_`) by name, which must be in scope at the use site. Every
// macro expands to a single `do { ... } while (0)` statement, so it composes
// safely with unbraced `if`/`else` and must be followed by a semicolon.
// Index arguments are evaluated more than once; pass plain variables only.

// Returns the exception object built by `error_reporter_` when the
// TfLiteStatus expression `x` is not kTfLiteOk.
#define TFLITE_PY_CHECK(x)                 \
  do {                                     \
    if ((x) != kTfLiteOk) {                \
      return error_reporter_->exception(); \
    }                                      \
  } while (0)

// Raises ValueError and returns nullptr unless 0 <= i < tensors_size().
// Sizes are printed with %zu because they are size_t (not necessarily
// `unsigned long`, e.g. on 64-bit Windows).
#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                    \
  do {                                                                      \
    if ((i) < 0 ||                                                          \
        static_cast<size_t>(i) >= interpreter_->tensors_size()) {           \
      PyErr_Format(PyExc_ValueError,                                        \
                   "Invalid tensor index %d exceeds max tensor index %zu",  \
                   (i), static_cast<size_t>(interpreter_->tensors_size())); \
      return nullptr;                                                       \
    }                                                                       \
  } while (0)

// Same as TFLITE_PY_TENSOR_BOUNDS_CHECK, but against the tensors of subgraph
// `subgraph_index` (which must already be bounds-checked).
#define TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index)              \
  do {                                                                         \
    if ((i) < 0 ||                                                             \
        static_cast<size_t>(i) >=                                              \
            interpreter_->subgraph(subgraph_index)->tensors_size()) {          \
      PyErr_Format(                                                            \
          PyExc_ValueError,                                                    \
          "Invalid tensor index %d exceeds max tensor index %zu", (i),         \
          static_cast<size_t>(                                                 \
              interpreter_->subgraph(subgraph_index)->tensors_size()));        \
      return nullptr;                                                          \
    }                                                                          \
  } while (0)

// Raises ValueError and returns nullptr unless 0 <= i < subgraphs_size().
#define TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(i)                                    \
  do {                                                                        \
    if ((i) < 0 ||                                                            \
        static_cast<size_t>(i) >= interpreter_->subgraphs_size()) {           \
      PyErr_Format(PyExc_ValueError,                                          \
                   "Invalid subgraph index %d exceeds max subgraph index %zu",\
                   (i), static_cast<size_t>(interpreter_->subgraphs_size())); \
      return nullptr;                                                         \
    }                                                                         \
  } while (0)

// Raises ValueError and returns nullptr unless 0 <= i < nodes_size().
#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                              \
  do {                                                               \
    if ((i) < 0 || static_cast<size_t>(i) >= interpreter_->nodes_size()) { \
      PyErr_SetString(PyExc_ValueError, "Invalid node index");       \
      return nullptr;                                                \
    }                                                                \
  } while (0)

// Raises ValueError and returns nullptr when `interpreter_` is null.
#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                                 \
  do {                                                                       \
    if (!interpreter_) {                                                     \
      PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
      return nullptr;                                                        \
    }                                                                        \
  } while (0)
83 
84 namespace tflite {
85 namespace interpreter_wrapper {
86 
87 namespace {
88 
89 using python_utils::PyDecrefDeleter;
90 
CreateInterpreter(const InterpreterWrapper::Model * model,const tflite::MutableOpResolver & resolver,bool preserve_all_tensors)91 std::unique_ptr<Interpreter> CreateInterpreter(
92     const InterpreterWrapper::Model* model,
93     const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) {
94   if (!model) {
95     return nullptr;
96   }
97 
98   ::tflite::python::ImportNumpy();
99 
100   std::unique_ptr<Interpreter> interpreter;
101   InterpreterBuilder builder(*model, resolver);
102   if (preserve_all_tensors) builder.PreserveAllTensorsExperimental();
103   if (builder(&interpreter) != kTfLiteOk) {
104     return nullptr;
105   }
106   return interpreter;
107 }
108 
PyArrayFromFloatVector(const float * data,npy_intp size)109 PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
110   void* pydata = malloc(size * sizeof(float));
111   memcpy(pydata, data, size * sizeof(float));
112   PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
113   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
114   return obj;
115 }
116 
PyArrayFromIntVector(const int * data,npy_intp size)117 PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
118   void* pydata = malloc(size * sizeof(int));
119   memcpy(pydata, data, size * sizeof(int));
120   PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
121   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
122   return obj;
123 }
124 
PyTupleFromQuantizationParam(const TfLiteQuantizationParams & param)125 PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
126   PyObject* result = PyTuple_New(2);
127   PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
128   PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
129   return result;
130 }
131 
PyDictFromSparsityParam(const TfLiteSparsity & param)132 PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
133   PyObject* result = PyDict_New();
134   PyDict_SetItemString(result, "traversal_order",
135                        PyArrayFromIntVector(param.traversal_order->data,
136                                             param.traversal_order->size));
137   PyDict_SetItemString(
138       result, "block_map",
139       PyArrayFromIntVector(param.block_map->data, param.block_map->size));
140   PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
141   for (int i = 0; i < param.dim_metadata_size; i++) {
142     PyObject* dim_metadata_i = PyDict_New();
143     if (param.dim_metadata[i].format == kTfLiteDimDense) {
144       PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
145       PyDict_SetItemString(dim_metadata_i, "dense_size",
146                            PyLong_FromSize_t(param.dim_metadata[i].dense_size));
147     } else {
148       PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
149       const auto* array_segments = param.dim_metadata[i].array_segments;
150       const auto* array_indices = param.dim_metadata[i].array_indices;
151       PyDict_SetItemString(
152           dim_metadata_i, "array_segments",
153           PyArrayFromIntVector(array_segments->data, array_segments->size));
154       PyDict_SetItemString(
155           dim_metadata_i, "array_indices",
156           PyArrayFromIntVector(array_indices->data, array_indices->size));
157     }
158     PyList_SetItem(dim_metadata, i, dim_metadata_i);
159   }
160   PyDict_SetItemString(result, "dim_metadata", dim_metadata);
161   return result;
162 }
163 
RegisterCustomOpByName(const char * registerer_name,tflite::MutableOpResolver * resolver,std::string * error_msg)164 bool RegisterCustomOpByName(const char* registerer_name,
165                             tflite::MutableOpResolver* resolver,
166                             std::string* error_msg) {
167   // Registerer functions take a pointer to a BuiltinOpResolver as an input
168   // parameter and return void.
169   // TODO(b/137576229): We should implement this functionality in a more
170   // principled way.
171   typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);
172 
173   // Look for the Registerer function by name.
174   RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
175       SharedLibrary::GetSymbol(registerer_name));
176 
177   // Fail in an informative way if the function was not found.
178   if (registerer == nullptr) {
179     *error_msg =
180         absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
181                         registerer_name, SharedLibrary::GetError());
182     return false;
183   }
184 
185   // Call the registerer with the resolver.
186   registerer(resolver);
187   return true;
188 }
189 
190 }  // namespace
191 
// Identifiers accepted as `op_resolver_id` by CreateInterpreterWrapper, each
// selecting which builtin op resolver to instantiate. The Python caller must
// pass one of these exact values. (Namespace-scope constexpr variables already
// have internal linkage, so `static` is unnecessary.)
constexpr int kBuiltinOpResolver = 1;
constexpr int kBuiltinRefOpResolver = 2;
constexpr int kBuiltinOpResolverWithoutDefaultDelegates = 3;
195 
CreateInterpreterWrapper(std::unique_ptr<InterpreterWrapper::Model> model,int op_resolver_id,std::unique_ptr<PythonErrorReporter> error_reporter,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg,bool preserve_all_tensors)196 InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
197     std::unique_ptr<InterpreterWrapper::Model> model, int op_resolver_id,
198     std::unique_ptr<PythonErrorReporter> error_reporter,
199     const std::vector<std::string>& registerers_by_name,
200     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
201     std::string* error_msg, bool preserve_all_tensors) {
202   if (!model) {
203     *error_msg = error_reporter->message();
204     return nullptr;
205   }
206 
207   std::unique_ptr<tflite::MutableOpResolver> resolver;
208   switch (op_resolver_id) {
209     case kBuiltinOpResolver:
210       resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
211       break;
212     case kBuiltinRefOpResolver:
213       resolver =
214           absl::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
215       break;
216     case kBuiltinOpResolverWithoutDefaultDelegates:
217       resolver = absl::make_unique<
218           tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
219       break;
220     default:
221       // This should not never happen because the eventual caller in
222       // interpreter.py should have passed a valid id here.
223       TFLITE_DCHECK(false);
224       return nullptr;
225   }
226 
227   for (const auto& registerer : registerers_by_name) {
228     if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
229       return nullptr;
230   }
231   for (const auto& registerer : registerers_by_func) {
232     registerer(reinterpret_cast<uintptr_t>(resolver.get()));
233   }
234   auto interpreter =
235       CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
236   if (!interpreter) {
237     *error_msg = error_reporter->message();
238     return nullptr;
239   }
240 
241   InterpreterWrapper* wrapper =
242       new InterpreterWrapper(std::move(model), std::move(error_reporter),
243                              std::move(resolver), std::move(interpreter));
244   return wrapper;
245 }
246 
InterpreterWrapper(std::unique_ptr<InterpreterWrapper::Model> model,std::unique_ptr<PythonErrorReporter> error_reporter,std::unique_ptr<tflite::MutableOpResolver> resolver,std::unique_ptr<Interpreter> interpreter)247 InterpreterWrapper::InterpreterWrapper(
248     std::unique_ptr<InterpreterWrapper::Model> model,
249     std::unique_ptr<PythonErrorReporter> error_reporter,
250     std::unique_ptr<tflite::MutableOpResolver> resolver,
251     std::unique_ptr<Interpreter> interpreter)
252     : model_(std::move(model)),
253       error_reporter_(std::move(error_reporter)),
254       resolver_(std::move(resolver)),
255       interpreter_(std::move(interpreter)) {}
256 
~InterpreterWrapper()257 InterpreterWrapper::~InterpreterWrapper() {}
258 
AllocateTensors(int subgraph_index)259 PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) {
260   TFLITE_PY_ENSURE_VALID_INTERPRETER();
261   TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
262   TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors());
263   Py_RETURN_NONE;
264 }
265 
Invoke(int subgraph_index)266 PyObject* InterpreterWrapper::Invoke(int subgraph_index) {
267   TFLITE_PY_ENSURE_VALID_INTERPRETER();
268   TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
269 
270   // Release the GIL so that we can run multiple interpreters in parallel
271   TfLiteStatus status_code = kTfLiteOk;
272   Py_BEGIN_ALLOW_THREADS;  // To return can happen between this and end!
273   tflite::Subgraph* subgraph = interpreter_->subgraph(subgraph_index);
274   status_code = subgraph->Invoke();
275 
276   if (!interpreter_->allow_buffer_handle_output_) {
277     for (int tensor_index : subgraph->outputs()) {
278       subgraph->EnsureTensorDataIsReadable(tensor_index);
279     }
280   }
281   Py_END_ALLOW_THREADS;
282 
283   TFLITE_PY_CHECK(
284       status_code);  // don't move this into the Py_BEGIN/Py_End block
285 
286   Py_RETURN_NONE;
287 }
288 
InputIndices() const289 PyObject* InterpreterWrapper::InputIndices() const {
290   TFLITE_PY_ENSURE_VALID_INTERPRETER();
291   PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
292                                             interpreter_->inputs().size());
293 
294   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
295 }
296 
OutputIndices() const297 PyObject* InterpreterWrapper::OutputIndices() const {
298   PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
299                                             interpreter_->outputs().size());
300 
301   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
302 }
303 
ResizeInputTensorImpl(int i,PyObject * value)304 PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
305   TFLITE_PY_ENSURE_VALID_INTERPRETER();
306 
307   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
308       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
309   if (!array_safe) {
310     PyErr_SetString(PyExc_ValueError,
311                     "Failed to convert numpy value into readable tensor.");
312     return nullptr;
313   }
314 
315   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
316 
317   if (PyArray_NDIM(array) != 1) {
318     PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
319                  PyArray_NDIM(array));
320     return nullptr;
321   }
322 
323   if (PyArray_TYPE(array) != NPY_INT32) {
324     PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
325                  PyArray_TYPE(array));
326     return nullptr;
327   }
328 
329   PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
330                       NPY_ARRAY_OWNDATA);
331   return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
332 }
333 
ResizeInputTensor(int i,PyObject * value,bool strict,int subgraph_index)334 PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
335                                                 bool strict,
336                                                 int subgraph_index) {
337   TFLITE_PY_ENSURE_VALID_INTERPRETER();
338   TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
339 
340   PyArrayObject* array =
341       reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
342   if (array == nullptr) {
343     return nullptr;
344   }
345 
346   std::vector<int> dims(PyArray_SHAPE(array)[0]);
347   memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
348 
349   if (strict) {
350     TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)
351                         ->ResizeInputTensorStrict(i, dims));
352   } else {
353     TFLITE_PY_CHECK(
354         interpreter_->subgraph(subgraph_index)->ResizeInputTensor(i, dims));
355   }
356   Py_RETURN_NONE;
357 }
358 
NumTensors() const359 int InterpreterWrapper::NumTensors() const {
360   if (!interpreter_) {
361     return 0;
362   }
363   return interpreter_->tensors_size();
364 }
365 
TensorName(int i) const366 std::string InterpreterWrapper::TensorName(int i) const {
367   if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
368     return "";
369   }
370 
371   const TfLiteTensor* tensor = interpreter_->tensor(i);
372   return tensor->name ? tensor->name : "";
373 }
374 
TensorType(int i) const375 PyObject* InterpreterWrapper::TensorType(int i) const {
376   TFLITE_PY_ENSURE_VALID_INTERPRETER();
377   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
378 
379   const TfLiteTensor* tensor = interpreter_->tensor(i);
380   if (tensor->type == kTfLiteNoType) {
381     PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
382     return nullptr;
383   }
384 
385   int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
386   if (code == -1) {
387     PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
388     return nullptr;
389   }
390   return PyArray_TypeObjectFromType(code);
391 }
392 
TensorSize(int i) const393 PyObject* InterpreterWrapper::TensorSize(int i) const {
394   TFLITE_PY_ENSURE_VALID_INTERPRETER();
395   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
396 
397   const TfLiteTensor* tensor = interpreter_->tensor(i);
398   if (tensor->dims == nullptr) {
399     PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
400     return nullptr;
401   }
402   PyObject* np_array =
403       PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
404 
405   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
406 }
407 
TensorSizeSignature(int i) const408 PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
409   TFLITE_PY_ENSURE_VALID_INTERPRETER();
410   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
411 
412   const TfLiteTensor* tensor = interpreter_->tensor(i);
413   const int32_t* size_signature_data = nullptr;
414   int32_t size_signature_size = 0;
415   if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
416     size_signature_data = tensor->dims_signature->data;
417     size_signature_size = tensor->dims_signature->size;
418   } else {
419     size_signature_data = tensor->dims->data;
420     size_signature_size = tensor->dims->size;
421   }
422   PyObject* np_array =
423       PyArrayFromIntVector(size_signature_data, size_signature_size);
424 
425   return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
426 }
427 
TensorSparsityParameters(int i) const428 PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
429   TFLITE_PY_ENSURE_VALID_INTERPRETER();
430   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
431   const TfLiteTensor* tensor = interpreter_->tensor(i);
432   if (tensor->sparsity == nullptr) {
433     return PyDict_New();
434   }
435 
436   return PyDictFromSparsityParam(*tensor->sparsity);
437 }
438 
TensorQuantization(int i) const439 PyObject* InterpreterWrapper::TensorQuantization(int i) const {
440   TFLITE_PY_ENSURE_VALID_INTERPRETER();
441   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
442   const TfLiteTensor* tensor = interpreter_->tensor(i);
443   return PyTupleFromQuantizationParam(tensor->params);
444 }
445 
TensorQuantizationParameters(int i) const446 PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
447   TFLITE_PY_ENSURE_VALID_INTERPRETER();
448   TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
449   const TfLiteTensor* tensor = interpreter_->tensor(i);
450   const TfLiteQuantization quantization = tensor->quantization;
451   float* scales_data = nullptr;
452   int32_t* zero_points_data = nullptr;
453   int32_t scales_size = 0;
454   int32_t zero_points_size = 0;
455   int32_t quantized_dimension = 0;
456   if (quantization.type == kTfLiteAffineQuantization) {
457     const TfLiteAffineQuantization* q_params =
458         reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
459     if (q_params->scale) {
460       scales_data = q_params->scale->data;
461       scales_size = q_params->scale->size;
462     }
463     if (q_params->zero_point) {
464       zero_points_data = q_params->zero_point->data;
465       zero_points_size = q_params->zero_point->size;
466     }
467     quantized_dimension = q_params->quantized_dimension;
468   }
469   PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
470   PyObject* zero_points_array =
471       PyArrayFromIntVector(zero_points_data, zero_points_size);
472 
473   PyObject* result = PyTuple_New(3);
474   PyTuple_SET_ITEM(result, 0, scales_array);
475   PyTuple_SET_ITEM(result, 1, zero_points_array);
476   PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
477   return result;
478 }
479 
SetTensor(int i,PyObject * value,int subgraph_index)480 PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value,
481                                         int subgraph_index) {
482   TFLITE_PY_ENSURE_VALID_INTERPRETER();
483   TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
484   TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index);
485 
486   std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
487       PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
488   if (!array_safe) {
489     PyErr_SetString(PyExc_ValueError,
490                     "Failed to convert value into readable tensor.");
491     return nullptr;
492   }
493 
494   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
495   TfLiteTensor* tensor = interpreter_->subgraph(subgraph_index)->tensor(i);
496 
497   if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
498     PyErr_Format(PyExc_ValueError,
499                  "Cannot set tensor:"
500                  " Got value of type %s"
501                  " but expected type %s for input %d, name: %s ",
502                  TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
503                  TfLiteTypeGetName(tensor->type), i, tensor->name);
504     return nullptr;
505   }
506 
507   if (PyArray_NDIM(array) != tensor->dims->size) {
508     PyErr_Format(PyExc_ValueError,
509                  "Cannot set tensor: Dimension mismatch."
510                  " Got %d"
511                  " but expected %d for input %d.",
512                  PyArray_NDIM(array), tensor->dims->size, i);
513     return nullptr;
514   }
515 
516   for (int j = 0; j < PyArray_NDIM(array); j++) {
517     if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
518       PyErr_Format(PyExc_ValueError,
519                    "Cannot set tensor: Dimension mismatch."
520                    " Got %ld"
521                    " but expected %d for dimension %d of input %d.",
522                    PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
523       return nullptr;
524     }
525   }
526 
527   if (tensor->type != kTfLiteString) {
528     if (tensor->data.raw == nullptr) {
529       PyErr_Format(PyExc_ValueError,
530                    "Cannot set tensor:"
531                    " Tensor is unallocated. Try calling allocate_tensors()"
532                    " first");
533       return nullptr;
534     }
535 
536     size_t size = PyArray_NBYTES(array);
537     if (size != tensor->bytes) {
538       PyErr_Format(PyExc_ValueError,
539                    "numpy array had %zu bytes but expected %zu bytes.", size,
540                    tensor->bytes);
541       return nullptr;
542     }
543     memcpy(tensor->data.raw, PyArray_DATA(array), size);
544   } else {
545     DynamicBuffer dynamic_buffer;
546     if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
547       return nullptr;
548     }
549     dynamic_buffer.WriteToTensor(tensor, nullptr);
550   }
551   Py_RETURN_NONE;
552 }
553 
NumNodes() const554 int InterpreterWrapper::NumNodes() const {
555   if (!interpreter_) {
556     return 0;
557   }
558   return interpreter_->nodes_size();
559 }
560 
NodeInputs(int i) const561 PyObject* InterpreterWrapper::NodeInputs(int i) const {
562   TFLITE_PY_ENSURE_VALID_INTERPRETER();
563   TFLITE_PY_NODES_BOUNDS_CHECK(i);
564 
565   const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
566   PyObject* inputs =
567       PyArrayFromIntVector(node->inputs->data, node->inputs->size);
568   return inputs;
569 }
570 
NodeOutputs(int i) const571 PyObject* InterpreterWrapper::NodeOutputs(int i) const {
572   TFLITE_PY_ENSURE_VALID_INTERPRETER();
573   TFLITE_PY_NODES_BOUNDS_CHECK(i);
574 
575   const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
576   PyObject* outputs =
577       PyArrayFromIntVector(node->outputs->data, node->outputs->size);
578   return outputs;
579 }
580 
NodeName(int i) const581 std::string InterpreterWrapper::NodeName(int i) const {
582   if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
583     return "";
584   }
585   // Get op name from registration
586   const TfLiteRegistration* node_registration =
587       &(interpreter_->node_and_registration(i)->second);
588   int32_t op_code = node_registration->builtin_code;
589   std::string op_name;
590   if (op_code == tflite::BuiltinOperator_CUSTOM) {
591     const char* custom_name = node_registration->custom_name;
592     op_name = custom_name ? custom_name : "UnknownCustomOp";
593   } else {
594     op_name = tflite::EnumNamesBuiltinOperator()[op_code];
595   }
596   std::string op_name_str(op_name);
597   return op_name_str;
598 }
599 
600 namespace {
601 
602 // Checks to see if a tensor access can succeed (returns nullptr on error).
603 // Otherwise returns Py_None.
CheckGetTensorArgs(Interpreter * interpreter_,int tensor_index,TfLiteTensor ** tensor,int * type_num,int subgraph_index)604 PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
605                              TfLiteTensor** tensor, int* type_num,
606                              int subgraph_index) {
607   TFLITE_PY_ENSURE_VALID_INTERPRETER();
608   TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
609   TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(tensor_index, subgraph_index);
610 
611   *tensor = interpreter_->subgraph(subgraph_index)->tensor(tensor_index);
612   if ((*tensor)->bytes == 0) {
613     PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
614     return nullptr;
615   }
616 
617   *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
618   if (*type_num == -1) {
619     PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
620     return nullptr;
621   }
622 
623   if (!(*tensor)->data.raw) {
624     PyErr_SetString(PyExc_ValueError,
625                     "Tensor data is null."
626                     " Run allocate_tensors() first");
627     return nullptr;
628   }
629 
630   Py_RETURN_NONE;
631 }
632 
633 }  // namespace
634 
GetSignatureDefs() const635 PyObject* InterpreterWrapper::GetSignatureDefs() const {
636   PyObject* result = PyDict_New();
637   for (const auto& sig_key : interpreter_->signature_keys()) {
638     PyObject* signature_def = PyDict_New();
639     PyObject* inputs = PyDict_New();
640     PyObject* outputs = PyDict_New();
641     const auto& signature_def_inputs =
642         interpreter_->signature_inputs(sig_key->c_str());
643     const auto& signature_def_outputs =
644         interpreter_->signature_outputs(sig_key->c_str());
645     for (const auto& input : signature_def_inputs) {
646       PyDict_SetItemString(inputs, input.first.c_str(),
647                            PyLong_FromLong(input.second));
648     }
649     for (const auto& output : signature_def_outputs) {
650       PyDict_SetItemString(outputs, output.first.c_str(),
651                            PyLong_FromLong(output.second));
652     }
653 
654     PyDict_SetItemString(signature_def, "inputs", inputs);
655     PyDict_SetItemString(signature_def, "outputs", outputs);
656     PyDict_SetItemString(result, sig_key->c_str(), signature_def);
657   }
658   return result;
659 }
660 
GetSubgraphIndexFromSignature(const char * signature_key)661 PyObject* InterpreterWrapper::GetSubgraphIndexFromSignature(
662     const char* signature_key) {
663   TFLITE_PY_ENSURE_VALID_INTERPRETER();
664 
665   int32_t subgraph_index =
666       interpreter_->GetSubgraphIndexFromSignature(signature_key);
667 
668   if (subgraph_index < 0) {
669     PyErr_SetString(PyExc_ValueError, "No matching signature.");
670     return nullptr;
671   }
672   return PyLong_FromLong(static_cast<int64_t>(subgraph_index));
673 }
674 
GetTensor(int i,int subgraph_index) const675 PyObject* InterpreterWrapper::GetTensor(int i, int subgraph_index) const {
676   // Sanity check accessor
677   TfLiteTensor* tensor = nullptr;
678   int type_num = 0;
679 
680   PyObject* check_result = CheckGetTensorArgs(interpreter_.get(), i, &tensor,
681                                               &type_num, subgraph_index);
682   if (check_result == nullptr) return check_result;
683   Py_XDECREF(check_result);
684 
685   std::vector<npy_intp> dims(tensor->dims->data,
686                              tensor->dims->data + tensor->dims->size);
687   if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
688       tensor->type != kTfLiteVariant) {
689     // Make a buffer copy but we must tell Numpy It owns that data or else
690     // it will leak.
691     void* data = malloc(tensor->bytes);
692     if (!data) {
693       PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
694       return nullptr;
695     }
696     memcpy(data, tensor->data.raw, tensor->bytes);
697     PyObject* np_array;
698     if (tensor->sparsity == nullptr) {
699       np_array =
700           PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
701     } else {
702       std::vector<npy_intp> sparse_buffer_dims(1);
703       size_t size_of_type;
704       if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
705         PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
706         free(data);
707         return nullptr;
708       }
709       sparse_buffer_dims[0] = tensor->bytes / size_of_type;
710       np_array = PyArray_SimpleNewFromData(
711           sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
712     }
713     PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
714                         NPY_ARRAY_OWNDATA);
715     return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
716   } else {
717     // Create a C-order array so the data is contiguous in memory.
718     const int32_t kCOrder = 0;
719     PyObject* py_object =
720         PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
721 
722     if (py_object == nullptr) {
723       PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
724       return nullptr;
725     }
726 
727     PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
728     PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
729     auto num_strings = GetStringCount(tensor);
730     for (int j = 0; j < num_strings; ++j) {
731       auto ref = GetString(tensor, j);
732 
733       PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
734       if (bytes == nullptr) {
735         Py_DECREF(py_object);
736         PyErr_Format(PyExc_ValueError,
737                      "Could not create PyBytes from string %d of input %d.", j,
738                      i);
739         return nullptr;
740       }
741       // PyArray_EMPTY produces an array full of Py_None, which we must decref.
742       Py_DECREF(data[j]);
743       data[j] = bytes;
744     }
745     return py_object;
746   }
747 }
748 
tensor(PyObject * base_object,int tensor_index,int subgraph_index)749 PyObject* InterpreterWrapper::tensor(PyObject* base_object, int tensor_index,
750                                      int subgraph_index) {
751   // Sanity check accessor
752   TfLiteTensor* tensor = nullptr;
753   int type_num = 0;
754 
755   PyObject* check_result = CheckGetTensorArgs(
756       interpreter_.get(), tensor_index, &tensor, &type_num, subgraph_index);
757   if (check_result == nullptr) return check_result;
758   Py_XDECREF(check_result);
759 
760   std::vector<npy_intp> dims(tensor->dims->data,
761                              tensor->dims->data + tensor->dims->size);
762   PyArrayObject* np_array =
763       reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
764           dims.size(), dims.data(), type_num, tensor->data.raw));
765   Py_INCREF(base_object);  // SetBaseObject steals, so we need to add.
766   PyArray_SetBaseObject(np_array, base_object);
767   return PyArray_Return(np_array);
768 }
769 
CreateWrapperCPPFromFile(const char * model_path,int op_resolver_id,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg,bool preserve_all_tensors)770 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
771     const char* model_path, int op_resolver_id,
772     const std::vector<std::string>& registerers_by_name,
773     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
774     std::string* error_msg, bool preserve_all_tensors) {
775   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
776   std::unique_ptr<InterpreterWrapper::Model> model =
777       Model::BuildFromFile(model_path, error_reporter.get());
778   return CreateInterpreterWrapper(std::move(model), op_resolver_id,
779                                   std::move(error_reporter),
780                                   registerers_by_name, registerers_by_func,
781                                   error_msg, preserve_all_tensors);
782 }
783 
CreateWrapperCPPFromFile(const char * model_path,int op_resolver_id,const std::vector<std::string> & registerers,std::string * error_msg,bool preserve_all_tensors)784 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
785     const char* model_path, int op_resolver_id,
786     const std::vector<std::string>& registerers, std::string* error_msg,
787     bool preserve_all_tensors) {
788   return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers,
789                                   {} /*registerers_by_func*/, error_msg,
790                                   preserve_all_tensors);
791 }
792 
CreateWrapperCPPFromBuffer(PyObject * data,int op_resolver_id,const std::vector<std::string> & registerers_by_name,const std::vector<std::function<void (uintptr_t)>> & registerers_by_func,std::string * error_msg,bool preserve_all_tensors)793 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
794     PyObject* data, int op_resolver_id,
795     const std::vector<std::string>& registerers_by_name,
796     const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
797     std::string* error_msg, bool preserve_all_tensors) {
798   char* buf = nullptr;
799   Py_ssize_t length;
800   std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
801 
802   if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
803     return nullptr;
804   }
805   std::unique_ptr<InterpreterWrapper::Model> model =
806       Model::BuildFromBuffer(buf, length, error_reporter.get());
807   return CreateInterpreterWrapper(std::move(model), op_resolver_id,
808                                   std::move(error_reporter),
809                                   registerers_by_name, registerers_by_func,
810                                   error_msg, preserve_all_tensors);
811 }
812 
CreateWrapperCPPFromBuffer(PyObject * data,int op_resolver_id,const std::vector<std::string> & registerers,std::string * error_msg,bool preserve_all_tensors)813 InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
814     PyObject* data, int op_resolver_id,
815     const std::vector<std::string>& registerers, std::string* error_msg,
816     bool preserve_all_tensors) {
817   return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
818                                     error_msg, preserve_all_tensors);
819 }
820 
ResetVariableTensors()821 PyObject* InterpreterWrapper::ResetVariableTensors() {
822   TFLITE_PY_ENSURE_VALID_INTERPRETER();
823   TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
824   Py_RETURN_NONE;
825 }
826 
SetNumThreads(int num_threads)827 PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
828   TFLITE_PY_ENSURE_VALID_INTERPRETER();
829   interpreter_->SetNumThreads(num_threads);
830   Py_RETURN_NONE;
831 }
832 
ModifyGraphWithDelegate(TfLiteDelegate * delegate)833 PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
834     TfLiteDelegate* delegate) {
835   TFLITE_PY_ENSURE_VALID_INTERPRETER();
836   TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
837   Py_RETURN_NONE;
838 }
839 
840 }  // namespace interpreter_wrapper
841 }  // namespace tflite
842