• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/lib/core/ndarray_tensor.h"
17 
18 #include <cstring>
19 
20 #include "tensorflow/c/eager/tfe_context_internal.h"
21 #include "tensorflow/c/tf_tensor_internal.h"
22 #include "tensorflow/core/lib/core/coding.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/python/lib/core/bfloat16.h"
27 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
28 #include "tensorflow/python/lib/core/numpy.h"
29 
30 namespace tensorflow {
31 namespace {
32 
numpy_type_name(int numpy_type)33 char const* numpy_type_name(int numpy_type) {
34   switch (numpy_type) {
35 #define TYPE_CASE(s) \
36   case s:            \
37     return #s
38 
39     TYPE_CASE(NPY_BOOL);
40     TYPE_CASE(NPY_BYTE);
41     TYPE_CASE(NPY_UBYTE);
42     TYPE_CASE(NPY_SHORT);
43     TYPE_CASE(NPY_USHORT);
44     TYPE_CASE(NPY_INT);
45     TYPE_CASE(NPY_UINT);
46     TYPE_CASE(NPY_LONG);
47     TYPE_CASE(NPY_ULONG);
48     TYPE_CASE(NPY_LONGLONG);
49     TYPE_CASE(NPY_ULONGLONG);
50     TYPE_CASE(NPY_FLOAT);
51     TYPE_CASE(NPY_DOUBLE);
52     TYPE_CASE(NPY_LONGDOUBLE);
53     TYPE_CASE(NPY_CFLOAT);
54     TYPE_CASE(NPY_CDOUBLE);
55     TYPE_CASE(NPY_CLONGDOUBLE);
56     TYPE_CASE(NPY_OBJECT);
57     TYPE_CASE(NPY_STRING);
58     TYPE_CASE(NPY_UNICODE);
59     TYPE_CASE(NPY_VOID);
60     TYPE_CASE(NPY_DATETIME);
61     TYPE_CASE(NPY_TIMEDELTA);
62     TYPE_CASE(NPY_HALF);
63     TYPE_CASE(NPY_NTYPES);
64     TYPE_CASE(NPY_NOTYPE);
65     TYPE_CASE(NPY_CHAR);
66     TYPE_CASE(NPY_USERDEF);
67     default:
68       return "not a numpy type";
69   }
70 }
71 
PyArrayDescr_to_TF_DataType(PyArray_Descr * descr,TF_DataType * out_tf_datatype)72 Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
73                                    TF_DataType* out_tf_datatype) {
74   PyObject* key;
75   PyObject* value;
76   Py_ssize_t pos = 0;
77   if (PyDict_Next(descr->fields, &pos, &key, &value)) {
78     // In Python 3, the keys of numpy custom struct types are unicode, unlike
79     // Python 2, where the keys are bytes.
80     const char* key_string =
81         PyBytes_Check(key) ? PyBytes_AsString(key)
82                            : PyBytes_AsString(PyUnicode_AsASCIIString(key));
83     if (!key_string) {
84       return errors::Internal("Corrupt numpy type descriptor");
85     }
86     tensorflow::string key = key_string;
87     // The typenames here should match the field names in the custom struct
88     // types constructed in test_util.py.
89     // TODO(mrry,keveman): Investigate Numpy type registration to replace this
90     // hard-coding of names.
91     if (key == "quint8") {
92       *out_tf_datatype = TF_QUINT8;
93     } else if (key == "qint8") {
94       *out_tf_datatype = TF_QINT8;
95     } else if (key == "qint16") {
96       *out_tf_datatype = TF_QINT16;
97     } else if (key == "quint16") {
98       *out_tf_datatype = TF_QUINT16;
99     } else if (key == "qint32") {
100       *out_tf_datatype = TF_QINT32;
101     } else if (key == "resource") {
102       *out_tf_datatype = TF_RESOURCE;
103     } else {
104       return errors::Internal("Unsupported numpy data type");
105     }
106     return Status::OK();
107   }
108   return errors::Internal("Unsupported numpy data type");
109 }
110 
PyArray_TYPE_to_TF_DataType(PyArrayObject * array,TF_DataType * out_tf_datatype)111 Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
112                                    TF_DataType* out_tf_datatype) {
113   int pyarray_type = PyArray_TYPE(array);
114   PyArray_Descr* descr = PyArray_DESCR(array);
115   switch (pyarray_type) {
116     case NPY_FLOAT16:
117       *out_tf_datatype = TF_HALF;
118       break;
119     case NPY_FLOAT32:
120       *out_tf_datatype = TF_FLOAT;
121       break;
122     case NPY_FLOAT64:
123       *out_tf_datatype = TF_DOUBLE;
124       break;
125     case NPY_INT32:
126       *out_tf_datatype = TF_INT32;
127       break;
128     case NPY_UINT8:
129       *out_tf_datatype = TF_UINT8;
130       break;
131     case NPY_UINT16:
132       *out_tf_datatype = TF_UINT16;
133       break;
134     case NPY_UINT32:
135       *out_tf_datatype = TF_UINT32;
136       break;
137     case NPY_UINT64:
138       *out_tf_datatype = TF_UINT64;
139       break;
140     case NPY_INT8:
141       *out_tf_datatype = TF_INT8;
142       break;
143     case NPY_INT16:
144       *out_tf_datatype = TF_INT16;
145       break;
146     case NPY_INT64:
147       *out_tf_datatype = TF_INT64;
148       break;
149     case NPY_BOOL:
150       *out_tf_datatype = TF_BOOL;
151       break;
152     case NPY_COMPLEX64:
153       *out_tf_datatype = TF_COMPLEX64;
154       break;
155     case NPY_COMPLEX128:
156       *out_tf_datatype = TF_COMPLEX128;
157       break;
158     case NPY_OBJECT:
159     case NPY_STRING:
160     case NPY_UNICODE:
161       *out_tf_datatype = TF_STRING;
162       break;
163     case NPY_VOID:
164       // Quantized types are currently represented as custom struct types.
165       // PyArray_TYPE returns NPY_VOID for structs, and we should look into
166       // descr to derive the actual type.
167       // Direct feeds of certain types of ResourceHandles are represented as a
168       // custom struct type.
169       return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
170     default:
171       if (pyarray_type == Bfloat16NumpyType()) {
172         *out_tf_datatype = TF_BFLOAT16;
173         break;
174       } else if (pyarray_type == NPY_ULONGLONG) {
175         // NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
176         // might be different on certain platforms.
177         *out_tf_datatype = TF_UINT64;
178         break;
179       } else if (pyarray_type == NPY_LONGLONG) {
180         // NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
181         // might be different on certain platforms.
182         *out_tf_datatype = TF_INT64;
183         break;
184       } else if (pyarray_type == NPY_INT) {
185         // NPY_INT is equivalent to NPY_INT32, while their enum values might be
186         // different on certain platforms.
187         *out_tf_datatype = TF_INT32;
188         break;
189       } else if (pyarray_type == NPY_UINT) {
190         // NPY_UINT is equivalent to NPY_UINT32, while their enum values might
191         // be different on certain platforms.
192         *out_tf_datatype = TF_UINT32;
193         break;
194       }
195       return errors::Internal("Unsupported numpy type: ",
196                               numpy_type_name(pyarray_type));
197   }
198   return Status::OK();
199 }
200 
PyObjectToString(PyObject * obj,const char ** ptr,Py_ssize_t * len,PyObject ** ptr_owner)201 Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
202                         PyObject** ptr_owner) {
203   *ptr_owner = nullptr;
204   if (PyBytes_Check(obj)) {
205     char* buf;
206     if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) {
207       return errors::Internal("Unable to get element as bytes.");
208     }
209     *ptr = buf;
210     return Status::OK();
211   } else if (PyUnicode_Check(obj)) {
212 #if (PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3))
213     *ptr = PyUnicode_AsUTF8AndSize(obj, len);
214     if (*ptr != nullptr) return Status::OK();
215 #else
216     PyObject* utemp = PyUnicode_AsUTF8String(obj);
217     char* buf;
218     if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) {
219       *ptr = buf;
220       *ptr_owner = utemp;
221       return Status::OK();
222     }
223     Py_XDECREF(utemp);
224 #endif
225     return errors::Internal("Unable to convert element to UTF-8");
226   } else {
227     return errors::Internal("Unsupported object type ", obj->ob_type->tp_name);
228   }
229 }
230 
231 // Iterate over the string array 'array', extract the ptr and len of each string
232 // element and call f(ptr, len).
233 template <typename F>
PyBytesArrayMap(PyArrayObject * array,F f)234 Status PyBytesArrayMap(PyArrayObject* array, F f) {
235   Safe_PyObjectPtr iter = tensorflow::make_safe(
236       PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
237   while (PyArray_ITER_NOTDONE(iter.get())) {
238     auto item = tensorflow::make_safe(PyArray_GETITEM(
239         array, static_cast<char*>(PyArray_ITER_DATA(iter.get()))));
240     if (!item) {
241       return errors::Internal("Unable to get element from the feed - no item.");
242     }
243     Py_ssize_t len;
244     const char* ptr;
245     PyObject* ptr_owner = nullptr;
246     TF_RETURN_IF_ERROR(PyObjectToString(item.get(), &ptr, &len, &ptr_owner));
247     f(ptr, len);
248     Py_XDECREF(ptr_owner);
249     PyArray_ITER_NEXT(iter.get());
250   }
251   return Status::OK();
252 }
253 
254 // Encode the strings in 'array' into a contiguous buffer and return the base of
255 // the buffer. The caller takes ownership of the buffer.
EncodePyBytesArray(PyArrayObject * array,tensorflow::int64 nelems,size_t * size,void ** buffer)256 Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
257                           size_t* size, void** buffer) {
258   // Encode all strings.
259   *size = nelems * sizeof(tensorflow::tstring);
260   std::unique_ptr<tensorflow::tstring[]> base_ptr(
261       new tensorflow::tstring[nelems]);
262   tensorflow::tstring* dst = base_ptr.get();
263 
264   TF_RETURN_IF_ERROR(
265       PyBytesArrayMap(array, [&dst](const char* ptr, Py_ssize_t len) {
266         dst->assign(ptr, len);
267         dst++;
268       }));
269   *buffer = base_ptr.release();
270   return Status::OK();
271 }
272 
CopyTF_TensorStringsToPyArray(const TF_Tensor * src,uint64 nelems,PyArrayObject * dst)273 Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
274                                      PyArrayObject* dst) {
275   const void* tensor_data = TF_TensorData(src);
276   DCHECK(tensor_data != nullptr);
277   DCHECK_EQ(TF_STRING, TF_TensorType(src));
278 
279   const tstring* tstr = static_cast<const tstring*>(tensor_data);
280 
281   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
282       TF_NewStatus(), TF_DeleteStatus);
283   auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
284   for (int64 i = 0; i < static_cast<int64>(nelems); ++i) {
285     const tstring& tstr_i = tstr[i];
286     auto py_string =
287         make_safe(PyBytes_FromStringAndSize(tstr_i.data(), tstr_i.size()));
288     if (py_string == nullptr) {
289       return errors::Internal(
290           "failed to create a python byte array when converting element #", i,
291           " of a TF_STRING tensor to a numpy ndarray");
292     }
293 
294     if (PyArray_SETITEM(dst, static_cast<char*>(PyArray_ITER_DATA(iter.get())),
295                         py_string.get()) != 0) {
296       return errors::Internal("Error settings element #", i,
297                               " in the numpy ndarray");
298     }
299     PyArray_ITER_NEXT(iter.get());
300   }
301   return Status::OK();
302 }
303 
304 // Determine the dimensions of a numpy ndarray to be created to represent an
305 // output Tensor.
GetPyArrayDimensionsForTensor(const TF_Tensor * tensor,gtl::InlinedVector<npy_intp,4> * dims,tensorflow::int64 * nelems)306 Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor,
307                                      gtl::InlinedVector<npy_intp, 4>* dims,
308                                      tensorflow::int64* nelems) {
309   dims->clear();
310   const int ndims = TF_NumDims(tensor);
311   if (TF_TensorType(tensor) == TF_RESOURCE) {
312     if (ndims != 0) {
313       return errors::InvalidArgument(
314           "Fetching of non-scalar resource tensors is not supported.");
315     }
316     dims->push_back(TF_TensorByteSize(tensor));
317     *nelems = dims->back();
318   } else {
319     *nelems = 1;
320     for (int i = 0; i < ndims; ++i) {
321       dims->push_back(TF_Dim(tensor, i));
322       *nelems *= dims->back();
323     }
324   }
325   return Status::OK();
326 }
327 
328 // Determine the type description (PyArray_Descr) of a numpy ndarray to be
329 // created to represent an output Tensor.
GetPyArrayDescrForTensor(const TF_Tensor * tensor,PyArray_Descr ** descr)330 Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
331                                 PyArray_Descr** descr) {
332   if (TF_TensorType(tensor) == TF_RESOURCE) {
333     PyObject* field = PyTuple_New(3);
334 #if PY_MAJOR_VERSION < 3
335     PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
336 #else
337     PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
338 #endif
339     PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
340     PyTuple_SetItem(field, 2, PyLong_FromLong(1));
341     PyObject* fields = PyList_New(1);
342     PyList_SetItem(fields, 0, field);
343     int convert_result = PyArray_DescrConverter(fields, descr);
344     Py_CLEAR(field);
345     Py_CLEAR(fields);
346     if (convert_result != 1) {
347       return errors::Internal("Failed to create numpy array description for ",
348                               "TF_RESOURCE-type tensor");
349     }
350   } else {
351     int type_num = -1;
352     TF_RETURN_IF_ERROR(
353         TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
354     *descr = PyArray_DescrFromType(type_num);
355   }
356 
357   return Status::OK();
358 }
359 
// Copies exactly N bytes. Because N is a template argument, memcpy always
// receives a compile-time-constant length, which most compilers expand into
// inline code instead of a library call.
template <size_t N>
inline void FixedSizeMemcpy(void* dst, const void* src) {
  memcpy(dst, src, N);
}

// Copies `size` bytes from `src` to `dst`. Sizes 1-16 use fixed-length copies,
// which are significantly faster for small buffers. Larger sizes fall through
// to a single generic library call. The fixed-length path uses memcpy, so the
// buffers must not overlap.
inline void FastMemcpy(void* dst, const void* src, size_t size) {
  // clang-format off
  switch (size) {
    case  1: FixedSizeMemcpy< 1>(dst, src); break;
    case  2: FixedSizeMemcpy< 2>(dst, src); break;
    case  3: FixedSizeMemcpy< 3>(dst, src); break;
    case  4: FixedSizeMemcpy< 4>(dst, src); break;
    case  5: FixedSizeMemcpy< 5>(dst, src); break;
    case  6: FixedSizeMemcpy< 6>(dst, src); break;
    case  7: FixedSizeMemcpy< 7>(dst, src); break;
    case  8: FixedSizeMemcpy< 8>(dst, src); break;
    case  9: FixedSizeMemcpy< 9>(dst, src); break;
    case 10: FixedSizeMemcpy<10>(dst, src); break;
    case 11: FixedSizeMemcpy<11>(dst, src); break;
    case 12: FixedSizeMemcpy<12>(dst, src); break;
    case 13: FixedSizeMemcpy<13>(dst, src); break;
    case 14: FixedSizeMemcpy<14>(dst, src); break;
    case 15: FixedSizeMemcpy<15>(dst, src); break;
    case 16: FixedSizeMemcpy<16>(dst, src); break;
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_POSIX) && \
    !defined(IS_MOBILE_PLATFORM)
    // On Linux, memmove appears to be faster than memcpy for large sizes.
    default: memmove(dst, src, size); break;
#else
    default: memcpy(dst, src, size); break;
#endif
  }
  // clang-format on
}
392 
393 }  // namespace
394 
395 // TODO(slebedev): revise TF_TensorToPyArray usages and switch to the
396 // aliased version where appropriate.
TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,PyObject ** out_ndarray)397 Status TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,
398                                       PyObject** out_ndarray) {
399   auto dtype = TF_TensorType(tensor.get());
400   if (dtype == TF_STRING || dtype == TF_RESOURCE) {
401     return TF_TensorToPyArray(std::move(tensor), out_ndarray);
402   }
403 
404   TF_Tensor* moved = tensor.release();
405   int64 nelems = -1;
406   gtl::InlinedVector<npy_intp, 4> dims;
407   TF_RETURN_IF_ERROR(GetPyArrayDimensionsForTensor(moved, &dims, &nelems));
408   return ArrayFromMemory(
409       dims.size(), dims.data(), TF_TensorData(moved),
410       static_cast<DataType>(dtype), [moved] { TF_DeleteTensor(moved); },
411       out_ndarray);
412 }
413 
414 // Converts the given TF_Tensor to a numpy ndarray.
415 // If the returned status is OK, the caller becomes the owner of *out_array.
TF_TensorToPyArray(Safe_TF_TensorPtr tensor,PyObject ** out_ndarray)416 Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
417   // A fetched operation will correspond to a null tensor, and a None
418   // in Python.
419   if (tensor == nullptr) {
420     Py_INCREF(Py_None);
421     *out_ndarray = Py_None;
422     return Status::OK();
423   }
424   int64 nelems = -1;
425   gtl::InlinedVector<npy_intp, 4> dims;
426   TF_RETURN_IF_ERROR(
427       GetPyArrayDimensionsForTensor(tensor.get(), &dims, &nelems));
428 
429   // If the type is neither string nor resource we can reuse the Tensor memory.
430   TF_Tensor* original = tensor.get();
431   TF_Tensor* moved = TF_TensorMaybeMove(tensor.release());
432   if (moved != nullptr) {
433     if (ArrayFromMemory(
434             dims.size(), dims.data(), TF_TensorData(moved),
435             static_cast<DataType>(TF_TensorType(moved)),
436             [moved] { TF_DeleteTensor(moved); }, out_ndarray)
437             .ok()) {
438       return Status::OK();
439     }
440   }
441   tensor.reset(original);
442 
443   // Copy the TF_TensorData into a newly-created ndarray and return it.
444   PyArray_Descr* descr = nullptr;
445   TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
446   Safe_PyObjectPtr safe_out_array =
447       tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
448   if (!safe_out_array) {
449     return errors::Internal("Could not allocate ndarray");
450   }
451   PyArrayObject* py_array =
452       reinterpret_cast<PyArrayObject*>(safe_out_array.get());
453   if (TF_TensorType(tensor.get()) == TF_STRING) {
454     Status s = CopyTF_TensorStringsToPyArray(tensor.get(), nelems, py_array);
455     if (!s.ok()) {
456       return s;
457     }
458   } else if (static_cast<size_t>(PyArray_NBYTES(py_array)) !=
459              TF_TensorByteSize(tensor.get())) {
460     return errors::Internal("ndarray was ", PyArray_NBYTES(py_array),
461                             " bytes but TF_Tensor was ",
462                             TF_TensorByteSize(tensor.get()), " bytes");
463   } else {
464     FastMemcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
465                PyArray_NBYTES(py_array));
466   }
467 
468   *out_ndarray = safe_out_array.release();
469   return Status::OK();
470 }
471 
NdarrayToTensor(TFE_Context * ctx,PyObject * ndarray,Safe_TF_TensorPtr * ret)472 Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
473                        Safe_TF_TensorPtr* ret) {
474   DCHECK(ret != nullptr);
475 
476   // Make sure we dereference this array object in case of error, etc.
477   Safe_PyObjectPtr array_safe(make_safe(
478       PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr)));
479   if (!array_safe) return errors::InvalidArgument("Not a ndarray.");
480   PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
481 
482   // Convert numpy dtype to TensorFlow dtype.
483   TF_DataType dtype = TF_FLOAT;
484   TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
485 
486   tensorflow::int64 nelems = 1;
487   gtl::InlinedVector<int64_t, 4> dims;
488   for (int i = 0; i < PyArray_NDIM(array); ++i) {
489     dims.push_back(PyArray_SHAPE(array)[i]);
490     nelems *= dims[i];
491   }
492 
493   // Create a TF_Tensor based on the fed data. In the case of non-string data
494   // type, this steals a reference to array, which will be relinquished when
495   // the underlying buffer is deallocated. For string, a new temporary buffer
496   // is allocated into which the strings are encoded.
497   if (dtype == TF_RESOURCE) {
498     size_t size = PyArray_NBYTES(array);
499     array_safe.release();
500 
501     if (ctx) {
502       *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
503           static_cast<tensorflow::DataType>(dtype), {}, 0, PyArray_DATA(array),
504           size, &DelayedNumpyDecref, array)});
505     } else {
506       *ret = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), size,
507                                     &DelayedNumpyDecref, array));
508     }
509 
510   } else if (dtype != TF_STRING) {
511     size_t size = PyArray_NBYTES(array);
512     array_safe.release();
513     if (ctx) {
514       *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
515           static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
516           PyArray_DATA(array), size, &DelayedNumpyDecref, array)});
517     } else {
518       *ret = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
519                                     PyArray_DATA(array), size,
520                                     &DelayedNumpyDecref, array));
521     }
522 
523   } else {
524     size_t size = 0;
525     void* encoded = nullptr;
526     TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
527     if (ctx) {
528       *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
529           static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
530           encoded, size,
531           [](void* data, size_t len, void* arg) {
532             delete[] reinterpret_cast<tensorflow::tstring*>(data);
533           },
534           nullptr)});
535     } else {
536       *ret = make_safe(TF_NewTensor(
537           dtype, dims.data(), dims.size(), encoded, size,
538           [](void* data, size_t len, void* arg) {
539             delete[] reinterpret_cast<tensorflow::tstring*>(data);
540           },
541           nullptr));
542     }
543   }
544 
545   return Status::OK();
546 }
547 
548 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
549 TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status);
550 
NdarrayToTensor(PyObject * obj,Tensor * ret)551 Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
552   Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
553   Status s = NdarrayToTensor(nullptr /*ctx*/, obj, &tf_tensor);
554   if (!s.ok()) {
555     return s;
556   }
557   return TF_TensorToTensor(tf_tensor.get(), ret);
558 }
559 
TensorToNdarray(const Tensor & t,PyObject ** ret)560 Status TensorToNdarray(const Tensor& t, PyObject** ret) {
561   Status status;
562   Safe_TF_TensorPtr tf_tensor = make_safe(TF_TensorFromTensor(t, &status));
563   if (!status.ok()) {
564     return status;
565   }
566   return TF_TensorToPyArray(std::move(tf_tensor), ret);
567 }
568 
569 }  // namespace tensorflow
570