• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define TFLITE_IMPORT_NUMPY  // See numpy.h for explanation.
17 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
18 
19 #include <memory>
20 
21 namespace tflite {
22 namespace python {
23 
ImportNumpy()24 void ImportNumpy() { import_array1(); }
25 
26 }  // namespace python
27 
28 namespace python_utils {
29 
30 struct PyObjectDereferencer {
operator ()tflite::python_utils::PyObjectDereferencer31   void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
32 };
33 using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
34 
TfLiteTypeToPyArrayType(TfLiteType tf_lite_type)35 int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
36   switch (tf_lite_type) {
37     case kTfLiteFloat32:
38       return NPY_FLOAT32;
39     case kTfLiteFloat16:
40       return NPY_FLOAT16;
41     case kTfLiteFloat64:
42       return NPY_FLOAT64;
43     case kTfLiteInt32:
44       return NPY_INT32;
45     case kTfLiteUInt32:
46       return NPY_UINT32;
47     case kTfLiteUInt16:
48       return NPY_UINT16;
49     case kTfLiteInt16:
50       return NPY_INT16;
51     case kTfLiteUInt8:
52       return NPY_UINT8;
53     case kTfLiteInt8:
54       return NPY_INT8;
55     case kTfLiteInt64:
56       return NPY_INT64;
57     case kTfLiteUInt64:
58       return NPY_UINT64;
59     case kTfLiteString:
60       return NPY_STRING;
61     case kTfLiteBool:
62       return NPY_BOOL;
63     case kTfLiteComplex64:
64       return NPY_COMPLEX64;
65     case kTfLiteComplex128:
66       return NPY_COMPLEX128;
67     case kTfLiteResource:
68     case kTfLiteVariant:
69       return NPY_OBJECT;
70     case kTfLiteNoType:
71       return NPY_NOTYPE;
72       // Avoid default so compiler errors created when new types are made.
73   }
74   return NPY_NOTYPE;
75 }
76 
TfLiteTypeFromPyType(int py_type)77 TfLiteType TfLiteTypeFromPyType(int py_type) {
78   switch (py_type) {
79     case NPY_FLOAT32:
80       return kTfLiteFloat32;
81     case NPY_FLOAT16:
82       return kTfLiteFloat16;
83     case NPY_FLOAT64:
84       return kTfLiteFloat64;
85     case NPY_INT32:
86       return kTfLiteInt32;
87     case NPY_UINT32:
88       return kTfLiteUInt32;
89     case NPY_INT16:
90       return kTfLiteInt16;
91     case NPY_UINT8:
92       return kTfLiteUInt8;
93     case NPY_INT8:
94       return kTfLiteInt8;
95     case NPY_INT64:
96       return kTfLiteInt64;
97     case NPY_UINT64:
98       return kTfLiteUInt64;
99     case NPY_BOOL:
100       return kTfLiteBool;
101     case NPY_OBJECT:
102     case NPY_STRING:
103     case NPY_UNICODE:
104       return kTfLiteString;
105     case NPY_COMPLEX64:
106       return kTfLiteComplex64;
107     case NPY_COMPLEX128:
108       return kTfLiteComplex128;
109   }
110   return kTfLiteNoType;
111 }
112 
TfLiteTypeFromPyArray(PyArrayObject * array)113 TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
114   int pyarray_type = PyArray_TYPE(array);
115   return TfLiteTypeFromPyType(pyarray_type);
116 }
117 
118 #if PY_VERSION_HEX >= 0x03030000
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)119 bool FillStringBufferFromPyUnicode(PyObject* value,
120                                    DynamicBuffer* dynamic_buffer) {
121   Py_ssize_t len = -1;
122   const char* buf = PyUnicode_AsUTF8AndSize(value, &len);
123   if (buf == nullptr) {
124     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
125     return false;
126   }
127   dynamic_buffer->AddString(buf, len);
128   return true;
129 }
130 #else
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)131 bool FillStringBufferFromPyUnicode(PyObject* value,
132                                    DynamicBuffer* dynamic_buffer) {
133   UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
134   if (!utemp) {
135     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
136     return false;
137   }
138   char* buf = nullptr;
139   Py_ssize_t len = -1;
140   if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
141     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
142     return false;
143   }
144   dynamic_buffer->AddString(buf, len);
145   return true;
146 }
147 #endif
148 
FillStringBufferFromPyString(PyObject * value,DynamicBuffer * dynamic_buffer)149 bool FillStringBufferFromPyString(PyObject* value,
150                                   DynamicBuffer* dynamic_buffer) {
151   if (PyUnicode_Check(value)) {
152     return FillStringBufferFromPyUnicode(value, dynamic_buffer);
153   }
154 
155   char* buf = nullptr;
156   Py_ssize_t len = -1;
157   if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
158     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
159     return false;
160   }
161   dynamic_buffer->AddString(buf, len);
162   return true;
163 }
164 
FillStringBufferWithPyArray(PyObject * value,DynamicBuffer * dynamic_buffer)165 bool FillStringBufferWithPyArray(PyObject* value,
166                                  DynamicBuffer* dynamic_buffer) {
167   if (!PyArray_Check(value)) {
168     PyErr_Format(PyExc_ValueError,
169                  "Passed in value type is not a numpy array, got type %s.",
170                  value->ob_type->tp_name);
171     return false;
172   }
173 
174   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
175   switch (PyArray_TYPE(array)) {
176     case NPY_OBJECT:
177     case NPY_STRING:
178     case NPY_UNICODE: {
179       if (PyArray_NDIM(array) == 0) {
180         dynamic_buffer->AddString(static_cast<char*>(PyArray_DATA(array)),
181                                   PyArray_NBYTES(array));
182         return true;
183       }
184       UniquePyObjectRef iter(PyArray_IterNew(value));
185       while (PyArray_ITER_NOTDONE(iter.get())) {
186         UniquePyObjectRef item(PyArray_GETITEM(
187             array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
188 
189         if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
190           return false;
191         }
192 
193         PyArray_ITER_NEXT(iter.get());
194       }
195       return true;
196     }
197     default:
198       break;
199   }
200 
201   PyErr_Format(PyExc_ValueError,
202                "Cannot use numpy array of type %d for string tensor.",
203                PyArray_TYPE(array));
204   return false;
205 }
206 
207 }  // namespace python_utils
208 }  // namespace tflite
209