/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"

#include <stdarg.h>

#include <cstring>
#include <functional>
#include <memory>
#include <sstream>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/kernels/register_ref.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/mutable_op_resolver.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"

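// Helper macros that convert TF Lite failures into Python exceptions. Each
// one sets a Python error and returns early from the enclosing function, so
// they are only usable inside functions that return PyObject*.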
#define TFLITE_PY_CHECK(x)               \
  if ((x) != kTfLiteOk) {                \
    return error_reporter_->exception(); \
  }

#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                    \
  if (i >= interpreter_->tensors_size() || i < 0) {                         \
    PyErr_Format(PyExc_ValueError,                                          \
                 "Invalid tensor index %d exceeds max tensor index %lu", i, \
                 interpreter_->tensors_size());                             \
    return nullptr;                                                         \
  }

#define TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index)             \
  if (i >= interpreter_->subgraph(subgraph_index)->tensors_size() || i < 0) { \
    PyErr_Format(PyExc_ValueError,                                            \
                 "Invalid tensor index %d exceeds max tensor index %lu", i,   \
                 interpreter_->subgraph(subgraph_index)->tensors_size());     \
    return nullptr;                                                           \
  }

#define TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(i)                                   \
  if (i >= interpreter_->subgraphs_size() || i < 0) {                        \
    PyErr_Format(PyExc_ValueError,                                           \
                 "Invalid subgraph index %d exceeds max subgraph index %lu", \
                 i, interpreter_->subgraphs_size());                         \
    return nullptr;                                                          \
  }

#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                   \
  if (i >= interpreter_->nodes_size() || i < 0) {         \
    PyErr_Format(PyExc_ValueError, "Invalid node index"); \
    return nullptr;                                       \
  }

#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
  if (!interpreter_) {                                                     \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr;                                                        \
  }

namespace tflite {
namespace interpreter_wrapper {

namespace {

using python_utils::PyDecrefDeleter;

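// Builds an Interpreter for `model` with the given op resolver. Returns
// nullptr on failure; errors surface through the error reporter the model
// was loaded with. `preserve_all_tensors` keeps intermediate tensors
// readable after Invoke() for debugging.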
std::unique_ptr<Interpreter> CreateInterpreter(
    const InterpreterWrapper::Model* model,
    const tflite::MutableOpResolver& resolver, bool preserve_all_tensors) {
  if (!model) {
    return nullptr;
  }

  ::tflite::python::ImportNumpy();

  std::unique_ptr<Interpreter> interpreter;
  InterpreterBuilder builder(*model, resolver);
  if (preserve_all_tensors) builder.PreserveAllTensorsExperimental();
  if (builder(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}

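// The next two helpers copy a C array into a fresh 1-D NumPy array. The data
// is duplicated into a malloc'd buffer, and NPY_ARRAY_OWNDATA transfers
// ownership of that buffer to NumPy, which frees it when the array dies.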
PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(float));
  memcpy(pydata, data, size * sizeof(float));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(int));
  memcpy(pydata, data, size * sizeof(int));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
  PyObject* result = PyTuple_New(2);
  PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
  PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
  return result;
}

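// Converts TfLiteSparsity into a Python dict mirroring the flatbuffer schema:
//   {"traversal_order": ..., "block_map": ..., "dim_metadata": [...]}
// where a "format" of 0 marks a dense dimension and 1 marks a sparse one
// described by CSR-style array_segments/array_indices.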
PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
  PyObject* result = PyDict_New();
  PyDict_SetItemString(result, "traversal_order",
                       PyArrayFromIntVector(param.traversal_order->data,
                                            param.traversal_order->size));
  PyDict_SetItemString(
      result, "block_map",
      PyArrayFromIntVector(param.block_map->data, param.block_map->size));
  PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
  for (int i = 0; i < param.dim_metadata_size; i++) {
    PyObject* dim_metadata_i = PyDict_New();
    if (param.dim_metadata[i].format == kTfLiteDimDense) {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
      PyDict_SetItemString(dim_metadata_i, "dense_size",
                           PyLong_FromSize_t(param.dim_metadata[i].dense_size));
    } else {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
      const auto* array_segments = param.dim_metadata[i].array_segments;
      const auto* array_indices = param.dim_metadata[i].array_indices;
      PyDict_SetItemString(
          dim_metadata_i, "array_segments",
          PyArrayFromIntVector(array_segments->data, array_segments->size));
      PyDict_SetItemString(
          dim_metadata_i, "array_indices",
          PyArrayFromIntVector(array_indices->data, array_indices->size));
    }
    PyList_SetItem(dim_metadata, i, dim_metadata_i);
  }
  PyDict_SetItemString(result, "dim_metadata", dim_metadata);
  return result;
}

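// For illustration, a registerer is expected to look like the following
// hypothetical example (MyCustomOp and GetMyCustomOpRegistration() are
// placeholders; C linkage keeps the symbol name unmangled for GetSymbol):
//
//   extern "C" void MyCustomOpRegisterer(tflite::MutableOpResolver* r) {
//     r->AddCustom("MyCustomOp", GetMyCustomOpRegistration());
//   }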
bool RegisterCustomOpByName(const char* registerer_name,
                            tflite::MutableOpResolver* resolver,
                            std::string* error_msg) {
  // Registerer functions take a pointer to a MutableOpResolver as an input
  // parameter and return void.
  // TODO(b/137576229): We should implement this functionality in a more
  // principled way.
  typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);

  // Look for the Registerer function by name.
  RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
      SharedLibrary::GetSymbol(registerer_name));

  // Fail in an informative way if the function was not found.
  if (registerer == nullptr) {
    *error_msg =
        absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                        registerer_name, SharedLibrary::GetError());
    return false;
  }

  // Call the registerer with the resolver.
  registerer(resolver);
  return true;
}

}  // namespace

static constexpr int kBuiltinOpResolver = 1;
static constexpr int kBuiltinRefOpResolver = 2;
static constexpr int kBuiltinOpResolverWithoutDefaultDelegates = 3;

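// Shared factory behind the file- and buffer-based entry points below: picks
// the op resolver matching op_resolver_id (the constants above, which the
// caller in interpreter.py mirrors), runs any custom-op registerers, and
// builds the interpreter. On failure it fills *error_msg and returns nullptr.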
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model, int op_resolver_id,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  if (!model) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  std::unique_ptr<tflite::MutableOpResolver> resolver;
  switch (op_resolver_id) {
    case kBuiltinOpResolver:
      resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
      break;
    case kBuiltinRefOpResolver:
      resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinRefOpResolver>();
      break;
    case kBuiltinOpResolverWithoutDefaultDelegates:
      resolver = absl::make_unique<
          tflite::ops::builtin::BuiltinOpResolverWithoutDefaultDelegates>();
      break;
    default:
      // This should never happen because the eventual caller in
      // interpreter.py should have passed a valid id here.
      TFLITE_DCHECK(false);
      return nullptr;
  }

  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
      return nullptr;
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  auto interpreter =
      CreateInterpreter(model.get(), *resolver, preserve_all_tensors);
  if (!interpreter) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  InterpreterWrapper* wrapper =
      new InterpreterWrapper(std::move(model), std::move(error_reporter),
                             std::move(resolver), std::move(interpreter));
  return wrapper;
}

InterpreterWrapper::InterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    std::unique_ptr<tflite::MutableOpResolver> resolver,
    std::unique_ptr<Interpreter> interpreter)
    : model_(std::move(model)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      interpreter_(std::move(interpreter)) {}

InterpreterWrapper::~InterpreterWrapper() {}

PyObject* InterpreterWrapper::AllocateTensors(int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)->AllocateTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::Invoke(int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  // Release the GIL so that we can run multiple interpreters in parallel
  TfLiteStatus status_code = kTfLiteOk;
  Py_BEGIN_ALLOW_THREADS;  // No return can happen between this and the end!
  tflite::Subgraph* subgraph = interpreter_->subgraph(subgraph_index);
  status_code = subgraph->Invoke();

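  // If outputs are backed by buffer handles (e.g. delegate-owned memory),
  // copy the data back into CPU-readable tensor memory before the caller
  // can observe it.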
  if (!interpreter_->allow_buffer_handle_output_) {
    for (int tensor_index : subgraph->outputs()) {
      subgraph->EnsureTensorDataIsReadable(tensor_index);
    }
  }
  Py_END_ALLOW_THREADS;

  TFLITE_PY_CHECK(
      status_code);  // don't move this into the Py_BEGIN/Py_END block

  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::InputIndices() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
                                            interpreter_->inputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::OutputIndices() const {
  PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
                                            interpreter_->outputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

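// Converts `value` to a 1-D int32 NumPy array of new dimensions and returns
// it (or nullptr with a Python error set). Shared by ResizeInputTensor below.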
PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert numpy value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  if (PyArray_NDIM(array) != 1) {
    PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
                 PyArray_NDIM(array));
    return nullptr;
  }

  if (PyArray_TYPE(array) != NPY_INT32) {
    PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
                 PyArray_TYPE(array));
    return nullptr;
  }

  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(array),
                      NPY_ARRAY_OWNDATA);
  return PyArray_Return(reinterpret_cast<PyArrayObject*>(array));
}

PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
                                                bool strict,
                                                int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);

  PyArrayObject* array =
      reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
  if (array == nullptr) {
    return nullptr;
  }

  std::vector<int> dims(PyArray_SHAPE(array)[0]);
  memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));

  if (strict) {
    TFLITE_PY_CHECK(interpreter_->subgraph(subgraph_index)
                        ->ResizeInputTensorStrict(i, dims));
  } else {
    TFLITE_PY_CHECK(
        interpreter_->subgraph(subgraph_index)->ResizeInputTensor(i, dims));
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumTensors() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->tensors_size();
}

std::string InterpreterWrapper::TensorName(int i) const {
  if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
    return "";
  }

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return tensor->name ? tensor->name : "";
}

PyObject* InterpreterWrapper::TensorType(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->type == kTfLiteNoType) {
    PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
    return nullptr;
  }

  int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
  if (code == -1) {
    PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
    return nullptr;
  }
  return PyArray_TypeObjectFromType(code);
}

PyObject* InterpreterWrapper::TensorSize(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->dims == nullptr) {
    PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
    return nullptr;
  }
  PyObject* np_array =
      PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const int32_t* size_signature_data = nullptr;
  int32_t size_signature_size = 0;
  if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
    size_signature_data = tensor->dims_signature->data;
    size_signature_size = tensor->dims_signature->size;
  } else {
    size_signature_data = tensor->dims->data;
    size_signature_size = tensor->dims->size;
  }
  PyObject* np_array =
      PyArrayFromIntVector(size_signature_data, size_signature_size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->sparsity == nullptr) {
    return PyDict_New();
  }

  return PyDictFromSparsityParam(*tensor->sparsity);
}

PyObject* InterpreterWrapper::TensorQuantization(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return PyTupleFromQuantizationParam(tensor->params);
}

PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const TfLiteQuantization quantization = tensor->quantization;
  float* scales_data = nullptr;
  int32_t* zero_points_data = nullptr;
  int32_t scales_size = 0;
  int32_t zero_points_size = 0;
  int32_t quantized_dimension = 0;
  if (quantization.type == kTfLiteAffineQuantization) {
    const TfLiteAffineQuantization* q_params =
        reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
    if (q_params->scale) {
      scales_data = q_params->scale->data;
      scales_size = q_params->scale->size;
    }
    if (q_params->zero_point) {
      zero_points_data = q_params->zero_point->data;
      zero_points_size = q_params->zero_point->size;
    }
    quantized_dimension = q_params->quantized_dimension;
  }
  PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
  PyObject* zero_points_array =
      PyArrayFromIntVector(zero_points_data, zero_points_size);

  PyObject* result = PyTuple_New(3);
  PyTuple_SET_ITEM(result, 0, scales_array);
  PyTuple_SET_ITEM(result, 1, zero_points_array);
  PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
  return result;
}

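// Copies a NumPy array into tensor `i` after validating its dtype, rank, and
// per-dimension shape. String tensors take a separate path: DynamicBuffer
// serializes the Python strings into TF Lite's packed string-tensor format.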
PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value,
                                        int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(i, subgraph_index);

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  TfLiteTensor* tensor = interpreter_->subgraph(subgraph_index)->tensor(i);

  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s ",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), i, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension mismatch."
                 " Got %d"
                 " but expected %d for input %d.",
                 PyArray_NDIM(array), tensor->dims->size, i);
    return nullptr;
  }

  for (int j = 0; j < PyArray_NDIM(array); j++) {
    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Dimension mismatch."
                   " Got %ld"
                   " but expected %d for dimension %d of input %d.",
                   PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
      return nullptr;
    }
  }

  if (tensor->type != kTfLiteString) {
    if (tensor->data.raw == nullptr) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor:"
                   " Tensor is unallocated. Try calling allocate_tensors()"
                   " first");
      return nullptr;
    }

    size_t size = PyArray_NBYTES(array);
    if (size != tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "numpy array had %zu bytes but expected %zu bytes.", size,
                   tensor->bytes);
      return nullptr;
    }
    memcpy(tensor->data.raw, PyArray_DATA(array), size);
  } else {
    DynamicBuffer dynamic_buffer;
    if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
      return nullptr;
    }
    dynamic_buffer.WriteToTensor(tensor, nullptr);
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumNodes() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->nodes_size();
}

PyObject* InterpreterWrapper::NodeInputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* inputs =
      PyArrayFromIntVector(node->inputs->data, node->inputs->size);
  return inputs;
}

PyObject* InterpreterWrapper::NodeOutputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* outputs =
      PyArrayFromIntVector(node->outputs->data, node->outputs->size);
  return outputs;
}

std::string InterpreterWrapper::NodeName(int i) const {
  if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
    return "";
  }
  // Get op name from registration
  const TfLiteRegistration* node_registration =
      &(interpreter_->node_and_registration(i)->second);
  int32_t op_code = node_registration->builtin_code;
  std::string op_name;
  if (op_code == tflite::BuiltinOperator_CUSTOM) {
    const char* custom_name = node_registration->custom_name;
    op_name = custom_name ? custom_name : "UnknownCustomOp";
  } else {
    op_name = tflite::EnumNamesBuiltinOperator()[op_code];
  }
  std::string op_name_str(op_name);
  return op_name_str;
}

namespace {

// Checks to see if a tensor access can succeed (returns nullptr on error).
// Otherwise returns Py_None.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
                             TfLiteTensor** tensor, int* type_num,
                             int subgraph_index) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_SUBGRAPH_BOUNDS_CHECK(subgraph_index);
  TFLITE_PY_SUBGRAPH_TENSOR_BOUNDS_CHECK(tensor_index, subgraph_index);

  *tensor = interpreter_->subgraph(subgraph_index)->tensor(tensor_index);
  if ((*tensor)->bytes == 0) {
    PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
    return nullptr;
  }

  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
  if (*type_num == -1) {
    PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
    return nullptr;
  }

  if (!(*tensor)->data.raw) {
    PyErr_SetString(PyExc_ValueError,
                    "Tensor data is null."
                    " Run allocate_tensors() first");
    return nullptr;
  }

  Py_RETURN_NONE;
}

}  // namespace

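// Returns {signature_key: {"inputs": {name: tensor_index},
//                          "outputs": {name: tensor_index}}}
// for every SignatureDef in the model.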
PyObject* InterpreterWrapper::GetSignatureDefs() const {
  PyObject* result = PyDict_New();
  for (const auto& sig_key : interpreter_->signature_keys()) {
    PyObject* signature_def = PyDict_New();
    PyObject* inputs = PyDict_New();
    PyObject* outputs = PyDict_New();
    const auto& signature_def_inputs =
        interpreter_->signature_inputs(sig_key->c_str());
    const auto& signature_def_outputs =
        interpreter_->signature_outputs(sig_key->c_str());
    for (const auto& input : signature_def_inputs) {
      PyDict_SetItemString(inputs, input.first.c_str(),
                           PyLong_FromLong(input.second));
    }
    for (const auto& output : signature_def_outputs) {
      PyDict_SetItemString(outputs, output.first.c_str(),
                           PyLong_FromLong(output.second));
    }

    PyDict_SetItemString(signature_def, "inputs", inputs);
    PyDict_SetItemString(signature_def, "outputs", outputs);
    PyDict_SetItemString(result, sig_key->c_str(), signature_def);
  }
  return result;
}

PyObject* InterpreterWrapper::GetSubgraphIndexFromSignature(
    const char* signature_key) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  int32_t subgraph_index =
      interpreter_->GetSubgraphIndexFromSignature(signature_key);

  if (subgraph_index < 0) {
    PyErr_SetString(PyExc_ValueError, "No matching signature.");
    return nullptr;
  }
  return PyLong_FromLong(static_cast<int64_t>(subgraph_index));
}

PyObject* InterpreterWrapper::GetTensor(int i, int subgraph_index) const {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(interpreter_.get(), i, &tensor,
                                              &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
      tensor->type != kTfLiteVariant) {
    // Make a buffer copy, but we must tell NumPy it owns that data or else
    // it will leak.
    void* data = malloc(tensor->bytes);
    if (!data) {
      PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
      return nullptr;
    }
    memcpy(data, tensor->data.raw, tensor->bytes);
    PyObject* np_array;
    if (tensor->sparsity == nullptr) {
      np_array =
          PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
    } else {
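      // A sparse tensor's dense dims don't describe its packed buffer, so
      // expose the raw data as a flat 1-D array of the element type.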
      std::vector<npy_intp> sparse_buffer_dims(1);
      size_t size_of_type;
      if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
        PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
        free(data);
        return nullptr;
      }
      sparse_buffer_dims[0] = tensor->bytes / size_of_type;
      np_array = PyArray_SimpleNewFromData(
          sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num, data);
    }
    PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
                        NPY_ARRAY_OWNDATA);
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
  } else {
    // Create a C-order array so the data is contiguous in memory.
    const int32_t kCOrder = 0;
    PyObject* py_object =
        PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);

    if (py_object == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
      return nullptr;
    }

    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
    PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
    auto num_strings = GetStringCount(tensor);
    for (int j = 0; j < num_strings; ++j) {
      auto ref = GetString(tensor, j);

      PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
      if (bytes == nullptr) {
        Py_DECREF(py_object);
        PyErr_Format(PyExc_ValueError,
                     "Could not create PyBytes from string %d of input %d.", j,
                     i);
        return nullptr;
      }
      // PyArray_EMPTY produces an array full of Py_None, which we must decref.
      Py_DECREF(data[j]);
      data[j] = bytes;
    }
    return py_object;
  }
}

PyObject* InterpreterWrapper::tensor(PyObject* base_object, int tensor_index,
                                     int subgraph_index) {
  // Sanity check accessor
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result = CheckGetTensorArgs(
      interpreter_.get(), tensor_index, &tensor, &type_num, subgraph_index);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
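  // Zero-copy view: the NumPy array aliases the tensor's buffer directly.
  // base_object (the Python wrapper that owns the interpreter) becomes the
  // array's base object, keeping the wrapper alive while the view exists.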
  PyArrayObject* np_array =
      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
          dims.size(), dims.data(), type_num, tensor->data.raw));
  Py_INCREF(base_object);  // PyArray_SetBaseObject steals a reference, so add one.
  PyArray_SetBaseObject(np_array, base_object);
  return PyArray_Return(np_array);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path, int op_resolver_id,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromFile(model_path, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), op_resolver_id,
                                  std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path, int op_resolver_id,
    const std::vector<std::string>& registerers, std::string* error_msg,
    bool preserve_all_tensors) {
  return CreateWrapperCPPFromFile(model_path, op_resolver_id, registerers,
                                  {} /*registerers_by_func*/, error_msg,
                                  preserve_all_tensors);
}

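// These factories back the Python-level Interpreter constructor. As a
// minimal sketch of the usual call path (Python side, assuming the stock
// TF Lite bindings; "model.tflite" is a placeholder path):
//
//   from tensorflow.lite.python.interpreter import Interpreter
//   interp = Interpreter(model_path="model.tflite")
//   interp.allocate_tensors()
//   interp.invoke()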
InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg, bool preserve_all_tensors) {
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromBuffer(buf, length, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), op_resolver_id,
                                  std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg, preserve_all_tensors);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, int op_resolver_id,
    const std::vector<std::string>& registerers, std::string* error_msg,
    bool preserve_all_tensors) {
  return CreateWrapperCPPFromBuffer(data, op_resolver_id, registerers, {},
                                    error_msg, preserve_all_tensors);
}

PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  interpreter_->SetNumThreads(num_threads);
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
    TfLiteDelegate* delegate) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite