/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h"

#include <stdarg.h>

#include <functional>
#include <sstream>
#include <string>

#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/util.h"

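// Error-handling helpers for the Python bindings below. Each macro converts a
// failure into a Python exception and bails out of the enclosing function:
// TFLITE_PY_CHECK propagates TfLite errors through the Python error reporter,
// the bounds checks raise ValueError for out-of-range indices, and
// TFLITE_PY_ENSURE_VALID_INTERPRETER guards against use before initialization.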
#define TFLITE_PY_CHECK(x)               \
  if ((x) != kTfLiteOk) {                \
    return error_reporter_->exception(); \
  }

#define TFLITE_PY_TENSOR_BOUNDS_CHECK(i)                                     \
  if (i >= interpreter_->tensors_size() || i < 0) {                          \
    PyErr_Format(PyExc_ValueError,                                           \
                 "Invalid tensor index %d exceeds max tensor index %zu", i,  \
                 interpreter_->tensors_size());                              \
    return nullptr;                                                          \
  }

#define TFLITE_PY_NODES_BOUNDS_CHECK(i)                   \
  if (i >= interpreter_->nodes_size() || i < 0) {         \
    PyErr_Format(PyExc_ValueError, "Invalid node index"); \
    return nullptr;                                       \
  }

#define TFLITE_PY_ENSURE_VALID_INTERPRETER()                               \
  if (!interpreter_) {                                                     \
    PyErr_SetString(PyExc_ValueError, "Interpreter was not initialized."); \
    return nullptr;                                                        \
  }

namespace tflite {
namespace interpreter_wrapper {

namespace {

using python_utils::PyDecrefDeleter;

std::unique_ptr<Interpreter> CreateInterpreter(
    const InterpreterWrapper::Model* model,
    const tflite::ops::builtin::BuiltinOpResolver& resolver) {
  if (!model) {
    return nullptr;
  }

  ::tflite::python::ImportNumpy();

  std::unique_ptr<Interpreter> interpreter;
  if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return nullptr;
  }
  return interpreter;
}

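// Creates a 1-D float32 NumPy array holding a copy of `data`. The array is
// marked NPY_ARRAY_OWNDATA so NumPy frees the copied buffer when the array
// is garbage-collected.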
PyObject* PyArrayFromFloatVector(const float* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(float));
  memcpy(pydata, data, size * sizeof(float));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_FLOAT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

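// Int32 counterpart of PyArrayFromFloatVector.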
PyObject* PyArrayFromIntVector(const int* data, npy_intp size) {
  void* pydata = malloc(size * sizeof(int));
  memcpy(pydata, data, size * sizeof(int));
  PyObject* obj = PyArray_SimpleNewFromData(1, &size, NPY_INT32, pydata);
  PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(obj), NPY_ARRAY_OWNDATA);
  return obj;
}

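// Returns a (scale, zero_point) Python tuple for the given quantization
// parameters.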
PyObject* PyTupleFromQuantizationParam(const TfLiteQuantizationParams& param) {
  PyObject* result = PyTuple_New(2);
  PyTuple_SET_ITEM(result, 0, PyFloat_FromDouble(param.scale));
  PyTuple_SET_ITEM(result, 1, PyLong_FromLong(param.zero_point));
  return result;
}

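// Converts TfLiteSparsity into a Python dict with keys "traversal_order",
// "block_map", and "dim_metadata"; each dim_metadata entry holds either a
// dense size (format 0) or sparse array_segments/array_indices (format 1).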
PyObject* PyDictFromSparsityParam(const TfLiteSparsity& param) {
  PyObject* result = PyDict_New();
  PyDict_SetItemString(result, "traversal_order",
                       PyArrayFromIntVector(param.traversal_order->data,
                                            param.traversal_order->size));
  PyDict_SetItemString(
      result, "block_map",
      PyArrayFromIntVector(param.block_map->data, param.block_map->size));
  PyObject* dim_metadata = PyList_New(param.dim_metadata_size);
  for (int i = 0; i < param.dim_metadata_size; i++) {
    PyObject* dim_metadata_i = PyDict_New();
    if (param.dim_metadata[i].format == kTfLiteDimDense) {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(0));
      PyDict_SetItemString(dim_metadata_i, "dense_size",
                           PyLong_FromSize_t(param.dim_metadata[i].dense_size));
    } else {
      PyDict_SetItemString(dim_metadata_i, "format", PyLong_FromSize_t(1));
      const auto* array_segments = param.dim_metadata[i].array_segments;
      const auto* array_indices = param.dim_metadata[i].array_indices;
      PyDict_SetItemString(
          dim_metadata_i, "array_segments",
          PyArrayFromIntVector(array_segments->data, array_segments->size));
      PyDict_SetItemString(
          dim_metadata_i, "array_indices",
          PyArrayFromIntVector(array_indices->data, array_indices->size));
    }
    PyList_SetItem(dim_metadata, i, dim_metadata_i);
  }
  PyDict_SetItemString(result, "dim_metadata", dim_metadata);
  return result;
}

bool RegisterCustomOpByName(const char* registerer_name,
                            tflite::MutableOpResolver* resolver,
                            std::string* error_msg) {
  // Registerer functions take a pointer to a MutableOpResolver as their only
  // parameter and return void.
  // TODO(b/137576229): We should implement this functionality in a more
  // principled way.
  typedef void (*RegistererFunctionType)(tflite::MutableOpResolver*);

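  // For reference, a conforming registerer might look like this sketch (the
  // symbol name and the custom op shown are hypothetical):
  //
  //   extern "C" void TF_RegisterMyOps(tflite::MutableOpResolver* resolver) {
  //     resolver->AddCustom("MyCustomOp", Register_MY_CUSTOM_OP());
  //   }
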
  // Look for the Registerer function by name.
  RegistererFunctionType registerer = reinterpret_cast<RegistererFunctionType>(
      SharedLibrary::GetSymbol(registerer_name));

  // Fail in an informative way if the function was not found.
  if (registerer == nullptr) {
    *error_msg =
        absl::StrFormat("Looking up symbol '%s' failed with error '%s'.",
                        registerer_name, SharedLibrary::GetError());
    return false;
  }

  // Call the registerer with the resolver.
  registerer(resolver);
  return true;
}

}  // namespace

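// Builds an InterpreterWrapper around `model`, running both the named and the
// function-valued custom-op registerers before constructing the interpreter.
// On failure, fills *error_msg and returns nullptr.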
InterpreterWrapper* InterpreterWrapper::CreateInterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  if (!model) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  auto resolver = absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
  for (const auto& registerer : registerers_by_name) {
    if (!RegisterCustomOpByName(registerer.c_str(), resolver.get(), error_msg))
      return nullptr;
  }
  for (const auto& registerer : registerers_by_func) {
    registerer(reinterpret_cast<uintptr_t>(resolver.get()));
  }
  auto interpreter = CreateInterpreter(model.get(), *resolver);
  if (!interpreter) {
    *error_msg = error_reporter->message();
    return nullptr;
  }

  InterpreterWrapper* wrapper =
      new InterpreterWrapper(std::move(model), std::move(error_reporter),
                             std::move(resolver), std::move(interpreter));
  return wrapper;
}

InterpreterWrapper::InterpreterWrapper(
    std::unique_ptr<InterpreterWrapper::Model> model,
    std::unique_ptr<PythonErrorReporter> error_reporter,
    std::unique_ptr<tflite::ops::builtin::BuiltinOpResolver> resolver,
    std::unique_ptr<Interpreter> interpreter)
    : model_(std::move(model)),
      error_reporter_(std::move(error_reporter)),
      resolver_(std::move(resolver)),
      interpreter_(std::move(interpreter)) {}

InterpreterWrapper::~InterpreterWrapper() {}

PyObject* InterpreterWrapper::AllocateTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->AllocateTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::Invoke() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  // Release the GIL so that we can run multiple interpreters in parallel.
  TfLiteStatus status_code = kTfLiteOk;
  Py_BEGIN_ALLOW_THREADS;  // No `return` may occur before Py_END_ALLOW_THREADS!
  status_code = interpreter_->Invoke();
  Py_END_ALLOW_THREADS;

  // Don't move this check into the Py_BEGIN/Py_END block: raising a Python
  // exception requires holding the GIL.
  TFLITE_PY_CHECK(status_code);

  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::InputIndices() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* np_array = PyArrayFromIntVector(interpreter_->inputs().data(),
                                            interpreter_->inputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::OutputIndices() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* np_array = PyArrayFromIntVector(interpreter_->outputs().data(),
                                            interpreter_->outputs().size());

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

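// Converts `value` to a 1-D int32 NumPy array describing the new shape.
// Returns a new reference to the converted array, or nullptr with a Python
// error set.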
PyObject* InterpreterWrapper::ResizeInputTensorImpl(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert numpy value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());

  if (PyArray_NDIM(array) != 1) {
    PyErr_Format(PyExc_ValueError, "Shape should be 1D instead of %d.",
                 PyArray_NDIM(array));
    return nullptr;
  }

  if (PyArray_TYPE(array) != NPY_INT32) {
    PyErr_Format(PyExc_ValueError, "Shape must be type int32 (was %d).",
                 PyArray_TYPE(array));
    return nullptr;
  }

  // Transfer ownership of the reference to the caller; otherwise array_safe
  // would decref the array at scope exit and the returned pointer would
  // dangle. Note: do not mark the array NPY_ARRAY_OWNDATA here, since
  // PyArray_FromAny may return a view that does not own its buffer.
  return PyArray_Return(
      reinterpret_cast<PyArrayObject*>(array_safe.release()));
}

PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value,
                                                bool strict) {
  PyArrayObject* array =
      reinterpret_cast<PyArrayObject*>(ResizeInputTensorImpl(i, value));
  if (array == nullptr) {
    return nullptr;
  }

  std::vector<int> dims(PyArray_SHAPE(array)[0]);
  memcpy(dims.data(), PyArray_BYTES(array), dims.size() * sizeof(int));
  // ResizeInputTensorImpl returned a new reference; release it now that the
  // dims have been copied out.
  Py_DECREF(array);

  if (strict) {
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensorStrict(i, dims));
  } else {
    TFLITE_PY_CHECK(interpreter_->ResizeInputTensor(i, dims));
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumTensors() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->tensors_size();
}

std::string InterpreterWrapper::TensorName(int i) const {
  if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
    return "";
  }

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return tensor->name ? tensor->name : "";
}

PyObject* InterpreterWrapper::TensorType(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->type == kTfLiteNoType) {
    PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
    return nullptr;
  }

  int code = python_utils::TfLiteTypeToPyArrayType(tensor->type);
  if (code == -1) {
    // Report the TfLite type rather than `code`, which is always -1 here.
    PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", tensor->type);
    return nullptr;
  }
  return PyArray_TypeObjectFromType(code);
}

PyObject* InterpreterWrapper::TensorSize(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->dims == nullptr) {
    PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
    return nullptr;
  }
  PyObject* np_array =
      PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSizeSignature(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const int32_t* size_signature_data = nullptr;
  int32_t size_signature_size = 0;
  if (tensor->dims_signature != nullptr && tensor->dims_signature->size != 0) {
    size_signature_data = tensor->dims_signature->data;
    size_signature_size = tensor->dims_signature->size;
  } else {
    size_signature_data = tensor->dims->data;
    size_signature_size = tensor->dims->size;
  }
  PyObject* np_array =
      PyArrayFromIntVector(size_signature_data, size_signature_size);

  return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}

PyObject* InterpreterWrapper::TensorSparsityParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  if (tensor->sparsity == nullptr) {
    return PyDict_New();
  }

  return PyDictFromSparsityParam(*tensor->sparsity);
}

PyObject* InterpreterWrapper::TensorQuantization(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  return PyTupleFromQuantizationParam(tensor->params);
}

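// Returns a (scales, zero_points, quantized_dimension) tuple. For tensors
// without affine quantization the arrays are empty and the dimension is 0.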
PyObject* InterpreterWrapper::TensorQuantizationParameters(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
  const TfLiteTensor* tensor = interpreter_->tensor(i);
  const TfLiteQuantization quantization = tensor->quantization;
  float* scales_data = nullptr;
  int32_t* zero_points_data = nullptr;
  int32_t scales_size = 0;
  int32_t zero_points_size = 0;
  int32_t quantized_dimension = 0;
  if (quantization.type == kTfLiteAffineQuantization) {
    const TfLiteAffineQuantization* q_params =
        reinterpret_cast<const TfLiteAffineQuantization*>(quantization.params);
    if (q_params->scale) {
      scales_data = q_params->scale->data;
      scales_size = q_params->scale->size;
    }
    if (q_params->zero_point) {
      zero_points_data = q_params->zero_point->data;
      zero_points_size = q_params->zero_point->size;
    }
    quantized_dimension = q_params->quantized_dimension;
  }
  PyObject* scales_array = PyArrayFromFloatVector(scales_data, scales_size);
  PyObject* zero_points_array =
      PyArrayFromIntVector(zero_points_data, zero_points_size);

  PyObject* result = PyTuple_New(3);
  PyTuple_SET_ITEM(result, 0, scales_array);
  PyTuple_SET_ITEM(result, 1, zero_points_array);
  PyTuple_SET_ITEM(result, 2, PyLong_FromLong(quantized_dimension));
  return result;
}

PyObject* InterpreterWrapper::SetTensor(int i, PyObject* value) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(i);

  std::unique_ptr<PyObject, PyDecrefDeleter> array_safe(
      PyArray_FromAny(value, nullptr, 0, 0, NPY_ARRAY_CARRAY, nullptr));
  if (!array_safe) {
    PyErr_SetString(PyExc_ValueError,
                    "Failed to convert value into readable tensor.");
    return nullptr;
  }

  PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
  TfLiteTensor* tensor = interpreter_->tensor(i);

  if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor:"
                 " Got value of type %s"
                 " but expected type %s for input %d, name: %s",
                 TfLiteTypeGetName(python_utils::TfLiteTypeFromPyArray(array)),
                 TfLiteTypeGetName(tensor->type), i, tensor->name);
    return nullptr;
  }

  if (PyArray_NDIM(array) != tensor->dims->size) {
    PyErr_Format(PyExc_ValueError,
                 "Cannot set tensor: Dimension mismatch."
                 " Got %d"
                 " but expected %d for input %d.",
                 PyArray_NDIM(array), tensor->dims->size, i);
    return nullptr;
  }

  for (int j = 0; j < PyArray_NDIM(array); j++) {
    if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor: Dimension mismatch."
                   " Got %ld"
                   " but expected %d for dimension %d of input %d.",
                   PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
      return nullptr;
    }
  }

  if (tensor->type != kTfLiteString) {
    if (tensor->data.raw == nullptr) {
      PyErr_Format(PyExc_ValueError,
                   "Cannot set tensor:"
                   " Tensor is unallocated. Try calling allocate_tensors()"
                   " first");
      return nullptr;
    }

    size_t size = PyArray_NBYTES(array);
    if (size != tensor->bytes) {
      PyErr_Format(PyExc_ValueError,
                   "numpy array had %zu bytes but expected %zu bytes.", size,
                   tensor->bytes);
      return nullptr;
    }
    memcpy(tensor->data.raw, PyArray_DATA(array), size);
  } else {
    DynamicBuffer dynamic_buffer;
    if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
      return nullptr;
    }
    dynamic_buffer.WriteToTensor(tensor, nullptr);
  }
  Py_RETURN_NONE;
}

int InterpreterWrapper::NumNodes() const {
  if (!interpreter_) {
    return 0;
  }
  return interpreter_->nodes_size();
}

PyObject* InterpreterWrapper::NodeInputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* inputs =
      PyArrayFromIntVector(node->inputs->data, node->inputs->size);
  return inputs;
}

PyObject* InterpreterWrapper::NodeOutputs(int i) const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_NODES_BOUNDS_CHECK(i);

  const TfLiteNode* node = &(interpreter_->node_and_registration(i)->first);
  PyObject* outputs =
      PyArrayFromIntVector(node->outputs->data, node->outputs->size);
  return outputs;
}

std::string InterpreterWrapper::NodeName(int i) const {
  if (!interpreter_ || i >= interpreter_->nodes_size() || i < 0) {
    return "";
  }
  // Get the op name from the registration.
  const TfLiteRegistration* node_registration =
      &(interpreter_->node_and_registration(i)->second);
  int32_t op_code = node_registration->builtin_code;
  if (op_code == tflite::BuiltinOperator_CUSTOM) {
    const char* custom_name = node_registration->custom_name;
    return custom_name ? custom_name : "UnknownCustomOp";
  }
  return tflite::EnumNamesBuiltinOperator()[op_code];
}

namespace {

// Checks whether a tensor access can succeed. Returns nullptr (with a Python
// error set) on failure; otherwise returns Py_None and fills in *tensor and
// *type_num. The parameter is named `interpreter_` so the TFLITE_PY_* macros
// above can be reused here.
PyObject* CheckGetTensorArgs(Interpreter* interpreter_, int tensor_index,
                             TfLiteTensor** tensor, int* type_num) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_TENSOR_BOUNDS_CHECK(tensor_index);

  *tensor = interpreter_->tensor(tensor_index);
  if ((*tensor)->bytes == 0) {
    PyErr_SetString(PyExc_ValueError, "Invalid tensor size.");
    return nullptr;
  }

  *type_num = python_utils::TfLiteTypeToPyArrayType((*tensor)->type);
  if (*type_num == -1) {
    PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
    return nullptr;
  }

  if (!(*tensor)->data.raw) {
    PyErr_SetString(PyExc_ValueError,
                    "Tensor data is null."
                    " Run allocate_tensors() first");
    return nullptr;
  }

  Py_RETURN_NONE;
}

}  // namespace

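// Returns a dict mapping each signature name to
// {"inputs": {name: tensor_index}, "outputs": {name: tensor_index}}.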
PyObject* InterpreterWrapper::GetSignatureDefs() const {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  PyObject* result = PyDict_New();
  for (const auto& sig_def_name : interpreter_->signature_def_names()) {
    PyObject* signature_def = PyDict_New();
    PyObject* inputs = PyDict_New();
    PyObject* outputs = PyDict_New();
    const auto& signature_def_inputs =
        interpreter_->signature_inputs(sig_def_name->c_str());
    const auto& signature_def_outputs =
        interpreter_->signature_outputs(sig_def_name->c_str());
    for (const auto& input : signature_def_inputs) {
      PyDict_SetItemString(inputs, input.first.c_str(),
                           PyLong_FromLong(input.second));
    }
    for (const auto& output : signature_def_outputs) {
      PyDict_SetItemString(outputs, output.first.c_str(),
                           PyLong_FromLong(output.second));
    }

    PyDict_SetItemString(signature_def, "inputs", inputs);
    PyDict_SetItemString(signature_def, "outputs", outputs);
    PyDict_SetItemString(result, sig_def_name->c_str(), signature_def);
  }
  return result;
}

PyObject* InterpreterWrapper::GetOutputTensorFromSignatureDefName(
    const char* output_name, const char* method_name) const {
  const auto& outputs = interpreter_->signature_outputs(method_name);
  const auto& output = outputs.find(output_name);
  if (output == outputs.end()) {
    // Set a Python error: returning nullptr without one raises SystemError.
    PyErr_Format(PyExc_ValueError, "Output name '%s' not found in signature.",
                 output_name);
    return nullptr;
  }
  return GetTensor(output->second);
}

PyObject* InterpreterWrapper::SetInputTensorFromSignatureDefName(
    const char* input_name, const char* method_name, PyObject* value) {
  const auto& inputs = interpreter_->signature_inputs(method_name);
  const auto& input = inputs.find(input_name);
  if (input == inputs.end()) {
    PyErr_Format(PyExc_ValueError, "Input name '%s' not found in signature.",
                 input_name);
    return nullptr;
  }
  return SetTensor(input->second, value);
}

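// Returns a copy of tensor i's data as a NumPy array. Numeric tensors are
// memcpy'd into a buffer owned by the array; string tensors come back as an
// object array of bytes.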
PyObject* InterpreterWrapper::GetTensor(int i) const {
  // Sanity check accessor.
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result =
      CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  if (tensor->type != kTfLiteString && tensor->type != kTfLiteResource &&
      tensor->type != kTfLiteVariant) {
    // Make a buffer copy, and tell NumPy it owns that data so it does not
    // leak.
    void* data = malloc(tensor->bytes);
    if (!data) {
      PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
      return nullptr;
    }
    memcpy(data, tensor->data.raw, tensor->bytes);
    PyObject* np_array;
    if (tensor->sparsity == nullptr) {
      np_array =
          PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
    } else {
      std::vector<npy_intp> sparse_buffer_dims(1);
      size_t size_of_type;
      if (GetSizeOfType(nullptr, tensor->type, &size_of_type) != kTfLiteOk) {
        PyErr_SetString(PyExc_ValueError, "Unknown tensor type.");
        free(data);
        return nullptr;
      }
      sparse_buffer_dims[0] = tensor->bytes / size_of_type;
      np_array = PyArray_SimpleNewFromData(
          sparse_buffer_dims.size(), sparse_buffer_dims.data(), type_num,
          data);
    }
    PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
                        NPY_ARRAY_OWNDATA);
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
  } else {
    // Create a C-order array so the data is contiguous in memory.
    const int32_t kCOrder = 0;
    PyObject* py_object =
        PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);

    if (py_object == nullptr) {
      PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
      return nullptr;
    }

    PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
    PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
    auto num_strings = GetStringCount(tensor);
    for (int j = 0; j < num_strings; ++j) {
      auto ref = GetString(tensor, j);

      PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
      if (bytes == nullptr) {
        Py_DECREF(py_object);
        PyErr_Format(PyExc_ValueError,
                     "Could not create PyBytes from string %d of input %d.", j,
                     i);
        return nullptr;
      }
      // PyArray_EMPTY produces an array full of Py_None, which we must decref.
      Py_DECREF(data[j]);
      data[j] = bytes;
    }
    return py_object;
  }
}

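// Returns a NumPy array that aliases tensor i's buffer (no copy).
// `base_object` is set as the array's base so the backing memory is kept
// alive for the lifetime of the view.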
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
  // Sanity check accessor.
  TfLiteTensor* tensor = nullptr;
  int type_num = 0;

  PyObject* check_result =
      CheckGetTensorArgs(interpreter_.get(), i, &tensor, &type_num);
  if (check_result == nullptr) return check_result;
  Py_XDECREF(check_result);

  std::vector<npy_intp> dims(tensor->dims->data,
                             tensor->dims->data + tensor->dims->size);
  PyArrayObject* np_array =
      reinterpret_cast<PyArrayObject*>(PyArray_SimpleNewFromData(
          dims.size(), dims.data(), type_num, tensor->data.raw));
  Py_INCREF(base_object);  // PyArray_SetBaseObject steals a reference.
  PyArray_SetBaseObject(np_array, base_object);
  return PyArray_Return(np_array);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path,
    const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromFile(model_path, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg);
}
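
// Typical use from the binding layer looks like the following sketch (the
// model path is illustrative and error handling is elided):
//
//   std::string error;
//   std::unique_ptr<InterpreterWrapper> wrapper(
//       InterpreterWrapper::CreateWrapperCPPFromFile(
//           "/path/to/model.tflite", /*registerers=*/{}, &error));
//   if (!wrapper) { /* surface `error` to Python */ }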

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromFile(
    const char* model_path, const std::vector<std::string>& registerers,
    std::string* error_msg) {
  return CreateWrapperCPPFromFile(model_path, registerers, {}, error_msg);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers_by_name,
    const std::vector<std::function<void(uintptr_t)>>& registerers_by_func,
    std::string* error_msg) {
  char* buf = nullptr;
  Py_ssize_t length;
  std::unique_ptr<PythonErrorReporter> error_reporter(new PythonErrorReporter);

  if (python_utils::ConvertFromPyString(data, &buf, &length) == -1) {
    return nullptr;
  }
  std::unique_ptr<InterpreterWrapper::Model> model =
      Model::BuildFromBuffer(buf, length, error_reporter.get());
  return CreateInterpreterWrapper(std::move(model), std::move(error_reporter),
                                  registerers_by_name, registerers_by_func,
                                  error_msg);
}

InterpreterWrapper* InterpreterWrapper::CreateWrapperCPPFromBuffer(
    PyObject* data, const std::vector<std::string>& registerers,
    std::string* error_msg) {
  return CreateWrapperCPPFromBuffer(data, registerers, {}, error_msg);
}

PyObject* InterpreterWrapper::ResetVariableTensors() {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ResetVariableTensors());
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::SetNumThreads(int num_threads) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  interpreter_->SetNumThreads(num_threads);
  Py_RETURN_NONE;
}

PyObject* InterpreterWrapper::ModifyGraphWithDelegate(
    TfLiteDelegate* delegate) {
  TFLITE_PY_ENSURE_VALID_INTERPRETER();
  TFLITE_PY_CHECK(interpreter_->ModifyGraphWithDelegate(delegate));
  Py_RETURN_NONE;
}

}  // namespace interpreter_wrapper
}  // namespace tflite