• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/eager/pywrap_tensor.h"
17 
18 #include <stdlib.h>
19 #include <string.h>
20 
21 #include <cmath>
22 
23 #include "structmember.h"  // NOLINT // For PyMemberDef
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
30 #include "tensorflow/c/tf_status.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
35 #include "tensorflow/python/eager/pywrap_tfe.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor.h"
37 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
38 #include "tensorflow/python/lib/core/numpy.h"
39 #include "tensorflow/python/lib/core/py_exception_registry.h"
40 #include "tensorflow/python/lib/core/py_seq_tensor.h"
41 #include "tensorflow/python/lib/core/safe_ptr.h"
42 
43 // forward declare
44 struct EagerTensor;
45 namespace tensorflow {
46 
// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other. On failure, sets `status` and returns nullptr (no Python error is
// set here; callers are expected to translate `status` themselves).
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
  // Resource handles have no meaningful NumPy representation.
  if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Cannot convert a Tensor of dtype resource to a NumPy array.");
    return nullptr;
  }

  tensorflow::Safe_TF_TensorPtr tensor = nullptr;
  // Resolving the handle may block (e.g. on a device-to-host copy), so
  // release the GIL for the duration.
  Py_BEGIN_ALLOW_THREADS;
  tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
  Py_END_ALLOW_THREADS;
  if (!status->status.ok()) {
    return nullptr;
  }

  PyObject* ret = nullptr;
  auto cppstatus =
      tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
  tensorflow::Set_TF_Status_from_Status(status, cppstatus);
  if (!status->status.ok()) {
    Py_XDECREF(ret);
    return nullptr;
  }
  CHECK_NE(ret, nullptr);
  return ret;
}
76 }  // namespace tensorflow
77 namespace {
78 
79 using tensorflow::TFE_TensorHandleToNumpy;
80 
81 // An instance of _EagerTensorProfiler that will receive callbacks about
82 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
83 PyObject* eager_tensor_profiler = nullptr;
84 
// Read-only dict. Please don't use this in any setting where the dict might
// actually get mutated. This is only used to pass empty kwargs when creating a
// new EagerTensor.
// The dict is created once and deliberately never released (process-lifetime
// singleton).
PyObject* EmptyDict() {
  static PyObject* empty_dict = PyDict_New();
  return empty_dict;
}
92 
EmptyTuple()93 PyObject* EmptyTuple() {
94   static PyObject* empty_tuple = PyTuple_New(0);
95   return empty_tuple;
96 }
97 
GetContextHandle(PyObject * py_context)98 TFE_Context* GetContextHandle(PyObject* py_context) {
99   tensorflow::Safe_PyObjectPtr py_context_handle(
100       PyObject_GetAttrString(py_context, "_handle"));
101   if (py_context_handle == nullptr) {
102     // Current Python code makes sure this never happens. If it does, or
103     // becomes hard to maintain, we can call the ensure_initialized() method
104     // here.
105     PyErr_SetString(
106         PyExc_TypeError,
107         "Expected `context` argument in EagerTensor constructor to have a "
108         "`_handle` attribute but it did not. Was eager Context initialized?");
109     return nullptr;
110   }
111 
112   auto* ctx = reinterpret_cast<TFE_Context*>(
113       PyCapsule_GetPointer(py_context_handle.get(), nullptr));
114   if (ctx == nullptr) {
115     PyErr_SetString(PyExc_TypeError,
116                     tensorflow::strings::StrCat(
117                         "Expected context._handle to contain a PyCapsule "
118                         "encoded pointer to TFE_Context. Got ",
119                         Py_TYPE(py_context_handle.get())->tp_name)
120                         .c_str());
121   }
122   return ctx;
123 }
124 
125 
// Helper function to convert `v` to a tensorflow::DataType and store it in
// `*out`. Returns true on success, false otherwise.
// Note that we assume that v is a python int (not long) representing a
// TF_DataType/tensorflow::DataType value; the integer value itself is not
// range-checked against the DataType enum.
bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
#if PY_MAJOR_VERSION < 3
  if (PyInt_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
    return true;
  }
#else
  // Python 3 has a single integer type, PyLong.
  if (PyLong_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
    return true;
  }
#endif
  return false;
}
144 
// Helper function to create a python integer from TF_DataType.
// Returns a new reference, or nullptr on allocation failure.
PyObject* PyIntFromDataType(TF_DataType l) {
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(l);
#else
  return PyLong_FromLong(l);
#endif
}
153 
// PyObject->tensorflow::DataType conversion function to be used with
// PyArg_Parse* APIs (returns 1 on success, 0 on failure per the "O&"
// converter convention). Py_None maps to DT_INVALID, meaning "no dtype
// requested".
int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
  if (obj == Py_None) {
    *dst = tensorflow::DataType::DT_INVALID;
  } else if (!PyIntToDataType(obj, dst)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
            .c_str());
    return 0;
  }

  return 1;
}
170 
// Conversion function extracting a const char** device name from a PyObject.
// The function should be used with PyArg_Parse* APIs (returns 1 on success,
// 0 on failure). Py_None maps to nullptr ("no device requested").
// NOTE(review): `*dst` appears to point into `obj`'s internal buffer
// (TFE_GetPythonString), so it is only valid while `obj` is alive — confirm
// callers keep the argument tuple alive while using the name.
int ConvertDeviceName(PyObject* obj, const char** dst) {
  if (obj == Py_None) {
    *dst = nullptr;
  } else {
    auto device_name = TFE_GetPythonString(obj);
    if (device_name == nullptr) {
      // Clear the extraction error and replace it with a uniform message.
      PyErr_Clear();
      PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
      return 0;
    }
    *dst = device_name;
  }

  return 1;
}
188 
RaiseExceptionTypeFromTFStatus(TF_Status * status)189 void RaiseExceptionTypeFromTFStatus(TF_Status* status) {
190   TF_Code code = TF_GetCode(status);
191   PyObject* exception = tensorflow::PyExceptionRegistry::Lookup(code);
192   PyErr_SetObject(exception,
193                   pybind11::make_tuple(pybind11::none(), pybind11::none(),
194                                        TF_Message(status))
195                       .ptr());
196 }
197 
198 }  // namespace
199 
200 namespace tensorflow {
201 // This function checks whether the desired type is "compatible" with the
202 // inferred type. At a high level, compatibility means that all integral types
203 // are compatible with each other, and all floating types are compatible with
204 // each other.
205 //
206 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
207 // compatible with int32). This is intended to match graph behavior.
IsCompatible(DataType desired,DataType returned)208 bool IsCompatible(DataType desired, DataType returned) {
209   if (desired == returned) return true;
210 
211   if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
212     return true;
213   } else if (DataTypeIsFloating(desired) &&
214              (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
215     return true;
216   } else if (DataTypeIsComplex(desired) &&
217              (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
218               DataTypeIsFloating(returned))) {
219     return true;
220   } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
221     return true;
222   }
223   return false;
224 }
225 
// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
// execute TFE Ops) to a separate common library.
// Casts data referred to by `handle` from type `src_type_enum` to type
// `dst_type_enum` by executing a "Cast" op pinned to /device:CPU:0.
// Returns a new handle owned by the caller, or nullptr with the error in
// `out_status`.
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
                            TF_DataType src_type_enum,
                            TF_DataType dst_type_enum, TF_Status* out_status) {
  if (ctx == nullptr) return nullptr;
  const char* op_name = "Cast";
  const char* device_name = "/device:CPU:0";
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
// Frees the partially-built op and bails out; used after every fallible step.
#define RETURN_ERROR  \
  {                   \
    TFE_DeleteOp(op); \
    return nullptr;   \
  }
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetDevice(op, device_name, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpAddInput(op, handle, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
  TFE_OpSetAttrBool(op, "Truncate", false);
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op, &output, &num_outputs, out_status);
  // Treat a wrong output count or a null output as failure even if the
  // status looks OK; free any handle that was produced before bailing.
  if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
    if (output != nullptr) {
      TFE_DeleteTensorHandle(output);
    }
    RETURN_ERROR
  }
  TFE_DeleteOp(op);
  return output;
#undef RETURN_ERROR
}
263 
ConvertToEagerTensorUncached(TFE_Context * ctx,PyObject * value,tensorflow::DataType dtype,const char * device_name)264 TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
265                                                PyObject* value,
266                                                tensorflow::DataType dtype,
267                                                const char* device_name) {
268   tensorflow::Safe_PyObjectPtr value_decrefer;
269   if (PyArray_IsScalar(value, Generic)) {
270     // Convert numpy scalars to numpy arrays.
271     value = PyArray_FromScalar(value, nullptr);
272     // The returned value needs to be DECREF'd, but the original value was
273     // created in python code, and doesn't need to be DECREF'd.
274     value_decrefer.reset(value);
275   }
276 
277   Safe_TFE_TensorHandlePtr handle =
278       make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
279 
280   if (handle == nullptr) return nullptr;
281 
282   Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
283   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
284   if (dtype != tensorflow::DT_INVALID &&
285       dtype != static_cast<DataType>(handle_dtype)) {
286     if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
287       handle = tensorflow::make_safe(
288           tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
289                                 static_cast<TF_DataType>(dtype), status.get()));
290       if (!status->status.ok()) {
291         PyErr_SetString(PyExc_TypeError,
292                         absl::StrCat("Error while casting from dtype ",
293                                      tensorflow::DataTypeString(
294                                          static_cast<DataType>(handle_dtype)),
295                                      " to ", tensorflow::DataTypeString(dtype),
296                                      ". ", TF_Message(status.get()))
297                             .c_str());
298         return nullptr;
299       }
300     } else {
301       tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
302       PyErr_SetString(
303           PyExc_TypeError,
304           absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
305                        " to EagerTensor of dtype ",
306                        tensorflow::DataTypeString(dtype))
307               .c_str());
308       return nullptr;
309     }
310   }
311 
312   // We always generate CPU:0 tensors, but we may need to change the device
313   // slightly, as for example from /job:localhost/... to /job:worker/...
314   //
315   // Note that this is a shallow copy and will share the underlying buffer,
316   // because we are copying to the same device.
317   if (device_name != nullptr &&
318       strstr(device_name, "/device:CPU:0") != nullptr) {
319     handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
320                                                     device_name, status.get()));
321     const TF_Code code = TF_GetCode(status.get());
322     if (code != TF_OK) {
323       RaiseExceptionTypeFromTFStatus(status.get());
324       return nullptr;
325     }
326   }
327 
328   return handle.release();
329 }
330 
ConvertToEagerTensor(TFE_Context * ctx,PyObject * value,DataType dtype,const char * device_name)331 TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
332                                        DataType dtype,
333                                        const char* device_name) {
334   // Reduce the overhead of allocation/transfer-to-device for scalars by
335   // caching the corresponding handles. Note that currently only Python
336   // scalars are cached.
337   // TODO(slebedev): also cache singleton NumPy arrays and scalars?
338   if (PyArray_IsPythonNumber(value)) {
339     auto* cache = TFE_TensorHandleCache::Get();
340     TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
341     if (handle != nullptr) return handle;
342     handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
343     if (handle == nullptr) return nullptr;
344     if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
345       cache->Insert(value, dtype, ctx, device_name, handle);
346     }
347     return handle;
348   } else {
349     return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
350   }
351 }
352 
353 }  // namespace tensorflow
354 
355 extern "C" {
356 
// Reserved space for the runtime-provided parent class's instance data; see
// the comment on `unused` below.
static const int kMaxEagerTensorParentSize = 64;

// Instance layout for the EagerTensor extension type.
// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // Owned; deleted in EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Unique id assigned at construction (see get_uid() in EagerTensor_init);
  // also used to delete the tape trace on dealloc.
  int64_t id;
  // Indicates whether it's a packed tensor or not.
  bool is_packed;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating a new
  // Status objects on different functions that operate on EagerTensor and need
  // to use a TF_Status object. However note that accesses to `status` are not
  // thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;
403 
namespace {

// Returns true on success - successfully invoked or no profiler registered.
// Returns false if some error occurred.
// Calls `eager_tensor_profiler.created(created_tensor)` when a profiler has
// been registered via TFE_Py_InitEagerTensor.
bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
  if (eager_tensor_profiler != nullptr) {
#if PY_MAJOR_VERSION < 3
    PyObject* created_method_name = PyString_InternFromString("created");
#else
    PyObject* created_method_name = PyUnicode_InternFromString("created");
#endif
    if (created_method_name == nullptr) {
      return false;
    }
    PyObject* result = PyObject_CallMethodObjArgs(
        eager_tensor_profiler, created_method_name, created_tensor, NULL);
    if (result == nullptr) {
      LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
      // While we can potentially continue because the error is related to
      // profiling, we choose to return an error because:
      //  - If profiling is used, the user likely wants to stop execution on
      //    profiling errors.
      //  - Error in profiling code might have left some state in an invalid
      //    form that can lead to an error later on. Better to fail fast.
      Py_DECREF(created_method_name);
      return false;
    }
    Py_DECREF(created_method_name);
    Py_DECREF(result);
  }
  return true;
}

}  // namespace
438 
// tp_init for EagerTensor.
// Implements EagerTensor(value, device, dtype=None): converts `value` to a
// TFE_TensorHandle and stores it on `self`. Returns 0 on success, -1 with a
// Python exception set on failure.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
  // Initialize every field first so that dealloc is safe even if we fail
  // partway through below.
  self->id = get_uid();
  self->handle = nullptr;
  self->is_packed = false;
  Py_INCREF(Py_None);
  self->handle_data = Py_None;
  Py_INCREF(Py_None);
  self->tensor_shape = Py_None;
  self->status.status = tensorflow::Status::OK();
  self->dict = nullptr;
  self->weakreflist = nullptr;
  self->context = nullptr;
  PyObject* value;
  const char* device_name = nullptr;
  tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
  const char* kwlist[] = {"value", "device", "dtype", nullptr};
  // "OO&|O&": `value` and `device` are required (device parsed by
  // ConvertDeviceName); `dtype` is optional (parsed by ConvertDataType).
  if (!PyArg_ParseTupleAndKeywords(
          args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
          ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
    return -1;
  }

  PyObject* py_context = GetPyEagerContext();
  if (py_context == nullptr) return -1;
  // NOTE(review): assumed to be a new reference — ownership moves to `self`
  // and is released in EagerTensor_dealloc (which DECREFs self->context).
  self->context = py_context;

  auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
                                                  value, dtype, device_name);
  if (handle == nullptr) return -1;
  self->handle = handle;

  // Notify the registered profiler (if any) that a tensor was created.
  if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
    return -1;
  }

  return 0;
}
477 
// tp_dealloc for EagerTensor.
// Releases all owned references and the tensor handle, then frees the
// object. Order matters: GC untrack and weakref clearing must happen before
// any destruction, and the context must outlive the handle.
void EagerTensor_dealloc(EagerTensor* self) {
  // Unhook the object from python's GC so that the weakref deleter doesn't
  // try to re-delete this.
  PyObject_GC_UnTrack((PyObject*)self);

  // Clear weak references to self.
  // Needs to happen before any actual destruction.
  PyObject_ClearWeakRefs((PyObject*)self);

  Py_DECREF(self->handle_data);
  Py_DECREF(self->tensor_shape);
  // If an attribute dictionary has been created, release it. Note that this
  // is only ever created by CPython's attribute setting methods; we don't
  // create it ourselves.
  Py_CLEAR(self->dict);
  if (self->handle != nullptr) {
    TFE_DeleteTensorHandle(self->handle);
    self->handle = nullptr;
  }

  // Decref context after deleting the tensor handle.
  Py_XDECREF(self->context);

  // We have the global interpreter lock, so use this chance to perform delayed
  // refcount decrements.
  tensorflow::ClearDecrefCache();
  // Save the id before tp_free releases the memory backing `self`.
  auto id = self->id;
  Py_TYPE(self)->tp_free(self);
  TFE_Py_TapeSetDeleteTrace(id);
}
509 
// Getter for `_id`. Returns the tensor's unique id as a Python int.
static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
  return PyLong_FromLongLong(self->id);
}
514 
// Getter for `_datatype_enum`. Returns the tensor's TF_DataType as a Python
// int.
static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
  return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
}
519 
// Getter for `_shape_tuple`.
// Returns the tensor's shape as a Python tuple of ints, or nullptr with a
// Python exception set on failure. Uses self->status as scratch space and
// always resets it before an error return.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumDims(handle, &self->status);
  TF_Code code = TF_GetCode(&self->status);
  if (code != TF_OK) {
    RaiseExceptionTypeFromTFStatus(&self->status);
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
  PyObject* shape = PyTuple_New(n);
  if (PyErr_Occurred()) return nullptr;
  for (int i = 0; i < n; ++i) {
    PyObject* dim =
        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
    code = TF_GetCode(&self->status);
    // Note: PyTuple_SetItem steals the reference to `dim` (even on failure),
    // so it is only reached when the dim lookup itself succeeded.
    if (code != TF_OK || dim == nullptr ||
        PyTuple_SetItem(shape, i, dim) != 0) {
      if (code != TF_OK) {
        RaiseExceptionTypeFromTFStatus(&self->status);
      } else {
        PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
      }
      // Cleanup self->status before returning.
      self->status.status = tensorflow::Status::OK();
      Py_DECREF(shape);
      // NOTE(review): if PyTuple_SetItem itself failed, `dim` was already
      // consumed and this DECREF would be extra — safe in practice only
      // because SetItem cannot fail on a fresh tuple with a valid index.
      if (dim != nullptr) Py_DECREF(dim);
      return nullptr;
    }
  }
  return shape;
}
553 
// Getter for `_rank`. Returns the number of dimensions as a Python int, or
// nullptr with a Python exception set on failure.
static PyObject* EagerTensor_rank(EagerTensor* self) {
  int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(num_dims);
#else
  return PyLong_FromLong(num_dims);
#endif
}
568 
569 // Getter for `_num_elements`.
EagerTensor_num_elements(EagerTensor * self)570 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
571   auto handle = self->handle;
572   int n = TFE_TensorHandleNumElements(handle, &self->status);
573   if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
574     // Cleanup self->status before returning.
575     self->status.status = tensorflow::Status::OK();
576     return nullptr;
577   }
578   return PyLong_FromLongLong(n);
579 }
580 
EagerTensor_handle_data(EagerTensor * self,void * unused)581 static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
582   Py_INCREF(self->handle_data);
583   return self->handle_data;
584 }
585 
EagerTensor_sethandle_data(EagerTensor * self,PyObject * value,void * unused)586 static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
587                                       void* unused) {
588   Py_DECREF(self->handle_data);
589   Py_INCREF(value);
590   self->handle_data = value;
591   return 0;
592 }
593 
EagerTensor_tensor_shape(EagerTensor * self,void * unused)594 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
595   Py_INCREF(self->tensor_shape);
596   return self->tensor_shape;
597 }
598 
EagerTensor_settensor_shape(EagerTensor * self,PyObject * value,void * unused)599 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
600                                        void* unused) {
601   Py_DECREF(self->tensor_shape);
602   Py_INCREF(value);
603   self->tensor_shape = value;
604   return 0;
605 }
606 
// Function `_copy_to_device`.
// Copies the tensor to the device named by the single positional argument
// and returns a new EagerTensor, or nullptr with a Python exception set.
static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
                                            PyObject* kwds) {
  // Keyword arguments are rejected outright.
  if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;

  const char* device_name = nullptr;
  if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
                        &device_name)) {
    return nullptr;
  }

  // Note that this is a shallow copy and will share the underlying buffer
  // if copying to the same device.
  TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
      self->handle, GetContextHandle(self->context), device_name,
      &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }

  // EagerTensorFromHandle takes ownership of `handle`.
  return EagerTensorFromHandle(handle);
}
631 
// Function `_numpy_internal`.
// Convert an EagerTensor to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
// Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
  auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    Py_XDECREF(py_array);
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  } else {
    // PyArray_Return converts 0-d arrays to numpy scalars.
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
  }
}
648 
// Getter `device`. Returns the name of the device that produced the tensor
// as a Python string, or nullptr with a Python exception set on failure.
static PyObject* EagerTensor_device(EagerTensor* self) {
  const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}
