• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
17 
18 #include <memory>
19 
20 #include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
21 
22 namespace tflite {
23 namespace python_utils {
24 
25 struct PyObjectDereferencer {
operator ()tflite::python_utils::PyObjectDereferencer26   void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
27 };
28 
29 using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
30 
TfLiteTypeToPyArrayType(TfLiteType tf_lite_type)31 int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
32   switch (tf_lite_type) {
33     case kTfLiteFloat32:
34       return NPY_FLOAT32;
35     case kTfLiteInt32:
36       return NPY_INT32;
37     case kTfLiteInt16:
38       return NPY_INT16;
39     case kTfLiteUInt8:
40       return NPY_UINT8;
41     case kTfLiteInt8:
42       return NPY_INT8;
43     case kTfLiteInt64:
44       return NPY_INT64;
45     case kTfLiteString:
46       return NPY_STRING;
47     case kTfLiteBool:
48       return NPY_BOOL;
49     case kTfLiteComplex64:
50       return NPY_COMPLEX64;
51     case kTfLiteNoType:
52       return NPY_NOTYPE;
53       // Avoid default so compiler errors created when new types are made.
54   }
55   return NPY_NOTYPE;
56 }
57 
TfLiteTypeFromPyArray(PyArrayObject * array)58 TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) {
59   int pyarray_type = PyArray_TYPE(array);
60   switch (pyarray_type) {
61     case NPY_FLOAT32:
62       return kTfLiteFloat32;
63     case NPY_INT32:
64       return kTfLiteInt32;
65     case NPY_INT16:
66       return kTfLiteInt16;
67     case NPY_UINT8:
68       return kTfLiteUInt8;
69     case NPY_INT8:
70       return kTfLiteInt8;
71     case NPY_INT64:
72       return kTfLiteInt64;
73     case NPY_BOOL:
74       return kTfLiteBool;
75     case NPY_OBJECT:
76     case NPY_STRING:
77     case NPY_UNICODE:
78       return kTfLiteString;
79     case NPY_COMPLEX64:
80       return kTfLiteComplex64;
81       // Avoid default so compiler errors created when new types are made.
82   }
83   return kTfLiteNoType;
84 }
85 
86 #if PY_VERSION_HEX >= 0x03030000
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)87 bool FillStringBufferFromPyUnicode(PyObject* value,
88                                    DynamicBuffer* dynamic_buffer) {
89   Py_ssize_t len = -1;
90   const char* buf = PyUnicode_AsUTF8AndSize(value, &len);
91   if (buf == NULL) {
92     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
93     return false;
94   }
95   dynamic_buffer->AddString(buf, len);
96   return true;
97 }
98 #else
FillStringBufferFromPyUnicode(PyObject * value,DynamicBuffer * dynamic_buffer)99 bool FillStringBufferFromPyUnicode(PyObject* value,
100                                    DynamicBuffer* dynamic_buffer) {
101   UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
102   if (!utemp) {
103     PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
104     return false;
105   }
106   char* buf = nullptr;
107   Py_ssize_t len = -1;
108   if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
109     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
110     return false;
111   }
112   dynamic_buffer->AddString(buf, len);
113   return true;
114 }
115 #endif
116 
FillStringBufferFromPyString(PyObject * value,DynamicBuffer * dynamic_buffer)117 bool FillStringBufferFromPyString(PyObject* value,
118                                   DynamicBuffer* dynamic_buffer) {
119   if (PyUnicode_Check(value)) {
120     return FillStringBufferFromPyUnicode(value, dynamic_buffer);
121   }
122 
123   char* buf = nullptr;
124   Py_ssize_t len = -1;
125   if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
126     PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
127     return false;
128   }
129   dynamic_buffer->AddString(buf, len);
130   return true;
131 }
132 
FillStringBufferWithPyArray(PyObject * value,DynamicBuffer * dynamic_buffer)133 bool FillStringBufferWithPyArray(PyObject* value,
134                                  DynamicBuffer* dynamic_buffer) {
135   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
136   switch (PyArray_TYPE(array)) {
137     case NPY_OBJECT:
138     case NPY_STRING:
139     case NPY_UNICODE: {
140       UniquePyObjectRef iter(PyArray_IterNew(value));
141       while (PyArray_ITER_NOTDONE(iter.get())) {
142         UniquePyObjectRef item(PyArray_GETITEM(
143             array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
144 
145         if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
146           return false;
147         }
148 
149         PyArray_ITER_NEXT(iter.get());
150       }
151       return true;
152     }
153     default:
154       break;
155   }
156 
157   PyErr_Format(PyExc_ValueError,
158                "Cannot use numpy array of type %d for string tensor.",
159                PyArray_TYPE(array));
160   return false;
161 }
162 
ConvertFromPyString(PyObject * obj,char ** data,Py_ssize_t * length)163 int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) {
164 #if PY_MAJOR_VERSION >= 3
165   return PyBytes_AsStringAndSize(obj, data, length);
166 #else
167   return PyString_AsStringAndSize(obj, data, length);
168 #endif
169 }
170 
ConvertToPyString(const char * data,size_t length)171 PyObject* ConvertToPyString(const char* data, size_t length) {
172 #if PY_MAJOR_VERSION >= 3
173   return PyBytes_FromStringAndSize(data, length);
174 #else
175   return PyString_FromStringAndSize(data, length);
176 #endif
177 }
178 
179 }  // namespace python_utils
180 }  // namespace tflite
181