• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdlib.h>
17 
18 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
19 #include "tensorflow/python/lib/core/numpy.h"
20 #include "tensorflow/python/lib/core/py_seq_tensor.h"
21 #include "tensorflow/python/lib/core/safe_ptr.h"
22 
23 #include "tensorflow/python/eager/pywrap_tensor.h"
24 #include "tensorflow/python/eager/pywrap_tfe.h"
25 
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/core/lib/strings/strcat.h"
28 #include "tensorflow/python/lib/core/ndarray_tensor.h"
29 
30 #include "tensorflow/core/framework/types.h"
31 
32 #include "structmember.h"  // NOLINT // For PyMemberDef
33 
34 // forward declare
35 struct EagerTensor;
36 
37 namespace {
38 
39 // An instance of _EagerTensorProfiler that will receive callbacks about
40 // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
41 PyObject* eager_tensor_profiler = nullptr;
42 
GetContext(PyObject * ctx)43 TFE_Context* GetContext(PyObject* ctx) {
44   TFE_Context* context =
45       reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(ctx, nullptr));
46   if (context == nullptr) {
47     PyErr_SetString(PyExc_TypeError,
48                     tensorflow::strings::StrCat(
49                         "Expecting a PyCapsule encoded context handle. Got ",
50                         Py_TYPE(ctx)->tp_name)
51                         .c_str());
52   }
53   return context;
54 }
55 
56 // Convert a Python numpy.ndarray object to a TFE_TensorHandle.
57 // The two may share underlying storage so changes to one may reflect in the
58 // other.
NumpyToTensorHandle(PyObject * obj)59 TFE_TensorHandle* NumpyToTensorHandle(PyObject* obj) {
60   tensorflow::Tensor t;
61   auto cppstatus = tensorflow::NdarrayToTensor(obj, &t);
62   if (cppstatus.ok()) {
63     return TFE_NewTensorHandle(t);
64   } else {
65     PyErr_SetString(PyExc_ValueError,
66                     tensorflow::strings::StrCat(
67                         "Failed to convert numpy ndarray to a Tensor (",
68                         cppstatus.error_message(), ").")
69                         .c_str());
70     return nullptr;
71   }
72 }
73 
CopyToDevice(TFE_TensorHandle * handle,PyObject * ctx,PyObject * dev)74 TFE_TensorHandle* CopyToDevice(TFE_TensorHandle* handle, PyObject* ctx,
75                                PyObject* dev) {
76   const char* device = "";
77   if (dev != nullptr && dev != Py_None) {
78     device = PyBytes_AsString(dev);
79 #if PY_MAJOR_VERSION >= 3
80     if (device == nullptr) {
81       PyErr_Clear();
82       device = PyUnicode_AsUTF8(dev);
83     }
84 #endif
85     if (device == nullptr) {
86       PyErr_SetString(PyExc_TypeError,
87                       "Error parsing device argument to CopyToDevice");
88       return nullptr;
89     }
90   }
91   TFE_Context* context = GetContext(ctx);
92   if (context == nullptr) {  // PyErr already set by GetContext
93     return nullptr;
94   }
95   auto status = tensorflow::make_safe(TF_NewStatus());
96   TFE_TensorHandle* new_handle =
97       TFE_TensorHandleCopyToDevice(handle, context, device, status.get());
98   if (TF_GetCode(status.get()) != TF_OK) {
99     PyErr_SetString(
100         PyExc_RuntimeError,
101         tensorflow::strings::StrCat("Error copying tensor to device: ", device,
102                                     ". ", TF_Message(status.get()))
103             .c_str());
104     return nullptr;
105   }
106   return new_handle;
107 }
108 
109 // Helper function to convert `v` to an int and store it in `*out`. Returns true
110 // on success, false otherwise.
111 // Note that we assume that v is a python int (not long) representing a
112 // TF_DataType value.
PyIntToDataType(PyObject * v,int * out)113 bool PyIntToDataType(PyObject* v, int* out) {
114 #if PY_MAJOR_VERSION < 3
115   if (PyInt_Check(v)) {
116     *out = PyInt_AS_LONG(v);
117     return true;
118   }
119 #else
120   if (PyLong_Check(v)) {
121     *out = PyLong_AsLong(v);
122     return true;
123   }
124 #endif
125   return false;
126 }
127 
128 // Helper function to create a python integer from TF_DataType.
PyIntFromDataType(TF_DataType l)129 PyObject* PyIntFromDataType(TF_DataType l) {
130 #if PY_MAJOR_VERSION < 3
131   return PyInt_FromLong(l);
132 #else
133   return PyLong_FromLong(l);
134 #endif
135 }
136 
137 }  // namespace
138 
139 namespace tensorflow {
140 // This function checks whether the desired type is "compatible" with the
141 // inferred type. At a high level, compatibility means that all integral types
142 // are compatible with each other, and all floating types are compatible with
143 // each other.
144 //
145 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
146 // compatible with int32). This is intended to match graph behavior.
IsCompatible(int desired_dtype,TF_DataType returned_dtype)147 bool IsCompatible(int desired_dtype, TF_DataType returned_dtype) {
148   tensorflow::DataType desired =
149       static_cast<tensorflow::DataType>(desired_dtype);
150   tensorflow::DataType returned =
151       static_cast<tensorflow::DataType>(returned_dtype);
152 
153   if (desired == returned) return true;
154 
155   if (tensorflow::DataTypeIsInteger(desired) &&
156       tensorflow::DataTypeIsInteger(returned)) {
157     return true;
158   } else if (tensorflow::DataTypeIsFloating(desired) &&
159              (tensorflow::DataTypeIsFloating(returned) ||
160               tensorflow::DataTypeIsInteger(returned))) {
161     return true;
162   } else if (tensorflow::DataTypeIsComplex(desired) &&
163              (tensorflow::DataTypeIsComplex(returned) ||
164               tensorflow::DataTypeIsInteger(returned) ||
165               tensorflow::DataTypeIsFloating(returned))) {
166     return true;
167   } else if (tensorflow::DataTypeIsQuantized(desired) &&
168              tensorflow::DataTypeIsInteger(returned)) {
169     return true;
170   }
171   return false;
172 }
173 
174 // Casts data referred to by `handle` from type `src_type_enum` to type
175 // `dst_type_enum`.
EagerCast(TFE_Context * ctx,TFE_TensorHandle * handle,TF_DataType src_type_enum,TF_DataType dst_type_enum,TF_Status * out_status)176 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
177                             TF_DataType src_type_enum,
178                             TF_DataType dst_type_enum, TF_Status* out_status) {
179   if (ctx == nullptr) return nullptr;
180   const char* op_name = "Cast";
181   const char* device_name = "/job:localhost/replica:0/task:0/device:CPU:0";
182   TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
183 #define RETURN_ERROR  \
184   {                   \
185     TFE_DeleteOp(op); \
186     return nullptr;   \
187   }
188   if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
189   TFE_OpSetDevice(op, device_name, out_status);
190   if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
191   TFE_OpAddInput(op, handle, out_status);
192   if (TF_GetCode(out_status) != TF_OK) RETURN_ERROR
193   TFE_OpSetAttrType(op, "SrcT", src_type_enum);
194   TFE_OpSetAttrType(op, "DstT", dst_type_enum);
195   TFE_OpSetAttrBool(op, "Truncate", false);
196   TFE_TensorHandle* output = nullptr;
197   int num_outputs = 1;
198   TFE_Execute(op, &output, &num_outputs, out_status);
199   if (TF_GetCode(out_status) != TF_OK || num_outputs != 1 ||
200       output == nullptr) {
201     if (output != nullptr) {
202       TFE_DeleteTensorHandle(output);
203     }
204     RETURN_ERROR
205   }
206   TFE_DeleteOp(op);
207   return output;
208 #undef RETURN_ERROR
209 }
210 
ConvertToEagerTensor(PyObject * value,PyObject * dtype)211 TFE_TensorHandle* ConvertToEagerTensor(PyObject* value, PyObject* dtype) {
212   int desired_dtype = -1;
213   if (dtype != Py_None) {
214     if (!PyIntToDataType(dtype, &desired_dtype)) {
215       PyErr_SetString(PyExc_TypeError,
216                       tensorflow::strings::StrCat(
217                           "Expecting a DataType value for dtype. Got ",
218                           Py_TYPE(dtype)->tp_name)
219                           .c_str());
220       return nullptr;
221     }
222   }
223   tensorflow::Safe_PyObjectPtr value_decrefer;
224   if (PyArray_IsScalar(value, Generic)) {
225     // Convert numpy scalars to numpy arrays.
226     value = PyArray_FromScalar(value, nullptr);
227     // The returned value needs to be DECREF'd, but the original value was
228     // created in python code, and doesn't need to be DECREF'd.
229     value_decrefer.reset(value);
230   }
231   if (PyArray_Check(value)) {
232     int desired_np_dtype = -1;
233     if (desired_dtype >= 0) {
234       if (!tensorflow::TF_DataType_to_PyArray_TYPE(
235                static_cast<TF_DataType>(desired_dtype), &desired_np_dtype)
236                .ok()) {
237         PyErr_SetString(PyExc_TypeError,
238                         tensorflow::strings::StrCat(
239                             "Invalid dtype argument value ", desired_dtype)
240                             .c_str());
241         return nullptr;
242       }
243     }
244     PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
245     int current_np_dtype = PyArray_TYPE(array);
246     auto safe_value = tensorflow::make_safe(static_cast<PyObject*>(nullptr));
247     if ((desired_np_dtype >= 0 && desired_np_dtype != current_np_dtype) ||
248         !PyArray_ISCARRAY(array)) {
249       int new_dtype =
250           desired_np_dtype >= 0 ? desired_np_dtype : current_np_dtype;
251       safe_value = tensorflow::make_safe(
252           PyArray_FromAny(value, PyArray_DescrFromType(new_dtype), 0, 0,
253                           NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST, nullptr));
254       if (PyErr_Occurred()) return nullptr;
255       if (safe_value == nullptr) {
256         PyErr_SetString(PyExc_ValueError, "Error while casting a numpy value");
257         return nullptr;
258       }
259       value = safe_value.get();
260     }
261     return NumpyToTensorHandle(value);
262   } else {
263     tensorflow::Tensor t;
264     // TODO(josh11b): Have PySeqToTensor set python errors instead of
265     // returning Status.
266     auto cppstatus = tensorflow::PySeqToTensor(value, dtype, &t);
267     if (!cppstatus.ok()) {
268       PyErr_SetString(PyExc_ValueError, cppstatus.error_message().c_str());
269       return nullptr;
270     }
271     return TFE_NewTensorHandle(t);
272   }
273 }
274 }  // namespace tensorflow
275 
276 extern "C" {
277 
278 static const int kMaxEagerTensorParentSize = 64;
279 
280 // TODO(agarwal): store context handle in EagerTensor.
281 typedef struct EagerTensor {
282   PyObject_HEAD;
283   // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
284   // parent class. The parent class is set at runtime, so we don't know the
285   // exact size at compile time.
286   char unused[kMaxEagerTensorParentSize];
287   TFE_TensorHandle* handle;
288   int64_t id;
289   // This mirrors tensorflow.core.framework.ops.Tensor._handle_data Which will
290   // be None for tensors of type other than DT_REOSURCE. For DT_RESOURCE
291   // tensors, this will contain a serialized HandleData proto with shape
292   // inference metadata about shapes and dtypes of resources accessible from
293   // this handle.
294   // Note that we assume that handle_data cannot participate in reference
295   // cycles, and hence don't provide GC support for it.
296   PyObject* handle_data;
297 
298   // This stores `_keras_mask` object and is set by Tensorflow layers.
299   PyObject* keras_mask;
300 
301   // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
302   // first time that `_EagerTensorBase`'s `shape` property is called.
303   PyObject* tensor_shape;
304 
305   // We store a status object here as an optimization to avoid allocating a new
306   // Status objects on different functions that operate on EagerTensor and need
307   // to use a TF_Status object. However note that accesses to `status` are not
308   // thread-safe.
309   TF_Status* status;
310 
311   PyObject* weakreflist; /* List of weak references */
312 
313   // Per-instance attribute dictionary, to support monkey patching
314   // (e.g. EagerTensor.assign when slicing variables). This dictionary is
315   // created by CPython the first time an attribute is assigned, pointed to by
316   // tp_dictoffset. Note that garbage collection is not enabled for
317   // EagerTensors, so assigning objects to EagerTensor attributes which require
318   // garbage collection is likely to cause issues.
319   PyObject* dict;
320 } EagerTensor;
321 
322 namespace {
323 
324 // Returns true on success - successfully invoked or no profiler registered.
325 // Returns false if some error occurred.
MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor * created_tensor)326 bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
327   if (eager_tensor_profiler != nullptr) {
328 #if PY_MAJOR_VERSION < 3
329     PyObject* created_method_name = PyString_InternFromString("created");
330 #else
331     PyObject* created_method_name = PyUnicode_InternFromString("created");
332 #endif
333     if (created_method_name == nullptr) {
334       return false;
335     }
336     PyObject* result = PyObject_CallMethodObjArgs(
337         eager_tensor_profiler, created_method_name, created_tensor, NULL);
338     if (result == nullptr) {
339       LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
340       // While we can potentially continue because the error is related to
341       // profiling, we choose to return an error because:
342       //  - If profiling is used, the user likely wants to stop execution on
343       //    profiling errors.
344       //  - Error in profiling code might have left some state in an invalid
345       //    form that can lead to an error later on. Better to fail fast.
346       Py_DECREF(created_method_name);
347       return false;
348     }
349     Py_DECREF(created_method_name);
350     Py_DECREF(result);
351   }
352   return true;
353 }
354 
355 }  // namespace
356 
357 // tp_init for EagerTensor.
EagerTensor_init(EagerTensor * self,PyObject * args,PyObject * kwds)358 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
359   self->id = get_uid();
360   self->handle = nullptr;
361   Py_INCREF(Py_None);
362   self->handle_data = Py_None;
363   Py_INCREF(Py_None);
364   self->keras_mask = Py_None;
365   Py_INCREF(Py_None);
366   self->tensor_shape = Py_None;
367   self->status = TF_NewStatus();
368   self->dict = nullptr;
369   self->weakreflist = nullptr;
370   PyObject* value;
371   PyObject* context = nullptr;
372   PyObject* device = nullptr;
373   PyObject* dtype = Py_None;
374   PyObject* other_value = nullptr;
375   const char* kwlist[] = {"value", "context",     "device",
376                           "dtype", "other_value", nullptr};
377   if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OO",
378                                    const_cast<char**>(kwlist), &value, &context,
379                                    &device, &dtype, &other_value)) {
380     return -1;
381   }
382 
383   if (other_value != nullptr) {
384     if (!EagerTensor_CheckExact(other_value)) {
385       PyErr_SetString(PyExc_TypeError,
386                       tensorflow::strings::StrCat(
387                           "Expecting an EagerTensor for other_value, got ",
388                           Py_TYPE(other_value)->tp_name)
389                           .c_str());
390 
391       return -1;
392     }
393     EagerTensor* other = reinterpret_cast<EagerTensor*>(other_value);
394     self->handle =
395         TFE_TensorHandleCopySharingTensor(other->handle, self->status);
396 
397     if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
398       return -1;
399     }
400 
401     return 0;
402   }
403 
404   // Extract dtype
405   int desired_dtype = -1;
406   if (dtype != Py_None) {
407     if (!PyIntToDataType(dtype, &desired_dtype)) {
408       PyErr_SetString(PyExc_TypeError,
409                       tensorflow::strings::StrCat(
410                           "Expecting a DataType value for dtype. Got ",
411                           Py_TYPE(dtype)->tp_name)
412                           .c_str());
413       return -1;
414     }
415   }
416   PyErr_Clear();
417   tensorflow::Safe_TFE_TensorHandlePtr handle =
418       tensorflow::make_safe(static_cast<TFE_TensorHandle*>(
419           tensorflow::ConvertToEagerTensor(value, dtype)));
420   if (handle == nullptr) return -1;
421   TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
422   if (desired_dtype >= 0 && desired_dtype != handle_dtype) {
423     // Check type compatibility.
424     if (tensorflow::IsCompatible(desired_dtype, handle_dtype)) {
425       handle = tensorflow::make_safe(tensorflow::EagerCast(
426           GetContext(context), handle.get(), handle_dtype,
427           static_cast<TF_DataType>(desired_dtype), self->status));
428       if (TF_GetCode(self->status) != TF_OK) {
429         PyErr_SetString(
430             PyExc_TypeError,
431             tensorflow::strings::StrCat(
432                 "Error while casting from DataType ",
433                 tensorflow::DataTypeString(
434                     static_cast<tensorflow::DataType>(handle_dtype)),
435                 " to ",
436                 tensorflow::DataTypeString(
437                     static_cast<tensorflow::DataType>(desired_dtype)),
438                 ". ", TF_Message(self->status))
439                 .c_str());
440         // Cleanup self->status before returning.
441         TF_SetStatus(self->status, TF_OK, "");
442         return -1;
443       }
444       handle_dtype = TFE_TensorHandleDataType(handle.get());
445     } else {
446       tensorflow::Safe_PyObjectPtr value_str(PyObject_Str(value));
447       PyErr_SetString(
448           PyExc_TypeError,
449           tensorflow::strings::StrCat(
450               "Cannot convert provided value to EagerTensor. Provided value: ",
451               TFE_GetPythonString(value_str.get()), " Requested dtype: ",
452               tensorflow::DataTypeString(
453                   static_cast<tensorflow::DataType>(desired_dtype)))
454               .c_str());
455       return -1;
456     }
457   }
458 
459   // Almost all TensorFlow kernels for GPU devices keep int32 tensors in host
460   // memory. We approximate the same behavior for eager execution - keeping
461   // int32 tensors in host memory.
462   //
463   // We do so to preclude the need for callers into such kernels from having to
464   // explicitly place the int32 tensors in host memory. For example, without
465   // this, one needed:
466   //
467   // with tf.device('/gpu:0'):
468   //   ...// code here
469   //   with tf.device('/cpu:0'):
470   //     shape = tf.constant(...)
471   //   y = tf.random_uniform(shape)
472   //
473   // Without the CPU device block, tfe.ops.random_uniform would fail since the
474   // kernel expects the shape in host memory.
475   //
476   // With this support, we simplify the code:
477   //
478   // with tf.device('/gpu:0'):
479   //   y = tf.random_uniform(...)
480   //
481   // The approximation is not exact there are GPU kernels which do not require
482   // host memory for int32 tensors. This will lead to a discrepancy between
483   // eager and graph execution.
484   // TODO(ashankar): Fix this.
485   if (handle_dtype != TF_INT32) {
486     // Note that this is a shallow copy and will share the underlying buffer
487     // if copying to the same device.
488     handle = tensorflow::make_safe(CopyToDevice(handle.get(), context, device));
489     if (handle == nullptr) return -1;
490   }
491   self->handle = handle.release();
492 
493   if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
494     return -1;
495   }
496 
497   return 0;
498 }
499 
500 // tp_dealloc for EagerTensor.
EagerTensor_dealloc(EagerTensor * self)501 void EagerTensor_dealloc(EagerTensor* self) {
502   // Clear weak references to self.
503   // Needs to happen before any actual destruction.
504   PyObject_ClearWeakRefs((PyObject*)self);
505 
506   TF_DeleteStatus(self->status);
507   Py_DECREF(self->handle_data);
508   Py_DECREF(self->keras_mask);
509   Py_DECREF(self->tensor_shape);
510   // If an attribute dictionary has been created, release it. Note that this
511   // is only ever created by CPython's attribute setting methods; we don't
512   // create it ourselves.
513   Py_CLEAR(self->dict);
514   if (self->handle != nullptr) {
515     TFE_DeleteTensorHandle(self->handle);
516     self->handle = nullptr;
517   }
518   // We have the global interpreter lock, so use this chance to perform delayed
519   // refcount decrements.
520   tensorflow::ClearDecrefCache();
521   auto id = self->id;
522   Py_TYPE(self)->tp_free(self);
523   TFE_Py_TapeSetDeleteTrace(id);
524 }
525 
526 // Getter for `_id`.
EagerTensor_getid(EagerTensor * self,void * closure)527 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
528   return PyLong_FromLongLong(self->id);
529 }
530 
531 // Getter for `_datatype_enum`.
EagerTensor_datatype_enum(EagerTensor * self)532 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
533   return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
534 }
535 
536 // Getter for `_shape_tuple`.
EagerTensor_shape_tuple(EagerTensor * self)537 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
538   auto handle = self->handle;
539   int n = TFE_TensorHandleNumDims(handle, self->status);
540   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
541     // Cleanup self->status before returning.
542     TF_SetStatus(self->status, TF_OK, "");
543     return nullptr;
544   }
545   PyObject* shape = PyTuple_New(n);
546   if (PyErr_Occurred()) return nullptr;
547   for (int i = 0; i < n; ++i) {
548     PyObject* dim =
549         PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status));
550     if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) ||
551         dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) {
552       // Cleanup self->status before returning.
553       TF_SetStatus(self->status, TF_OK, "");
554       Py_DECREF(shape);
555       if (dim != nullptr) Py_DECREF(dim);
556       PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
557       return nullptr;
558     }
559   }
560   return shape;
561 }
562 
563 // Getter for `_rank`.
EagerTensor_rank(EagerTensor * self)564 static PyObject* EagerTensor_rank(EagerTensor* self) {
565   int num_dims = TFE_TensorHandleNumDims(self->handle, self->status);
566   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
567     // Cleanup self->status before returning.
568     TF_SetStatus(self->status, TF_OK, "");
569     return nullptr;
570   }
571 #if PY_MAJOR_VERSION < 3
572   return PyInt_FromLong(num_dims);
573 #else
574   return PyLong_FromLong(num_dims);
575 #endif
576 }
577 
578 // Getter for `_num_elements`.
EagerTensor_num_elements(EagerTensor * self)579 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
580   auto handle = self->handle;
581   int n = TFE_TensorHandleNumElements(handle, self->status);
582   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
583     // Cleanup self->status before returning.
584     TF_SetStatus(self->status, TF_OK, "");
585     return nullptr;
586   }
587   return PyLong_FromLongLong(n);
588 }
589 
EagerTensor_tensor_handle(EagerTensor * self,void * unused)590 static PyObject* EagerTensor_tensor_handle(EagerTensor* self, void* unused) {
591   Py_INCREF(self->handle_data);
592   return self->handle_data;
593 }
594 
EagerTensor_settensor_handle(EagerTensor * self,PyObject * value,void * unused)595 static int EagerTensor_settensor_handle(EagerTensor* self, PyObject* value,
596                                         void* unused) {
597   Py_DECREF(self->handle_data);
598   Py_INCREF(value);
599   self->handle_data = value;
600   return 0;
601 }
602 
EagerTensor_keras_mask(EagerTensor * self,void * unused)603 static PyObject* EagerTensor_keras_mask(EagerTensor* self, void* unused) {
604   Py_INCREF(self->keras_mask);
605   return self->keras_mask;
606 }
607 
EagerTensor_setkeras_mask(EagerTensor * self,PyObject * value,void * unused)608 static int EagerTensor_setkeras_mask(EagerTensor* self, PyObject* value,
609                                      void* unused) {
610   Py_DECREF(self->keras_mask);
611   Py_INCREF(value);
612   self->keras_mask = value;
613   return 0;
614 }
615 
EagerTensor_tensor_shape(EagerTensor * self,void * unused)616 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
617   Py_INCREF(self->tensor_shape);
618   return self->tensor_shape;
619 }
620 
EagerTensor_settensor_shape(EagerTensor * self,PyObject * value,void * unused)621 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
622                                        void* unused) {
623   Py_DECREF(self->tensor_shape);
624   Py_INCREF(value);
625   self->tensor_shape = value;
626   return 0;
627 }
628 // Function `_copy_to_device`.
EagerTensor_copy_to_device(EagerTensor * self,PyObject * args,PyObject * kwds)629 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
630                                             PyObject* kwds) {
631   const char* kwlist[] = {"context", "device", nullptr};
632   PyObject* ctx = nullptr;
633   PyObject* dev = nullptr;
634   if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", const_cast<char**>(kwlist),
635                                    &ctx, &dev) ||
636       !ctx || !dev) {
637     return nullptr;
638   }
639   auto handle = CopyToDevice(self->handle, ctx, dev);
640   return EagerTensorFromHandle(handle);
641 }
642 
643 // Function `_numpy`.
644 // Convert an EagerTensor to a Python numpy.ndarray object.
645 // The two may share underlying storage so changes to one may reflect in the
646 // other.
647 // Note that if `self` is not on CPU, we raise an Exception.
EagerTensor_numpy(EagerTensor * self)648 static PyObject* EagerTensor_numpy(EagerTensor* self) {
649   auto status = tensorflow::make_safe(TF_NewStatus());
650   const tensorflow::Tensor* t =
651       TFE_TensorHandleUnderlyingTensorInHostMemory(self->handle, status.get());
652   if (TF_GetCode(status.get()) != TF_OK) {
653     PyErr_SetString(PyExc_RuntimeError, TF_Message(status.get()));
654     return nullptr;
655   }
656   PyObject* ret = nullptr;
657   auto cppstatus = tensorflow::TensorToNdarray(*t, &ret);
658   if (MaybeRaiseExceptionFromStatus(cppstatus, PyExc_RuntimeError)) {
659     Py_XDECREF(ret);
660     return nullptr;
661   } else {
662     return ret;
663   }
664 }
665 
666 // Getter `device`.
EagerTensor_device(EagerTensor * self)667 static PyObject* EagerTensor_device(EagerTensor* self) {
668   const char* device = TFE_TensorHandleDeviceName(self->handle, self->status);
669   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
670     // Cleanup self->status before returning.
671     TF_SetStatus(self->status, TF_OK, "");
672     return nullptr;
673   }
674 #if PY_MAJOR_VERSION >= 3
675   return PyUnicode_FromString(device);
676 #else
677   return PyBytes_FromString(device);
678 #endif
679 }
680 
681 // Getter `backing_device`.
EagerTensor_backing_device(EagerTensor * self)682 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
683   const char* device =
684       TFE_TensorHandleBackingDeviceName(self->handle, self->status);
685   if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) {
686     // Cleanup self->status before returning.
687     TF_SetStatus(self->status, TF_OK, "");
688     return nullptr;
689   }
690 #if PY_MAJOR_VERSION >= 3
691   return PyUnicode_FromString(device);
692 #else
693   return PyBytes_FromString(device);
694 #endif
695 }
696 
697 static PyGetSetDef EagerTensor_getseters[] = {
698     {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
699      const_cast<char*>("_id"), nullptr},
700     {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
701      const_cast<char*>("device"), nullptr},
702     {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
703      nullptr, const_cast<char*>("backing_device"), nullptr},
704     {const_cast<char*>("_handle_data"), (getter)EagerTensor_tensor_handle,
705      (setter)EagerTensor_settensor_handle, const_cast<char*>("_tensor_handle"),
706      nullptr},
707     {const_cast<char*>("_keras_mask"), (getter)EagerTensor_keras_mask,
708      (setter)EagerTensor_setkeras_mask, const_cast<char*>("_keras_mask"),
709      nullptr},
710     {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
711      (setter)EagerTensor_settensor_shape, const_cast<char*>("_tensor_shape"),
712      nullptr},
713     {nullptr} /* Sentinel */
714 };
715 
716 #if PY_MAJOR_VERSION < 3
717 // Only used for Python2 since Python3 seems to set the __dict__ correctly.
718 static PyMemberDef EagerTensor_members[] = {
719     {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
720      READONLY},
721     {nullptr},
722 };
723 #endif
724 
725 static PyMethodDef EagerTensor_methods[] = {
726     {"_numpy", (PyCFunction)EagerTensor_numpy, METH_NOARGS,
727      PyDoc_STR("_numpy")},
728     {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
729      PyDoc_STR("_datatype_enum")},
730     {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
731      PyDoc_STR("_shape_tuple")},
732     {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS, PyDoc_STR("_rank")},
733     {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
734      METH_VARARGS | METH_KEYWORDS, PyDoc_STR("_copy_to_device")},
735     {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
736      PyDoc_STR("_num_elements")},
737     {nullptr, nullptr},
738 };
739 
// Note that here we dynamically create a new class as a subclass of a
// "HEAPTYPE" class that is itself created in Python code and passed in at
// runtime. This is fairly atypical and undocumented.
//
// Unfortunately, this needs a different approach on Python 2.x and
// Python 3.x:
// - On Python 2.x, we create the class as a static type and set its tp_base to
//   the passed-in type. Merely adding Py_TPFLAGS_HEAPTYPE to tp_flags does not
//   work: a heap type needs further initialization of the underlying
//   PyHeapTypeObject, and skipping that leads to random crashes, especially
//   during garbage collection.
// - Python 3.x explicitly disallows a static subclass of a HEAPTYPE base class.
//   However, it provides a new function, PyType_FromSpecWithBases, to create
//   types dynamically.
754 
755 // Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
756 PyTypeObject* EagerTensorType = nullptr;
757 
#if PY_MAJOR_VERSION >= 3
// Slots from which TFE_Py_InitEagerTensor builds the EagerTensor type via
// PyType_FromSpecWithBases on Python 3. No tp_members, tp_weaklistoffset or
// tp_dictoffset entries are listed here; tp_dictoffset is assigned directly on
// the created type object in TFE_Py_InitEagerTensor.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getseters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    // Sentinel terminating the slot list.
    {0, nullptr},
};
#else
// Statically allocated EagerTensor type used on Python 2. tp_base is left
// nullptr here; TFE_Py_InitEagerTensor fills it in with the runtime base class
// before calling PyType_Ready. Slots left as nullptr (e.g. tp_new) are left for
// PyType_Ready to fill in through inheritance from tp_base.
// NOTE: this initializer is positional, so entries must stay in PyTypeObject
// field order.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
    nullptr,                            /* tp_print */
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    nullptr,                            /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT,                 /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getseters,              /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif
812 
813 }  // extern "C"
814 
EagerTensor_CheckExact(const PyObject * o)815 bool EagerTensor_CheckExact(const PyObject* o) {
816   return Py_TYPE(o) == EagerTensorType;
817 }
818 
EagerTensor_Handle(const PyObject * o)819 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
820   return reinterpret_cast<const EagerTensor*>(o)->handle;
821 }
822 
EagerTensorFromHandle(TFE_TensorHandle * handle)823 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) {
824   if (handle == nullptr) {
825     return nullptr;
826   }
827   EagerTensor* t = reinterpret_cast<EagerTensor*>(
828       EagerTensorType->tp_new(EagerTensorType, Py_None, Py_None));
829   if (t != nullptr) {
830     t->id = get_uid();
831     Py_INCREF(Py_None);
832     t->handle_data = Py_None;
833     Py_INCREF(Py_None);
834     t->keras_mask = Py_None;
835     Py_INCREF(Py_None);
836     t->tensor_shape = Py_None;
837     t->handle = handle;
838     t->status = TF_NewStatus();
839     t->weakreflist = nullptr;
840 
841     if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
842       return nullptr;
843     }
844   }
845   return reinterpret_cast<PyObject*>(t);
846 }
847 
PyEagerTensor_ID(const PyObject * tensor)848 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
849   DCHECK(EagerTensor_CheckExact(tensor));
850   return reinterpret_cast<const EagerTensor*>(tensor)->id;
851 }
852 
PyEagerTensor_Dtype(const PyObject * tensor)853 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
854   DCHECK(EagerTensor_CheckExact(tensor));
855   return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
856       reinterpret_cast<const EagerTensor*>(tensor)->handle));
857 }
858 
PyEagerTensor_NumElements(const PyObject * tensor)859 tensorflow::int64 PyEagerTensor_NumElements(const PyObject* tensor) {
860   DCHECK(EagerTensor_CheckExact(tensor));
861   const EagerTensor* as_c_eager_tensor =
862       reinterpret_cast<const EagerTensor*>(tensor);
863   tensorflow::int64 result = TFE_TensorHandleNumElements(
864       as_c_eager_tensor->handle, as_c_eager_tensor->status);
865 
866   if (MaybeRaiseExceptionFromTFStatus(as_c_eager_tensor->status,
867                                       PyExc_ValueError)) {
868     // Cleanup status before returning.
869     TF_SetStatus(as_c_eager_tensor->status, TF_OK, "");
870     return -1;
871   }
872 
873   return result;
874 }
875 
TFE_Py_InitEagerTensor(PyObject * base_class)876 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
877   if (!PyType_Check(base_class)) {
878     PyErr_SetString(
879         PyExc_TypeError,
880         tensorflow::strings::StrCat(
881             "Expecting a class definition for `base_class` passed to ",
882             "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
883             .c_str());
884     return nullptr;
885   }
886   // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
887   // EagerTensor to allow for the space usage of the base class.
888   PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
889   if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
890     PyErr_SetString(
891         PyExc_TypeError,
892         tensorflow::strings::StrCat(
893             "Unable to create subclass EagerTensor from base class ",
894             Py_TYPE(base_class)->tp_name,
895             ". Need its size to be <= ", kMaxEagerTensorParentSize)
896             .c_str());
897     return nullptr;
898   }
899   if (base_class_type->tp_itemsize != 0) {
900     PyErr_SetString(
901         PyExc_TypeError,
902         tensorflow::strings::StrCat(
903             "Unable to create subclass EagerTensor from base class ",
904             Py_TYPE(base_class)->tp_name,
905             " which supports variable length instances.")
906             .c_str());
907     return nullptr;
908   }
909   Py_INCREF(base_class);
910 #if PY_MAJOR_VERSION >= 3
911   PyObject* bases = PyTuple_New(1);
912   PyTuple_SET_ITEM(bases, 0, base_class);
913 
914   tensorflow::Safe_PyObjectPtr base_class_module(
915       PyObject_GetAttrString(base_class, "__module__"));
916   const char* module = nullptr;
917   if (PyErr_Occurred()) {
918     PyErr_Clear();
919     module = "__builtin__";
920   } else {
921     module = PyBytes_AsString(base_class_module.get());
922     if (module == nullptr) {
923       PyErr_Clear();
924       module = PyUnicode_AsUTF8(base_class_module.get());
925       if (module == nullptr) {
926         PyErr_Clear();
927         module = "__builtin__";
928       }
929     }
930   }
931 
932   // NOTE: The c_str from this string needs to outlast the function, hence is
933   // static.
934   static tensorflow::string fully_qualified_name =
935       tensorflow::strings::StrCat(module, ".EagerTensor");
936 
937   static PyType_Spec EagerTensor_Type_spec = {
938       fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
939       Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
940 
941   EagerTensorType = reinterpret_cast<PyTypeObject*>(
942       PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
943   if (PyErr_Occurred()) {
944     return nullptr;
945   }
946   if (EagerTensorType == nullptr) {
947     PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
948     return nullptr;
949   }
950   EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
951 #else
952   _EagerTensorType.tp_base = base_class_type;
953 
954   if (PyType_Ready(&_EagerTensorType) < 0) {
955     if (PyErr_Occurred()) return nullptr;
956     PyErr_SetString(PyExc_RuntimeError,
957                     "Error while creating EagerTensor type.");
958     return nullptr;
959   }
960   EagerTensorType = &_EagerTensorType;
961   Py_INCREF(EagerTensorType);
962 #endif
963   return reinterpret_cast<PyObject*>(EagerTensorType);
964 }
965 
TFE_Py_SetEagerTensorProfiler(PyObject * profiler)966 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
967   Py_XDECREF(eager_tensor_profiler);
968 
969   if (profiler == Py_None) {
970     eager_tensor_profiler = nullptr;
971   } else {
972     eager_tensor_profiler = profiler;
973     Py_INCREF(eager_tensor_profiler);
974   }
975   Py_RETURN_NONE;
976 }
977 
TFE_Py_TensorShapeSlice(PyObject * tensors,int slice_dim)978 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
979   if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
980     PyErr_SetString(PyExc_TypeError,
981                     tensorflow::strings::StrCat(
982                         "tensors argument must be a list or a tuple. Got \"",
983                         Py_TYPE(tensors)->tp_name, "\"")
984                         .c_str());
985     return nullptr;
986   }
987   if (slice_dim < 0) {
988     PyErr_SetString(
989         PyExc_ValueError,
990         tensorflow::strings::StrCat("Slice dimension must be non-negative. "
991                                     "Got ",
992                                     slice_dim)
993             .c_str());
994     return nullptr;
995   }
996 
997   Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
998   int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
999   auto tensor = tensorflow::make_safe(TF_AllocateTensor(
1000       TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int));
1001   int32_t* data = reinterpret_cast<int32_t*>(TF_TensorData(tensor.get()));
1002   auto status = tensorflow::make_safe(TF_NewStatus());
1003   for (Py_ssize_t i = 0; i < num_tensors; ++i) {
1004     PyObject* tensor_obj = PySequence_Fast_GET_ITEM(tensors, i);
1005     if (!EagerTensor_CheckExact(tensor_obj)) {
1006       PyErr_SetString(PyExc_TypeError,
1007                       tensorflow::strings::StrCat(
1008                           "Expected a list of EagerTensors but "
1009                           "element ",
1010                           i, " has type \"", Py_TYPE(tensor_obj)->tp_name, "\"")
1011                           .c_str());
1012       return nullptr;
1013     }
1014 
1015     EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
1016     TFE_TensorHandle* handle = t->handle;
1017     int num_dims = TFE_TensorHandleNumDims(handle, status.get());
1018     if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1019       return nullptr;
1020     }
1021     if (slice_dim >= num_dims) {
1022       PyErr_SetString(
1023           PyExc_IndexError,
1024           tensorflow::strings::StrCat("Slice dimension (", slice_dim,
1025                                       ") must be smaller than rank of all "
1026                                       "tensors, but tensor at index ",
1027                                       i, " has rank ", num_dims)
1028               .c_str());
1029       return nullptr;
1030     }
1031     int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
1032     if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1033       return nullptr;
1034     }
1035     data[i] = dim;
1036   }
1037 
1038   TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get());
1039   if (TF_GetCode(status.get()) != TF_OK) {
1040     PyErr_SetString(
1041         PyExc_RuntimeError,
1042         tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
1043                                     TF_Message(status.get()))
1044             .c_str());
1045     return nullptr;
1046   }
1047 
1048   return EagerTensorFromHandle(handle);
1049 }
1050 
TFE_Py_TensorShapeOnDevice(PyObject * tensor)1051 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1052   if (!EagerTensor_CheckExact(tensor)) {
1053     PyErr_SetString(
1054         PyExc_TypeError,
1055         tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1056                                     Py_TYPE(tensor)->tp_name, "\"")
1057             .c_str());
1058     return nullptr;
1059   }
1060   TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1061 
1062   auto status = tensorflow::make_safe(TF_NewStatus());
1063   TFE_TensorDebugInfo* debug_info =
1064       TFE_TensorHandleTensorDebugInfo(handle, status.get());
1065   if (TF_GetCode(status.get()) != TF_OK) {
1066     PyErr_SetString(
1067         PyExc_RuntimeError,
1068         tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1069                                     TF_Message(status.get()))
1070             .c_str());
1071     return nullptr;
1072   }
1073 
1074   int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1075   PyObject* shape = PyTuple_New(rank);
1076   for (int i = 0; i < rank; ++i) {
1077     tensorflow::int64 dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1078     PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1079   }
1080   TFE_DeleteTensorDebugInfo(debug_info);
1081 
1082   return shape;
1083 }
1084