1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/eager/pywrap_tensor.h"
17
18 #include <stdlib.h>
19 #include <string.h>
20
21 #include <cmath>
22
23 #include "structmember.h" // NOLINT // For PyMemberDef
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
30 #include "tensorflow/c/tf_status.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
35 #include "tensorflow/python/eager/pywrap_tfe.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor.h"
37 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
38 #include "tensorflow/python/lib/core/numpy.h"
39 #include "tensorflow/python/lib/core/py_exception_registry.h"
40 #include "tensorflow/python/lib/core/py_seq_tensor.h"
41 #include "tensorflow/python/lib/core/safe_ptr.h"
42
43 // forward declare
44 struct EagerTensor;
45 namespace tensorflow {
46
47 // Convert a TFE_TensorHandle to a Python numpy.ndarray object.
48 // The two may share underlying storage so changes to one may reflect in the
49 // other.
TFE_TensorHandleToNumpy(TFE_TensorHandle * handle,TF_Status * status)50 PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
51 if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
52 TF_SetStatus(status, TF_INVALID_ARGUMENT,
53 "Cannot convert a Tensor of dtype resource to a NumPy array.");
54 return nullptr;
55 }
56
57 tensorflow::Safe_TF_TensorPtr tensor = nullptr;
58 Py_BEGIN_ALLOW_THREADS;
59 tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
60 Py_END_ALLOW_THREADS;
61 if (!status->status.ok()) {
62 return nullptr;
63 }
64
65 PyObject* ret = nullptr;
66 auto cppstatus =
67 tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
68 tensorflow::Set_TF_Status_from_Status(status, cppstatus);
69 if (!status->status.ok()) {
70 Py_XDECREF(ret);
71 return nullptr;
72 }
73 CHECK_NE(ret, nullptr);
74 return ret;
75 }
76 } // namespace tensorflow
77 namespace {
78
79 using tensorflow::TFE_TensorHandleToNumpy;
80
81 // An instance of _EagerTensorProfiler that will receive callbacks about
82 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
83 PyObject* eager_tensor_profiler = nullptr;
84
85 // Read-only dict. Please don't use this in any setting where the dict might
86 // actually get mutated. This is only used to pass empty kwargs when creating a
87 // new EagerTensor.
EmptyDict()88 PyObject* EmptyDict() {
89 static PyObject* empty_dict = PyDict_New();
90 return empty_dict;
91 }
92
EmptyTuple()93 PyObject* EmptyTuple() {
94 static PyObject* empty_tuple = PyTuple_New(0);
95 return empty_tuple;
96 }
97
GetContextHandle(PyObject * py_context)98 TFE_Context* GetContextHandle(PyObject* py_context) {
99 tensorflow::Safe_PyObjectPtr py_context_handle(
100 PyObject_GetAttrString(py_context, "_handle"));
101 if (py_context_handle == nullptr) {
102 // Current Python code makes sure this never happens. If it does, or
103 // becomes hard to maintain, we can call the ensure_initialized() method
104 // here.
105 PyErr_SetString(
106 PyExc_TypeError,
107 "Expected `context` argument in EagerTensor constructor to have a "
108 "`_handle` attribute but it did not. Was eager Context initialized?");
109 return nullptr;
110 }
111
112 auto* ctx = reinterpret_cast<TFE_Context*>(
113 PyCapsule_GetPointer(py_context_handle.get(), nullptr));
114 if (ctx == nullptr) {
115 PyErr_SetString(PyExc_TypeError,
116 tensorflow::strings::StrCat(
117 "Expected context._handle to contain a PyCapsule "
118 "encoded pointer to TFE_Context. Got ",
119 Py_TYPE(py_context_handle.get())->tp_name)
120 .c_str());
121 }
122 return ctx;
123 }
124
125
126 // Helper function to convert `v` to a tensorflow::DataType and store it in
127 // `*out`. Returns true on success, false otherwise.
128 // Note that we assume that v is a python int (not long) representing a
129 // TF_DataType/tensorflow::DataType value.
PyIntToDataType(PyObject * v,tensorflow::DataType * out)130 bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
131 #if PY_MAJOR_VERSION < 3
132 if (PyInt_Check(v)) {
133 *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
134 return true;
135 }
136 #else
137 if (PyLong_Check(v)) {
138 *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
139 return true;
140 }
141 #endif
142 return false;
143 }
144
145 // Helper function to create a python integer from TF_DataType.
PyIntFromDataType(TF_DataType l)146 PyObject* PyIntFromDataType(TF_DataType l) {
147 #if PY_MAJOR_VERSION < 3
148 return PyInt_FromLong(l);
149 #else
150 return PyLong_FromLong(l);
151 #endif
152 }
153
154 // PyObject->tensorflow::DataType conversion function to be used with
155 // PyArg_Parse* APIs.
ConvertDataType(PyObject * obj,tensorflow::DataType * dst)156 int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
157 if (obj == Py_None) {
158 *dst = tensorflow::DataType::DT_INVALID;
159 } else if (!PyIntToDataType(obj, dst)) {
160 PyErr_SetString(
161 PyExc_TypeError,
162 tensorflow::strings::StrCat(
163 "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
164 .c_str());
165 return 0;
166 }
167
168 return 1;
169 }
170
171 // Conversion function extracting a const char** device name from a PyObject.
172 // The function should be used with PyArg_Parse* APIs.
ConvertDeviceName(PyObject * obj,const char ** dst)173 int ConvertDeviceName(PyObject* obj, const char** dst) {
174 if (obj == Py_None) {
175 *dst = nullptr;
176 } else {
177 auto device_name = TFE_GetPythonString(obj);
178 if (device_name == nullptr) {
179 PyErr_Clear();
180 PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
181 return 0;
182 }
183 *dst = device_name;
184 }
185
186 return 1;
187 }
188
RaiseExceptionTypeFromTFStatus(TF_Status * status)189 void RaiseExceptionTypeFromTFStatus(TF_Status* status) {
190 TF_Code code = TF_GetCode(status);
191 PyObject* exception = tensorflow::PyExceptionRegistry::Lookup(code);
192 PyErr_SetObject(exception,
193 pybind11::make_tuple(pybind11::none(), pybind11::none(),
194 TF_Message(status))
195 .ptr());
196 }
197
198 } // namespace
199
200 namespace tensorflow {
201 // This function checks whether the desired type is "compatible" with the
202 // inferred type. At a high level, compatibility means that all integral types
203 // are compatible with each other, and all floating types are compatible with
204 // each other.
205 //
206 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
207 // compatible with int32). This is intended to match graph behavior.
IsCompatible(DataType desired,DataType returned)208 bool IsCompatible(DataType desired, DataType returned) {
209 if (desired == returned) return true;
210
211 if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
212 return true;
213 } else if (DataTypeIsFloating(desired) &&
214 (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
215 return true;
216 } else if (DataTypeIsComplex(desired) &&
217 (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
218 DataTypeIsFloating(returned))) {
219 return true;
220 } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
221 return true;
222 }
223 return false;
224 }
225
226 // TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
227 // execute TFE Ops) to a separate common library.
228 // Casts data referred to by `handle` from type `src_type_enum` to type
229 // `dst_type_enum`.
EagerCast(TFE_Context * ctx,TFE_TensorHandle * handle,TF_DataType src_type_enum,TF_DataType dst_type_enum,TF_Status * out_status)230 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
231 TF_DataType src_type_enum,
232 TF_DataType dst_type_enum, TF_Status* out_status) {
233 if (ctx == nullptr) return nullptr;
234 const char* op_name = "Cast";
235 const char* device_name = "/device:CPU:0";
236 TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
237 #define RETURN_ERROR \
238 { \
239 TFE_DeleteOp(op); \
240 return nullptr; \
241 }
242 if (!out_status->status.ok()) RETURN_ERROR
243 TFE_OpSetDevice(op, device_name, out_status);
244 if (!out_status->status.ok()) RETURN_ERROR
245 TFE_OpAddInput(op, handle, out_status);
246 if (!out_status->status.ok()) RETURN_ERROR
247 TFE_OpSetAttrType(op, "SrcT", src_type_enum);
248 TFE_OpSetAttrType(op, "DstT", dst_type_enum);
249 TFE_OpSetAttrBool(op, "Truncate", false);
250 TFE_TensorHandle* output = nullptr;
251 int num_outputs = 1;
252 TFE_Execute(op, &output, &num_outputs, out_status);
253 if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
254 if (output != nullptr) {
255 TFE_DeleteTensorHandle(output);
256 }
257 RETURN_ERROR
258 }
259 TFE_DeleteOp(op);
260 return output;
261 #undef RETURN_ERROR
262 }
263
ConvertToEagerTensorUncached(TFE_Context * ctx,PyObject * value,tensorflow::DataType dtype,const char * device_name)264 TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
265 PyObject* value,
266 tensorflow::DataType dtype,
267 const char* device_name) {
268 tensorflow::Safe_PyObjectPtr value_decrefer;
269 if (PyArray_IsScalar(value, Generic)) {
270 // Convert numpy scalars to numpy arrays.
271 value = PyArray_FromScalar(value, nullptr);
272 // The returned value needs to be DECREF'd, but the original value was
273 // created in python code, and doesn't need to be DECREF'd.
274 value_decrefer.reset(value);
275 }
276
277 Safe_TFE_TensorHandlePtr handle =
278 make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
279
280 if (handle == nullptr) return nullptr;
281
282 Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
283 TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
284 if (dtype != tensorflow::DT_INVALID &&
285 dtype != static_cast<DataType>(handle_dtype)) {
286 if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
287 handle = tensorflow::make_safe(
288 tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
289 static_cast<TF_DataType>(dtype), status.get()));
290 if (!status->status.ok()) {
291 PyErr_SetString(PyExc_TypeError,
292 absl::StrCat("Error while casting from dtype ",
293 tensorflow::DataTypeString(
294 static_cast<DataType>(handle_dtype)),
295 " to ", tensorflow::DataTypeString(dtype),
296 ". ", TF_Message(status.get()))
297 .c_str());
298 return nullptr;
299 }
300 } else {
301 tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
302 PyErr_SetString(
303 PyExc_TypeError,
304 absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
305 " to EagerTensor of dtype ",
306 tensorflow::DataTypeString(dtype))
307 .c_str());
308 return nullptr;
309 }
310 }
311
312 // We always generate CPU:0 tensors, but we may need to change the device
313 // slightly, as for example from /job:localhost/... to /job:worker/...
314 //
315 // Note that this is a shallow copy and will share the underlying buffer,
316 // because we are copying to the same device.
317 if (device_name != nullptr &&
318 strstr(device_name, "/device:CPU:0") != nullptr) {
319 handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
320 device_name, status.get()));
321 const TF_Code code = TF_GetCode(status.get());
322 if (code != TF_OK) {
323 RaiseExceptionTypeFromTFStatus(status.get());
324 return nullptr;
325 }
326 }
327
328 return handle.release();
329 }
330
ConvertToEagerTensor(TFE_Context * ctx,PyObject * value,DataType dtype,const char * device_name)331 TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
332 DataType dtype,
333 const char* device_name) {
334 // Reduce the overhead of allocation/transfer-to-device for scalars by
335 // caching the corresponding handles. Note that currently only Python
336 // scalars are cached.
337 // TODO(slebedev): also cache singleton NumPy arrays and scalars?
338 if (PyArray_IsPythonNumber(value)) {
339 auto* cache = TFE_TensorHandleCache::Get();
340 TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
341 if (handle != nullptr) return handle;
342 handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
343 if (handle == nullptr) return nullptr;
344 if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
345 cache->Insert(value, dtype, ctx, device_name, handle);
346 }
347 return handle;
348 } else {
349 return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
350 }
351 }
352
353 } // namespace tensorflow
354
355 extern "C" {
356
357 static const int kMaxEagerTensorParentSize = 64;
358
359 // TODO(agarwal): store context handle in EagerTensor.
360 typedef struct EagerTensor {
361 PyObject_HEAD;
362 // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
363 // parent class. The parent class is set at runtime, so we don't know the
364 // exact size at compile time.
365 char unused[kMaxEagerTensorParentSize];
366 TFE_TensorHandle* handle;
367 int64_t id;
368 // Indicates whether it's a packed tensor or not.
369 bool is_packed;
370 // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
371 // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
372 // tensors, this will contain a serialized HandleData proto with shape
373 // inference metadata about shapes and dtypes of resources accessible from
374 // this handle.
375 // Note that we assume that handle_data cannot participate in reference
376 // cycles, and hence don't provide GC support for it.
377 PyObject* handle_data;
378
379 // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
380 // first time that `_EagerTensorBase`'s `shape` property is called.
381 PyObject* tensor_shape;
382
383 // We store a status object here as an optimization to avoid allocating a new
384 // Status objects on different functions that operate on EagerTensor and need
385 // to use a TF_Status object. However note that accesses to `status` are not
386 // thread-safe.
387 TF_Status status;
388
389 // The eager Context (from eager/context.py) used by this Tensor.
390 // This is currently used only to make sure context outlives TensorHandles.
391 PyObject* context;
392
393 PyObject* weakreflist; /* List of weak references */
394
395 // Per-instance attribute dictionary, to support monkey patching
396 // (e.g. EagerTensor.assign when slicing variables). This dictionary is
397 // created by CPython the first time an attribute is assigned, pointed to by
398 // tp_dictoffset. Note that garbage collection is not enabled for
399 // EagerTensors, so assigning objects to EagerTensor attributes which require
400 // garbage collection is likely to cause issues.
401 PyObject* dict;
402 } EagerTensor;
403
404 namespace {
405
406 // Returns true on success - successfully invoked or no profiler registered.
407 // Returns false if some error occurred.
MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor * created_tensor)408 bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
409 if (eager_tensor_profiler != nullptr) {
410 #if PY_MAJOR_VERSION < 3
411 PyObject* created_method_name = PyString_InternFromString("created");
412 #else
413 PyObject* created_method_name = PyUnicode_InternFromString("created");
414 #endif
415 if (created_method_name == nullptr) {
416 return false;
417 }
418 PyObject* result = PyObject_CallMethodObjArgs(
419 eager_tensor_profiler, created_method_name, created_tensor, NULL);
420 if (result == nullptr) {
421 LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
422 // While we can potentially continue because the error is related to
423 // profiling, we choose to return an error because:
424 // - If profiling is used, the user likely wants to stop execution on
425 // profiling errors.
426 // - Error in profiling code might have left some state in an invalid
427 // form that can lead to an error later on. Better to fail fast.
428 Py_DECREF(created_method_name);
429 return false;
430 }
431 Py_DECREF(created_method_name);
432 Py_DECREF(result);
433 }
434 return true;
435 }
436
437 } // namespace
438
439 // tp_init for EagerTensor.
EagerTensor_init(EagerTensor * self,PyObject * args,PyObject * kwds)440 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
441 self->id = get_uid();
442 self->handle = nullptr;
443 self->is_packed = false;
444 Py_INCREF(Py_None);
445 self->handle_data = Py_None;
446 Py_INCREF(Py_None);
447 self->tensor_shape = Py_None;
448 self->status.status = tensorflow::Status::OK();
449 self->dict = nullptr;
450 self->weakreflist = nullptr;
451 self->context = nullptr;
452 PyObject* value;
453 const char* device_name = nullptr;
454 tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
455 const char* kwlist[] = {"value", "device", "dtype", nullptr};
456 if (!PyArg_ParseTupleAndKeywords(
457 args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
458 ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
459 return -1;
460 }
461
462 PyObject* py_context = GetPyEagerContext();
463 if (py_context == nullptr) return -1;
464 self->context = py_context;
465
466 auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
467 value, dtype, device_name);
468 if (handle == nullptr) return -1;
469 self->handle = handle;
470
471 if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
472 return -1;
473 }
474
475 return 0;
476 }
477
478 // tp_dealloc for EagerTensor.
EagerTensor_dealloc(EagerTensor * self)479 void EagerTensor_dealloc(EagerTensor* self) {
480 // Unhook the object from python's GC so that the weakref deleter doesn't
481 // try to re-delete this.
482 PyObject_GC_UnTrack((PyObject*)self);
483
484 // Clear weak references to self.
485 // Needs to happen before any actual destruction.
486 PyObject_ClearWeakRefs((PyObject*)self);
487
488 Py_DECREF(self->handle_data);
489 Py_DECREF(self->tensor_shape);
490 // If an attribute dictionary has been created, release it. Note that this
491 // is only ever created by CPython's attribute setting methods; we don't
492 // create it ourselves.
493 Py_CLEAR(self->dict);
494 if (self->handle != nullptr) {
495 TFE_DeleteTensorHandle(self->handle);
496 self->handle = nullptr;
497 }
498
499 // Decref context after deleting the tensor handle.
500 Py_XDECREF(self->context);
501
502 // We have the global interpreter lock, so use this chance to perform delayed
503 // refcount decrements.
504 tensorflow::ClearDecrefCache();
505 auto id = self->id;
506 Py_TYPE(self)->tp_free(self);
507 TFE_Py_TapeSetDeleteTrace(id);
508 }
509
510 // Getter for `_id`.
EagerTensor_getid(EagerTensor * self,void * closure)511 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
512 return PyLong_FromLongLong(self->id);
513 }
514
515 // Getter for `_datatype_enum`.
EagerTensor_datatype_enum(EagerTensor * self)516 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
517 return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
518 }
519
520 // Getter for `_shape_tuple`.
EagerTensor_shape_tuple(EagerTensor * self)521 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
522 auto handle = self->handle;
523 int n = TFE_TensorHandleNumDims(handle, &self->status);
524 TF_Code code = TF_GetCode(&self->status);
525 if (code != TF_OK) {
526 RaiseExceptionTypeFromTFStatus(&self->status);
527 // Cleanup self->status before returning.
528 self->status.status = tensorflow::Status::OK();
529 return nullptr;
530 }
531 PyObject* shape = PyTuple_New(n);
532 if (PyErr_Occurred()) return nullptr;
533 for (int i = 0; i < n; ++i) {
534 PyObject* dim =
535 PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
536 code = TF_GetCode(&self->status);
537 if (code != TF_OK || dim == nullptr ||
538 PyTuple_SetItem(shape, i, dim) != 0) {
539 if (code != TF_OK) {
540 RaiseExceptionTypeFromTFStatus(&self->status);
541 } else {
542 PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
543 }
544 // Cleanup self->status before returning.
545 self->status.status = tensorflow::Status::OK();
546 Py_DECREF(shape);
547 if (dim != nullptr) Py_DECREF(dim);
548 return nullptr;
549 }
550 }
551 return shape;
552 }
553
554 // Getter for `_rank`.
EagerTensor_rank(EagerTensor * self)555 static PyObject* EagerTensor_rank(EagerTensor* self) {
556 int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
557 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
558 // Cleanup self->status before returning.
559 self->status.status = tensorflow::Status::OK();
560 return nullptr;
561 }
562 #if PY_MAJOR_VERSION < 3
563 return PyInt_FromLong(num_dims);
564 #else
565 return PyLong_FromLong(num_dims);
566 #endif
567 }
568
569 // Getter for `_num_elements`.
EagerTensor_num_elements(EagerTensor * self)570 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
571 auto handle = self->handle;
572 int n = TFE_TensorHandleNumElements(handle, &self->status);
573 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
574 // Cleanup self->status before returning.
575 self->status.status = tensorflow::Status::OK();
576 return nullptr;
577 }
578 return PyLong_FromLongLong(n);
579 }
580
EagerTensor_handle_data(EagerTensor * self,void * unused)581 static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
582 Py_INCREF(self->handle_data);
583 return self->handle_data;
584 }
585
EagerTensor_sethandle_data(EagerTensor * self,PyObject * value,void * unused)586 static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
587 void* unused) {
588 Py_DECREF(self->handle_data);
589 Py_INCREF(value);
590 self->handle_data = value;
591 return 0;
592 }
593
EagerTensor_tensor_shape(EagerTensor * self,void * unused)594 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
595 Py_INCREF(self->tensor_shape);
596 return self->tensor_shape;
597 }
598
EagerTensor_settensor_shape(EagerTensor * self,PyObject * value,void * unused)599 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
600 void* unused) {
601 Py_DECREF(self->tensor_shape);
602 Py_INCREF(value);
603 self->tensor_shape = value;
604 return 0;
605 }
606
607 // Function `_copy_to_device`.
EagerTensor_copy_to_device(EagerTensor * self,PyObject * args,PyObject * kwds)608 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
609 PyObject* kwds) {
610 if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;
611
612 const char* device_name = nullptr;
613 if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
614 &device_name)) {
615 return nullptr;
616 }
617
618 // Note that this is a shallow copy and will share the underlying buffer
619 // if copying to the same device.
620 TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
621 self->handle, GetContextHandle(self->context), device_name,
622 &self->status);
623 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
624 // Cleanup self->status before returning.
625 self->status.status = tensorflow::Status::OK();
626 return nullptr;
627 }
628
629 return EagerTensorFromHandle(handle);
630 }
631
632 // Function `_numpy_internal`.
633 // Convert an EagerTensor to a Python numpy.ndarray object.
634 // The two may share underlying storage so changes to one may reflect in the
635 // other.
636 // Note that if `self` is not on CPU, we raise an Exception.
EagerTensor_numpy_internal(EagerTensor * self)637 static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
638 auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
639 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
640 Py_XDECREF(py_array);
641 // Cleanup self->status before returning.
642 self->status.status = tensorflow::Status::OK();
643 return nullptr;
644 } else {
645 return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
646 }
647 }
648
649 // Getter `device`.
EagerTensor_device(EagerTensor * self)650 static PyObject* EagerTensor_device(EagerTensor* self) {
651 const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
652 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
653 // Cleanup self->status before returning.
654 self->status.status = tensorflow::Status::OK();
655 return nullptr;
656 }
657 #if PY_MAJOR_VERSION >= 3
658 return PyUnicode_FromString(device);
659 #else
660 return PyBytes_FromString(device);
661 #endif
662 }
663
664 // Getter `backing_device`.
EagerTensor_backing_device(EagerTensor * self)665 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
666 const char* device =
667 TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
668 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
669 // Cleanup self->status before returning.
670 self->status.status = tensorflow::Status::OK();
671 return nullptr;
672 }
673 #if PY_MAJOR_VERSION >= 3
674 return PyUnicode_FromString(device);
675 #else
676 return PyBytes_FromString(device);
677 #endif
678 }
679
680 // Getter `is_packed`.
EagerTensor_is_packed(EagerTensor * self)681 static PyObject* EagerTensor_is_packed(EagerTensor* self) {
682 return PyBool_FromLong(self->is_packed);
683 }
684
685 static PyGetSetDef EagerTensor_getsetters[] = {
686 {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
687 const_cast<char*>("Tensor ID."), nullptr},
688 {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
689 const_cast<char*>("Device of op that produced the tensor."), nullptr},
690 {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
691 nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
692 nullptr},
693 {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
694 const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
695 nullptr},
696 {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
697 (setter)EagerTensor_sethandle_data,
698 const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
699 nullptr},
700 {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
701 (setter)EagerTensor_settensor_shape,
702 const_cast<char*>("Shape of the tensor."), nullptr},
703 {nullptr} /* Sentinel */
704 };
705
706 #if PY_MAJOR_VERSION < 3
707 // Only used for Python2 since Python3 seems to set the __dict__ correctly.
708 static PyMemberDef EagerTensor_members[] = {
709 {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
710 READONLY},
711 {nullptr},
712 };
713 #endif
714
715 static PyMethodDef EagerTensor_methods[] = {
716 {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
717 PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
718 {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
719 PyDoc_STR("The DType of the tensor as an enum.")},
720 {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
721 PyDoc_STR("The shape of the tensor as a python tuple.")},
722 {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
723 PyDoc_STR("The rank of the tensor.")},
724 {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
725 METH_VARARGS | METH_KEYWORDS,
726 PyDoc_STR("Copies the tensor to the desired device.")},
727 {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
728 PyDoc_STR("Number of elements in the tensor.")},
729 {nullptr, nullptr},
730 };
731
EagerTensor_getbuffer(EagerTensor * self,Py_buffer * view,int flags)732 static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
733 int flags) {
734 if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
735 PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
736 return -1;
737 }
738
739 // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
740 // DT_STRING so the following is only slightly slower than a NumPy-free
741 // implementation.
742 auto py_array = tensorflow::make_safe(
743 TFE_TensorHandleToNumpy(self->handle, &self->status));
744 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
745 // Cleanup self->status before returning.
746 self->status.status = tensorflow::Status::OK();
747 return -1;
748 }
749 if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
750 return -1;
751 }
752 view->readonly = 1;
753 return 0;
754 }
755
756 static PyBufferProcs EagerTensor_as_buffer = {
757 #if PY_MAJOR_VERSION < 3
758 nullptr, nullptr, nullptr, nullptr,
759 #endif
760 (getbufferproc)EagerTensor_getbuffer,
761 // Never called because getbufferproc delegates to NumPy.
762 (releasebufferproc) nullptr};
763
764 // Note that here we are trying to dynamically create a new class as a subclass
765 // of a "HEAPTYPE" class that is itself created in python code and passed in at
766 // runtime. This is fairly atypical and undocumented.
767 //
768 // We use the following strategy for this. Unfortunately, we have to use
769 // different approaches for python2.x vs python3.x
770 // For python2.x, we create the class as a static type and set its tp_base to
771 // the passed in type. Unfortunately setting tp_flags to include
772 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
773 // initialization of the underlying PyHeapTypeObject and not doing that leads to
774 // some random crashes especially during garbage collection.
775 // python3.x explicitly disables a static subclass of a HEAPTYPE base class.
776 // However it provides a new function, PyType_FromSpecWithBases, to create
777 // types dynamically.
778
779 // Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
780 PyTypeObject* EagerTensorType = nullptr;
781
782 #if PY_MAJOR_VERSION >= 3
783 static PyType_Slot EagerTensor_Type_slots[] = {
784 {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
785 {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
786 {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
787 {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
788 {0, nullptr},
789 };
790 #else
791
// Type flags used by the statically allocated (Python 2) EagerTensor type.
#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// Statically allocated type object used on the non-Python-3 branch.
// The initializer is positional, so the order of the entries below must match
// the field order of PyTypeObject. tp_base is left null here;
// TFE_Py_InitEagerTensor fills it in before calling PyType_Ready.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                   /* tp_name */
    sizeof(EagerTensor),             /* tp_basicsize */
    0,                               /* tp_itemsize */
    (destructor)EagerTensor_dealloc, /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};
841
842 #endif
843
844 } // extern "C"
845
EagerTensor_CheckExact(const PyObject * o)846 bool EagerTensor_CheckExact(const PyObject* o) {
847 return Py_TYPE(o) == EagerTensorType;
848 }
849
EagerTensor_Handle(const PyObject * o)850 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
851 return reinterpret_cast<const EagerTensor*>(o)->handle;
852 }
853
EagerTensorFromHandle(TFE_TensorHandle * handle,const bool is_packed)854 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
855 const bool is_packed) {
856 if (handle == nullptr) {
857 return nullptr;
858 }
859 EagerTensor* t = reinterpret_cast<EagerTensor*>(
860 EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
861 if (t != nullptr) {
862 t->id = get_uid();
863 t->is_packed = is_packed;
864 Py_INCREF(Py_None);
865 t->handle_data = Py_None;
866 Py_INCREF(Py_None);
867 t->tensor_shape = Py_None;
868 t->handle = handle;
869 t->status.status = tensorflow::Status::OK();
870 t->weakreflist = nullptr;
871 PyObject* py_context = GetPyEagerContext();
872 if (py_context == nullptr) {
873 LOG(ERROR) << "Cannot create an eager tensor before eager context has "
874 "been set or after it has been deleted";
875 return nullptr;
876 }
877 t->context = py_context;
878
879 if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
880 return nullptr;
881 }
882 }
883 return reinterpret_cast<PyObject*>(t);
884 }
885
PyEagerTensor_ID(const PyObject * tensor)886 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
887 DCHECK(EagerTensor_CheckExact(tensor));
888 return reinterpret_cast<const EagerTensor*>(tensor)->id;
889 }
890
PyEagerTensor_Dtype(const PyObject * tensor)891 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
892 DCHECK(EagerTensor_CheckExact(tensor));
893 return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
894 reinterpret_cast<const EagerTensor*>(tensor)->handle));
895 }
896
PyEagerTensor_NumElements(PyObject * tensor)897 tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
898 DCHECK(EagerTensor_CheckExact(tensor));
899 EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
900 tensorflow::int64 result = TFE_TensorHandleNumElements(
901 as_c_eager_tensor->handle, &as_c_eager_tensor->status);
902
903 if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
904 PyExc_ValueError)) {
905 // Cleanup status before returning.
906 as_c_eager_tensor->status.status = tensorflow::Status::OK();
907 return -1;
908 }
909
910 return result;
911 }
912
TFE_Py_InitEagerTensor(PyObject * base_class)913 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
914 if (!PyType_Check(base_class)) {
915 PyErr_SetString(
916 PyExc_TypeError,
917 tensorflow::strings::StrCat(
918 "Expecting a class definition for `base_class` passed to ",
919 "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
920 .c_str());
921 return nullptr;
922 }
923 // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
924 // EagerTensor to allow for the space usage of the base class.
925 PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
926 if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
927 PyErr_SetString(
928 PyExc_TypeError,
929 tensorflow::strings::StrCat(
930 "Unable to create subclass EagerTensor from base class ",
931 Py_TYPE(base_class)->tp_name,
932 ". Need its size to be <= ", kMaxEagerTensorParentSize)
933 .c_str());
934 return nullptr;
935 }
936 if (base_class_type->tp_itemsize != 0) {
937 PyErr_SetString(
938 PyExc_TypeError,
939 tensorflow::strings::StrCat(
940 "Unable to create subclass EagerTensor from base class ",
941 Py_TYPE(base_class)->tp_name,
942 " which supports variable length instances.")
943 .c_str());
944 return nullptr;
945 }
946 Py_INCREF(base_class);
947 #if PY_MAJOR_VERSION >= 3
948 PyObject* bases = PyTuple_New(1);
949 PyTuple_SET_ITEM(bases, 0, base_class);
950
951 tensorflow::Safe_PyObjectPtr base_class_module(
952 PyObject_GetAttrString(base_class, "__module__"));
953 const char* module = nullptr;
954 if (PyErr_Occurred()) {
955 PyErr_Clear();
956 module = "__builtin__";
957 } else {
958 module = PyBytes_AsString(base_class_module.get());
959 if (module == nullptr) {
960 PyErr_Clear();
961 module = PyUnicode_AsUTF8(base_class_module.get());
962 if (module == nullptr) {
963 PyErr_Clear();
964 module = "__builtin__";
965 }
966 }
967 }
968
969 // NOTE: The c_str from this string needs to outlast the function, hence is
970 // static.
971 static tensorflow::string fully_qualified_name =
972 tensorflow::strings::StrCat(module, ".EagerTensor");
973
974 static PyType_Spec EagerTensor_Type_spec = {
975 fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
976 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
977
978 EagerTensorType = reinterpret_cast<PyTypeObject*>(
979 PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
980 if (PyErr_Occurred()) {
981 return nullptr;
982 }
983 if (EagerTensorType == nullptr) {
984 PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
985 return nullptr;
986 }
987 EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
988 EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
989 #else
990 _EagerTensorType.tp_base = base_class_type;
991
992 if (PyType_Ready(&_EagerTensorType) < 0) {
993 if (PyErr_Occurred()) return nullptr;
994 PyErr_SetString(PyExc_RuntimeError,
995 "Error while creating EagerTensor type.");
996 return nullptr;
997 }
998 EagerTensorType = &_EagerTensorType;
999 Py_INCREF(EagerTensorType);
1000 #endif
1001 return reinterpret_cast<PyObject*>(EagerTensorType);
1002 }
1003
TFE_Py_SetEagerTensorProfiler(PyObject * profiler)1004 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
1005 Py_XDECREF(eager_tensor_profiler);
1006
1007 if (profiler == Py_None) {
1008 eager_tensor_profiler = nullptr;
1009 } else {
1010 eager_tensor_profiler = profiler;
1011 Py_INCREF(eager_tensor_profiler);
1012 }
1013 Py_RETURN_NONE;
1014 }
1015
TFE_Py_TensorShapeSlice(PyObject * tensors,int slice_dim)1016 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
1017 if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
1018 PyErr_SetString(PyExc_TypeError,
1019 tensorflow::strings::StrCat(
1020 "tensors argument must be a list or a tuple. Got \"",
1021 Py_TYPE(tensors)->tp_name, "\"")
1022 .c_str());
1023 return nullptr;
1024 }
1025 if (slice_dim < 0) {
1026 PyErr_SetString(
1027 PyExc_ValueError,
1028 tensorflow::strings::StrCat("Slice dimension must be non-negative. "
1029 "Got ",
1030 slice_dim)
1031 .c_str());
1032 return nullptr;
1033 }
1034
1035 PyObject* py_context = GetPyEagerContext();
1036 if (py_context == nullptr) {
1037 PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
1038 "Cannot create EagerTensor when "
1039 "EagerContext is not valid")
1040 .c_str());
1041 return nullptr;
1042 }
1043
1044 TFE_Context* ctx = GetContextHandle(py_context);
1045
1046 Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
1047 PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
1048 int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
1049
1050 auto status = tensorflow::make_safe(TF_NewStatus());
1051
1052 // Create an empty tensor.
1053 auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
1054 tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
1055
1056 if (num_tensors_int > 0) {
1057 int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
1058
1059 // Fill the tensor with dims.
1060 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
1061 PyObject* tensor_obj = tensors_array[i];
1062 if (!EagerTensor_CheckExact(tensor_obj)) {
1063 PyErr_SetString(
1064 PyExc_TypeError,
1065 tensorflow::strings::StrCat("Expected a list of EagerTensors but "
1066 "element ",
1067 i, " has type \"",
1068 Py_TYPE(tensor_obj)->tp_name, "\"")
1069 .c_str());
1070 return nullptr;
1071 }
1072
1073 EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
1074 TFE_TensorHandle* handle = t->handle;
1075 int num_dims = TFE_TensorHandleNumDims(handle, status.get());
1076 if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1077 return nullptr;
1078 }
1079 if (slice_dim >= num_dims) {
1080 PyErr_SetString(
1081 PyExc_IndexError,
1082 tensorflow::strings::StrCat("Slice dimension (", slice_dim,
1083 ") must be smaller than rank of all "
1084 "tensors, but tensor at index ",
1085 i, " has rank ", num_dims)
1086 .c_str());
1087 return nullptr;
1088 }
1089 int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
1090 if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1091 return nullptr;
1092 }
1093 data[i] = dim;
1094 }
1095 }
1096
1097 TFE_TensorHandle* handle =
1098 tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
1099
1100 if (!status->status.ok()) {
1101 PyErr_SetString(
1102 PyExc_RuntimeError,
1103 tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
1104 TF_Message(status.get()))
1105 .c_str());
1106 return nullptr;
1107 }
1108
1109 return EagerTensorFromHandle(handle);
1110 }
1111
TFE_Py_TensorShapeOnDevice(PyObject * tensor)1112 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1113 if (!EagerTensor_CheckExact(tensor)) {
1114 PyErr_SetString(
1115 PyExc_TypeError,
1116 tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1117 Py_TYPE(tensor)->tp_name, "\"")
1118 .c_str());
1119 return nullptr;
1120 }
1121 TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1122
1123 auto status = tensorflow::make_safe(TF_NewStatus());
1124 TFE_TensorDebugInfo* debug_info =
1125 TFE_TensorHandleTensorDebugInfo(handle, status.get());
1126 if (!status->status.ok()) {
1127 PyErr_SetString(
1128 PyExc_RuntimeError,
1129 tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1130 TF_Message(status.get()))
1131 .c_str());
1132 return nullptr;
1133 }
1134
1135 int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1136 PyObject* shape = PyTuple_New(rank);
1137 for (int i = 0; i < rank; ++i) {
1138 tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1139 PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1140 }
1141 TFE_DeleteTensorDebugInfo(debug_info);
1142
1143 return shape;
1144 }
1145