663 
// Getter `backing_device`. Returns the name of the device where the tensor's
// memory actually resides (which may differ from `device`), or nullptr with
// a Python exception set on failure.
static PyObject* EagerTensor_backing_device(EagerTensor* self) {
  const char* device =
      TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}
679 
// Getter `is_packed`. Returns the is_packed flag as a Python bool.
static PyObject* EagerTensor_is_packed(EagerTensor* self) {
  return PyBool_FromLong(self->is_packed);
}
684 
// Property table for the EagerTensor type (installed via Py_tp_getset).
static PyGetSetDef EagerTensor_getsetters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};
705 
#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
// Exposes the per-instance attribute dictionary (read-only member).
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif
714 
// Method table for the EagerTensor type (installed via Py_tp_methods).
static PyMethodDef EagerTensor_methods[] = {
    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
     PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("The DType of the tensor as an enum.")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("The shape of the tensor as a python tuple.")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
     PyDoc_STR("The rank of the tensor.")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS,
     PyDoc_STR("Copies the tensor to the desired device.")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("Number of elements in the tensor.")},
    {nullptr, nullptr},
};
731 
EagerTensor_getbuffer(EagerTensor * self,Py_buffer * view,int flags)732 static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
733                                  int flags) {
734   if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
735     PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
736     return -1;
737   }
738 
739   // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
740   // DT_STRING so the following is only slightly slower than a NumPy-free
741   // implementation.
742   auto py_array = tensorflow::make_safe(
743       TFE_TensorHandleToNumpy(self->handle, &self->status));
744   if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
745     // Cleanup self->status before returning.
746     self->status.status = tensorflow::Status::OK();
747     return -1;
748   }
749   if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
750     return -1;
751   }
752   view->readonly = 1;
753   return 0;
754 }
755 
// Buffer-protocol vtable for EagerTensor. Only the getbuffer slot is
// populated; release is handled by the NumPy object that getbufferproc
// delegates to.
static PyBufferProcs EagerTensor_as_buffer = {
#if PY_MAJOR_VERSION < 3
    nullptr, nullptr, nullptr, nullptr,
#endif
    (getbufferproc)EagerTensor_getbuffer,
    // Never called because getbufferproc delegates to NumPy.
    (releasebufferproc) nullptr};
