/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"

#include <stdarg.h>

#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"

#define TFLITE_PY_CHECK(x)               \
  if ((x) != kTfLiteOk) {                \
    return error_reporter_->exception(); \
  }

#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                    \
  if (i >= interpreter_->tensors_size() || i < 0) {                         \
    PyErr_Format(PyExc_ValueError,                                          \
                 "Invalid tensor index %d exceeds max tensor index %lu", i, \
                 interpreter_->tensors_size());                             \
    return nullptr;                                                         \
  }

#define TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index)             \
  if (i >= interpreter_->subgraph(subgraph_index)->tensors_size() || i < 0) { \
    PyErr_Format(PyExc_ValueError,                                            \
                 "Invalid tensor index %d exceeds max tensor index %lu", i,   \
                 interpreter_->subgraph(subgraph_index)->tensors_size());     \
    return nullptr;                                                           \
  }

#define TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(i)                                   \
  if (i >= interpreter_->subgraphs_size() || i < 0) {                        \
    PyErr_Format(PyExc_ValueError,                                           \
                 "Invalid subgraph index %d exceeds max subgraph index %lu", \
                 i, interpreter_->subgraphs_size());                         \
    return nullptr;                                                          \
  }

#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                   \
  if (i >= interpreter_->nodes_size() || i < 0) {         \
    PyErr_Format(PyExc_ValueError, "Invalid node index"); \
    return nullptr;                                       \
  }

#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
  if (!interpreter_) {                                                     \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr;                                                        \
  }
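
// For illustration, a call such as
// TFLITE_PY_CHECK(interpreter_->AllocateTensors()) in the methods below
// expands to:
//
//   if ((interpreter_->AllocateTensors()) != kTfLiteOk) {
//     return error_reporter_->exception();
//   }
//
// i.e. a TfLite failure is surfaced to CPython by raising a Python exception
// from the buffered error message and returning nullptr.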

namespace tflite {
namespace interpreter_wrapper {

namespace {

using python_utils::PyDecrefDeleter;

std::unique_ptr<Interpreter> CreateInterpreter(
    const InterpreterWrapper::Model* model,
    const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) {
  if (!model) {
    return nullptr;
  }

  ::tflite::python::ImportNumpy();

  std::unique_ptr<Interpreter> interpreter;
  InterpreterOptions options;
  options.SetPreserveAllTensors(preserve_all_tensors);
  InterpreterBuilder builder(*model, resolver, &options);
  if (builder(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}

PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(float));
  memcpy(pydata, data, size * sizeof(float));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(int));
  memcpy(pydata, data, size * sizeof(int));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
  PyObject* result = PyTuple_New(2);
  PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
  PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
  return result;
}

PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
  PyObject* result = PyDict_New();
  PyDict_SetItemString(result, "traversal_order",
                       PyArrayFromIntVector(param.traversal_order->data,
                                            param.traversal_order->size));
  PyDict_SetItemString(
      result, "block_map",
      PyArrayFromIntVector(param.block_map->data, param.block_map->size));
  PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
  for (int i = 0; i < param.dim_metadata_size; i++) {
    PyObject* dim_metadata_i = PyDict_New();
    if (param.dim_metadata[i].format == kTfLiteDimDense) {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
      PyDict_SetItemString(dim_metadata_i, "dense_size",
                           PyLong_FromSize_t(param.dim_metadata[i].dense_size));
    } else {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
      const auto* array_segments = param.dim_metadata[i].array_segments;
      const auto* array_indices = param.dim_metadata[i].array_indices;
      PyDict_SetItemString(
          dim_metadata_i, "array_segments",
          PyArrayFromIntVector(array_segments->data, array_segments->size));
      PyDict_SetItemString(
          dim_metadata_i, "array_indices",
          PyArrayFromIntVector(array_indices->data, array_indices->size));
    }
    PyList_SetItem(dim_metadata, i, dim_metadata_i);
  }
  PyDict_SetItemString(result, "dim_metadata", dim_metadata);
  return result;
}
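
// For illustration only: a matrix stored in 2x2 sparse blocks might yield a
// dict along these lines (hypothetical values):
//
//   {"traversal_order": array([0, 1, 2, 3]), "block_map": array([0, 1]),
//    "dim_metadata": [{"format": 0, "dense_size": 2}, ...]}
//
// where "format" 0 denotes a dense dimension (kTfLiteDimDense) and 1 a sparse
// one carrying "array_segments" and "array_indices".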

bool RegisterCustomOpByName(const char* registerer_name,
                            tflite::MutableOpResolver* resolver,
                            std::string* error_msg) {
  // Registerer functions take a pointer to a MutableOpResolver as an input
  // parameter and return void.
  // TODO(b/137576229): We should implement this functionality in a more
  // principled way.
  typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);

  // Look up the registerer function by name.
  RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
      SharedLibrary::GetSymbol(registerer_name));

  // Fail in an informative way if the function was not found.
  if (registerer == nullptr) {
    *error_msg =
        absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                        registerer_name, SharedLibrary::GetError());
    return false;
  }

  // Call the registerer with the resolver.
  registerer(resolver);
  return true;
}
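
// For illustration, a registerer exported by a custom-op shared library could
// look like the following sketch (the op and function names are hypothetical):
//
//   extern "C" void RegisterMyCustomOps(tflite::MutableOpResolver* resolver) {
//     // Register_MY_CUSTOM_OP() would return a TfLiteRegistration*.
//     resolver->AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
//   }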

}  // namespace

static constexpr int kBuiltinOpResolver = 1;
static constexpr int kBuiltinRefOpResolver = 2;
static constexpr int kBuiltinOpResolverWithoutDefaultDelegates = 3;

InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model, int op_resolver_id,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  if (!model) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  std::unique_ptr<tflite::MutableOpResolver> resolver;
  switch (op_resolver_id) {
    case kBuiltinOpResolver:
      resolver = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
      break;
    case kBuiltinRefOpResolver:
      resolver = std::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
      break;
    case kBuiltinOpResolverWithoutDefaultDelegates:
      resolver = std::make_unique<
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
      break;
    default:
      // This should never happen because the eventual caller in
      // interpreter.py should have passed a valid id here.
      TFLITE_DCHECK(false);
      return nullptr;
  }

  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
      return nullptr;
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  auto interpreter =
      CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
  if (!interpreter) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  InterpreterWrapper* wrapper =
      new InterpreterWrapper(std::move(model), std::move(error_reporter),
                             std::move(resolver), std::move(interpreter));
  return wrapper;
}

InterpreterWrapper::InterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    std::unique_ptr<tflite::MutableOpResolver> resolver,
    std::unique_ptr<Interpreter> interpreter)
    : model_(std::move(model)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      interpreter_(std::move(interpreter)) {}

InterpreterWrapper::~InterpreterWrapper() {}

// LINT.IfChange
static constexpr int kUndeterminedSubgraphIndex = -1;
// LINT.ThenChange(//tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper_pybind11.cc)
PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  if (subgraph_index == kUndeterminedSubgraphIndex) {
    TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  } else {
    TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
    TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors());
  }
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::Invoke(int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  // Release the GIL so that we can run multiple interpreters in parallel.
  TfLiteStatus status_code = kTfLiteOk;
  Py_BEGIN_ALLOW_THREADS;  // No early return may occur between this and
                           // Py_END_ALLOW_THREADS!
  tflite::Subgraph* subgraph = interpreter_->subgraph(subgraph_index);
  status_code = subgraph->Invoke();

  if (!interpreter_->allow_buffer_handle_output_) {
    for (int tensor_index : subgraph->outputs()) {
      subgraph->EnsureTensorDataIsReadable(tensor_index);
    }
  }
  Py_END_ALLOW_THREADS;

  TFLITE_PY_CHECK(
      status_code);  // Don't move this into the Py_BEGIN/Py_END block: it can
                     // return early.

  Py_RETURN_NONE;
}
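
// A rough sketch of what the Py_BEGIN/END_ALLOW_THREADS pair above expands to
// (per the CPython documentation):
//
//   { PyThreadState* _save = PyEval_SaveThread();  // releases the GIL
//     ...                                          // no CPython calls allowed
//     PyEval_RestoreThread(_save); }               // re-acquires the GIL
//
// An early return inside the block would skip PyEval_RestoreThread() and leave
// the GIL released, which is why the TFLITE_PY_CHECK (which may return) only
// runs after Py_END_ALLOW_THREADS.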

PyObject* InterpreterWrapper::InputIndices() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
                                            interpreter_->inputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::OutputIndices() const {
  PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
                                            interpreter_->outputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert numpy value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  if (PyArray_NDIM(array) != 1) {
    PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
                 PyArray_NDIM(array));
    return nullptr;
  }

  if (PyArray_TYPE(array) != NPY_INT32) {
    PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
                 PyArray_TYPE(array));
    return nullptr;
  }

  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
                      NPY_ARRAY_OWNDATA);
  return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
}

PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
                                                bool strict,
                                                int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  PyArrayObject* array =
      reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
  if (array == nullptr) {
    return nullptr;
  }

  std::vector<int> dims(PyArray_SHAPE(array)[0]);
  memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));

  if (strict) {
    TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)
                        ->ResizeInputTensorStrict(i, dims));
  } else {
    TFLITE_PY_CHECK(
        interpreter_->subgraph(subgraph_index)->ResizeInputTensor(i, dims));
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumTensors() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->tensors_size();
}

std::string InterpreterWrapper::TensorName(int i) const {
  if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
    return "";
  }

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return tensor->name ? tensor->name : "";
}

PyObject* InterpreterWrapper::TensorType(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->type == kTfLiteNoType) {
    PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
    return nullptr;
  }

  int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
  if (code == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
    return nullptr;
  }
  return PyArray_TypeObjectFromType(code);
}

PyObject* InterpreterWrapper::TensorSize(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->dims == nullptr) {
    PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
    return nullptr;
  }
  PyObject* np_array =
      PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const int32_t* size_signature_data = nullptr;
  int32_t size_signature_size = 0;
  if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
    size_signature_data = tensor->dims_signature->data;
    size_signature_size = tensor->dims_signature->size;
  } else {
    size_signature_data = tensor->dims->data;
    size_signature_size = tensor->dims->size;
  }
  PyObject* np_array =
      PyArrayFromIntVector(size_signature_data, size_signature_size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->sparsity == nullptr) {
    return PyDict_New();
  }

  return PyDictFromSparsityParam(*tensor->sparsity);
}

PyObject* InterpreterWrapper::TensorQuantization(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return PyTupleFromQuantizationParam(tensor->params);
}

PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const TfLiteQuantization quantization = tensor->quantization;
  float* scales_data = nullptr;
  int32_t* zero_points_data = nullptr;
  int32_t scales_size = 0;
  int32_t zero_points_size = 0;
  int32_t quantized_dimension = 0;
  if (quantization.type == kTfLiteAffineQuantization) {
    const TfLiteAffineQuantization* q_params =
        reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
    if (q_params->scale) {
      scales_data = q_params->scale->data;
      scales_size = q_params->scale->size;
    }
    if (q_params->zero_point) {
      zero_points_data = q_params->zero_point->data;
      zero_points_size = q_params->zero_point->size;
    }
    quantized_dimension = q_params->quantized_dimension;
  }
  PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
  PyObject* zero_points_array =
      PyArrayFromIntVector(zero_points_data, zero_points_size);

  PyObject* result = PyTuple_New(3);
  PyTuple_SET_ITEM(result, 0, scales_array);
  PyTuple_SET_ITEM(result, 1, zero_points_array);
  PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
  return result;
}
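
// On the Python side the returned triple unpacks as
// (scales, zero_points, quantized_dimension); e.g., for a tensor quantized
// per channel along dimension 0 with two channels (illustrative values):
//
//   (array([0.5, 0.25], dtype=float32), array([0, 0], dtype=int32), 0)
//
// For a tensor without affine quantization both arrays are empty and the
// quantized dimension is 0.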

PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value,
                                        int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index);

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  TfLiteTensor* tensor = interpreter_->subgraph(subgraph_index)->tensor(i);

  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), i, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension mismatch."
                 " Got %d"
                 " but expected %d for input %d.",
                 PyArray_NDIM(array), tensor->dims->size, i);
    return nullptr;
  }

  for (int j = 0; j < PyArray_NDIM(array); j++) {
    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Dimension mismatch."
                   " Got %ld"
                   " but expected %d for dimension %d of input %d.",
                   PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
      return nullptr;
    }
  }

  if (tensor->type != kTfLiteString) {
    // A null data pointer is allowed only for empty (zero-byte) tensors.
    if (tensor->data.raw == nullptr && tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor:"
                   " Tensor is unallocated. Try calling allocate_tensors()"
                   " first");
      return nullptr;
    }

    size_t size = PyArray_NBYTES(array);
    if (size != tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "numpy array had %zu bytes but expected %zu bytes.", size,
                   tensor->bytes);
      return nullptr;
    }
    memcpy(tensor->data.raw, PyArray_DATA(array), size);
  } else {
    DynamicBuffer dynamic_buffer;
    if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
      return nullptr;
    }
    dynamic_buffer.WriteToTensor(tensor, nullptr);
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumNodes() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->nodes_size();
}

PyObject* InterpreterWrapper::NodeInputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* inputs =
      PyArrayFromIntVector(node->inputs->data, node->inputs->size);
  return inputs;
}

PyObject* InterpreterWrapper::NodeOutputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* outputs =
      PyArrayFromIntVector(node->outputs->data, node->outputs->size);
  return outputs;
}

std::string InterpreterWrapper::NodeName(int i) const {
  if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
    return "";
  }
  // Get the op name from the registration.
  const TfLiteRegistration* node_registration =
      &(interpreter_->node_and_registration(i)->second);
  int32_t op_code = node_registration->builtin_code;
  std::string op_name;
  if (op_code == tflite::BuiltinOperator_CUSTOM) {
    const char* custom_name = node_registration->custom_name;
    op_name = custom_name ? custom_name : "UnknownCustomOp";
  } else {
    op_name = tflite::EnumNamesBuiltinOperator()[op_code];
  }
  return op_name;
}

namespace {

// Checks that a tensor access can succeed; returns nullptr (with a Python
// error set) on failure. Otherwise returns Py_None.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
                             TfLiteTensor** tensor, int* type_num,
                             int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(tensor_index, subgraph_index);

  *tensor = interpreter_->subgraph(subgraph_index)->tensor(tensor_index);
  // The size is invalid only when bytes is 0 but the pointer is allocated.
  if ((*tensor)->bytes == 0 && (*tensor)->data.raw) {
    PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
    return nullptr;
  }

  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
  if (*type_num == -1) {
    PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
    return nullptr;
  }

  // Tensor data can't be null if bytes is > 0; 0 bytes is valid if the tensor
  // is empty.
  if (!(*tensor)->data.raw && (*tensor)->bytes) {
    PyErr_SetString(PyExc_ValueError,
                    "Tensor data is null."
                    " Run allocate_tensors() first");
    return nullptr;
  }

  Py_RETURN_NONE;
}

}  // namespace

PyObject* InterpreterWrapper::GetSignatureDefs() const {
  PyObject* result = PyDict_New();
  for (const auto& sig_key : interpreter_->signature_keys()) {
    PyObject* signature_def = PyDict_New();
    PyObject* inputs = PyDict_New();
    PyObject* outputs = PyDict_New();
    const auto& signature_def_inputs =
        interpreter_->signature_inputs(sig_key->c_str());
    const auto& signature_def_outputs =
        interpreter_->signature_outputs(sig_key->c_str());
    for (const auto& input : signature_def_inputs) {
      PyDict_SetItemString(inputs, input.first.c_str(),
                           PyLong_FromLong(input.second));
    }
    for (const auto& output : signature_def_outputs) {
      PyDict_SetItemString(outputs, output.first.c_str(),
                           PyLong_FromLong(output.second));
    }

    PyDict_SetItemString(signature_def, "inputs", inputs);
    PyDict_SetItemString(signature_def, "outputs", outputs);
    PyDict_SetItemString(result, sig_key->c_str(), signature_def);
  }
  return result;
}
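
// The returned structure maps each signature key to its input and output
// tensor indices; for a model with a single signature it might look like
// (hypothetical names and indices):
//
//   {"serving_default": {"inputs": {"x": 0}, "outputs": {"y": 23}}}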

PyObject* InterpreterWrapper::GetSubgraphIndexFromSignature(
    const char* signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  int32_t subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key);

  if (subgraph_index < 0) {
    PyErr_SetString(PyExc_ValueError, "No matching signature.");
    return nullptr;
  }
  return PyLong_FromLong(static_cast<int64_t>(subgraph_index));
}

PyObject* InterpreterWrapper::GetTensor(int i, int subgraph_index) const {
  // Sanity check accessor.
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(interpreter_.get(), i, &tensor,
                                              &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
      tensor->type != kTfLiteVariant) {
    // Make a copy of the buffer, and tell NumPy that it owns that data so the
    // copy does not leak.
    void* data = malloc(tensor->bytes);
    if (!data) {
      PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
      return nullptr;
    }
    memcpy(data, tensor->data.raw, tensor->bytes);
    PyObject* np_array;
    if (tensor->sparsity == nullptr) {
      np_array =
          PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
    } else {
      std::vector<npy_intp> sparse_buffer_dims(1);
      size_t size_of_type;
      if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
        PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
        free(data);
        return nullptr;
      }
      sparse_buffer_dims[0] = tensor->bytes / size_of_type;
      np_array = PyArray_SimpleNewFromData(
          sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
    }
    PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
                        NPY_ARRAY_OWNDATA);
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
  } else {
    // Create a C-order array so the data is contiguous in memory.
    const int32_t kCOrder = 0;
    PyObject* py_object =
        PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);

    if (py_object == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
      return nullptr;
    }

    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
    PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
    auto num_strings = GetStringCount(tensor);
    for (int j = 0; j < num_strings; ++j) {
      auto ref = GetString(tensor, j);

      PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
      if (bytes == nullptr) {
        Py_DECREF(py_object);
        PyErr_Format(PyExc_ValueError,
                     "Could not create PyBytes from string %d of input %d.", j,
                     i);
        return nullptr;
      }
      // PyArray_EMPTY produces an array full of Py_None, which we must decref.
      Py_DECREF(data[j]);
      data[j] = bytes;
    }
    return py_object;
  }
}

PyObject* InterpreterWrapper::tensor(PyObject* base_object, int tensor_index,
                                     int subgraph_index) {
  // Sanity check accessor.
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(
      interpreter_.get(), tensor_index, &tensor, &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  PyArrayObject* np_array =
      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
          dims.size(), dims.data(), type_num, tensor->data.raw));
  Py_INCREF(base_object);  // PyArray_SetBaseObject steals a reference, so we
                           // must add one of our own.
  PyArray_SetBaseObject(np_array, base_object);
  return PyArray_Return(np_array);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path, int op_resolver_id,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromFile(model_path, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), op_resolver_id,
                                  std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path, int op_resolver_id,
    const std::vector<std::string>& registerers, std::string* error_msg,
    bool preserve_all_tensors) {
  return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers,
                                  {} /*registerers_by_func*/, error_msg,
                                  preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromBuffer(buf, length, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), op_resolver_id,
                                  std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers, std::string* error_msg,
    bool preserve_all_tensors) {
  return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
                                    error_msg, preserve_all_tensors);
}

PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  interpreter_->SetNumThreads(num_threads);
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
    TfLiteDelegate* delegate) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite