1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define TFLITE_IMPORT_NUMPY // See numpy.h for explanation.
17 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
18
19 #include <memory>
20
21 namespace tflite {
22 namespace python {
23
ImportNumpy()24 void ImportNumpy() { import_array1(); }
25
26 } // namespace python
27
28 namespace python_utils {
29
30 struct PyObjectDereferencer {
operator ()tflite::python_utils::PyObjectDereferencer31 void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
32 };
33 using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
34
TfLiteTypeToPyArrayType(TfLiteType tf_lite_type)35 int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
36 switch (tf_lite_type) {
37 case kTfLiteFloat32:
38 return NPY_FLOAT32;
39 case kTfLiteFloat16:
40 return NPY_FLOAT16;
41 case kTfLiteFloat64:
42 return NPY_FLOAT64;
43 case kTfLiteInt32:
44 return NPY_INT32;
45 case kTfLiteUInt32:
46 return NPY_UINT32;
47 case kTfLiteUInt16:
48 return NPY_UINT16;
49 case kTfLiteInt16:
50 return NPY_INT16;
51 case kTfLiteUInt8:
52 return NPY_UINT8;
53 case kTfLiteInt8:
54 return NPY_INT8;
55 case kTfLiteInt64:
56 return NPY_INT64;
57 case kTfLiteUInt64:
58 return NPY_UINT64;
59 case kTfLiteString:
60 return NPY_STRING;
61 case kTfLiteBool:
62 return NPY_BOOL;
63 case kTfLiteComplex64:
64 return NPY_COMPLEX64;
65 case kTfLiteComplex128:
66 return NPY_COMPLEX128;
67 case kTfLiteResource:
68 case kTfLiteVariant:
69 return NPY_OBJECT;
70 case kTfLiteNoType:
71 return NPY_NOTYPE;
72 // Avoid default so compiler errors created when new types are made.
73 }
74 return NPY_NOTYPE;
75 }
76
TfLiteTypeFromPyType(int py_type)77 TfLiteType TfLiteTypeFromPyType(int py_type) {
78 switch (py_type) {
79 case NPY_FLOAT32:
80 return kTfLiteFloat32;
81 case NPY_FLOAT16:
82 return kTfLiteFloat16;
83 case NPY_FLOAT64:
84 return kTfLiteFloat64;
85 case NPY_INT32:
86 return kTfLiteInt32;
87 case NPY_UINT32:
88 return kTfLiteUInt32;
89 case NPY_INT16:
90 return kTfLiteInt16;
91 case NPY_UINT8:
92 return kTfLiteUInt8;
93 case NPY_INT8:
94 return kTfLiteInt8;
95 case NPY_INT64:
96 return kTfLiteInt64;
97 case NPY_UINT64:
98 return kTfLiteUInt64;
99 case NPY_BOOL:
100 return kTfLiteBool;
101 case NPY_OBJECT:
102 case NPY_STRING:
103 case NPY_UNICODE:
104 return kTfLiteString;
105 case NPY_COMPLEX64:
106 return kTfLiteComplex64;
107 case NPY_COMPLEX128:
108 return kTfLiteComplex128;
109 }
110 return kTfLiteNoType;
111 }
112
TfLiteTypeFromPyArray(PyArrayObject * array)113 TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
114 int pyarray_type = PyArray_TYPE(array);
115 return TfLiteTypeFromPyType(pyarray_type);
116 }
117
118 #if PY_VERSION_HEX >= 0x03030000
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)119 bool FillStringBufferFromPyUnicode(PyObject* value,
120 DynamicBuffer* dynamic_buffer) {
121 Py_ssize_t len = -1;
122 const char* buf = PyUnicode_AsUTF8AndSize(value, &len);
123 if (buf == nullptr) {
124 PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
125 return false;
126 }
127 dynamic_buffer->AddString(buf, len);
128 return true;
129 }
130 #else
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)131 bool FillStringBufferFromPyUnicode(PyObject* value,
132 DynamicBuffer* dynamic_buffer) {
133 UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
134 if (!utemp) {
135 PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
136 return false;
137 }
138 char* buf = nullptr;
139 Py_ssize_t len = -1;
140 if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
141 PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
142 return false;
143 }
144 dynamic_buffer->AddString(buf, len);
145 return true;
146 }
147 #endif
148
FillStringBufferFromPyString(PyObject * value,DynamicBuffer * dynamic_buffer)149 bool FillStringBufferFromPyString(PyObject* value,
150 DynamicBuffer* dynamic_buffer) {
151 if (PyUnicode_Check(value)) {
152 return FillStringBufferFromPyUnicode(value, dynamic_buffer);
153 }
154
155 char* buf = nullptr;
156 Py_ssize_t len = -1;
157 if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
158 PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
159 return false;
160 }
161 dynamic_buffer->AddString(buf, len);
162 return true;
163 }
164
FillStringBufferWithPyArray(PyObject * value,DynamicBuffer * dynamic_buffer)165 bool FillStringBufferWithPyArray(PyObject* value,
166 DynamicBuffer* dynamic_buffer) {
167 if (!PyArray_Check(value)) {
168 PyErr_Format(PyExc_ValueError,
169 "Passed in value type is not a numpy array, got type %s.",
170 value->ob_type->tp_name);
171 return false;
172 }
173
174 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
175 switch (PyArray_TYPE(array)) {
176 case NPY_OBJECT:
177 case NPY_STRING:
178 case NPY_UNICODE: {
179 if (PyArray_NDIM(array) == 0) {
180 dynamic_buffer->AddString(static_cast<char*>(PyArray_DATA(array)),
181 PyArray_NBYTES(array));
182 return true;
183 }
184 UniquePyObjectRef iter(PyArray_IterNew(value));
185 while (PyArray_ITER_NOTDONE(iter.get())) {
186 UniquePyObjectRef item(PyArray_GETITEM(
187 array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
188
189 if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
190 return false;
191 }
192
193 PyArray_ITER_NEXT(iter.get());
194 }
195 return true;
196 }
197 default:
198 break;
199 }
200
201 PyErr_Format(PyExc_ValueError,
202 "Cannot use numpy array of type %d for string tensor.",
203 PyArray_TYPE(array));
204 return false;
205 }
206
207 } // namespace python_utils
208 } // namespace tflite
209