763 
764 // Note that here we are trying to dynamically create a new class as a subclass
765 // of a "HEAPTYPE" class that is itself created in python code and passed in at
766 // runtime. This is fairly atypical and undocumented.
767 //
768 // We use the following strategy for this. Unfortunately, we have to use
769 // different approaches for python2.x vs python3.x
770 // For python2.x, we create the class as a static type and set its tp_base to
771 // the passed in type. Unfortunately setting tp_flags to include
772 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
773 // initialization of the underlying PyHeapTypeObject and not doing that leads to
774 // some random crashes especially during garbage collection.
775 // python3.x explicitly disables a static subclass of a HEAPTYPE base class.
776 // However it provides a new function, PyType_FromSpecWithBases, to create
777 // types dynamically.
778 
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
// It stays nullptr until TFE_Py_InitEagerTensor has been called.
PyTypeObject* EagerTensorType = nullptr;
781 
#if PY_MAJOR_VERSION >= 3
// Slots used by PyType_FromSpecWithBases in TFE_Py_InitEagerTensor to build
// the EagerTensor type dynamically. tp_as_buffer and tp_dictoffset cannot be
// expressed as slots; TFE_Py_InitEagerTensor patches them onto the created
// type afterwards.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    {0, nullptr},
};
#else

#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// Python 2 static type object; TFE_Py_InitEagerTensor fills in tp_base with
// the caller-provided base class before PyType_Ready.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr,                            /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif
843 
844 }  // extern "C"
845 
EagerTensor_CheckExact(const PyObject * o)846 bool EagerTensor_CheckExact(const PyObject* o) {
847   return Py_TYPE(o) == EagerTensorType;
848 }
849 
EagerTensor_Handle(const PyObject * o)850 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
851   return reinterpret_cast<const EagerTensor*>(o)->handle;
852 }
853 
// Wraps `handle` in a freshly allocated EagerTensor Python object, taking
// ownership of the handle. Returns nullptr when `handle` is nullptr or when
// construction fails.
PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
                                const bool is_packed) {
  if (handle == nullptr) {
    return nullptr;
  }
  EagerTensor* t = reinterpret_cast<EagerTensor*>(
      EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
  if (t != nullptr) {
    t->id = get_uid();
    t->is_packed = is_packed;
    // handle_data and tensor_shape start out as None; Python-side code may
    // replace them later.
    Py_INCREF(Py_None);
    t->handle_data = Py_None;
    Py_INCREF(Py_None);
    t->tensor_shape = Py_None;
    t->handle = handle;
    t->status.status = tensorflow::Status::OK();
    t->weakreflist = nullptr;
    PyObject* py_context = GetPyEagerContext();
    if (py_context == nullptr) {
      LOG(ERROR) << "Cannot create an eager tensor before eager context has "
                    "been set or after it has been deleted";
      // NOTE(review): `t` is not released on this path, which looks like a
      // leak of the partially constructed EagerTensor (and of `handle`, which
      // it now owns) — confirm whether Py_DECREF(t) is safe here given the
      // expectations of EagerTensor_dealloc.
      return nullptr;
    }
    t->context = py_context;

    if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
      // NOTE(review): `t` also appears to leak on this error path — verify.
      return nullptr;
    }
  }
  return reinterpret_cast<PyObject*>(t);
}
885 
PyEagerTensor_ID(const PyObject * tensor)886 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
887   DCHECK(EagerTensor_CheckExact(tensor));
888   return reinterpret_cast<const EagerTensor*>(tensor)->id;
889 }
890 
PyEagerTensor_Dtype(const PyObject * tensor)891 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
892   DCHECK(EagerTensor_CheckExact(tensor));
893   return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
894       reinterpret_cast<const EagerTensor*>(tensor)->handle));
895 }
896 
PyEagerTensor_NumElements(PyObject * tensor)897 tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
898   DCHECK(EagerTensor_CheckExact(tensor));
899   EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
900   tensorflow::int64 result = TFE_TensorHandleNumElements(
901       as_c_eager_tensor->handle, &as_c_eager_tensor->status);
902 
903   if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
904                                       PyExc_ValueError)) {
905     // Cleanup status before returning.
906     as_c_eager_tensor->status.status = tensorflow::Status::OK();
907     return -1;
908   }
909 
910   return result;
911 }
912 
// Creates the EagerTensor type as a subclass of the Python-defined
// `base_class` and publishes it via the global EagerTensorType. Returns the
// new type object, or nullptr with a Python exception set on failure.
//
// `base_class` must be a class whose instances need at most
// kMaxEagerTensorParentSize bytes and that has no variable-length instances
// (tp_itemsize == 0), since EagerTensor reserves a fixed amount of space for
// the base class's state.
PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
  if (!PyType_Check(base_class)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a class definition for `base_class` passed to ",
            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
            .c_str());
    return nullptr;
  }
  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
  // EagerTensor to allow for the space usage of the base class.
  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            ". Need its size to be <= ", kMaxEagerTensorParentSize)
            .c_str());
    return nullptr;
  }
  if (base_class_type->tp_itemsize != 0) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            " which supports variable length instances.")
            .c_str());
    return nullptr;
  }
  Py_INCREF(base_class);
