• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/eager/pywrap_tensor.h"
17 
18 #include <stdlib.h>
19 #include <string.h>
20 
21 #include <cmath>
22 
23 #include "structmember.h"  // NOLINT // For PyMemberDef
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
30 #include "tensorflow/c/tf_status.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
35 #include "tensorflow/python/eager/pywrap_tfe.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor.h"
37 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
38 #include "tensorflow/python/lib/core/numpy.h"
39 #include "tensorflow/python/lib/core/py_exception_registry.h"
40 #include "tensorflow/python/lib/core/py_seq_tensor.h"
41 #include "tensorflow/python/lib/core/pybind11_status.h"
42 #include "tensorflow/python/lib/core/safe_ptr.h"
43 
44 // forward declare
45 struct EagerTensor;
46 namespace tensorflow {
47 
48 // Convert a TFE_TensorHandle to a Python numpy.ndarray object.
49 // The two may share underlying storage so changes to one may reflect in the
50 // other.
TFE_TensorHandleToNumpy(TFE_TensorHandle * handle,TF_Status * status)51 PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
52   if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
53     TF_SetStatus(status, TF_INVALID_ARGUMENT,
54                  "Cannot convert a Tensor of dtype resource to a NumPy array.");
55     return nullptr;
56   }
57 
58   if (TFE_TensorHandleDataType(handle) == TF_VARIANT) {
59     TF_SetStatus(status, TF_INVALID_ARGUMENT,
60                  "Cannot convert a Tensor of dtype variant to a NumPy array.");
61     return nullptr;
62   }
63   tensorflow::Safe_TF_TensorPtr tensor = nullptr;
64   Py_BEGIN_ALLOW_THREADS;
65   tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
66   Py_END_ALLOW_THREADS;
67   if (!status->status.ok()) {
68     return nullptr;
69   }
70 
71   PyObject* ret = nullptr;
72   auto cppstatus =
73       tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
74   tensorflow::Set_TF_Status_from_Status(status, cppstatus);
75   if (!status->status.ok()) {
76     Py_XDECREF(ret);
77     return nullptr;
78   }
79   CHECK_NE(ret, nullptr);
80   return ret;
81 }
82 }  // namespace tensorflow
83 namespace {
84 
85 using tensorflow::TFE_TensorHandleToNumpy;
86 
87 // An instance of _EagerTensorProfiler that will receive callbacks about
88 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
89 PyObject* eager_tensor_profiler = nullptr;
90 
91 // Read-only dict. Please don't use this in any setting where the dict might
92 // actually get mutated. This is only used to pass empty kwargs when creating a
93 // new EagerTensor.
EmptyDict()94 PyObject* EmptyDict() {
95   static PyObject* empty_dict = PyDict_New();
96   return empty_dict;
97 }
98 
EmptyTuple()99 PyObject* EmptyTuple() {
100   static PyObject* empty_tuple = PyTuple_New(0);
101   return empty_tuple;
102 }
103 
GetContextHandle(PyObject * py_context)104 TFE_Context* GetContextHandle(PyObject* py_context) {
105   tensorflow::Safe_PyObjectPtr py_context_handle(
106       PyObject_GetAttrString(py_context, "_handle"));
107   if (py_context_handle == nullptr) {
108     // Current Python code makes sure this never happens. If it does, or
109     // becomes hard to maintain, we can call the ensure_initialized() method
110     // here.
111     PyErr_SetString(
112         PyExc_TypeError,
113         "Expected `context` argument in EagerTensor constructor to have a "
114         "`_handle` attribute but it did not. Was eager Context initialized?");
115     return nullptr;
116   }
117 
118   auto* ctx = reinterpret_cast<TFE_Context*>(
119       PyCapsule_GetPointer(py_context_handle.get(), nullptr));
120   if (ctx == nullptr) {
121     PyErr_SetString(PyExc_TypeError,
122                     tensorflow::strings::StrCat(
123                         "Expected context._handle to contain a PyCapsule "
124                         "encoded pointer to TFE_Context. Got ",
125                         Py_TYPE(py_context_handle.get())->tp_name)
126                         .c_str());
127   }
128   return ctx;
129 }
130 
131 
132 // Helper function to convert `v` to a tensorflow::DataType and store it in
133 // `*out`. Returns true on success, false otherwise.
134 // Note that we assume that v is a python int (not long) representing a
135 // TF_DataType/tensorflow::DataType value.
PyIntToDataType(PyObject * v,tensorflow::DataType * out)136 bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
137 #if PY_MAJOR_VERSION < 3
138   if (PyInt_Check(v)) {
139     *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
140     return true;
141   }
142 #else
143   if (PyLong_Check(v)) {
144     *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
145     return true;
146   }
147 #endif
148   return false;
149 }
150 
151 // Helper function to create a python integer from TF_DataType.
PyIntFromDataType(TF_DataType l)152 PyObject* PyIntFromDataType(TF_DataType l) {
153 #if PY_MAJOR_VERSION < 3
154   return PyInt_FromLong(l);
155 #else
156   return PyLong_FromLong(l);
157 #endif
158 }
159 
160 // PyObject->tensorflow::DataType conversion function to be used with
161 // PyArg_Parse* APIs.
ConvertDataType(PyObject * obj,tensorflow::DataType * dst)162 int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
163   if (obj == Py_None) {
164     *dst = tensorflow::DataType::DT_INVALID;
165   } else if (!PyIntToDataType(obj, dst)) {
166     PyErr_SetString(
167         PyExc_TypeError,
168         tensorflow::strings::StrCat(
169             "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
170             .c_str());
171     return 0;
172   }
173 
174   return 1;
175 }
176 
177 // Conversion function extracting a const char** device name from a PyObject.
178 // The function should be used with PyArg_Parse* APIs.
ConvertDeviceName(PyObject * obj,const char ** dst)179 int ConvertDeviceName(PyObject* obj, const char** dst) {
180   if (obj == Py_None) {
181     *dst = nullptr;
182   } else {
183     auto device_name = TFE_GetPythonString(obj);
184     if (device_name == nullptr) {
185       PyErr_Clear();
186       PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
187       return 0;
188     }
189     *dst = device_name;
190   }
191 
192   return 1;
193 }
194 
RaiseExceptionTypeFromTFStatus(TF_Status * tf_status)195 void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) {
196   auto status = tensorflow::StatusFromTF_Status(tf_status);
197   SetRegisteredErrFromStatus(status);
198 }
199 
200 }  // namespace
201 
202 namespace tensorflow {
203 // This function checks whether the desired type is "compatible" with the
204 // inferred type. At a high level, compatibility means that all integral types
205 // are compatible with each other, and all floating types are compatible with
206 // each other.
207 //
208 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
209 // compatible with int32). This is intended to match graph behavior.
IsCompatible(DataType desired,DataType returned)210 bool IsCompatible(DataType desired, DataType returned) {
211   if (desired == returned) return true;
212 
213   if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
214     return true;
215   } else if (DataTypeIsFloating(desired) &&
216              (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
217     return true;
218   } else if (DataTypeIsComplex(desired) &&
219              (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
220               DataTypeIsFloating(returned))) {
221     return true;
222   } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
223     return true;
224   }
225   return false;
226 }
227 
228 // TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
229 // execute TFE Ops) to a separate common library.
230 // Casts data referred to by `handle` from type `src_type_enum` to type
231 // `dst_type_enum`.
EagerCast(TFE_Context * ctx,TFE_TensorHandle * handle,TF_DataType src_type_enum,TF_DataType dst_type_enum,TF_Status * out_status)232 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
233                             TF_DataType src_type_enum,
234                             TF_DataType dst_type_enum, TF_Status* out_status) {
235   if (ctx == nullptr) return nullptr;
236   const char* op_name = "Cast";
237   const char* device_name = "/device:CPU:0";
238   TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
239 #define RETURN_ERROR  \
240   {                   \
241     TFE_DeleteOp(op); \
242     return nullptr;   \
243   }
244   if (!out_status->status.ok()) RETURN_ERROR
245   TFE_OpSetDevice(op, device_name, out_status);
246   if (!out_status->status.ok()) RETURN_ERROR
247   TFE_OpAddInput(op, handle, out_status);
248   if (!out_status->status.ok()) RETURN_ERROR
249   TFE_OpSetAttrType(op, "SrcT", src_type_enum);
250   TFE_OpSetAttrType(op, "DstT", dst_type_enum);
251   TFE_OpSetAttrBool(op, "Truncate", false);
252   TFE_TensorHandle* output = nullptr;
253   int num_outputs = 1;
254   TFE_Execute(op, &output, &num_outputs, out_status);
255   if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
256     if (output != nullptr) {
257       TFE_DeleteTensorHandle(output);
258     }
259     RETURN_ERROR
260   }
261   TFE_DeleteOp(op);
262   return output;
263 #undef RETURN_ERROR
264 }
265 
EagerConst(TFE_Context * ctx,TFE_TensorHandle * handle,const char * device_name,TF_Status * out_status)266 Safe_TFE_TensorHandlePtr EagerConst(TFE_Context* ctx, TFE_TensorHandle* handle,
267                                     const char* device_name,
268                                     TF_Status* out_status) {
269   const char* op_name = "_EagerConst";
270   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
271       TFE_NewOp(ctx, op_name, out_status), TFE_DeleteOp);
272   if (!out_status->status.ok()) return nullptr;
273   TFE_OpSetDevice(op.get(), device_name, out_status);
274   if (!out_status->status.ok()) return nullptr;
275   TFE_OpAddInput(op.get(), handle, out_status);
276   if (!out_status->status.ok()) return nullptr;
277   TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(handle));
278   TFE_TensorHandle* output = nullptr;
279   int num_outputs = 1;
280   TFE_Execute(op.get(), &output, &num_outputs, out_status);
281   Safe_TFE_TensorHandlePtr result(output);
282   if (!out_status->status.ok() || num_outputs != 1) {
283     return nullptr;
284   }
285   return result;
286 }
287 
ConvertToEagerTensorUncached(TFE_Context * ctx,PyObject * value,tensorflow::DataType dtype,const char * device_name)288 TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
289                                                PyObject* value,
290                                                tensorflow::DataType dtype,
291                                                const char* device_name) {
292   tensorflow::Safe_PyObjectPtr value_decrefer;
293   if (PyArray_IsScalar(value, Generic)) {
294     // Convert numpy scalars to numpy arrays.
295     value = PyArray_FromScalar(value, nullptr);
296     // The returned value needs to be DECREF'd, but the original value was
297     // created in python code, and doesn't need to be DECREF'd.
298     value_decrefer.reset(value);
299   }
300 
301   Safe_TFE_TensorHandlePtr handle =
302       make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
303 
304   if (handle == nullptr) return nullptr;
305 
306   Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
307   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
308   if (dtype != tensorflow::DT_INVALID &&
309       dtype != static_cast<DataType>(handle_dtype)) {
310     if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
311       handle = tensorflow::make_safe(
312           tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
313                                 static_cast<TF_DataType>(dtype), status.get()));
314       if (!status->status.ok()) {
315         PyErr_SetString(PyExc_TypeError,
316                         absl::StrCat("Error while casting from dtype ",
317                                      tensorflow::DataTypeString(
318                                          static_cast<DataType>(handle_dtype)),
319                                      " to ", tensorflow::DataTypeString(dtype),
320                                      ". ", TF_Message(status.get()))
321                             .c_str());
322         return nullptr;
323       }
324     } else {
325       tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
326       PyErr_SetString(
327           PyExc_TypeError,
328           absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
329                        " to EagerTensor of dtype ",
330                        tensorflow::DataTypeString(dtype))
331               .c_str());
332       return nullptr;
333     }
334   }
335 
336   // We always initially generate CPU:0 tensors. Copy to the current device.
337   if (device_name != nullptr) {
338     if (strstr(device_name, "/device:CPU:0") != nullptr) {
339       // We always generate CPU:0 tensors, but we may need to change the device
340       // slightly, as for example from /job:localhost/... to /job:worker/...
341       //
342       // Note that this is a shallow copy and will share the underlying buffer,
343       // because we are copying to the same device.
344       handle = make_safe(TFE_TensorHandleCopyToDevice(
345           handle.get(), ctx, device_name, status.get()));
346       const TF_Code code = TF_GetCode(status.get());
347       if (code != TF_OK) {
348         RaiseExceptionTypeFromTFStatus(status.get());
349         return nullptr;
350       }
351     } else {
352       /*Copy the constant to the current device. Identity is sometimes
353         overloaded to allow copies like this, but using a different op allows
354         devices to support constant creation without allowing copies via
355         identity ops.
356 
357         Note that running this _EagerConst op limits mirroring of cached Python
358         literals somewhat. Mirroring of constants themselves works:
359 
360         with tf.device("GPU:0"):
361           tf.constant(1.)  # Cached on CPU:0, mirrored to GPU:0
362         with tf.device("GPU:1"):
363           tf.constant(1.)  # Cache hit for the CPU version, new mirror to GPU:1.
364         with tf.device("GPU:1"):
365           tf.constant(1.)  # Cache hit for the CPU version, cached mirror
366 
367         But mirrors for the output of `tf.constant` are not shared just because
368         there was a cache hit for the input literal, because of _EagerConst:
369 
370         x = tf.constant(2.)  # Cached on CPU:0
371         with tf.device("GPU:1"):
372           tf.identity(x)  # `x` now mirrored to GPU:1
373         y = tf.constant(2.)  # Cache hit for CPU version
374         with tf.device("GPU:1"):
375           tf.identity(y)  # `y` now mirrored on GPU:1 (new copy!)*/
376       handle =
377           tensorflow::EagerConst(ctx, handle.get(), device_name, status.get());
378       const TF_Code code = TF_GetCode(status.get());
379       if (code != TF_OK) {
380         RaiseExceptionTypeFromTFStatus(status.get());
381         return nullptr;
382       }
383     }
384   }
385 
386   return handle.release();
387 }
388 
ConvertToEagerTensor(TFE_Context * ctx,PyObject * value,DataType dtype,const char * device_name)389 TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
390                                        DataType dtype,
391                                        const char* device_name) {
392   // Reduce the overhead of allocation/transfer-to-device for scalars by
393   // caching the corresponding handles. Note that currently only Python
394   // scalars are cached.
395   // TODO(slebedev): also cache singleton NumPy arrays and scalars?
396   if (PyArray_IsPythonNumber(value)) {
397     auto* cache = TFE_TensorHandleCache::Get();
398     TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
399     if (handle != nullptr) return handle;
400     handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
401     if (handle == nullptr) return nullptr;
402     if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
403       cache->Insert(value, dtype, ctx, device_name, handle);
404     }
405     return handle;
406   } else {
407     return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
408   }
409 }
410 
411 }  // namespace tensorflow
412 
413 extern "C" {
414 
415 static const int kMaxEagerTensorParentSize = 64;
416 
417 // TODO(agarwal): store context handle in EagerTensor.
418 typedef struct EagerTensor {
419   PyObject_HEAD;
420   // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
421   // parent class. The parent class is set at runtime, so we don't know the
422   // exact size at compile time.
423   char unused[kMaxEagerTensorParentSize];
424   TFE_TensorHandle* handle;
425   int64_t id;
426   // Indicates whether it's a packed tensor or not.
427   bool is_packed;
428   // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
429   // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
430   // tensors, this will contain a serialized HandleData proto with shape
431   // inference metadata about shapes and dtypes of resources accessible from
432   // this handle.
433   // Note that we assume that handle_data cannot participate in reference
434   // cycles, and hence don't provide GC support for it.
435   PyObject* handle_data;
436 
437   // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
438   // first time that `_EagerTensorBase`'s `shape` property is called.
439   PyObject* tensor_shape;
440 
441   // We store a status object here as an optimization to avoid allocating a new
442   // Status objects on different functions that operate on EagerTensor and need
443   // to use a TF_Status object. However note that accesses to `status` are not
444   // thread-safe.
445   TF_Status status;
446 
447   // The eager Context (from eager/context.py) used by this Tensor.
448   // This is currently used only to make sure context outlives TensorHandles.
449   PyObject* context;
450 
451   PyObject* weakreflist; /* List of weak references */
452 
453   // Per-instance attribute dictionary, to support monkey patching
454   // (e.g. EagerTensor.assign when slicing variables). This dictionary is
455   // created by CPython the first time an attribute is assigned, pointed to by
456   // tp_dictoffset. Note that garbage collection is not enabled for
457   // EagerTensors, so assigning objects to EagerTensor attributes which require
458   // garbage collection is likely to cause issues.
459   PyObject* dict;
460 } EagerTensor;
461 
462 namespace {
463 
464 // Returns true on success - successfully invoked or no profiler registered.
465 // Returns false if some error occurred.
MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor * created_tensor)466 bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
467   if (eager_tensor_profiler != nullptr) {
468 #if PY_MAJOR_VERSION < 3
469     PyObject* created_method_name = PyString_InternFromString("created");
470 #else
471     PyObject* created_method_name = PyUnicode_InternFromString("created");
472 #endif
473     if (created_method_name == nullptr) {
474       return false;
475     }
476     PyObject* result = PyObject_CallMethodObjArgs(
477         eager_tensor_profiler, created_method_name, created_tensor, NULL);
478     if (result == nullptr) {
479       LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
480       // While we can potentially continue because the error is related to
481       // profiling, we choose to return an error because:
482       //  - If profiling is used, the user likely wants to stop execution on
483       //    profiling errors.
484       //  - Error in profiling code might have left some state in an invalid
485       //    form that can lead to an error later on. Better to fail fast.
486       Py_DECREF(created_method_name);
487       return false;
488     }
489     Py_DECREF(created_method_name);
490     Py_DECREF(result);
491   }
492   return true;
493 }
494 
495 }  // namespace
496 
497 // tp_init for EagerTensor.
EagerTensor_init(EagerTensor * self,PyObject * args,PyObject * kwds)498 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
499   self->id = get_uid();
500   self->handle = nullptr;
501   self->is_packed = false;
502   Py_INCREF(Py_None);
503   self->handle_data = Py_None;
504   Py_INCREF(Py_None);
505   self->tensor_shape = Py_None;
506   self->status.status = tensorflow::Status::OK();
507   self->dict = nullptr;
508   self->weakreflist = nullptr;
509   self->context = nullptr;
510   PyObject* value;
511   const char* device_name = nullptr;
512   tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
513   const char* kwlist[] = {"value", "device", "dtype", nullptr};
514   if (!PyArg_ParseTupleAndKeywords(
515           args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
516           ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
517     return -1;
518   }
519 
520   PyObject* py_context = GetPyEagerContext();
521   if (py_context == nullptr) return -1;
522   self->context = py_context;
523 
524   auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
525                                                   value, dtype, device_name);
526   if (handle == nullptr) return -1;
527   self->handle = handle;
528 
529   if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
530     return -1;
531   }
532 
533   return 0;
534 }
535 
536 // tp_dealloc for EagerTensor.
EagerTensor_dealloc(EagerTensor * self)537 void EagerTensor_dealloc(EagerTensor* self) {
538   // Unhook the object from python's GC so that the weakref deleter doesn't
539   // try to re-delete this.
540   PyObject_GC_UnTrack((PyObject*)self);
541 
542   // Clear weak references to self.
543   // Needs to happen before any actual destruction.
544   PyObject_ClearWeakRefs((PyObject*)self);
545 
546   Py_DECREF(self->handle_data);
547   Py_DECREF(self->tensor_shape);
548   // If an attribute dictionary has been created, release it. Note that this
549   // is only ever created by CPython's attribute setting methods; we don't
550   // create it ourselves.
551   Py_CLEAR(self->dict);
552   if (self->handle != nullptr) {
553     TFE_DeleteTensorHandle(self->handle);
554     self->handle = nullptr;
555   }
556 
557   // Decref context after deleting the tensor handle.
558   Py_XDECREF(self->context);
559 
560   // We have the global interpreter lock, so use this chance to perform delayed
561   // refcount decrements.
562   tensorflow::ClearDecrefCache();
563   auto id = self->id;
564   Py_TYPE(self)->tp_free(self);
565   TFE_Py_TapeSetDeleteTrace(id);
566 }
567 
568 // Getter for `_id`.
EagerTensor_getid(EagerTensor * self,void * closure)569 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
570   return PyLong_FromLongLong(self->id);
571 }
572 
573 // Getter for `_datatype_enum`.
EagerTensor_datatype_enum(EagerTensor * self)574 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
575   return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
576 }
577 
578 // Getter for `_shape_tuple`.
EagerTensor_shape_tuple(EagerTensor * self)579 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
580   auto handle = self->handle;
581   int n = TFE_TensorHandleNumDims(handle, &self->status);
582   TF_Code code = TF_GetCode(&self->status);
583   if (code != TF_OK) {
584     RaiseExceptionTypeFromTFStatus(&self->status);
585     // Cleanup self->status before returning.
586     self->status.status = tensorflow::Status::OK();
587     return nullptr;
588   }
589   PyObject* shape = PyTuple_New(n);
590   if (PyErr_Occurred()) return nullptr;
591   for (int i = 0; i < n; ++i) {
592     int64_t dim_c_value = TFE_TensorHandleDim(handle, i, &self->status);
593     PyObject* dim;
594     // The C++ convention is -1 for unknown/variable axis lengths. Translate
595     // that to the Python "None" convention. Unknown axis lengths are unusual
596     // for eager tensors.
597     if (dim_c_value < 0) {
598       Py_IncRef(Py_None);
599       dim = Py_None;
600     } else {
601       dim = PyLong_FromLongLong(dim_c_value);
602     }
603     code = TF_GetCode(&self->status);
604     if (code != TF_OK || dim == nullptr ||
605         PyTuple_SetItem(shape, i, dim) != 0) {
606       if (code != TF_OK) {
607         RaiseExceptionTypeFromTFStatus(&self->status);
608       } else {
609         PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
610       }
611       // Cleanup self->status before returning.
612       self->status.status = tensorflow::Status::OK();
613       Py_DECREF(shape);
614       if (dim != nullptr) Py_DECREF(dim);
615       return nullptr;
616     }
617   }
618   return shape;
619 }
620 
621 // Getter for `_rank`.
EagerTensor_rank(EagerTensor * self)622 static PyObject* EagerTensor_rank(EagerTensor* self) {
623   int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
624   if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
625     // Cleanup self->status before returning.
626     self->status.status = tensorflow::Status::OK();
627     return nullptr;
628   }
629 #if PY_MAJOR_VERSION < 3
630   return PyInt_FromLong(num_dims);
631 #else
632   return PyLong_FromLong(num_dims);
633 #endif
634 }
635 
636 // Getter for `_num_elements`.
EagerTensor_num_elements(EagerTensor * self)637 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
638   auto handle = self->handle;
639   int n = TFE_TensorHandleNumElements(handle, &self->status);
640   if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
641     // Cleanup self->status before returning.
642     self->status.status = tensorflow::Status::OK();
643     return nullptr;
644   }
645   return PyLong_FromLongLong(n);
646 }
647 
EagerTensor_handle_data(EagerTensor * self,void * unused)648 static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
649   Py_INCREF(self->handle_data);
650   return self->handle_data;
651 }
652 
EagerTensor_sethandle_data(EagerTensor * self,PyObject * value,void * unused)653 static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
654                                       void* unused) {
655   Py_DECREF(self->handle_data);
656   Py_INCREF(value);
657   self->handle_data = value;
658   return 0;
659 }
660 
EagerTensor_tensor_shape(EagerTensor * self,void * unused)661 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
662   Py_INCREF(self->tensor_shape);
663   return self->tensor_shape;
664 }
665 
EagerTensor_settensor_shape(EagerTensor * self,PyObject * value,void * unused)666 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
667                                        void* unused) {
668   Py_DECREF(self->tensor_shape);
669   Py_INCREF(value);
670   self->tensor_shape = value;
671   return 0;
672 }
673 
674 // Function `_copy_to_device`.
EagerTensor_copy_to_device(EagerTensor * self,PyObject * args,PyObject * kwds)675 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
676                                             PyObject* kwds) {
677   if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;
678 
679   const char* device_name = nullptr;
680   if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
681                         &device_name)) {
682     return nullptr;
683   }
684 
685   // Note that this is a shallow copy and will share the underlying buffer
686   // if copying to the same device.
687   TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
688       self->handle, GetContextHandle(self->context), device_name,
689       &self->status);
690   if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
691     // Cleanup self->status before returning.
692     self->status.status = tensorflow::Status::OK();
693     return nullptr;
694   }
695 
696   return EagerTensorFromHandle(handle);
697 }
698 
699 // Function `_numpy_internal`.
700 // Convert an EagerTensor to a Python numpy.ndarray object.
701 // The two may share underlying storage so changes to one may reflect in the
702 // other.
703 // Note that if `self` is not on CPU, we raise an Exception.
EagerTensor_numpy_internal(EagerTensor * self)704 static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
705   auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
706   if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
707     Py_XDECREF(py_array);
708     // Cleanup self->status before returning.
709     self->status.status = tensorflow::Status::OK();
710     return nullptr;
711   } else {
712     return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
713   }
714 }
715 
716 // Function `_prefer_custom_summarizer`.
717 //
718 // A hint that callers should prefer `SummarizeValue` to resolving this handle
719 // and formatting the tensor.
EagerTensor_prefer_custom_summarizer(EagerTensor * self)720 static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) {
721   if (tensorflow::unwrap(self->handle)->PreferCustomSummarizer()) {
722     Py_RETURN_TRUE;
723   } else {
724     Py_RETURN_FALSE;
725   }
726 }
727 
728 // Function `_summarize_value`.
729 //
730 // Returns a string PyObject which summarizes the value of this tensor. It does
731 // not include a shape or dtype.
EagerTensor_summarize_value(EagerTensor * self)732 static PyObject* EagerTensor_summarize_value(EagerTensor* self) {
733   std::string summary;
734   tensorflow::Status status =
735       tensorflow::unwrap(self->handle)->SummarizeValue(summary);
736   if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
737     return nullptr;
738   }
739   return PyUnicode_FromString(summary.c_str());
740 }
741 
742 // Getter `device`.
EagerTensor_device(EagerTensor * self)743 static PyObject* EagerTensor_device(EagerTensor* self) {
744   const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
745   if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
746     // Cleanup self->status before returning.
747     self->status.status = tensorflow::Status::OK();
748     return nullptr;
749   }
750 #if PY_MAJOR_VERSION >= 3
751   return PyUnicode_FromString(device);
752 #else
753   return PyBytes_FromString(device);
754 #endif
755 }
756 
757 // Getter `backing_device`.
EagerTensor_backing_device(EagerTensor * self)758 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
759   const char* device =
760       TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
761   if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
762     // Cleanup self->status before returning.
763     self->status.status = tensorflow::Status::OK();
764     return nullptr;
765   }
766 #if PY_MAJOR_VERSION >= 3
767   return PyUnicode_FromString(device);
768 #else
769   return PyBytes_FromString(device);
770 #endif
771 }
772 
773 // Getter `is_packed`.
EagerTensor_is_packed(EagerTensor * self)774 static PyObject* EagerTensor_is_packed(EagerTensor* self) {
775   return PyBool_FromLong(self->is_packed);
776 }
777 
778 static PyGetSetDef EagerTensor_getsetters[] = {
779     {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
780      const_cast<char*>("Tensor ID."), nullptr},
781     {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
782      const_cast<char*>("Device of op that produced the tensor."), nullptr},
783     {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
784      nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
785      nullptr},
786     {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
787      const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
788      nullptr},
789     {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
790      (setter)EagerTensor_sethandle_data,
791      const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
792      nullptr},
793     {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
794      (setter)EagerTensor_settensor_shape,
795      const_cast<char*>("Shape of the tensor."), nullptr},
796     {nullptr} /* Sentinel */
797 };
798 
799 #if PY_MAJOR_VERSION < 3
800 // Only used for Python2 since Python3 seems to set the __dict__ correctly.
801 static PyMemberDef EagerTensor_members[] = {
802     {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
803      READONLY},
804     {nullptr},
805 };
806 #endif
807 
808 static PyMethodDef EagerTensor_methods[] = {
809     {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
810      PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
811     {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
812      PyDoc_STR("The DType of the tensor as an enum.")},
813     {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
814      PyDoc_STR("The shape of the tensor as a python tuple.")},
815     {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
816      PyDoc_STR("The rank of the tensor.")},
817     {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
818      METH_VARARGS | METH_KEYWORDS,
819      PyDoc_STR("Copies the tensor to the desired device.")},
820     {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
821      PyDoc_STR("Number of elements in the tensor.")},
822     {"_prefer_custom_summarizer",
823      (PyCFunction)EagerTensor_prefer_custom_summarizer, METH_NOARGS,
824      PyDoc_STR("Indicates whether _numpy_internal loses information.")},
825     {"_summarize_value", (PyCFunction)EagerTensor_summarize_value, METH_NOARGS,
826      PyDoc_STR("A string which summarizes the value of this tensor.")},
827     {nullptr, nullptr},
828 };
829 
EagerTensor_getbuffer(EagerTensor * self,Py_buffer * view,int flags)830 static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
831                                  int flags) {
832   if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
833     PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
834     return -1;
835   }
836 
837   // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
838   // DT_STRING so the following is only slightly slower than a NumPy-free
839   // implementation.
840   auto py_array = tensorflow::make_safe(
841       TFE_TensorHandleToNumpy(self->handle, &self->status));
842   if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
843     // Cleanup self->status before returning.
844     self->status.status = tensorflow::Status::OK();
845     return -1;
846   }
847   if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
848     return -1;
849   }
850   view->readonly = 1;
851   return 0;
852 }
853 
854 static PyBufferProcs EagerTensor_as_buffer = {
855 #if PY_MAJOR_VERSION < 3
856     nullptr, nullptr, nullptr, nullptr,
857 #endif
858     (getbufferproc)EagerTensor_getbuffer,
859     // Never called because getbufferproc delegates to NumPy.
860     (releasebufferproc) nullptr};
861 
// Note that here we dynamically create a new class as a subclass of a
// "HEAPTYPE" class that is itself created in Python code and passed in at
// runtime. This is fairly atypical and undocumented.
//
// Unfortunately, Python 2.x and Python 3.x require different approaches.
// For Python 2.x, we create the class as a static type and set its tp_base to
// the passed-in type. Simply adding Py_TPFLAGS_HEAPTYPE to tp_flags does not
// work on its own, because a heap type also needs further initialization of
// the underlying PyHeapTypeObject; skipping that leads to seemingly random
// crashes, especially during garbage collection.
// Python 3.x explicitly disallows a static subclass of a HEAPTYPE base class.
// However, it provides a new function, PyType_FromSpecWithBases, to create
// types dynamically.
876 
877 // Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
878 PyTypeObject* EagerTensorType = nullptr;
879 
880 #if PY_MAJOR_VERSION >= 3
881 static PyType_Slot EagerTensor_Type_slots[] = {
882     {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
883     {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
884     {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
885     {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
886     {0, nullptr},
887 };
888 #else
889 
890 #define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)
891 
892 // TODO(agarwal): support active_trace.
893 static PyTypeObject _EagerTensorType = {
894     // clang-format off
895     PyVarObject_HEAD_INIT(nullptr, 0)
896     // clang-format on
897     "EagerTensor",                      /* tp_name */
898     sizeof(EagerTensor),                /* tp_basicsize */
899     0,                                  /* tp_itemsize */
900     (destructor)EagerTensor_dealloc,    /* tp_dealloc */
901 #if PY_VERSION_HEX < 0x03080000
902     nullptr,                            /* tp_print */
903 #else
904     0, /* tp_vectorcall_offset */
905 #endif
906     nullptr,                            /* tp_getattr */
907     nullptr,                            /* tp_setattr */
908     nullptr,                            /* tp_compare */
909     nullptr,                            /* tp_repr */
910     nullptr,                            /* tp_as_number */
911     nullptr,                            /* tp_as_sequence */
912     nullptr,                            /* tp_as_mapping */
913     nullptr,                            /* tp_hash */
914     nullptr,                            /* tp_call */
915     nullptr,                            /* tp_str */
916     nullptr,                            /* tp_getattro */
917     nullptr,                            /* tp_setattro */
918     &EagerTensor_as_buffer,             /* tp_as_buffer */
919     EAGER_TENSOR_TPFLAGS,               /* tp_flags */
920     nullptr,                            /* tp_doc */
921     nullptr,                            /* tp_traverse */
922     nullptr,                            /* tp_clear */
923     nullptr,                            /* tp_richcompare */
924     offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
925     nullptr,                            /* tp_iter */
926     nullptr,                            /* tp_iternext */
927     EagerTensor_methods,                /* tp_methods */
928     EagerTensor_members,                /* tp_members */
929     EagerTensor_getsetters,             /* tp_getset */
930     nullptr,                            /* tp_base */
931     nullptr,                            /* tp_dict */
932     nullptr,                            /* tp_descr_get */
933     nullptr,                            /* tp_descr_set */
934     offsetof(EagerTensor, dict),        /* tp_dictoffset */
935     (initproc)EagerTensor_init,         /* tp_init */
936     nullptr,                            /* tp_alloc */
937     nullptr,                            /* tp_new */
938 };
939 
940 #endif
941 
942 }  // extern "C"
943 
EagerTensor_CheckExact(const PyObject * o)944 bool EagerTensor_CheckExact(const PyObject* o) {
945   return Py_TYPE(o) == EagerTensorType;
946 }
947 
EagerTensor_Handle(const PyObject * o)948 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
949   return reinterpret_cast<const EagerTensor*>(o)->handle;
950 }
951 
EagerTensorFromHandle(TFE_TensorHandle * handle,const bool is_packed)952 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
953                                 const bool is_packed) {
954   if (handle == nullptr) {
955     return nullptr;
956   }
957   EagerTensor* t = reinterpret_cast<EagerTensor*>(
958       EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
959   if (t != nullptr) {
960     t->id = get_uid();
961     t->is_packed = is_packed;
962     Py_INCREF(Py_None);
963     t->handle_data = Py_None;
964     Py_INCREF(Py_None);
965     t->tensor_shape = Py_None;
966     t->handle = handle;
967     t->status.status = tensorflow::Status::OK();
968     t->weakreflist = nullptr;
969     PyObject* py_context = GetPyEagerContext();
970     if (py_context == nullptr) {
971       LOG(ERROR) << "Cannot create an eager tensor before eager context has "
972                     "been set or after it has been deleted";
973       return nullptr;
974     }
975     t->context = py_context;
976 
977     if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
978       return nullptr;
979     }
980   }
981   return reinterpret_cast<PyObject*>(t);
982 }
983 
PyEagerTensor_ID(const PyObject * tensor)984 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
985   DCHECK(EagerTensor_CheckExact(tensor));
986   return reinterpret_cast<const EagerTensor*>(tensor)->id;
987 }
988 
PyEagerTensor_Dtype(const PyObject * tensor)989 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
990   DCHECK(EagerTensor_CheckExact(tensor));
991   return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
992       reinterpret_cast<const EagerTensor*>(tensor)->handle));
993 }
994 
PyEagerTensor_NumElements(PyObject * tensor)995 tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
996   DCHECK(EagerTensor_CheckExact(tensor));
997   EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
998   int64_t result = TFE_TensorHandleNumElements(as_c_eager_tensor->handle,
999                                                &as_c_eager_tensor->status);
1000 
1001   if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
1002                                       PyExc_ValueError)) {
1003     // Cleanup status before returning.
1004     as_c_eager_tensor->status.status = tensorflow::Status::OK();
1005     return -1;
1006   }
1007 
1008   return result;
1009 }
1010 
TFE_Py_InitEagerTensor(PyObject * base_class)1011 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
1012   if (!PyType_Check(base_class)) {
1013     PyErr_SetString(
1014         PyExc_TypeError,
1015         tensorflow::strings::StrCat(
1016             "Expecting a class definition for `base_class` passed to ",
1017             "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
1018             .c_str());
1019     return nullptr;
1020   }
1021   // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
1022   // EagerTensor to allow for the space usage of the base class.
1023   PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
1024   if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
1025     PyErr_SetString(
1026         PyExc_TypeError,
1027         tensorflow::strings::StrCat(
1028             "Unable to create subclass EagerTensor from base class ",
1029             Py_TYPE(base_class)->tp_name,
1030             ". Need its size to be <= ", kMaxEagerTensorParentSize)
1031             .c_str());
1032     return nullptr;
1033   }
1034   if (base_class_type->tp_itemsize != 0) {
1035     PyErr_SetString(
1036         PyExc_TypeError,
1037         tensorflow::strings::StrCat(
1038             "Unable to create subclass EagerTensor from base class ",
1039             Py_TYPE(base_class)->tp_name,
1040             " which supports variable length instances.")
1041             .c_str());
1042     return nullptr;
1043   }
1044   Py_INCREF(base_class);
1045 #if PY_MAJOR_VERSION >= 3
1046   PyObject* bases = PyTuple_New(1);
1047   PyTuple_SET_ITEM(bases, 0, base_class);
1048 
1049   tensorflow::Safe_PyObjectPtr base_class_module(
1050       PyObject_GetAttrString(base_class, "__module__"));
1051   const char* module = nullptr;
1052   if (PyErr_Occurred()) {
1053     PyErr_Clear();
1054     module = "__builtin__";
1055   } else {
1056     module = PyBytes_AsString(base_class_module.get());
1057     if (module == nullptr) {
1058       PyErr_Clear();
1059       module = PyUnicode_AsUTF8(base_class_module.get());
1060       if (module == nullptr) {
1061         PyErr_Clear();
1062         module = "__builtin__";
1063       }
1064     }
1065   }
1066 
1067   // NOTE: The c_str from this string needs to outlast the function, hence is
1068   // static.
1069   static tensorflow::string fully_qualified_name =
1070       tensorflow::strings::StrCat(module, ".EagerTensor");
1071 
1072   static PyType_Spec EagerTensor_Type_spec = {
1073       fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
1074       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
1075 
1076   EagerTensorType = reinterpret_cast<PyTypeObject*>(
1077       PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
1078   if (PyErr_Occurred()) {
1079     return nullptr;
1080   }
1081   if (EagerTensorType == nullptr) {
1082     PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
1083     return nullptr;
1084   }
1085   EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
1086   EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
1087 #else
1088   _EagerTensorType.tp_base = base_class_type;
1089 
1090   if (PyType_Ready(&_EagerTensorType) < 0) {
1091     if (PyErr_Occurred()) return nullptr;
1092     PyErr_SetString(PyExc_RuntimeError,
1093                     "Error while creating EagerTensor type.");
1094     return nullptr;
1095   }
1096   EagerTensorType = &_EagerTensorType;
1097   Py_INCREF(EagerTensorType);
1098 #endif
1099   return reinterpret_cast<PyObject*>(EagerTensorType);
1100 }
1101 
TFE_Py_SetEagerTensorProfiler(PyObject * profiler)1102 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
1103   Py_XDECREF(eager_tensor_profiler);
1104 
1105   if (profiler == Py_None) {
1106     eager_tensor_profiler = nullptr;
1107   } else {
1108     eager_tensor_profiler = profiler;
1109     Py_INCREF(eager_tensor_profiler);
1110   }
1111   Py_RETURN_NONE;
1112 }
1113 
TFE_Py_TensorShapeSlice(PyObject * tensors,int slice_dim)1114 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
1115   if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
1116     PyErr_SetString(PyExc_TypeError,
1117                     tensorflow::strings::StrCat(
1118                         "tensors argument must be a list or a tuple. Got \"",
1119                         Py_TYPE(tensors)->tp_name, "\"")
1120                         .c_str());
1121     return nullptr;
1122   }
1123   if (slice_dim < 0) {
1124     PyErr_SetString(
1125         PyExc_ValueError,
1126         tensorflow::strings::StrCat("Slice dimension must be non-negative. "
1127                                     "Got ",
1128                                     slice_dim)
1129             .c_str());
1130     return nullptr;
1131   }
1132 
1133   PyObject* py_context = GetPyEagerContext();
1134   if (py_context == nullptr) {
1135     PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
1136                                             "Cannot create EagerTensor when "
1137                                             "EagerContext is not valid")
1138                                             .c_str());
1139     return nullptr;
1140   }
1141 
1142   TFE_Context* ctx = GetContextHandle(py_context);
1143 
1144   Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
1145   PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
1146   int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
1147 
1148   auto status = tensorflow::make_safe(TF_NewStatus());
1149 
1150   // Create an empty tensor.
1151   auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
1152       tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
1153 
1154   if (num_tensors_int > 0) {
1155     int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
1156 
1157     // Fill the tensor with dims.
1158     for (Py_ssize_t i = 0; i < num_tensors; ++i) {
1159       PyObject* tensor_obj = tensors_array[i];
1160       if (!EagerTensor_CheckExact(tensor_obj)) {
1161         PyErr_SetString(
1162             PyExc_TypeError,
1163             tensorflow::strings::StrCat("Expected a list of EagerTensors but "
1164                                         "element ",
1165                                         i, " has type \"",
1166                                         Py_TYPE(tensor_obj)->tp_name, "\"")
1167                 .c_str());
1168         return nullptr;
1169       }
1170 
1171       EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
1172       TFE_TensorHandle* handle = t->handle;
1173       int num_dims = TFE_TensorHandleNumDims(handle, status.get());
1174       if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1175         return nullptr;
1176       }
1177       if (slice_dim >= num_dims) {
1178         PyErr_SetString(
1179             PyExc_IndexError,
1180             tensorflow::strings::StrCat("Slice dimension (", slice_dim,
1181                                         ") must be smaller than rank of all "
1182                                         "tensors, but tensor at index ",
1183                                         i, " has rank ", num_dims)
1184                 .c_str());
1185         return nullptr;
1186       }
1187       int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
1188       if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1189         return nullptr;
1190       }
1191       data[i] = dim;
1192     }
1193   }
1194 
1195   TFE_TensorHandle* handle =
1196       tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
1197 
1198   if (!status->status.ok()) {
1199     PyErr_SetString(
1200         PyExc_RuntimeError,
1201         tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
1202                                     TF_Message(status.get()))
1203             .c_str());
1204     return nullptr;
1205   }
1206 
1207   return EagerTensorFromHandle(handle);
1208 }
1209 
TFE_Py_TensorShapeOnDevice(PyObject * tensor)1210 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1211   if (!EagerTensor_CheckExact(tensor)) {
1212     PyErr_SetString(
1213         PyExc_TypeError,
1214         tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1215                                     Py_TYPE(tensor)->tp_name, "\"")
1216             .c_str());
1217     return nullptr;
1218   }
1219   TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1220 
1221   auto status = tensorflow::make_safe(TF_NewStatus());
1222   TFE_TensorDebugInfo* debug_info =
1223       TFE_TensorHandleTensorDebugInfo(handle, status.get());
1224   if (!status->status.ok()) {
1225     PyErr_SetString(
1226         PyExc_RuntimeError,
1227         tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1228                                     TF_Message(status.get()))
1229             .c_str());
1230     return nullptr;
1231   }
1232 
1233   int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1234   PyObject* shape = PyTuple_New(rank);
1235   for (int i = 0; i < rank; ++i) {
1236     int64_t dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1237     PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1238   }
1239   TFE_DeleteTensorDebugInfo(debug_info);
1240 
1241   return shape;
1242 }
1243