1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/lib/core/ndarray_tensor.h"
17
18 #include <cstring>
19
20 #include "tensorflow/c/eager/tfe_context_internal.h"
21 #include "tensorflow/c/tf_tensor_internal.h"
22 #include "tensorflow/core/lib/core/coding.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 #include "tensorflow/core/platform/types.h"
26 #include "tensorflow/python/lib/core/bfloat16.h"
27 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
28 #include "tensorflow/python/lib/core/numpy.h"
29
30 namespace tensorflow {
31 namespace {
32
numpy_type_name(int numpy_type)33 char const* numpy_type_name(int numpy_type) {
34 switch (numpy_type) {
35 #define TYPE_CASE(s) \
36 case s: \
37 return #s
38
39 TYPE_CASE(NPY_BOOL);
40 TYPE_CASE(NPY_BYTE);
41 TYPE_CASE(NPY_UBYTE);
42 TYPE_CASE(NPY_SHORT);
43 TYPE_CASE(NPY_USHORT);
44 TYPE_CASE(NPY_INT);
45 TYPE_CASE(NPY_UINT);
46 TYPE_CASE(NPY_LONG);
47 TYPE_CASE(NPY_ULONG);
48 TYPE_CASE(NPY_LONGLONG);
49 TYPE_CASE(NPY_ULONGLONG);
50 TYPE_CASE(NPY_FLOAT);
51 TYPE_CASE(NPY_DOUBLE);
52 TYPE_CASE(NPY_LONGDOUBLE);
53 TYPE_CASE(NPY_CFLOAT);
54 TYPE_CASE(NPY_CDOUBLE);
55 TYPE_CASE(NPY_CLONGDOUBLE);
56 TYPE_CASE(NPY_OBJECT);
57 TYPE_CASE(NPY_STRING);
58 TYPE_CASE(NPY_UNICODE);
59 TYPE_CASE(NPY_VOID);
60 TYPE_CASE(NPY_DATETIME);
61 TYPE_CASE(NPY_TIMEDELTA);
62 TYPE_CASE(NPY_HALF);
63 TYPE_CASE(NPY_NTYPES);
64 TYPE_CASE(NPY_NOTYPE);
65 TYPE_CASE(NPY_CHAR);
66 TYPE_CASE(NPY_USERDEF);
67 default:
68 return "not a numpy type";
69 }
70 }
71
PyArrayDescr_to_TF_DataType(PyArray_Descr * descr,TF_DataType * out_tf_datatype)72 Status PyArrayDescr_to_TF_DataType(PyArray_Descr* descr,
73 TF_DataType* out_tf_datatype) {
74 PyObject* key;
75 PyObject* value;
76 Py_ssize_t pos = 0;
77 if (PyDict_Next(descr->fields, &pos, &key, &value)) {
78 // In Python 3, the keys of numpy custom struct types are unicode, unlike
79 // Python 2, where the keys are bytes.
80 const char* key_string =
81 PyBytes_Check(key) ? PyBytes_AsString(key)
82 : PyBytes_AsString(PyUnicode_AsASCIIString(key));
83 if (!key_string) {
84 return errors::Internal("Corrupt numpy type descriptor");
85 }
86 tensorflow::string key = key_string;
87 // The typenames here should match the field names in the custom struct
88 // types constructed in test_util.py.
89 // TODO(mrry,keveman): Investigate Numpy type registration to replace this
90 // hard-coding of names.
91 if (key == "quint8") {
92 *out_tf_datatype = TF_QUINT8;
93 } else if (key == "qint8") {
94 *out_tf_datatype = TF_QINT8;
95 } else if (key == "qint16") {
96 *out_tf_datatype = TF_QINT16;
97 } else if (key == "quint16") {
98 *out_tf_datatype = TF_QUINT16;
99 } else if (key == "qint32") {
100 *out_tf_datatype = TF_QINT32;
101 } else if (key == "resource") {
102 *out_tf_datatype = TF_RESOURCE;
103 } else {
104 return errors::Internal("Unsupported numpy data type");
105 }
106 return Status::OK();
107 }
108 return errors::Internal("Unsupported numpy data type");
109 }
110
PyArray_TYPE_to_TF_DataType(PyArrayObject * array,TF_DataType * out_tf_datatype)111 Status PyArray_TYPE_to_TF_DataType(PyArrayObject* array,
112 TF_DataType* out_tf_datatype) {
113 int pyarray_type = PyArray_TYPE(array);
114 PyArray_Descr* descr = PyArray_DESCR(array);
115 switch (pyarray_type) {
116 case NPY_FLOAT16:
117 *out_tf_datatype = TF_HALF;
118 break;
119 case NPY_FLOAT32:
120 *out_tf_datatype = TF_FLOAT;
121 break;
122 case NPY_FLOAT64:
123 *out_tf_datatype = TF_DOUBLE;
124 break;
125 case NPY_INT32:
126 *out_tf_datatype = TF_INT32;
127 break;
128 case NPY_UINT8:
129 *out_tf_datatype = TF_UINT8;
130 break;
131 case NPY_UINT16:
132 *out_tf_datatype = TF_UINT16;
133 break;
134 case NPY_UINT32:
135 *out_tf_datatype = TF_UINT32;
136 break;
137 case NPY_UINT64:
138 *out_tf_datatype = TF_UINT64;
139 break;
140 case NPY_INT8:
141 *out_tf_datatype = TF_INT8;
142 break;
143 case NPY_INT16:
144 *out_tf_datatype = TF_INT16;
145 break;
146 case NPY_INT64:
147 *out_tf_datatype = TF_INT64;
148 break;
149 case NPY_BOOL:
150 *out_tf_datatype = TF_BOOL;
151 break;
152 case NPY_COMPLEX64:
153 *out_tf_datatype = TF_COMPLEX64;
154 break;
155 case NPY_COMPLEX128:
156 *out_tf_datatype = TF_COMPLEX128;
157 break;
158 case NPY_OBJECT:
159 case NPY_STRING:
160 case NPY_UNICODE:
161 *out_tf_datatype = TF_STRING;
162 break;
163 case NPY_VOID:
164 // Quantized types are currently represented as custom struct types.
165 // PyArray_TYPE returns NPY_VOID for structs, and we should look into
166 // descr to derive the actual type.
167 // Direct feeds of certain types of ResourceHandles are represented as a
168 // custom struct type.
169 return PyArrayDescr_to_TF_DataType(descr, out_tf_datatype);
170 default:
171 if (pyarray_type == Bfloat16NumpyType()) {
172 *out_tf_datatype = TF_BFLOAT16;
173 break;
174 } else if (pyarray_type == NPY_ULONGLONG) {
175 // NPY_ULONGLONG is equivalent to NPY_UINT64, while their enum values
176 // might be different on certain platforms.
177 *out_tf_datatype = TF_UINT64;
178 break;
179 } else if (pyarray_type == NPY_LONGLONG) {
180 // NPY_LONGLONG is equivalent to NPY_INT64, while their enum values
181 // might be different on certain platforms.
182 *out_tf_datatype = TF_INT64;
183 break;
184 } else if (pyarray_type == NPY_INT) {
185 // NPY_INT is equivalent to NPY_INT32, while their enum values might be
186 // different on certain platforms.
187 *out_tf_datatype = TF_INT32;
188 break;
189 } else if (pyarray_type == NPY_UINT) {
190 // NPY_UINT is equivalent to NPY_UINT32, while their enum values might
191 // be different on certain platforms.
192 *out_tf_datatype = TF_UINT32;
193 break;
194 }
195 return errors::Internal("Unsupported numpy type: ",
196 numpy_type_name(pyarray_type));
197 }
198 return Status::OK();
199 }
200
PyObjectToString(PyObject * obj,const char ** ptr,Py_ssize_t * len,PyObject ** ptr_owner)201 Status PyObjectToString(PyObject* obj, const char** ptr, Py_ssize_t* len,
202 PyObject** ptr_owner) {
203 *ptr_owner = nullptr;
204 if (PyBytes_Check(obj)) {
205 char* buf;
206 if (PyBytes_AsStringAndSize(obj, &buf, len) != 0) {
207 return errors::Internal("Unable to get element as bytes.");
208 }
209 *ptr = buf;
210 return Status::OK();
211 } else if (PyUnicode_Check(obj)) {
212 #if (PY_MAJOR_VERSION > 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION >= 3))
213 *ptr = PyUnicode_AsUTF8AndSize(obj, len);
214 if (*ptr != nullptr) return Status::OK();
215 #else
216 PyObject* utemp = PyUnicode_AsUTF8String(obj);
217 char* buf;
218 if (utemp != nullptr && PyBytes_AsStringAndSize(utemp, &buf, len) != -1) {
219 *ptr = buf;
220 *ptr_owner = utemp;
221 return Status::OK();
222 }
223 Py_XDECREF(utemp);
224 #endif
225 return errors::Internal("Unable to convert element to UTF-8");
226 } else {
227 return errors::Internal("Unsupported object type ", obj->ob_type->tp_name);
228 }
229 }
230
231 // Iterate over the string array 'array', extract the ptr and len of each string
232 // element and call f(ptr, len).
233 template <typename F>
PyBytesArrayMap(PyArrayObject * array,F f)234 Status PyBytesArrayMap(PyArrayObject* array, F f) {
235 Safe_PyObjectPtr iter = tensorflow::make_safe(
236 PyArray_IterNew(reinterpret_cast<PyObject*>(array)));
237 while (PyArray_ITER_NOTDONE(iter.get())) {
238 auto item = tensorflow::make_safe(PyArray_GETITEM(
239 array, static_cast<char*>(PyArray_ITER_DATA(iter.get()))));
240 if (!item) {
241 return errors::Internal("Unable to get element from the feed - no item.");
242 }
243 Py_ssize_t len;
244 const char* ptr;
245 PyObject* ptr_owner = nullptr;
246 TF_RETURN_IF_ERROR(PyObjectToString(item.get(), &ptr, &len, &ptr_owner));
247 f(ptr, len);
248 Py_XDECREF(ptr_owner);
249 PyArray_ITER_NEXT(iter.get());
250 }
251 return Status::OK();
252 }
253
254 // Encode the strings in 'array' into a contiguous buffer and return the base of
255 // the buffer. The caller takes ownership of the buffer.
EncodePyBytesArray(PyArrayObject * array,tensorflow::int64 nelems,size_t * size,void ** buffer)256 Status EncodePyBytesArray(PyArrayObject* array, tensorflow::int64 nelems,
257 size_t* size, void** buffer) {
258 // Encode all strings.
259 *size = nelems * sizeof(tensorflow::tstring);
260 std::unique_ptr<tensorflow::tstring[]> base_ptr(
261 new tensorflow::tstring[nelems]);
262 tensorflow::tstring* dst = base_ptr.get();
263
264 TF_RETURN_IF_ERROR(
265 PyBytesArrayMap(array, [&dst](const char* ptr, Py_ssize_t len) {
266 dst->assign(ptr, len);
267 dst++;
268 }));
269 *buffer = base_ptr.release();
270 return Status::OK();
271 }
272
CopyTF_TensorStringsToPyArray(const TF_Tensor * src,uint64 nelems,PyArrayObject * dst)273 Status CopyTF_TensorStringsToPyArray(const TF_Tensor* src, uint64 nelems,
274 PyArrayObject* dst) {
275 const void* tensor_data = TF_TensorData(src);
276 DCHECK(tensor_data != nullptr);
277 DCHECK_EQ(TF_STRING, TF_TensorType(src));
278
279 const tstring* tstr = static_cast<const tstring*>(tensor_data);
280
281 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
282 TF_NewStatus(), TF_DeleteStatus);
283 auto iter = make_safe(PyArray_IterNew(reinterpret_cast<PyObject*>(dst)));
284 for (int64 i = 0; i < static_cast<int64>(nelems); ++i) {
285 const tstring& tstr_i = tstr[i];
286 auto py_string =
287 make_safe(PyBytes_FromStringAndSize(tstr_i.data(), tstr_i.size()));
288 if (py_string == nullptr) {
289 return errors::Internal(
290 "failed to create a python byte array when converting element #", i,
291 " of a TF_STRING tensor to a numpy ndarray");
292 }
293
294 if (PyArray_SETITEM(dst, static_cast<char*>(PyArray_ITER_DATA(iter.get())),
295 py_string.get()) != 0) {
296 return errors::Internal("Error settings element #", i,
297 " in the numpy ndarray");
298 }
299 PyArray_ITER_NEXT(iter.get());
300 }
301 return Status::OK();
302 }
303
304 // Determine the dimensions of a numpy ndarray to be created to represent an
305 // output Tensor.
GetPyArrayDimensionsForTensor(const TF_Tensor * tensor,gtl::InlinedVector<npy_intp,4> * dims,tensorflow::int64 * nelems)306 Status GetPyArrayDimensionsForTensor(const TF_Tensor* tensor,
307 gtl::InlinedVector<npy_intp, 4>* dims,
308 tensorflow::int64* nelems) {
309 dims->clear();
310 const int ndims = TF_NumDims(tensor);
311 if (TF_TensorType(tensor) == TF_RESOURCE) {
312 if (ndims != 0) {
313 return errors::InvalidArgument(
314 "Fetching of non-scalar resource tensors is not supported.");
315 }
316 dims->push_back(TF_TensorByteSize(tensor));
317 *nelems = dims->back();
318 } else {
319 *nelems = 1;
320 for (int i = 0; i < ndims; ++i) {
321 dims->push_back(TF_Dim(tensor, i));
322 *nelems *= dims->back();
323 }
324 }
325 return Status::OK();
326 }
327
328 // Determine the type description (PyArray_Descr) of a numpy ndarray to be
329 // created to represent an output Tensor.
GetPyArrayDescrForTensor(const TF_Tensor * tensor,PyArray_Descr ** descr)330 Status GetPyArrayDescrForTensor(const TF_Tensor* tensor,
331 PyArray_Descr** descr) {
332 if (TF_TensorType(tensor) == TF_RESOURCE) {
333 PyObject* field = PyTuple_New(3);
334 #if PY_MAJOR_VERSION < 3
335 PyTuple_SetItem(field, 0, PyBytes_FromString("resource"));
336 #else
337 PyTuple_SetItem(field, 0, PyUnicode_FromString("resource"));
338 #endif
339 PyTuple_SetItem(field, 1, PyArray_TypeObjectFromType(NPY_UBYTE));
340 PyTuple_SetItem(field, 2, PyLong_FromLong(1));
341 PyObject* fields = PyList_New(1);
342 PyList_SetItem(fields, 0, field);
343 int convert_result = PyArray_DescrConverter(fields, descr);
344 Py_CLEAR(field);
345 Py_CLEAR(fields);
346 if (convert_result != 1) {
347 return errors::Internal("Failed to create numpy array description for ",
348 "TF_RESOURCE-type tensor");
349 }
350 } else {
351 int type_num = -1;
352 TF_RETURN_IF_ERROR(
353 TF_DataType_to_PyArray_TYPE(TF_TensorType(tensor), &type_num));
354 *descr = PyArray_DescrFromType(type_num);
355 }
356
357 return Status::OK();
358 }
359
// Copies `size` bytes from `src` to `dst`. Sizes 1..16 are dispatched to
// memcpy calls with a constant length, which compilers typically expand
// inline; larger (and zero) sizes use a single library call.
inline void FastMemcpy(void* dst, const void* src, size_t size) {
  // clang-format off
  switch (size) {
    case 1:  memcpy(dst, src, 1);  return;
    case 2:  memcpy(dst, src, 2);  return;
    case 3:  memcpy(dst, src, 3);  return;
    case 4:  memcpy(dst, src, 4);  return;
    case 5:  memcpy(dst, src, 5);  return;
    case 6:  memcpy(dst, src, 6);  return;
    case 7:  memcpy(dst, src, 7);  return;
    case 8:  memcpy(dst, src, 8);  return;
    case 9:  memcpy(dst, src, 9);  return;
    case 10: memcpy(dst, src, 10); return;
    case 11: memcpy(dst, src, 11); return;
    case 12: memcpy(dst, src, 12); return;
    case 13: memcpy(dst, src, 13); return;
    case 14: memcpy(dst, src, 14); return;
    case 15: memcpy(dst, src, 15); return;
    case 16: memcpy(dst, src, 16); return;
    default: break;
  }
  // clang-format on
#if defined(PLATFORM_GOOGLE) || \
    (defined(PLATFORM_POSIX) && !defined(IS_MOBILE_PLATFORM))
  // On Linux, memmove appears to be faster than memcpy for
  // large sizes, strangely enough.
  memmove(dst, src, size);
#else
  memcpy(dst, src, size);
#endif
}
392
393 } // namespace
394
395 // TODO(slebedev): revise TF_TensorToPyArray usages and switch to the
396 // aliased version where appropriate.
TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,PyObject ** out_ndarray)397 Status TF_TensorToMaybeAliasedPyArray(Safe_TF_TensorPtr tensor,
398 PyObject** out_ndarray) {
399 auto dtype = TF_TensorType(tensor.get());
400 if (dtype == TF_STRING || dtype == TF_RESOURCE) {
401 return TF_TensorToPyArray(std::move(tensor), out_ndarray);
402 }
403
404 TF_Tensor* moved = tensor.release();
405 int64 nelems = -1;
406 gtl::InlinedVector<npy_intp, 4> dims;
407 TF_RETURN_IF_ERROR(GetPyArrayDimensionsForTensor(moved, &dims, &nelems));
408 return ArrayFromMemory(
409 dims.size(), dims.data(), TF_TensorData(moved),
410 static_cast<DataType>(dtype), [moved] { TF_DeleteTensor(moved); },
411 out_ndarray);
412 }
413
414 // Converts the given TF_Tensor to a numpy ndarray.
415 // If the returned status is OK, the caller becomes the owner of *out_array.
TF_TensorToPyArray(Safe_TF_TensorPtr tensor,PyObject ** out_ndarray)416 Status TF_TensorToPyArray(Safe_TF_TensorPtr tensor, PyObject** out_ndarray) {
417 // A fetched operation will correspond to a null tensor, and a None
418 // in Python.
419 if (tensor == nullptr) {
420 Py_INCREF(Py_None);
421 *out_ndarray = Py_None;
422 return Status::OK();
423 }
424 int64 nelems = -1;
425 gtl::InlinedVector<npy_intp, 4> dims;
426 TF_RETURN_IF_ERROR(
427 GetPyArrayDimensionsForTensor(tensor.get(), &dims, &nelems));
428
429 // If the type is neither string nor resource we can reuse the Tensor memory.
430 TF_Tensor* original = tensor.get();
431 TF_Tensor* moved = TF_TensorMaybeMove(tensor.release());
432 if (moved != nullptr) {
433 if (ArrayFromMemory(
434 dims.size(), dims.data(), TF_TensorData(moved),
435 static_cast<DataType>(TF_TensorType(moved)),
436 [moved] { TF_DeleteTensor(moved); }, out_ndarray)
437 .ok()) {
438 return Status::OK();
439 }
440 }
441 tensor.reset(original);
442
443 // Copy the TF_TensorData into a newly-created ndarray and return it.
444 PyArray_Descr* descr = nullptr;
445 TF_RETURN_IF_ERROR(GetPyArrayDescrForTensor(tensor.get(), &descr));
446 Safe_PyObjectPtr safe_out_array =
447 tensorflow::make_safe(PyArray_Empty(dims.size(), dims.data(), descr, 0));
448 if (!safe_out_array) {
449 return errors::Internal("Could not allocate ndarray");
450 }
451 PyArrayObject* py_array =
452 reinterpret_cast<PyArrayObject*>(safe_out_array.get());
453 if (TF_TensorType(tensor.get()) == TF_STRING) {
454 Status s = CopyTF_TensorStringsToPyArray(tensor.get(), nelems, py_array);
455 if (!s.ok()) {
456 return s;
457 }
458 } else if (static_cast<size_t>(PyArray_NBYTES(py_array)) !=
459 TF_TensorByteSize(tensor.get())) {
460 return errors::Internal("ndarray was ", PyArray_NBYTES(py_array),
461 " bytes but TF_Tensor was ",
462 TF_TensorByteSize(tensor.get()), " bytes");
463 } else {
464 FastMemcpy(PyArray_DATA(py_array), TF_TensorData(tensor.get()),
465 PyArray_NBYTES(py_array));
466 }
467
468 *out_ndarray = safe_out_array.release();
469 return Status::OK();
470 }
471
NdarrayToTensor(TFE_Context * ctx,PyObject * ndarray,Safe_TF_TensorPtr * ret)472 Status NdarrayToTensor(TFE_Context* ctx, PyObject* ndarray,
473 Safe_TF_TensorPtr* ret) {
474 DCHECK(ret != nullptr);
475
476 // Make sure we dereference this array object in case of error, etc.
477 Safe_PyObjectPtr array_safe(make_safe(
478 PyArray_FromAny(ndarray, nullptr, 0, 0, NPY_ARRAY_CARRAY_RO, nullptr)));
479 if (!array_safe) return errors::InvalidArgument("Not a ndarray.");
480 PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
481
482 // Convert numpy dtype to TensorFlow dtype.
483 TF_DataType dtype = TF_FLOAT;
484 TF_RETURN_IF_ERROR(PyArray_TYPE_to_TF_DataType(array, &dtype));
485
486 tensorflow::int64 nelems = 1;
487 gtl::InlinedVector<int64_t, 4> dims;
488 for (int i = 0; i < PyArray_NDIM(array); ++i) {
489 dims.push_back(PyArray_SHAPE(array)[i]);
490 nelems *= dims[i];
491 }
492
493 // Create a TF_Tensor based on the fed data. In the case of non-string data
494 // type, this steals a reference to array, which will be relinquished when
495 // the underlying buffer is deallocated. For string, a new temporary buffer
496 // is allocated into which the strings are encoded.
497 if (dtype == TF_RESOURCE) {
498 size_t size = PyArray_NBYTES(array);
499 array_safe.release();
500
501 if (ctx) {
502 *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
503 static_cast<tensorflow::DataType>(dtype), {}, 0, PyArray_DATA(array),
504 size, &DelayedNumpyDecref, array)});
505 } else {
506 *ret = make_safe(TF_NewTensor(dtype, {}, 0, PyArray_DATA(array), size,
507 &DelayedNumpyDecref, array));
508 }
509
510 } else if (dtype != TF_STRING) {
511 size_t size = PyArray_NBYTES(array);
512 array_safe.release();
513 if (ctx) {
514 *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
515 static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
516 PyArray_DATA(array), size, &DelayedNumpyDecref, array)});
517 } else {
518 *ret = make_safe(TF_NewTensor(dtype, dims.data(), dims.size(),
519 PyArray_DATA(array), size,
520 &DelayedNumpyDecref, array));
521 }
522
523 } else {
524 size_t size = 0;
525 void* encoded = nullptr;
526 TF_RETURN_IF_ERROR(EncodePyBytesArray(array, nelems, &size, &encoded));
527 if (ctx) {
528 *ret = make_safe(new TF_Tensor{tensorflow::unwrap(ctx)->CreateTensor(
529 static_cast<tensorflow::DataType>(dtype), dims.data(), dims.size(),
530 encoded, size,
531 [](void* data, size_t len, void* arg) {
532 delete[] reinterpret_cast<tensorflow::tstring*>(data);
533 },
534 nullptr)});
535 } else {
536 *ret = make_safe(TF_NewTensor(
537 dtype, dims.data(), dims.size(), encoded, size,
538 [](void* data, size_t len, void* arg) {
539 delete[] reinterpret_cast<tensorflow::tstring*>(data);
540 },
541 nullptr));
542 }
543 }
544
545 return Status::OK();
546 }
547
548 Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst);
549 TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status);
550
NdarrayToTensor(PyObject * obj,Tensor * ret)551 Status NdarrayToTensor(PyObject* obj, Tensor* ret) {
552 Safe_TF_TensorPtr tf_tensor = make_safe(static_cast<TF_Tensor*>(nullptr));
553 Status s = NdarrayToTensor(nullptr /*ctx*/, obj, &tf_tensor);
554 if (!s.ok()) {
555 return s;
556 }
557 return TF_TensorToTensor(tf_tensor.get(), ret);
558 }
559
TensorToNdarray(const Tensor & t,PyObject ** ret)560 Status TensorToNdarray(const Tensor& t, PyObject** ret) {
561 Status status;
562 Safe_TF_TensorPtr tf_tensor = make_safe(TF_TensorFromTensor(t, &status));
563 if (!status.ok()) {
564 return status;
565 }
566 return TF_TensorToPyArray(std::move(tf_tensor), ret);
567 }
568
569 } // namespace tensorflow
570