/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/python/eager/pywrap_tensor.h"

#include <stdlib.h>
#include <string.h>

#include <cmath>

#include "structmember.h"  // NOLINT // For PyMemberDef
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
#include "tensorflow/python/lib/core/numpy.h"
#include "tensorflow/python/lib/core/py_seq_tensor.h"
#include "tensorflow/python/lib/core/safe_ptr.h"

// Forward declaration.
struct EagerTensor;

namespace {

// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_SetEagerTensorProfiler, if
// at all.
PyObject* eager_tensor_profiler = nullptr;

TFE_Context* GetContextHandle(PyObject* py_context) {
  tensorflow::Safe_PyObjectPtr py_context_handle(
      PyObject_GetAttrString(py_context, "_handle"));
  if (py_context_handle == nullptr) {
    // Current Python code makes sure this never happens. If it does, or
    // becomes hard to maintain, we can call the ensure_initialized() method
    // here.
    PyErr_SetString(
        PyExc_TypeError,
        "Expected `context` argument in EagerTensor constructor to have a "
        "`_handle` attribute but it did not. Was eager Context initialized?");
    return nullptr;
  }

  auto* ctx = reinterpret_cast<TFE_Context*>(
      PyCapsule_GetPointer(py_context_handle.get(), nullptr));
  if (ctx == nullptr) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "Expected context._handle to contain a PyCapsule "
                        "encoded pointer to TFE_Context. Got ",
                        Py_TYPE(py_context_handle.get())->tp_name)
                        .c_str());
  }
  return ctx;
}

// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
  if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "Cannot convert a Tensor of dtype resource to a NumPy array.");
    return nullptr;
  }

  tensorflow::Safe_TF_TensorPtr tensor = nullptr;
  Py_BEGIN_ALLOW_THREADS;
  tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
  Py_END_ALLOW_THREADS;
  if (!status->status.ok()) {
    return nullptr;
  }

  PyObject* ret = nullptr;
  auto cppstatus =
      tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
  tensorflow::Set_TF_Status_from_Status(status, cppstatus);
  if (!status->status.ok()) {
    Py_XDECREF(ret);
    return nullptr;
  }
  CHECK_NE(ret, nullptr);
  return ret;
}

// Helper function to convert `v` to a tensorflow::DataType and store it in
// `*out`. Returns true on success, false otherwise.
// Note that we assume that v is a python int (not long) representing a
// TF_DataType/tensorflow::DataType value.
bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
#if PY_MAJOR_VERSION < 3
  if (PyInt_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
    return true;
  }
#else
  if (PyLong_Check(v)) {
    *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
    return true;
  }
#endif
  return false;
}

// Helper function to create a python integer from TF_DataType.
PyObject* PyIntFromDataType(TF_DataType l) {
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(l);
#else
  return PyLong_FromLong(l);
#endif
}

// PyObject->tensorflow::DataType conversion function to be used with
// PyArg_Parse* APIs.
int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
  if (obj == Py_None) {
    *dst = tensorflow::DataType::DT_INVALID;
  } else if (!PyIntToDataType(obj, dst)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
            .c_str());
    return 0;
  }

  return 1;
}

// Conversion function extracting a const char** device name from a PyObject.
// The function should be used with PyArg_Parse* APIs.
int ConvertDeviceName(PyObject* obj, const char** dst) {
  if (obj == Py_None) {
    *dst = nullptr;
  } else {
    auto device_name = TFE_GetPythonString(obj);
    if (device_name == nullptr) {
      PyErr_Clear();
      PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
      return 0;
    }
    *dst = device_name;
  }

  return 1;
}
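
// Both converters above follow the PyArg_Parse* "O&" contract: return 1 on
// success, or set a Python exception and return 0 on failure. A minimal usage
// sketch (mirroring the "OO&|O&" format used by EagerTensor_init below;
// variable names here are illustrative only):
//
//   PyObject* value;
//   const char* device_name = nullptr;
//   tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
//   if (!PyArg_ParseTuple(args, "OO&O&", &value, ConvertDeviceName,
//                         &device_name, ConvertDataType, &dtype)) {
//     return nullptr;  // The converter already set the exception.
//   }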

}  // namespace

namespace tensorflow {
// This function checks whether the desired type is "compatible" with the
// inferred type. At a high level, compatibility means that all integral types
// are compatible with each other, and all floating types are compatible with
// each other.
//
// Type compatibility doesn't consider overflows (i.e. int64 is *always*
// compatible with int32). This is intended to match graph behavior.
bool IsCompatible(DataType desired, DataType returned) {
  if (desired == returned) return true;

  if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
    return true;
  } else if (DataTypeIsFloating(desired) &&
             (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
    return true;
  } else if (DataTypeIsComplex(desired) &&
             (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
              DataTypeIsFloating(returned))) {
    return true;
  } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
    return true;
  }
  return false;
}
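
// For example, under these rules:
//   IsCompatible(DT_INT64, DT_INT32)     -> true   (integer to integer)
//   IsCompatible(DT_FLOAT, DT_INT32)     -> true   (integer widens to float)
//   IsCompatible(DT_COMPLEX64, DT_FLOAT) -> true   (float widens to complex)
//   IsCompatible(DT_INT32, DT_FLOAT)     -> false  (floats never narrow to
//                                                   integers)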

// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
// execute TFE Ops) to a separate common library.
// Casts data referred to by `handle` from type `src_type_enum` to type
// `dst_type_enum`.
TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
                            TF_DataType src_type_enum,
                            TF_DataType dst_type_enum, TF_Status* out_status) {
  if (ctx == nullptr) return nullptr;
  const char* op_name = "Cast";
  const char* device_name = "/device:CPU:0";
  TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
#define RETURN_ERROR  \
  {                   \
    TFE_DeleteOp(op); \
    return nullptr;   \
  }
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetDevice(op, device_name, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpAddInput(op, handle, out_status);
  if (!out_status->status.ok()) RETURN_ERROR
  TFE_OpSetAttrType(op, "SrcT", src_type_enum);
  TFE_OpSetAttrType(op, "DstT", dst_type_enum);
  TFE_OpSetAttrBool(op, "Truncate", false);
  TFE_TensorHandle* output = nullptr;
  int num_outputs = 1;
  TFE_Execute(op, &output, &num_outputs, out_status);
  if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
    if (output != nullptr) {
      TFE_DeleteTensorHandle(output);
    }
    RETURN_ERROR
  }
  TFE_DeleteOp(op);
  return output;
#undef RETURN_ERROR
}
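
// EagerCast above is essentially the C-API equivalent of executing the "Cast"
// op eagerly, i.e. roughly what the following does from Python (a sketch of
// the correspondence; the SrcT/DstT/Truncate attribute names come from the
// Cast op definition used above):
//
//   y = tf.cast(x, tf.float32)  # runs the "Cast" kernel eagerly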

TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
                                               PyObject* value,
                                               tensorflow::DataType dtype,
                                               const char* device_name) {
  tensorflow::Safe_PyObjectPtr value_decrefer;
  if (PyArray_IsScalar(value, Generic)) {
    // Convert numpy scalars to numpy arrays.
    value = PyArray_FromScalar(value, nullptr);
    // The returned value needs to be DECREF'd, but the original value was
    // created in python code, and doesn't need to be DECREF'd.
    value_decrefer.reset(value);
  }

  Safe_TFE_TensorHandlePtr handle =
      make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));

  if (handle == nullptr) return nullptr;

  Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
  TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
  if (dtype != tensorflow::DT_INVALID &&
      dtype != static_cast<DataType>(handle_dtype)) {
    if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
      handle = tensorflow::make_safe(
          tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
                                static_cast<TF_DataType>(dtype), status.get()));
      if (!status->status.ok()) {
        PyErr_SetString(PyExc_TypeError,
                        absl::StrCat("Error while casting from dtype ",
                                     tensorflow::DataTypeString(
                                         static_cast<DataType>(handle_dtype)),
                                     " to ", tensorflow::DataTypeString(dtype),
                                     ". ", TF_Message(status.get()))
                            .c_str());
        return nullptr;
      }
    } else {
      tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
      PyErr_SetString(
          PyExc_TypeError,
          absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
                       " to EagerTensor of dtype ",
                       tensorflow::DataTypeString(dtype))
              .c_str());
      return nullptr;
    }
  }

  // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
  // memory. We approximate the same behavior for eager execution - keeping
  // int32 tensors in host memory.
  //
  // We do so to preclude the need for callers into such kernels from having to
  // explicitly place the int32 tensors in host memory. For example, without
  // this, one needed:
  //
  // with tf.device('/gpu:0'):
  //   ...// code here
  //   with tf.device('/cpu:0'):
  //     shape = tf.constant(...)
  //   y = tf.random_uniform(shape)
  //
  // Without the CPU device block, tfe.ops.random_uniform would fail since the
  // kernel expects the shape in host memory.
  //
  // With this support, we simplify the code:
  //
  // with tf.device('/gpu:0'):
  //   y = tf.random_uniform(...)
  //
  // The approximation is not exact: there are GPU kernels which do not require
  // host memory for int32 tensors. This will lead to a discrepancy between
  // eager and graph execution.
  //
  // To support remote execution, copy int32 tensors to another CPU device.
  // TODO(ashankar): Fix this.
  if (device_name != nullptr &&
      (TFE_TensorHandleDataType(handle.get()) != TF_INT32 ||
       strstr(device_name, "/device:CPU:0") != nullptr)) {
    // Note that this is a shallow copy and will share the underlying buffer
    // if copying to the same device.
    handle = make_safe(TFE_TensorHandleCopyToDevice(handle.get(), ctx,
                                                    device_name, status.get()));
    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_RuntimeError)) {
      return nullptr;
    }
  }

  return handle.release();
}

TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
                                       DataType dtype,
                                       const char* device_name) {
  // Reduce the overhead of allocation/transfer-to-device for scalars by
  // caching the corresponding handles. Note that currently only Python
  // scalars are cached.
  // TODO(slebedev): also cache singleton NumPy arrays and scalars?
  if (PyArray_IsPythonNumber(value)) {
    auto* cache = TFE_TensorHandleCache::Get();
    TFE_TensorHandle* handle = cache->Lookup(value, dtype, device_name);
    if (handle != nullptr) return handle;
    handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
    if (handle == nullptr) return nullptr;
    if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
      cache->Insert(value, dtype, device_name, handle);
    }
    return handle;
  } else {
    return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
  }
}
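
// Note that the branch above deliberately skips caching non-finite Python
// floats (NaN and infinities); presumably this sidesteps the awkward equality
// semantics of such values as cache keys (NaN in particular compares unequal
// to itself).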

}  // namespace tensorflow

extern "C" {

static const int kMaxEagerTensorParentSize = 64;

// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  TFE_TensorHandle* handle;
  int64_t id;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating new
  // Status objects in the various functions that operate on EagerTensor and
  // need a TF_Status object. Note, however, that accesses to `status` are not
  // thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;

namespace {

// Returns true on success - successfully invoked or no profiler registered.
// Returns false if some error occurred.
bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
  if (eager_tensor_profiler != nullptr) {
#if PY_MAJOR_VERSION < 3
    PyObject* created_method_name = PyString_InternFromString("created");
#else
    PyObject* created_method_name = PyUnicode_InternFromString("created");
#endif
    if (created_method_name == nullptr) {
      return false;
    }
    PyObject* result = PyObject_CallMethodObjArgs(
        eager_tensor_profiler, created_method_name, created_tensor, NULL);
    if (result == nullptr) {
      LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
      // While we can potentially continue because the error is related to
      // profiling, we choose to return an error because:
      //  - If profiling is used, the user likely wants to stop execution on
      //    profiling errors.
      //  - Error in profiling code might have left some state in an invalid
      //    form that can lead to an error later on. Better to fail fast.
      Py_DECREF(created_method_name);
      return false;
    }
    Py_DECREF(created_method_name);
    Py_DECREF(result);
  }
  return true;
}
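
// A profiler registered from Python only needs a created() method; a minimal
// sketch (the class name and method body are illustrative, only the created()
// name is required by the code above):
//
//   class _MyProfiler(object):
//     def created(self, eager_tensor):
//       print("created tensor with id", eager_tensor._id)
//
//   # Registered via the Python binding for TFE_Py_SetEagerTensorProfiler.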

}  // namespace

// tp_init for EagerTensor.
int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
  self->id = get_uid();
  self->handle = nullptr;
  Py_INCREF(Py_None);
  self->handle_data = Py_None;
  Py_INCREF(Py_None);
  self->tensor_shape = Py_None;
  self->status.status = tensorflow::Status::OK();
  self->dict = nullptr;
  self->weakreflist = nullptr;
  self->context = nullptr;
  PyObject* value;
  const char* device_name = nullptr;
  tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
  const char* kwlist[] = {"value", "device", "dtype", nullptr};
  if (!PyArg_ParseTupleAndKeywords(
          args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
          ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
    return -1;
  }

  PyObject* py_context = GetPyEagerContext();
  if (py_context == nullptr) return -1;
  self->context = py_context;

  auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
                                                  value, dtype, device_name);
  if (handle == nullptr) return -1;
  self->handle = handle;

  if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
    return -1;
  }

  return 0;
}

// tp_dealloc for EagerTensor.
void EagerTensor_dealloc(EagerTensor* self) {
  // Unhook the object from python's GC so that the weakref deleter doesn't
  // try to re-delete this.
  PyObject_GC_UnTrack((PyObject*)self);

  // Clear weak references to self.
  // Needs to happen before any actual destruction.
  PyObject_ClearWeakRefs((PyObject*)self);

  Py_DECREF(self->handle_data);
  Py_DECREF(self->tensor_shape);
  // If an attribute dictionary has been created, release it. Note that this
  // is only ever created by CPython's attribute setting methods; we don't
  // create it ourselves.
  Py_CLEAR(self->dict);
  if (self->handle != nullptr) {
    TFE_DeleteTensorHandle(self->handle);
    self->handle = nullptr;
  }

  // Decref context after deleting the tensor handle.
  Py_XDECREF(self->context);

  // We have the global interpreter lock, so use this chance to perform delayed
  // refcount decrements.
  tensorflow::ClearDecrefCache();
  auto id = self->id;
  Py_TYPE(self)->tp_free(self);
  TFE_Py_TapeSetDeleteTrace(id);
}

// Getter for `_id`.
static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
  return PyLong_FromLongLong(self->id);
}

// Getter for `_datatype_enum`.
static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
  return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
}

// Getter for `_shape_tuple`.
static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumDims(handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
  PyObject* shape = PyTuple_New(n);
  if (PyErr_Occurred()) return nullptr;
  for (int i = 0; i < n; ++i) {
    PyObject* dim =
        PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, &self->status));
    if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr) ||
        dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
      // Cleanup self->status before returning.
      self->status.status = tensorflow::Status::OK();
      Py_DECREF(shape);
      if (dim != nullptr) Py_DECREF(dim);
      PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
      return nullptr;
    }
  }
  return shape;
}

// Getter for `_rank`.
static PyObject* EagerTensor_rank(EagerTensor* self) {
  int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION < 3
  return PyInt_FromLong(num_dims);
#else
  return PyLong_FromLong(num_dims);
#endif
}

// Getter for `_num_elements`.
static PyObject* EagerTensor_num_elements(EagerTensor* self) {
  auto handle = self->handle;
  int n = TFE_TensorHandleNumElements(handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
  return PyLong_FromLongLong(n);
}

static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
  Py_INCREF(self->handle_data);
  return self->handle_data;
}

static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
                                      void* unused) {
  Py_DECREF(self->handle_data);
  Py_INCREF(value);
  self->handle_data = value;
  return 0;
}

static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
  Py_INCREF(self->tensor_shape);
  return self->tensor_shape;
}

static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
                                       void* unused) {
  Py_DECREF(self->tensor_shape);
  Py_INCREF(value);
  self->tensor_shape = value;
  return 0;
}

// Function `_copy_to_device`.
static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
                                            PyObject* kwds) {
  if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;

  const char* device_name = nullptr;
  if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
                        &device_name)) {
    return nullptr;
  }

  // Note that this is a shallow copy and will share the underlying buffer
  // if copying to the same device.
  TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
      self->handle, GetContextHandle(self->context), device_name,
      &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}

// Function `_numpy_internal`.
// Convert an EagerTensor to a Python numpy.ndarray object.
// The two may share underlying storage so changes to one may reflect in the
// other.
// Note that if `self` is not on CPU, we raise an Exception.
static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
  auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
    Py_XDECREF(py_array);
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  } else {
    return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
  }
}

// Getter `device`.
static PyObject* EagerTensor_device(EagerTensor* self) {
  const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}

// Getter `backing_device`.
static PyObject* EagerTensor_backing_device(EagerTensor* self) {
  const char* device =
      TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return nullptr;
  }
#if PY_MAJOR_VERSION >= 3
  return PyUnicode_FromString(device);
#else
  return PyBytes_FromString(device);
#endif
}

static PyGetSetDef EagerTensor_getseters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};

#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif

static PyMethodDef EagerTensor_methods[] = {
    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
     PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("The DType of the tensor as an enum.")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("The shape of the tensor as a python tuple.")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
     PyDoc_STR("The rank of the tensor.")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS,
     PyDoc_STR("Copies the tensor to the desired device.")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("Number of elements in the tensor.")},
    {nullptr, nullptr},
};

static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
                                 int flags) {
  if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
    PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
    return -1;
  }

  // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
  // DT_STRING so the following is only slightly slower than a NumPy-free
  // implementation.
  auto py_array = tensorflow::make_safe(
      TFE_TensorHandleToNumpy(self->handle, &self->status));
  if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
    // Cleanup self->status before returning.
    self->status.status = tensorflow::Status::OK();
    return -1;
  }
  if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
    return -1;
  }
  view->readonly = 1;
  return 0;
}
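
// With the buffer protocol hooked up below, a CPU-resident EagerTensor can be
// consumed from Python without a copy, e.g. (sketch):
//
//   view = memoryview(t)  # read-only; writable requests raise BufferError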

static PyBufferProcs EagerTensor_as_buffer = {
#if PY_MAJOR_VERSION < 3
    nullptr, nullptr, nullptr, nullptr,
#endif
    (getbufferproc)EagerTensor_getbuffer,
    // Never called because getbufferproc delegates to NumPy.
    (releasebufferproc) nullptr};

// Note that here we are trying to dynamically create a new class as a subclass
// of a "HEAPTYPE" class that is itself created in python code and passed in at
// runtime. This is fairly atypical and undocumented.
//
// We use the following strategy for this. Unfortunately, we have to use
// different approaches for python2.x vs python3.x.
// For python2.x, we create the class as a static type and set its tp_base to
// the passed in type. Unfortunately setting tp_flags to include
// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
// initialization of the underlying PyHeapTypeObject, and not doing that leads
// to some random crashes, especially during garbage collection.
// python3.x explicitly disallows a static subclass of a HEAPTYPE base class.
// However, it provides a new function, PyType_FromSpecWithBases, to create
// types dynamically.

// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
PyTypeObject* EagerTensorType = nullptr;

#if PY_MAJOR_VERSION >= 3
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getseters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    {0, nullptr},
};
#else

#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
    nullptr,                            /* tp_print */
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getseters,              /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif

}  // extern "C"

bool EagerTensor_CheckExact(const PyObject* o) {
  return Py_TYPE(o) == EagerTensorType;
}

TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
  return reinterpret_cast<const EagerTensor*>(o)->handle;
}

PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
  if (handle == nullptr) {
    return nullptr;
  }
  EagerTensor* t = reinterpret_cast<EagerTensor*>(
      EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
  if (t != nullptr) {
    t->id = get_uid();
    Py_INCREF(Py_None);
    t->handle_data = Py_None;
    Py_INCREF(Py_None);
    t->tensor_shape = Py_None;
    t->handle = handle;
    t->status.status = tensorflow::Status::OK();
    t->weakreflist = nullptr;
    PyObject* py_context = GetPyEagerContext();
    if (py_context == nullptr) {
      LOG(ERROR) << "Cannot create an eager tensor before eager context has "
                    "been set or after it has been deleted";
      return nullptr;
    }
    t->context = py_context;

    if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
      return nullptr;
    }
  }
  return reinterpret_cast<PyObject*>(t);
}

tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  return reinterpret_cast<const EagerTensor*>(tensor)->id;
}

tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
      reinterpret_cast<const EagerTensor*>(tensor)->handle));
}

tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
  DCHECK(EagerTensor_CheckExact(tensor));
  EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
  tensorflow::int64 result = TFE_TensorHandleNumElements(
      as_c_eager_tensor->handle, &as_c_eager_tensor->status);

  if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
                                      PyExc_ValueError)) {
    // Cleanup status before returning.
    as_c_eager_tensor->status.status = tensorflow::Status::OK();
    return -1;
  }

  return result;
}

PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
  if (!PyType_Check(base_class)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Expecting a class definition for `base_class` passed to ",
            "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
            .c_str());
    return nullptr;
  }
  // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
  // EagerTensor to allow for the space usage of the base class.
  PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
  if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            ". Need its size to be <= ", kMaxEagerTensorParentSize)
            .c_str());
    return nullptr;
  }
  if (base_class_type->tp_itemsize != 0) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat(
            "Unable to create subclass EagerTensor from base class ",
            Py_TYPE(base_class)->tp_name,
            " which supports variable length instances.")
            .c_str());
    return nullptr;
  }
  Py_INCREF(base_class);
#if PY_MAJOR_VERSION >= 3
  PyObject* bases = PyTuple_New(1);
  PyTuple_SET_ITEM(bases, 0, base_class);

  tensorflow::Safe_PyObjectPtr base_class_module(
      PyObject_GetAttrString(base_class, "__module__"));
  const char* module = nullptr;
  if (PyErr_Occurred()) {
    PyErr_Clear();
    module = "__builtin__";
  } else {
    module = PyBytes_AsString(base_class_module.get());
    if (module == nullptr) {
      PyErr_Clear();
      module = PyUnicode_AsUTF8(base_class_module.get());
      if (module == nullptr) {
        PyErr_Clear();
        module = "__builtin__";
      }
    }
  }

  // NOTE: The c_str from this string needs to outlast the function, hence is
  // static.
  static tensorflow::string fully_qualified_name =
      tensorflow::strings::StrCat(module, ".EagerTensor");

  static PyType_Spec EagerTensor_Type_spec = {
      fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
      Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};

  EagerTensorType = reinterpret_cast<PyTypeObject*>(
      PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
  if (PyErr_Occurred()) {
    return nullptr;
  }
  if (EagerTensorType == nullptr) {
    PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
    return nullptr;
  }
  EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
  EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
#else
  _EagerTensorType.tp_base = base_class_type;

  if (PyType_Ready(&_EagerTensorType) < 0) {
    if (PyErr_Occurred()) return nullptr;
    PyErr_SetString(PyExc_RuntimeError,
                    "Error while creating EagerTensor type.");
    return nullptr;
  }
  EagerTensorType = &_EagerTensorType;
  Py_INCREF(EagerTensorType);
#endif
  return reinterpret_cast<PyObject*>(EagerTensorType);
}

PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
  Py_XDECREF(eager_tensor_profiler);

  if (profiler == Py_None) {
    eager_tensor_profiler = nullptr;
  } else {
    eager_tensor_profiler = profiler;
    Py_INCREF(eager_tensor_profiler);
  }
  Py_RETURN_NONE;
}

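// Returns a new rank-1 int32 EagerTensor whose i-th entry is the size of
// dimension `slice_dim` of the i-th tensor in `tensors`. For example (a
// sketch, assuming two EagerTensors `a` and `b` of shapes [2, 3] and [4, 3]):
//
//   TFE_Py_TensorShapeSlice([a, b], 0)  ->  EagerTensor([2, 4])
//   TFE_Py_TensorShapeSlice([a, b], 1)  ->  EagerTensor([3, 3])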
PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
  if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat(
                        "tensors argument must be a list or a tuple. Got \"",
                        Py_TYPE(tensors)->tp_name, "\"")
                        .c_str());
    return nullptr;
  }
  if (slice_dim < 0) {
    PyErr_SetString(
        PyExc_ValueError,
        tensorflow::strings::StrCat("Slice dimension must be non-negative. "
                                    "Got ",
                                    slice_dim)
            .c_str());
    return nullptr;
  }

  Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
  PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
  int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
  auto tensor = tensorflow::make_safe(TF_AllocateTensor(
      TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
  int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
  auto status = tensorflow::make_safe(TF_NewStatus());
  for (Py_ssize_t i = 0; i < num_tensors; ++i) {
    PyObject* tensor_obj = tensors_array[i];
    if (!EagerTensor_CheckExact(tensor_obj)) {
      PyErr_SetString(PyExc_TypeError,
                      tensorflow::strings::StrCat(
                          "Expected a list of EagerTensors but "
                          "element ",
                          i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
                          .c_str());
      return nullptr;
    }

    EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
    TFE_TensorHandle* handle = t->handle;
    int num_dims = TFE_TensorHandleNumDims(handle, status.get());
    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
      return nullptr;
    }
    if (slice_dim >= num_dims) {
      PyErr_SetString(
          PyExc_IndexError,
          tensorflow::strings::StrCat("Slice dimension (", slice_dim,
                                      ") must be smaller than rank of all "
                                      "tensors, but tensor at index ",
                                      i, " has rank ", num_dims)
              .c_str());
      return nullptr;
    }
    int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
    if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
      return nullptr;
    }
    data[i] = dim;
  }

  TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
  if (!status->status.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  return EagerTensorFromHandle(handle);
}

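// Returns the shape of `tensor` as it is laid out on its device, as a Python
// tuple of ints (queried via TFE_TensorHandleTensorDebugInfo; this on-device
// shape can differ from the logical shape, e.g. when the device uses a padded
// or otherwise optimized layout).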
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
  if (!EagerTensor_CheckExact(tensor)) {
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat("Expected an EagerTensor but got type \"",
                                    Py_TYPE(tensor)->tp_name, "\"")
            .c_str());
    return nullptr;
  }
  TFE_TensorHandle* handle = EagerTensor_Handle(tensor);

  auto status = tensorflow::make_safe(TF_NewStatus());
  TFE_TensorDebugInfo* debug_info =
      TFE_TensorHandleTensorDebugInfo(handle, status.get());
  if (!status->status.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError,
        tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
                                    TF_Message(status.get()))
            .c_str());
    return nullptr;
  }

  int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
  PyObject* shape = PyTuple_New(rank);
  for (int i = 0; i < rank; ++i) {
    tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
    PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
  }
  TFE_DeleteTensorDebugInfo(debug_info);

  return shape;
}