• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define TFLITE_IMPORT_NUMPY  // See numpy.h for explanation.
17 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
18 
19 #include <memory>
20 
21 namespace tflite {
22 namespace python {
23 
ImportNumpy()24 void ImportNumpy() { import_array1(); }
25 
26 }  // namespace python
27 
28 namespace python_utils {
29 
30 struct PyObjectDereferencer {
operator ()tflite::python_utils::PyObjectDereferencer31   void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
32 };
33 using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
34 
TfLiteTypeToPyArrayType(TfLiteType tf_lite_type)35 int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
36   switch (tf_lite_type) {
37     case kTfLiteFloat32:
38       return NPY_FLOAT32;
39     case kTfLiteFloat16:
40       return NPY_FLOAT16;
41     case kTfLiteFloat64:
42       return NPY_FLOAT64;
43     case kTfLiteInt32:
44       return NPY_INT32;
45     case kTfLiteUInt32:
46       return NPY_UINT32;
47     case kTfLiteInt16:
48       return NPY_INT16;
49     case kTfLiteUInt8:
50       return NPY_UINT8;
51     case kTfLiteInt8:
52       return NPY_INT8;
53     case kTfLiteInt64:
54       return NPY_INT64;
55     case kTfLiteUInt64:
56       return NPY_UINT64;
57     case kTfLiteString:
58       return NPY_STRING;
59     case kTfLiteBool:
60       return NPY_BOOL;
61     case kTfLiteComplex64:
62       return NPY_COMPLEX64;
63     case kTfLiteComplex128:
64       return NPY_COMPLEX128;
65     case kTfLiteResource:
66     case kTfLiteVariant:
67       return NPY_OBJECT;
68     case kTfLiteNoType:
69       return NPY_NOTYPE;
70       // Avoid default so compiler errors created when new types are made.
71   }
72   return NPY_NOTYPE;
73 }
74 
TfLiteTypeFromPyType(int py_type)75 TfLiteType TfLiteTypeFromPyType(int py_type) {
76   switch (py_type) {
77     case NPY_FLOAT32:
78       return kTfLiteFloat32;
79     case NPY_FLOAT16:
80       return kTfLiteFloat16;
81     case NPY_FLOAT64:
82       return kTfLiteFloat64;
83     case NPY_INT32:
84       return kTfLiteInt32;
85     case NPY_UINT32:
86       return kTfLiteUInt32;
87     case NPY_INT16:
88       return kTfLiteInt16;
89     case NPY_UINT8:
90       return kTfLiteUInt8;
91     case NPY_INT8:
92       return kTfLiteInt8;
93     case NPY_INT64:
94       return kTfLiteInt64;
95     case NPY_UINT64:
96       return kTfLiteUInt64;
97     case NPY_BOOL:
98       return kTfLiteBool;
99     case NPY_OBJECT:
100     case NPY_STRING:
101     case NPY_UNICODE:
102       return kTfLiteString;
103     case NPY_COMPLEX64:
104       return kTfLiteComplex64;
105     case NPY_COMPLEX128:
106       return kTfLiteComplex128;
107   }
108   return kTfLiteNoType;
109 }
110 
TfLiteTypeFromPyArray(PyArrayObject * array)111 TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
112   int pyarray_type = PyArray_TYPE(array);
113   return TfLiteTypeFromPyType(pyarray_type);
114 }
115 
116 #if PY_VERSION_HEX >= 0x03030000
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)117 bool FillStringBufferFromPyUnicode(PyObject* value,
118                                    DynamicBuffer* dynamic_buffer) {
119   Py_ssize_t len = -1;
120   const char* buf = PyUnicode_AsUTF8AndSize(value, &len);
121   if (buf == NULL) {
122     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
123     return false;
124   }
125   dynamic_buffer->AddString(buf, len);
126   return true;
127 }
128 #else
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)129 bool FillStringBufferFromPyUnicode(PyObject* value,
130                                    DynamicBuffer* dynamic_buffer) {
131   UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
132   if (!utemp) {
133     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
134     return false;
135   }
136   char* buf = nullptr;
137   Py_ssize_t len = -1;
138   if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
139     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
140     return false;
141   }
142   dynamic_buffer->AddString(buf, len);
143   return true;
144 }
145 #endif
146 
FillStringBufferFromPyString(PyObject * value,DynamicBuffer * dynamic_buffer)147 bool FillStringBufferFromPyString(PyObject* value,
148                                   DynamicBuffer* dynamic_buffer) {
149   if (PyUnicode_Check(value)) {
150     return FillStringBufferFromPyUnicode(value, dynamic_buffer);
151   }
152 
153   char* buf = nullptr;
154   Py_ssize_t len = -1;
155   if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
156     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
157     return false;
158   }
159   dynamic_buffer->AddString(buf, len);
160   return true;
161 }
162 
FillStringBufferWithPyArray(PyObject * value,DynamicBuffer * dynamic_buffer)163 bool FillStringBufferWithPyArray(PyObject* value,
164                                  DynamicBuffer* dynamic_buffer) {
165   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
166   switch (PyArray_TYPE(array)) {
167     case NPY_OBJECT:
168     case NPY_STRING:
169     case NPY_UNICODE: {
170       if (PyArray_NDIM(array) == 0) {
171         dynamic_buffer->AddString(static_cast<char*>(PyArray_DATA(array)),
172                                   PyArray_NBYTES(array));
173         return true;
174       }
175       UniquePyObjectRef iter(PyArray_IterNew(value));
176       while (PyArray_ITER_NOTDONE(iter.get())) {
177         UniquePyObjectRef item(PyArray_GETITEM(
178             array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
179 
180         if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
181           return false;
182         }
183 
184         PyArray_ITER_NEXT(iter.get());
185       }
186       return true;
187     }
188     default:
189       break;
190   }
191 
192   PyErr_Format(PyExc_ValueError,
193                "Cannot use numpy array of type %d for string tensor.",
194                PyArray_TYPE(array));
195   return false;
196 }
197 
198 }  // namespace python_utils
199 }  // namespace tflite
200