/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/lite/python/interpreter_wrapper/python_utils.h" #include #include "tensorflow/lite/python/interpreter_wrapper/numpy.h" namespace tflite { namespace python_utils { struct PyObjectDereferencer { void operator()(PyObject* py_object) const { Py_DECREF(py_object); } }; using UniquePyObjectRef = std::unique_ptr; int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) { switch (tf_lite_type) { case kTfLiteFloat32: return NPY_FLOAT32; case kTfLiteInt32: return NPY_INT32; case kTfLiteInt16: return NPY_INT16; case kTfLiteUInt8: return NPY_UINT8; case kTfLiteInt8: return NPY_INT8; case kTfLiteInt64: return NPY_INT64; case kTfLiteString: return NPY_STRING; case kTfLiteBool: return NPY_BOOL; case kTfLiteComplex64: return NPY_COMPLEX64; case kTfLiteNoType: return NPY_NOTYPE; // Avoid default so compiler errors created when new types are made. } return NPY_NOTYPE; } TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array) { int pyarray_type = PyArray_TYPE(array); switch (pyarray_type) { case NPY_FLOAT32: return kTfLiteFloat32; case NPY_INT32: return kTfLiteInt32; case NPY_INT16: return kTfLiteInt16; case NPY_UINT8: return kTfLiteUInt8; case NPY_INT8: return kTfLiteInt8; case NPY_INT64: return kTfLiteInt64; case NPY_BOOL: return kTfLiteBool; case NPY_OBJECT: case NPY_STRING: case NPY_UNICODE: return kTfLiteString; case NPY_COMPLEX64: return kTfLiteComplex64; // Avoid default so compiler errors created when new types are made. } return kTfLiteNoType; } #if PY_VERSION_HEX >= 0x03030000 bool FillStringBufferFromPyUnicode(PyObject* value, DynamicBuffer* dynamic_buffer) { Py_ssize_t len = -1; const char* buf = PyUnicode_AsUTF8AndSize(value, &len); if (buf == NULL) { PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed."); return false; } dynamic_buffer->AddString(buf, len); return true; } #else bool FillStringBufferFromPyUnicode(PyObject* value, DynamicBuffer* dynamic_buffer) { UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value)); if (!utemp) { PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed."); return false; } char* buf = nullptr; Py_ssize_t len = -1; if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) { PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed."); return false; } dynamic_buffer->AddString(buf, len); return true; } #endif bool FillStringBufferFromPyString(PyObject* value, DynamicBuffer* dynamic_buffer) { if (PyUnicode_Check(value)) { return FillStringBufferFromPyUnicode(value, dynamic_buffer); } char* buf = nullptr; Py_ssize_t len = -1; if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) { PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed."); return false; } dynamic_buffer->AddString(buf, len); return true; } bool FillStringBufferWithPyArray(PyObject* value, DynamicBuffer* dynamic_buffer) { PyArrayObject* array = reinterpret_cast(value); switch (PyArray_TYPE(array)) { case NPY_OBJECT: case NPY_STRING: case NPY_UNICODE: { UniquePyObjectRef iter(PyArray_IterNew(value)); while (PyArray_ITER_NOTDONE(iter.get())) { UniquePyObjectRef item(PyArray_GETITEM( array, reinterpret_cast(PyArray_ITER_DATA(iter.get())))); if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) { return false; } PyArray_ITER_NEXT(iter.get()); } return true; } default: break; } PyErr_Format(PyExc_ValueError, "Cannot use numpy array of type %d for string tensor.", PyArray_TYPE(array)); return false; } int ConvertFromPyString(PyObject* obj, char** data, Py_ssize_t* length) { #if PY_MAJOR_VERSION >= 3 return PyBytes_AsStringAndSize(obj, data, length); #else return PyString_AsStringAndSize(obj, data, length); #endif } PyObject* ConvertToPyString(const char* data, size_t length) { #if PY_MAJOR_VERSION >= 3 return PyBytes_FromStringAndSize(data, length); #else return PyString_FromStringAndSize(data, length); #endif } } // namespace python_utils } // namespace tflite