#if PY_MAJOR_VERSION >= 3
  PyObject* bases = PyTuple_New(1);
  PyTuple_SET_ITEM(bases, 0, base_class);

  // Derive the module name for the new type from the base class, falling
  // back to "__builtin__" if it cannot be determined.
  tensorflow::Safe_PyObjectPtr base_class_module(
      PyObject_GetAttrString(base_class, "__module__"));
  const char* module = nullptr;
  if (PyErr_Occurred()) {
    PyErr_Clear();
    module = "__builtin__";
  } else {
    // Try bytes first (Python 2 style), then unicode.
    module = PyBytes_AsString(base_class_module.get());
    if (module == nullptr) {
      PyErr_Clear();
      module = PyUnicode_AsUTF8(base_class_module.get());
      if (module == nullptr) {
        PyErr_Clear();
        module = "__builtin__";
      }
    }
  }

  // NOTE: The c_str from this string needs to outlast the function, hence is
  // static.
  // NOTE(review): being static, it is computed only on the FIRST call; a
  // later call with a base class from a different module would reuse the
  // original name — presumably this function is called once; confirm.
  static tensorflow::string fully_qualified_name =
      tensorflow::strings::StrCat(module, ".EagerTensor");

  static PyType_Spec EagerTensor_Type_spec = {
      fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
      Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};

  EagerTensorType = reinterpret_cast<PyTypeObject*>(
      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
  if (PyErr_Occurred()) {
    return nullptr;
  }
  if (EagerTensorType == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
    return nullptr;
  }
  // These two cannot be expressed as PyType_Slots, so patch them onto the
  // freshly created type.
  EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
  EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
#else
  // Python 2: use the static type object and just point its tp_base at the
  // caller-provided base class.
  _EagerTensorType.tp_base = base_class_type;

  if (PyType_Ready(&_EagerTensorType) < 0) {
    if (PyErr_Occurred()) return nullptr;
    PyErr_SetString(PyExc_RuntimeError,
                    "Error while creating EagerTensor type.");
    return nullptr;
  }
  EagerTensorType = &_EagerTensorType;
  Py_INCREF(EagerTensorType);
#endif
  return reinterpret_cast<PyObject*>(EagerTensorType);
}
1003 
TFE_Py_SetEagerTensorProfiler(PyObject * profiler)1004 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
1005   Py_XDECREF(eager_tensor_profiler);
1006 
1007   if (profiler == Py_None) {
1008     eager_tensor_profiler = nullptr;
1009   } else {
1010     eager_tensor_profiler = profiler;
1011     Py_INCREF(eager_tensor_profiler);
1012   }
1013   Py_RETURN_NONE;
1014 }
1015 
// Builds a new rank-1 DT_INT32 EagerTensor whose i-th element is the size of
// dimension `slice_dim` of the i-th tensor in `tensors`.
//
// `tensors` must be a list or tuple of EagerTensors, each with rank greater
// than `slice_dim`, and `slice_dim` must be non-negative. Returns nullptr
// with a Python exception set on failure.
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
  if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "tensors argument must be a list or a tuple. Got \"",
                        Py_TYPE(tensors)->tp_name, "\"")
                        .c_str());
    return nullptr;
  }
  if (slice_dim < 0) {
    PyErr_SetString(
        PyExc_ValueError,
        tensorflow::strings::StrCat("Slice dimension must be non-negative. "
                                    "Got ",
                                    slice_dim)
            .c_str());
    return nullptr;
  }

  PyObject* py_context = GetPyEagerContext();
  if (py_context == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
                                            "Cannot create EagerTensor when "
                                            "EagerContext is not valid")
                                            .c_str());
    return nullptr;
  }

  TFE_Context* ctx = GetContextHandle(py_context);

  // Safe to use the *_Fast macros directly: `tensors` was verified above to
  // be a list or a tuple.
  Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
  PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
  int64_t num_tensors_int = static_cast<int64_t>(num_tensors);

  auto status = tensorflow::make_safe(TF_NewStatus());

  // Create an empty tensor.
  // NOTE(review): `tensor` does not appear to be released on the early error
  // returns below; confirm the ownership semantics of CreateTensor /
  // CreateLocalHandle before assuming this is leak-free.
  auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
      tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});

  if (num_tensors_int > 0) {
    int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());

    // Fill the tensor with dims.
    for (Py_ssize_t i = 0; i < num_tensors; ++i) {
      PyObject* tensor_obj = tensors_array[i];
      if (!EagerTensor_CheckExact(tensor_obj)) {
        PyErr_SetString(
            PyExc_TypeError,
            tensorflow::strings::StrCat("Expected a list of EagerTensors but "
                                        "element ",
                                        i, " has type \"",
                                        Py_TYPE(tensor_obj)->tp_name, "\"")
                .c_str());
        return nullptr;
      }

      EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
      TFE_TensorHandle* handle = t->handle;
      int num_dims = TFE_TensorHandleNumDims(handle, status.get());
      if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
        return nullptr;
      }
      if (slice_dim >= num_dims) {
        PyErr_SetString(
            PyExc_IndexError,
            tensorflow::strings::StrCat("Slice dimension (", slice_dim,
                                        ") must be smaller than rank of all "
                                        "tensors, but tensor at index ",
                                        i, " has rank ", num_dims)
                .c_str());
        return nullptr;
      }
      int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
      if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
        return nullptr;
      }
      // NOTE(review): implicit int64 -> int32 narrowing; dimensions larger
      // than INT32_MAX would be silently truncated here — confirm this is
      // acceptable for the callers of this API.
      data[i] = dim;
    }
  }

  TFE_TensorHandle* handle =
      tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));

  if (!status->status.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}
1111 
TFE_Py_TensorShapeOnDevice(PyObject * tensor)1112 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1113   if (!EagerTensor_CheckExact(tensor)) {
1114     PyErr_SetString(
1115         PyExc_TypeError,
1116         tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1117                                     Py_TYPE(tensor)->tp_name, "\"")
1118             .c_str());
1119     return nullptr;
1120   }
1121   TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1122 
1123   auto status = tensorflow::make_safe(TF_NewStatus());
1124   TFE_TensorDebugInfo* debug_info =
1125       TFE_TensorHandleTensorDebugInfo(handle, status.get());
1126   if (!status->status.ok()) {
1127     PyErr_SetString(
1128         PyExc_RuntimeError,
1129         tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1130                                     TF_Message(status.get()))
1131             .c_str());
1132     return nullptr;
1133   }
1134 
1135   int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1136   PyObject* shape = PyTuple_New(rank);
1137   for (int i = 0; i < rank; ++i) {
1138     tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1139     PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1140   }
1141   TFE_DeleteTensorDebugInfo(debug_info);
1142 
1143   return shape;
1144 }
1145