/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/python/eager/pywrap_tensor.h"

#include <stdlib.h>
#include <string.h>

#include <cmath>

#include "structmember.h"  // NOLINT // For PyMemberDef
#include "pybind11/pybind11.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_exception_registry.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

// forward declare
struct EagerTensor;
namespace tensorflow {

// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
  if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Cannot convert a Tensor of dtype resource to a NumPy array.");
    return nullptr;
  }

  if (TFE_TensorHandleDataType(handle) == TF_VARIANT) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Cannot convert a Tensor of dtype variant to a NumPy array.");
    return nullptr;
  }
  tensorflow::Safe_TF_TensorPtr tensor = nullptr;
  Py_BEGIN_ALLOW_THREADS;
  tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
  Py_END_ALLOW_THREADS;
  if (!status->status.ok()) {
    return nullptr;
  }

  PyObject* ret = nullptr;
  auto cppstatus =
      tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
  tensorflow::Set_TF_Status_from_Status(status, cppstatus);
  if (!status->status.ok()) {
    Py_XDECREF(ret);
    return nullptr;
  }
  CHECK_NE(ret, nullptr);
  return ret;
}
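
// Illustrative reading of the checks above (editorial comment, not part of the
// upstream source): resolving a DT_RESOURCE or DT_VARIANT handle through this
// conversion is expected to fail with an InvalidArgument status, while other
// dtypes produce an ndarray that may alias the tensor's buffer, as the header
// comment notes.
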
}  // namespace tensorflow
namespace {

using tensorflow::TFE_TensorHandleToNumpy;

// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_SetEagerTensorProfiler, if
// at all.
PyObject* eager_tensor_profiler = nullptr;

// Read-only dict. Please don't use this in any setting where the dict might
// actually get mutated. This is only used to pass empty kwargs when creating a
// new EagerTensor.
PyObject* EmptyDict() {
  static PyObject* empty_dict = PyDict_New();
  return empty_dict;
}

PyObject* EmptyTuple() {
  static PyObject* empty_tuple = PyTuple_New(0);
  return empty_tuple;
}

TFE_Context* GetContextHandle(PyObject* py_context) {
  tensorflow::Safe_PyObjectPtr py_context_handle(
      PyObject_GetAttrString(py_context, "_handle"));
  if (py_context_handle == nullptr) {
    // Current Python code makes sure this never happens. If it does, or
    // becomes hard to maintain, we can call the ensure_initialized() method
    // here.
    PyErr_SetString(
        PyExc_TypeError,
        "Expected `context` argument in EagerTensor constructor to have a "
        "`_handle` attribute but it did not. Was eager Context initialized?");
    return nullptr;
  }

  auto* ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(py_context_handle.get(), nullptr));
  if (ctx == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "Expected context._handle to contain a PyCapsule "
                        "encoded pointer to TFE_Context. Got ",
                        Py_TYPE(py_context_handle.get())->tp_name)
                        .c_str());
  }
  return ctx;
}
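
// For illustration (an assumption drawn from the lookup above, not an API
// guarantee): the Python-side Context object is expected to expose `_handle`
// as a PyCapsule wrapping a TFE_Context*, created when the eager context is
// initialized; anything else makes this helper raise a TypeError.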


// Helper function to convert `v` to a tensorflow::DataType and store it in
// `*out`. Returns true on success, false otherwise.
// Note that we assume that v is a python int (not long) representing a
// TF_DataType/tensorflow::DataType value.
bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
#if PY_MAJOR_VERSION < 3
  if (PyInt_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
    return true;
  }
#else
  if (PyLong_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
    return true;
  }
#endif
  return false;
}

// Helper function to create a python integer from TF_DataType.
PyObject* PyIntFromDataType(TF_DataType l) {
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(l);
#else
  return PyLong_FromLong(l);
#endif
}

// PyObject->tensorflow::DataType conversion function to be used with
// PyArg_Parse* APIs.
int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
  if (obj == Py_None) {
    *dst = tensorflow::DataType::DT_INVALID;
  } else if (!PyIntToDataType(obj, dst)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
            .c_str());
    return 0;
  }

  return 1;
}
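
// Illustrative usage (a reading of how this converter is wired up in
// EagerTensor_init below): with a format string such as "OO&|O&",
// PyArg_ParseTupleAndKeywords routes the `dtype` argument through
// ConvertDataType, e.g.
//   PyArg_ParseTupleAndKeywords(args, kwds, "OO&|O&", kwlist, &value,
//                               ConvertDeviceName, &device_name,
//                               ConvertDataType, &dtype);
// so Python callers may pass either None or an integer DataType enum value.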

// Conversion function extracting a const char** device name from a PyObject.
// The function should be used with PyArg_Parse* APIs.
int ConvertDeviceName(PyObject* obj, const char** dst) {
  if (obj == Py_None) {
    *dst = nullptr;
  } else {
    auto device_name = TFE_GetPythonString(obj);
    if (device_name == nullptr) {
      PyErr_Clear();
      PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
      return 0;
    }
    *dst = device_name;
  }

  return 1;
}

void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) {
  auto status = tensorflow::StatusFromTF_Status(tf_status);
  SetRegisteredErrFromStatus(status);
}

}  // namespace

namespace tensorflow {
// This function checks whether the desired type is "compatible" with the
// inferred type. At a high level, compatibility means that all integral types
// are compatible with each other, and all floating types are compatible with
// each other.
//
// Type compatibility doesn't consider overflows (i.e. int64 is *always*
// compatible with int32). This is intended to match graph behavior.
bool IsCompatible(DataType desired, DataType returned) {
  if (desired == returned) return true;

  if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
    return true;
  } else if (DataTypeIsFloating(desired) &&
             (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
    return true;
  } else if (DataTypeIsComplex(desired) &&
             (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
              DataTypeIsFloating(returned))) {
    return true;
  } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
    return true;
  }
  return false;
}
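
// Worked examples (for illustration; they follow directly from the branches
// above): IsCompatible(DT_INT32, DT_INT64) and IsCompatible(DT_FLOAT, DT_INT32)
// both return true, as do quantized-desired/integer-returned pairs, while
// IsCompatible(DT_INT32, DT_FLOAT) returns false because an integral dtype was
// requested for a float-typed value.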

// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
// execute TFE Ops) to a separate common library.
// Casts data referred to by `handle` from type `src_type_enum` to type
// `dst_type_enum`.
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
                            TF_DataType src_type_enum,
                            TF_DataType dst_type_enum, TF_Status* out_status) {
  if (ctx == nullptr) return nullptr;
  const char* op_name = "Cast";
  const char* device_name = "/device:CPU:0";
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
#define RETURN_ERROR  \
  {                   \
    TFE_DeleteOp(op); \
    return nullptr;   \
  }
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetDevice(op, device_name, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpAddInput(op, handle, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
  TFE_OpSetAttrBool(op, "Truncate", false);
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op, &output, &num_outputs, out_status);
  if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
    if (output != nullptr) {
      TFE_DeleteTensorHandle(output);
    }
    RETURN_ERROR
  }
  TFE_DeleteOp(op);
  return output;
#undef RETURN_ERROR
}
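
// Sketch of intended use (illustrative only; the ownership follows from
// TFE_Execute above): the caller receives a freshly created handle and is
// responsible for deleting it, e.g.
//   TFE_TensorHandle* as_float =
//       EagerCast(ctx, int_handle, TF_INT32, TF_FLOAT, status);
//   ...
//   TFE_DeleteTensorHandle(as_float);
// which is why ConvertToEagerTensorUncached below wraps the result in
// make_safe().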

Safe_TFE_TensorHandlePtr EagerConst(TFE_Context* ctx, TFE_TensorHandle* handle,
                                    const char* device_name,
                                    TF_Status* out_status) {
  const char* op_name = "_EagerConst";
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(ctx, op_name, out_status), TFE_DeleteOp);
  if (!out_status->status.ok()) return nullptr;
  TFE_OpSetDevice(op.get(), device_name, out_status);
  if (!out_status->status.ok()) return nullptr;
  TFE_OpAddInput(op.get(), handle, out_status);
  if (!out_status->status.ok()) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(handle));
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op.get(), &output, &num_outputs, out_status);
  Safe_TFE_TensorHandlePtr result(output);
  if (!out_status->status.ok() || num_outputs != 1) {
    return nullptr;
  }
  return result;
}

TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
                                               PyObject* value,
                                               tensorflow::DataType dtype,
                                               const char* device_name) {
  tensorflow::Safe_PyObjectPtr value_decrefer;
  if (PyArray_IsScalar(value, Generic)) {
    // Convert numpy scalars to numpy arrays.
    value = PyArray_FromScalar(value, nullptr);
    // The returned value needs to be DECREF'd, but the original value was
    // created in python code, and doesn't need to be DECREF'd.
    value_decrefer.reset(value);
  }

  Safe_TFE_TensorHandlePtr handle =
      make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));

  if (handle == nullptr) return nullptr;

  Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
  TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
  if (dtype != tensorflow::DT_INVALID &&
      dtype != static_cast<DataType>(handle_dtype)) {
    if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
      handle = tensorflow::make_safe(
          tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
                                static_cast<TF_DataType>(dtype), status.get()));
      if (!status->status.ok()) {
        PyErr_SetString(PyExc_TypeError,
                        absl::StrCat("Error while casting from dtype ",
                                     tensorflow::DataTypeString(
                                         static_cast<DataType>(handle_dtype)),
                                     " to ", tensorflow::DataTypeString(dtype),
                                     ". ", TF_Message(status.get()))
                            .c_str());
        return nullptr;
      }
    } else {
      tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
      PyErr_SetString(
          PyExc_TypeError,
          absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
                       " to EagerTensor of dtype ",
                       tensorflow::DataTypeString(dtype))
              .c_str());
      return nullptr;
    }
  }

  // We always initially generate CPU:0 tensors. Copy to the current device.
  if (device_name != nullptr) {
    if (strstr(device_name, "/device:CPU:0") != nullptr) {
      // We always generate CPU:0 tensors, but we may need to change the device
      // slightly, as for example from /job:localhost/... to /job:worker/...
      //
      // Note that this is a shallow copy and will share the underlying buffer,
      // because we are copying to the same device.
      handle = make_safe(TFE_TensorHandleCopyToDevice(
          handle.get(), ctx, device_name, status.get()));
      const TF_Code code = TF_GetCode(status.get());
      if (code != TF_OK) {
        RaiseExceptionTypeFromTFStatus(status.get());
        return nullptr;
      }
    } else {
      /*Copy the constant to the current device. Identity is sometimes
        overloaded to allow copies like this, but using a different op allows
        devices to support constant creation without allowing copies via
        identity ops.

        Note that running this _EagerConst op limits mirroring of cached Python
        literals somewhat. Mirroring of constants themselves works:

        with tf.device("GPU:0"):
          tf.constant(1.)  # Cached on CPU:0, mirrored to GPU:0
        with tf.device("GPU:1"):
          tf.constant(1.)  # Cache hit for the CPU version, new mirror to GPU:1.
        with tf.device("GPU:1"):
          tf.constant(1.)  # Cache hit for the CPU version, cached mirror

        But mirrors for the output of `tf.constant` are not shared just because
        there was a cache hit for the input literal, because of _EagerConst:

        x = tf.constant(2.)  # Cached on CPU:0
        with tf.device("GPU:1"):
          tf.identity(x)  # `x` now mirrored to GPU:1
        y = tf.constant(2.)  # Cache hit for CPU version
        with tf.device("GPU:1"):
          tf.identity(y)  # `y` now mirrored on GPU:1 (new copy!)*/
      handle =
          tensorflow::EagerConst(ctx, handle.get(), device_name, status.get());
      const TF_Code code = TF_GetCode(status.get());
      if (code != TF_OK) {
        RaiseExceptionTypeFromTFStatus(status.get());
        return nullptr;
      }
    }
  }

  return handle.release();
}

TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
                                       DataType dtype,
                                       const char* device_name) {
  // Reduce the overhead of allocation/transfer-to-device for scalars by
  // caching the corresponding handles. Note that currently only Python
  // scalars are cached.
  // TODO(slebedev): also cache singleton NumPy arrays and scalars?
  if (PyArray_IsPythonNumber(value)) {
    auto* cache = TFE_TensorHandleCache::Get();
    TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
    if (handle != nullptr) return handle;
    handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
    if (handle == nullptr) return nullptr;
    if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
      cache->Insert(value, dtype, ctx, device_name, handle);
    }
    return handle;
  } else {
    return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
  }
}
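
// Behavioral note (inferred from the cache logic above, stated as a reading
// rather than a guarantee): repeated conversions of the same Python scalar,
// e.g. calling tf.constant(1.0) twice with the same dtype, context and device,
// should return the cached handle, while non-finite floats (NaN/inf) are
// converted each time because they are never inserted into the cache.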

}  // namespace tensorflow

extern "C" {

static const int kMaxEagerTensorParentSize = 64;

// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  TFE_TensorHandle* handle;
  int64_t id;
  // Indicates whether it's a packed tensor or not.
  bool is_packed;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating new
  // Status objects in the different functions that operate on EagerTensor and
  // need to use a TF_Status object. Note, however, that accesses to `status`
  // are not thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;

namespace {

// Returns true on success - successfully invoked or no profiler registered.
// Returns false if some error occurred.
bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
  if (eager_tensor_profiler != nullptr) {
#if PY_MAJOR_VERSION < 3
    PyObject* created_method_name = PyString_InternFromString("created");
#else
    PyObject* created_method_name = PyUnicode_InternFromString("created");
#endif
    if (created_method_name == nullptr) {
      return false;
    }
    PyObject* result = PyObject_CallMethodObjArgs(
        eager_tensor_profiler, created_method_name, created_tensor, NULL);
    if (result == nullptr) {
      LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
      // While we can potentially continue because the error is related to
      // profiling, we choose to return an error because:
      //  - If profiling is used, the user likely wants to stop execution on
      //    profiling errors.
      //  - Error in profiling code might have left some state in an invalid
      //    form that can lead to an error later on. Better to fail fast.
      Py_DECREF(created_method_name);
      return false;
    }
    Py_DECREF(created_method_name);
    Py_DECREF(result);
  }
  return true;
}

}  // namespace

// tp_init for EagerTensor.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
  self->id = get_uid();
  self->handle = nullptr;
  self->is_packed = false;
  Py_INCREF(Py_None);
  self->handle_data = Py_None;
  Py_INCREF(Py_None);
  self->tensor_shape = Py_None;
  self->status.status = ::tensorflow::OkStatus();
  self->dict = nullptr;
  self->weakreflist = nullptr;
  self->context = nullptr;
  PyObject* value;
  const char* device_name = nullptr;
  tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
  const char* kwlist[] = {"value", "device", "dtype", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
          ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
    return -1;
  }

  PyObject* py_context = GetPyEagerContext();
  if (py_context == nullptr) return -1;
  self->context = py_context;

  auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
                                                  value, dtype, device_name);
  if (handle == nullptr) return -1;
  self->handle = handle;

  if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
    return -1;
  }

  return 0;
}
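
// Reading of the parse format above (illustrative, not an interface contract):
// from Python, the constructor is effectively
//   EagerTensor(value, device, dtype=None)
// where `device` may be None (mapped to nullptr by ConvertDeviceName) and
// `dtype` may be None or an integer DataType enum value (mapped to DT_INVALID
// or the corresponding enum by ConvertDataType).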

// tp_dealloc for EagerTensor.
void EagerTensor_dealloc(EagerTensor* self) {
  // Unhook the object from python's GC so that the weakref deleter doesn't
  // try to re-delete this.
  PyObject_GC_UnTrack((PyObject*)self);

  // Clear weak references to self.
  // Needs to happen before any actual destruction.
  PyObject_ClearWeakRefs((PyObject*)self);

  Py_DECREF(self->handle_data);
  Py_DECREF(self->tensor_shape);
  // If an attribute dictionary has been created, release it. Note that this
  // is only ever created by CPython's attribute setting methods; we don't
  // create it ourselves.
  Py_CLEAR(self->dict);
  if (self->handle != nullptr) {
    // Destructor may call arbitrary functions that end up calling into
    // Python from another thread.
    Py_BEGIN_ALLOW_THREADS;
    TFE_DeleteTensorHandle(self->handle);
    Py_END_ALLOW_THREADS;
    self->handle = nullptr;
  }

  // Decref context after deleting the tensor handle.
  Py_XDECREF(self->context);

  // We have the global interpreter lock, so use this chance to perform delayed
  // refcount decrements.
  tensorflow::ClearDecrefCache();
  auto id = self->id;
  Py_TYPE(self)->tp_free(self);
  TFE_Py_TapeSetDeleteTrace(id);
}

// Getter for `_id`.
static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
  return PyLong_FromLongLong(self->id);
}

// Getter for `_datatype_enum`.
static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
  return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
}

// Getter for `_shape_tuple`.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumDims(handle, &self->status);
  TF_Code code = TF_GetCode(&self->status);
  if (code != TF_OK) {
    RaiseExceptionTypeFromTFStatus(&self->status);
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }
  PyObject* shape = PyTuple_New(n);
  if (PyErr_Occurred()) return nullptr;
  for (int i = 0; i < n; ++i) {
    int64_t dim_c_value = TFE_TensorHandleDim(handle, i, &self->status);
    PyObject* dim;
    // The C++ convention is -1 for unknown/variable axis lengths. Translate
    // that to the Python "None" convention. Unknown axis lengths are unusual
    // for eager tensors.
    if (dim_c_value < 0) {
      Py_IncRef(Py_None);
      dim = Py_None;
    } else {
      dim = PyLong_FromLongLong(dim_c_value);
    }
    code = TF_GetCode(&self->status);
    if (code != TF_OK || dim == nullptr ||
        PyTuple_SetItem(shape, i, dim) != 0) {
      if (code != TF_OK) {
        RaiseExceptionTypeFromTFStatus(&self->status);
      } else {
        PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
      }
      // Cleanup self->status before returning.
      self->status.status = ::tensorflow::OkStatus();
      Py_DECREF(shape);
      if (dim != nullptr) Py_DECREF(dim);
      return nullptr;
    }
  }
  return shape;
}

// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
  int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(num_dims);
#else
  return PyLong_FromLong(num_dims);
#endif
}

// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumElements(handle, &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }
  return PyLong_FromLongLong(n);
}

static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
  Py_INCREF(self->handle_data);
  return self->handle_data;
}

static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
                                      void* unused) {
  Py_DECREF(self->handle_data);
  Py_INCREF(value);
  self->handle_data = value;
  return 0;
}

static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
  Py_INCREF(self->tensor_shape);
  return self->tensor_shape;
}

static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
                                       void* unused) {
  Py_DECREF(self->tensor_shape);
  Py_INCREF(value);
  self->tensor_shape = value;
  return 0;
}

// Function `_copy_to_device`.
static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
                                            PyObject* kwds) {
  if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;

  const char* device_name = nullptr;
  if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
                        &device_name)) {
    return nullptr;
  }

  // Note that this is a shallow copy and will share the underlying buffer
  // if copying to the same device.
  TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
      self->handle, GetContextHandle(self->context), device_name,
      &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
                                                  PyExc_RuntimeError)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}

// Function `_numpy_internal`.
// Convert an EagerTensor to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
// Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
  auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    Py_XDECREF(py_array);
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  } else {
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
  }
}

// Function `_prefer_custom_summarizer`.
//
// A hint that callers should prefer `SummarizeValue` to resolving this handle
// and formatting the tensor.
static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) {
  if (tensorflow::unwrap(self->handle)->PreferCustomSummarizer()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

// Function `_summarize_value`.
//
// Returns a string PyObject which summarizes the value of this tensor. It does
// not include a shape or dtype.
static PyObject* EagerTensor_summarize_value(EagerTensor* self) {
  std::string summary;
  tensorflow::Status status =
      tensorflow::unwrap(self->handle)->SummarizeValue(summary);
  if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
    return nullptr;
  }
  return PyUnicode_FromString(summary.c_str());
}

// Getter `device`.
static PyObject* EagerTensor_device(EagerTensor* self) {
  const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
                                                  PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}

// Getter `backing_device`.
static PyObject* EagerTensor_backing_device(EagerTensor* self) {
  const char* device =
      TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
                                                  PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}

// Getter `is_packed`.
static PyObject* EagerTensor_is_packed(EagerTensor* self) {
  return PyBool_FromLong(self->is_packed);
}

static PyGetSetDef EagerTensor_getsetters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};

#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif

static PyMethodDef EagerTensor_methods[] = {
    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
     PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("The DType of the tensor as an enum.")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("The shape of the tensor as a python tuple.")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
     PyDoc_STR("The rank of the tensor.")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS,
     PyDoc_STR("Copies the tensor to the desired device.")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("Number of elements in the tensor.")},
    {"_prefer_custom_summarizer",
     (PyCFunction)EagerTensor_prefer_custom_summarizer, METH_NOARGS,
     PyDoc_STR("Indicates whether _numpy_internal loses information.")},
    {"_summarize_value", (PyCFunction)EagerTensor_summarize_value, METH_NOARGS,
     PyDoc_STR("A string which summarizes the value of this tensor.")},
    {nullptr, nullptr},
};

static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
                                 int flags) {
  if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
    PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
    return -1;
  }

  // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
  // DT_STRING so the following is only slightly slower than a NumPy-free
  // implementation.
  auto py_array = tensorflow::make_safe(
      TFE_TensorHandleToNumpy(self->handle, &self->status));
  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
                                                  PyExc_BufferError)) {
    // Cleanup self->status before returning.
    self->status.status = ::tensorflow::OkStatus();
    return -1;
  }
  if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
    return -1;
  }
  view->readonly = 1;
  return 0;
}

static PyBufferProcs EagerTensor_as_buffer = {
#if PY_MAJOR_VERSION < 3
    nullptr, nullptr, nullptr, nullptr,
#endif
    (getbufferproc)EagerTensor_getbuffer,
    // Never called because getbufferproc delegates to NumPy.
    (releasebufferproc) nullptr};
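
// Usage note (an inference from the buffer protocol wiring above, not a
// documented guarantee): a read-only consumer such as memoryview(t) is served
// through the intermediate NumPy array returned by TFE_TensorHandleToNumpy,
// and any request for a writable buffer is rejected with BufferError.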

// Note that here we are trying to dynamically create a new class as a subclass
// of a "HEAPTYPE" class that is itself created in python code and passed in at
// runtime. This is fairly atypical and undocumented.
//
// We use the following strategy for this. Unfortunately, we have to use
// different approaches for python2.x vs python3.x.
// For python2.x, we create the class as a static type and set its tp_base to
// the passed-in type. Unfortunately setting tp_flags to include
// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
// initialization of the underlying PyHeapTypeObject and not doing that leads
// to random crashes, especially during garbage collection.
// python3.x explicitly disallows a static subclass of a HEAPTYPE base class.
// However, it provides a new function, PyType_FromSpecWithBases, to create
// types dynamically.

// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
PyTypeObject* EagerTensorType = nullptr;

#if PY_MAJOR_VERSION >= 3
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    {0, nullptr},
};
#else

#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr,                            /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif

}  // extern "C"

bool EagerTensor_CheckExact(const PyObject* o) {
  return Py_TYPE(o) == EagerTensorType;
}

TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
  return reinterpret_cast<const EagerTensor*>(o)->handle;
}

PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
                                const bool is_packed) {
  if (handle == nullptr) {
    return nullptr;
  }
  EagerTensor* t = reinterpret_cast<EagerTensor*>(
      EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
  if (t != nullptr) {
    t->id = get_uid();
    t->is_packed = is_packed;
    Py_INCREF(Py_None);
    t->handle_data = Py_None;
    Py_INCREF(Py_None);
    t->tensor_shape = Py_None;
    t->handle = handle;
    t->status.status = ::tensorflow::OkStatus();
    t->weakreflist = nullptr;
    PyObject* py_context = GetPyEagerContext();
    if (py_context == nullptr) {
      LOG(ERROR) << "Cannot create an eager tensor before eager context has "
                    "been set or after it has been deleted";
      return nullptr;
    }
    t->context = py_context;

    if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
      return nullptr;
    }
  }
  return reinterpret_cast<PyObject*>(t);
}
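
// Construction note (a reading of the code above, not an added contract):
// EagerTensorFromHandle takes ownership of `handle` and bypasses
// EagerTensor_init entirely; it calls tp_new directly and fills in the fields
// itself, so no Python-level conversion or device copy happens on this path.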

int64_t PyEagerTensor_ID(const PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  return reinterpret_cast<const EagerTensor*>(tensor)->id;
}

tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
      reinterpret_cast<const EagerTensor*>(tensor)->handle));
}

int64_t PyEagerTensor_NumElements(PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
  int64_t result = TFE_TensorHandleNumElements(as_c_eager_tensor->handle,
                                               &as_c_eager_tensor->status);

  if (tensorflow::MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
                                                  PyExc_ValueError)) {
    // Cleanup status before returning.
    as_c_eager_tensor->status.status = ::tensorflow::OkStatus();
    return -1;
  }

  return result;
}

PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
  if (!PyType_Check(base_class)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a class definition for `base_class` passed to ",
            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
            .c_str());
    return nullptr;
  }
  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
  // EagerTensor to allow for the space usage of the base class.
  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            ". Need its size to be <= ", kMaxEagerTensorParentSize)
            .c_str());
    return nullptr;
  }
  if (base_class_type->tp_itemsize != 0) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            " which supports variable length instances.")
            .c_str());
    return nullptr;
  }
  Py_INCREF(base_class);
#if PY_MAJOR_VERSION >= 3
  PyObject* bases = PyTuple_New(1);
  PyTuple_SET_ITEM(bases, 0, base_class);

  tensorflow::Safe_PyObjectPtr base_class_module(
      PyObject_GetAttrString(base_class, "__module__"));
  const char* module = nullptr;
  if (PyErr_Occurred()) {
    PyErr_Clear();
    module = "__builtin__";
  } else {
    module = PyBytes_AsString(base_class_module.get());
    if (module == nullptr) {
      PyErr_Clear();
      module = PyUnicode_AsUTF8(base_class_module.get());
      if (module == nullptr) {
        PyErr_Clear();
        module = "__builtin__";
      }
    }
  }

  // NOTE: The c_str from this string needs to outlast the function, hence is
  // static.
  static tensorflow::string fully_qualified_name =
      tensorflow::strings::StrCat(module, ".EagerTensor");

  static PyType_Spec EagerTensor_Type_spec = {
      fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
      Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};

  EagerTensorType = reinterpret_cast<PyTypeObject*>(
      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
  if (PyErr_Occurred()) {
    return nullptr;
  }
  if (EagerTensorType == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
    return nullptr;
  }
  EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
  EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
#else
  _EagerTensorType.tp_base = base_class_type;

  if (PyType_Ready(&_EagerTensorType) < 0) {
    if (PyErr_Occurred()) return nullptr;
    PyErr_SetString(PyExc_RuntimeError,
                    "Error while creating EagerTensor type.");
    return nullptr;
  }
  EagerTensorType = &_EagerTensorType;
  Py_INCREF(EagerTensorType);
#endif
  return reinterpret_cast<PyObject*>(EagerTensorType);
}

PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
  Py_XDECREF(eager_tensor_profiler);

  if (profiler == Py_None) {
    eager_tensor_profiler = nullptr;
  } else {
    eager_tensor_profiler = profiler;
    Py_INCREF(eager_tensor_profiler);
  }
  Py_RETURN_NONE;
}

PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
  if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "tensors argument must be a list or a tuple. Got \"",
                        Py_TYPE(tensors)->tp_name, "\"")
                        .c_str());
    return nullptr;
  }
  if (slice_dim < 0) {
    PyErr_SetString(
        PyExc_ValueError,
        tensorflow::strings::StrCat("Slice dimension must be non-negative. "
                                    "Got ",
                                    slice_dim)
            .c_str());
    return nullptr;
  }

  PyObject* py_context = GetPyEagerContext();
  if (py_context == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
                                            "Cannot create EagerTensor when "
                                            "EagerContext is not valid")
                                            .c_str());
    return nullptr;
  }

  TFE_Context* ctx = GetContextHandle(py_context);

  Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
  PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
  int64_t num_tensors_int = static_cast<int64_t>(num_tensors);

  auto status = tensorflow::make_safe(TF_NewStatus());

  // Create an empty tensor.
  auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
      tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});

  if (num_tensors_int > 0) {
    int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());

    // Fill the tensor with dims.
    for (Py_ssize_t i = 0; i < num_tensors; ++i) {
      PyObject* tensor_obj = tensors_array[i];
      if (!EagerTensor_CheckExact(tensor_obj)) {
        PyErr_SetString(
            PyExc_TypeError,
            tensorflow::strings::StrCat("Expected a list of EagerTensors but "
                                        "element ",
                                        i, " has type \"",
                                        Py_TYPE(tensor_obj)->tp_name, "\"")
                .c_str());
        return nullptr;
      }

      EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
      TFE_TensorHandle* handle = t->handle;
      int num_dims = TFE_TensorHandleNumDims(handle, status.get());
      if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(),
                                                      PyExc_ValueError)) {
        return nullptr;
      }
      if (slice_dim >= num_dims) {
        PyErr_SetString(
            PyExc_IndexError,
            tensorflow::strings::StrCat("Slice dimension (", slice_dim,
                                        ") must be smaller than rank of all "
                                        "tensors, but tensor at index ",
                                        i, " has rank ", num_dims)
                .c_str());
        return nullptr;
      }
      int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
      if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(),
                                                      PyExc_ValueError)) {
        return nullptr;
      }
      data[i] = dim;
    }
  }

  TFE_TensorHandle* handle =
      tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));

  if (!status->status.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}
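
// Worked example (illustrative; it follows from the loop above): given eager
// tensors with shapes [2, 3] and [4, 3], TFE_Py_TensorShapeSlice(tensors, 0)
// builds a new int32 EagerTensor holding [2, 4], i.e. each input's dimension
// along slice_dim, while slice_dim >= 2 would raise IndexError.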

PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
  if (!EagerTensor_CheckExact(tensor)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat("Expected an EagerTensor but got type \"",
                                    Py_TYPE(tensor)->tp_name, "\"")
            .c_str());
    return nullptr;
  }
  TFE_TensorHandle* handle = EagerTensor_Handle(tensor);

  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_TensorDebugInfo* debug_info =
      TFE_TensorHandleTensorDebugInfo(handle, status.get());
  if (!status->status.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
  PyObject* shape = PyTuple_New(rank);
  for (int i = 0; i < rank; ++i) {
    int64_t dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
    PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
  }
  TFE_DeleteTensorDebugInfo(debug_info);

  return shape;
}