1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/eager/pywrap_tensor.h"
17
18 #include <stdlib.h>
19 #include <string.h>
20
21 #include <cmath>
22
23 #include "structmember.h" // NOLINT // For PyMemberDef
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/eager/c_api.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/c/eager/tfe_context_internal.h"
29 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
30 #include "tensorflow/c/tf_status.h"
31 #include "tensorflow/core/framework/types.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
35 #include "tensorflow/python/eager/pywrap_tfe.h"
36 #include "tensorflow/python/lib/core/ndarray_tensor.h"
37 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
38 #include "tensorflow/python/lib/core/numpy.h"
39 #include "tensorflow/python/lib/core/py_exception_registry.h"
40 #include "tensorflow/python/lib/core/py_seq_tensor.h"
41 #include "tensorflow/python/lib/core/pybind11_status.h"
42 #include "tensorflow/python/lib/core/safe_ptr.h"
43
44 // forward declare
45 struct EagerTensor;
46 namespace tensorflow {
47
48 // Convert a TFE_TensorHandle to a Python numpy.ndarray object.
49 // The two may share underlying storage so changes to one may reflect in the
50 // other.
TFE_TensorHandleToNumpy(TFE_TensorHandle * handle,TF_Status * status)51 PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
52 if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
53 TF_SetStatus(status, TF_INVALID_ARGUMENT,
54 "Cannot convert a Tensor of dtype resource to a NumPy array.");
55 return nullptr;
56 }
57
58 if (TFE_TensorHandleDataType(handle) == TF_VARIANT) {
59 TF_SetStatus(status, TF_INVALID_ARGUMENT,
60 "Cannot convert a Tensor of dtype variant to a NumPy array.");
61 return nullptr;
62 }
63 tensorflow::Safe_TF_TensorPtr tensor = nullptr;
64 Py_BEGIN_ALLOW_THREADS;
65 tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
66 Py_END_ALLOW_THREADS;
67 if (!status->status.ok()) {
68 return nullptr;
69 }
70
71 PyObject* ret = nullptr;
72 auto cppstatus =
73 tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
74 tensorflow::Set_TF_Status_from_Status(status, cppstatus);
75 if (!status->status.ok()) {
76 Py_XDECREF(ret);
77 return nullptr;
78 }
79 CHECK_NE(ret, nullptr);
80 return ret;
81 }
82 } // namespace tensorflow
83 namespace {
84
85 using tensorflow::TFE_TensorHandleToNumpy;
86
// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
// While it is nullptr, MaybeInvokeCreatedOnEagerTensorProfiler does nothing;
// otherwise its `created` method is invoked with each newly initialized
// tensor.
PyObject* eager_tensor_profiler = nullptr;
90
91 // Read-only dict. Please don't use this in any setting where the dict might
92 // actually get mutated. This is only used to pass empty kwargs when creating a
93 // new EagerTensor.
EmptyDict()94 PyObject* EmptyDict() {
95 static PyObject* empty_dict = PyDict_New();
96 return empty_dict;
97 }
98
EmptyTuple()99 PyObject* EmptyTuple() {
100 static PyObject* empty_tuple = PyTuple_New(0);
101 return empty_tuple;
102 }
103
GetContextHandle(PyObject * py_context)104 TFE_Context* GetContextHandle(PyObject* py_context) {
105 tensorflow::Safe_PyObjectPtr py_context_handle(
106 PyObject_GetAttrString(py_context, "_handle"));
107 if (py_context_handle == nullptr) {
108 // Current Python code makes sure this never happens. If it does, or
109 // becomes hard to maintain, we can call the ensure_initialized() method
110 // here.
111 PyErr_SetString(
112 PyExc_TypeError,
113 "Expected `context` argument in EagerTensor constructor to have a "
114 "`_handle` attribute but it did not. Was eager Context initialized?");
115 return nullptr;
116 }
117
118 auto* ctx = reinterpret_cast<TFE_Context*>(
119 PyCapsule_GetPointer(py_context_handle.get(), nullptr));
120 if (ctx == nullptr) {
121 PyErr_SetString(PyExc_TypeError,
122 tensorflow::strings::StrCat(
123 "Expected context._handle to contain a PyCapsule "
124 "encoded pointer to TFE_Context. Got ",
125 Py_TYPE(py_context_handle.get())->tp_name)
126 .c_str());
127 }
128 return ctx;
129 }
130
131
132 // Helper function to convert `v` to a tensorflow::DataType and store it in
133 // `*out`. Returns true on success, false otherwise.
134 // Note that we assume that v is a python int (not long) representing a
135 // TF_DataType/tensorflow::DataType value.
PyIntToDataType(PyObject * v,tensorflow::DataType * out)136 bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
137 #if PY_MAJOR_VERSION < 3
138 if (PyInt_Check(v)) {
139 *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
140 return true;
141 }
142 #else
143 if (PyLong_Check(v)) {
144 *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
145 return true;
146 }
147 #endif
148 return false;
149 }
150
151 // Helper function to create a python integer from TF_DataType.
PyIntFromDataType(TF_DataType l)152 PyObject* PyIntFromDataType(TF_DataType l) {
153 #if PY_MAJOR_VERSION < 3
154 return PyInt_FromLong(l);
155 #else
156 return PyLong_FromLong(l);
157 #endif
158 }
159
160 // PyObject->tensorflow::DataType conversion function to be used with
161 // PyArg_Parse* APIs.
ConvertDataType(PyObject * obj,tensorflow::DataType * dst)162 int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
163 if (obj == Py_None) {
164 *dst = tensorflow::DataType::DT_INVALID;
165 } else if (!PyIntToDataType(obj, dst)) {
166 PyErr_SetString(
167 PyExc_TypeError,
168 tensorflow::strings::StrCat(
169 "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
170 .c_str());
171 return 0;
172 }
173
174 return 1;
175 }
176
177 // Conversion function extracting a const char** device name from a PyObject.
178 // The function should be used with PyArg_Parse* APIs.
ConvertDeviceName(PyObject * obj,const char ** dst)179 int ConvertDeviceName(PyObject* obj, const char** dst) {
180 if (obj == Py_None) {
181 *dst = nullptr;
182 } else {
183 auto device_name = TFE_GetPythonString(obj);
184 if (device_name == nullptr) {
185 PyErr_Clear();
186 PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
187 return 0;
188 }
189 *dst = device_name;
190 }
191
192 return 1;
193 }
194
RaiseExceptionTypeFromTFStatus(TF_Status * tf_status)195 void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) {
196 auto status = tensorflow::StatusFromTF_Status(tf_status);
197 SetRegisteredErrFromStatus(status);
198 }
199
200 } // namespace
201
202 namespace tensorflow {
203 // This function checks whether the desired type is "compatible" with the
204 // inferred type. At a high level, compatibility means that all integral types
205 // are compatible with each other, and all floating types are compatible with
206 // each other.
207 //
208 // Type compatibility doesn't consider overflows (i.e. int64 is *always*
209 // compatible with int32). This is intended to match graph behavior.
IsCompatible(DataType desired,DataType returned)210 bool IsCompatible(DataType desired, DataType returned) {
211 if (desired == returned) return true;
212
213 if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
214 return true;
215 } else if (DataTypeIsFloating(desired) &&
216 (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
217 return true;
218 } else if (DataTypeIsComplex(desired) &&
219 (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
220 DataTypeIsFloating(returned))) {
221 return true;
222 } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
223 return true;
224 }
225 return false;
226 }
227
228 // TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
229 // execute TFE Ops) to a separate common library.
230 // Casts data referred to by `handle` from type `src_type_enum` to type
231 // `dst_type_enum`.
EagerCast(TFE_Context * ctx,TFE_TensorHandle * handle,TF_DataType src_type_enum,TF_DataType dst_type_enum,TF_Status * out_status)232 TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
233 TF_DataType src_type_enum,
234 TF_DataType dst_type_enum, TF_Status* out_status) {
235 if (ctx == nullptr) return nullptr;
236 const char* op_name = "Cast";
237 const char* device_name = "/device:CPU:0";
238 TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
239 #define RETURN_ERROR \
240 { \
241 TFE_DeleteOp(op); \
242 return nullptr; \
243 }
244 if (!out_status->status.ok()) RETURN_ERROR
245 TFE_OpSetDevice(op, device_name, out_status);
246 if (!out_status->status.ok()) RETURN_ERROR
247 TFE_OpAddInput(op, handle, out_status);
248 if (!out_status->status.ok()) RETURN_ERROR
249 TFE_OpSetAttrType(op, "SrcT", src_type_enum);
250 TFE_OpSetAttrType(op, "DstT", dst_type_enum);
251 TFE_OpSetAttrBool(op, "Truncate", false);
252 TFE_TensorHandle* output = nullptr;
253 int num_outputs = 1;
254 TFE_Execute(op, &output, &num_outputs, out_status);
255 if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
256 if (output != nullptr) {
257 TFE_DeleteTensorHandle(output);
258 }
259 RETURN_ERROR
260 }
261 TFE_DeleteOp(op);
262 return output;
263 #undef RETURN_ERROR
264 }
265
EagerConst(TFE_Context * ctx,TFE_TensorHandle * handle,const char * device_name,TF_Status * out_status)266 Safe_TFE_TensorHandlePtr EagerConst(TFE_Context* ctx, TFE_TensorHandle* handle,
267 const char* device_name,
268 TF_Status* out_status) {
269 const char* op_name = "_EagerConst";
270 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
271 TFE_NewOp(ctx, op_name, out_status), TFE_DeleteOp);
272 if (!out_status->status.ok()) return nullptr;
273 TFE_OpSetDevice(op.get(), device_name, out_status);
274 if (!out_status->status.ok()) return nullptr;
275 TFE_OpAddInput(op.get(), handle, out_status);
276 if (!out_status->status.ok()) return nullptr;
277 TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(handle));
278 TFE_TensorHandle* output = nullptr;
279 int num_outputs = 1;
280 TFE_Execute(op.get(), &output, &num_outputs, out_status);
281 Safe_TFE_TensorHandlePtr result(output);
282 if (!out_status->status.ok() || num_outputs != 1) {
283 return nullptr;
284 }
285 return result;
286 }
287
ConvertToEagerTensorUncached(TFE_Context * ctx,PyObject * value,tensorflow::DataType dtype,const char * device_name)288 TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
289 PyObject* value,
290 tensorflow::DataType dtype,
291 const char* device_name) {
292 tensorflow::Safe_PyObjectPtr value_decrefer;
293 if (PyArray_IsScalar(value, Generic)) {
294 // Convert numpy scalars to numpy arrays.
295 value = PyArray_FromScalar(value, nullptr);
296 // The returned value needs to be DECREF'd, but the original value was
297 // created in python code, and doesn't need to be DECREF'd.
298 value_decrefer.reset(value);
299 }
300
301 Safe_TFE_TensorHandlePtr handle =
302 make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
303
304 if (handle == nullptr) return nullptr;
305
306 Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
307 TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
308 if (dtype != tensorflow::DT_INVALID &&
309 dtype != static_cast<DataType>(handle_dtype)) {
310 if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
311 handle = tensorflow::make_safe(
312 tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
313 static_cast<TF_DataType>(dtype), status.get()));
314 if (!status->status.ok()) {
315 PyErr_SetString(PyExc_TypeError,
316 absl::StrCat("Error while casting from dtype ",
317 tensorflow::DataTypeString(
318 static_cast<DataType>(handle_dtype)),
319 " to ", tensorflow::DataTypeString(dtype),
320 ". ", TF_Message(status.get()))
321 .c_str());
322 return nullptr;
323 }
324 } else {
325 tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
326 PyErr_SetString(
327 PyExc_TypeError,
328 absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
329 " to EagerTensor of dtype ",
330 tensorflow::DataTypeString(dtype))
331 .c_str());
332 return nullptr;
333 }
334 }
335
336 // We always initially generate CPU:0 tensors. Copy to the current device.
337 if (device_name != nullptr) {
338 if (strstr(device_name, "/device:CPU:0") != nullptr) {
339 // We always generate CPU:0 tensors, but we may need to change the device
340 // slightly, as for example from /job:localhost/... to /job:worker/...
341 //
342 // Note that this is a shallow copy and will share the underlying buffer,
343 // because we are copying to the same device.
344 handle = make_safe(TFE_TensorHandleCopyToDevice(
345 handle.get(), ctx, device_name, status.get()));
346 const TF_Code code = TF_GetCode(status.get());
347 if (code != TF_OK) {
348 RaiseExceptionTypeFromTFStatus(status.get());
349 return nullptr;
350 }
351 } else {
352 /*Copy the constant to the current device. Identity is sometimes
353 overloaded to allow copies like this, but using a different op allows
354 devices to support constant creation without allowing copies via
355 identity ops.
356
357 Note that running this _EagerConst op limits mirroring of cached Python
358 literals somewhat. Mirroring of constants themselves works:
359
360 with tf.device("GPU:0"):
361 tf.constant(1.) # Cached on CPU:0, mirrored to GPU:0
362 with tf.device("GPU:1"):
363 tf.constant(1.) # Cache hit for the CPU version, new mirror to GPU:1.
364 with tf.device("GPU:1"):
365 tf.constant(1.) # Cache hit for the CPU version, cached mirror
366
367 But mirrors for the output of `tf.constant` are not shared just because
368 there was a cache hit for the input literal, because of _EagerConst:
369
370 x = tf.constant(2.) # Cached on CPU:0
371 with tf.device("GPU:1"):
372 tf.identity(x) # `x` now mirrored to GPU:1
373 y = tf.constant(2.) # Cache hit for CPU version
374 with tf.device("GPU:1"):
375 tf.identity(y) # `y` now mirrored on GPU:1 (new copy!)*/
376 handle =
377 tensorflow::EagerConst(ctx, handle.get(), device_name, status.get());
378 const TF_Code code = TF_GetCode(status.get());
379 if (code != TF_OK) {
380 RaiseExceptionTypeFromTFStatus(status.get());
381 return nullptr;
382 }
383 }
384 }
385
386 return handle.release();
387 }
388
ConvertToEagerTensor(TFE_Context * ctx,PyObject * value,DataType dtype,const char * device_name)389 TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
390 DataType dtype,
391 const char* device_name) {
392 // Reduce the overhead of allocation/transfer-to-device for scalars by
393 // caching the corresponding handles. Note that currently only Python
394 // scalars are cached.
395 // TODO(slebedev): also cache singleton NumPy arrays and scalars?
396 if (PyArray_IsPythonNumber(value)) {
397 auto* cache = TFE_TensorHandleCache::Get();
398 TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
399 if (handle != nullptr) return handle;
400 handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
401 if (handle == nullptr) return nullptr;
402 if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
403 cache->Insert(value, dtype, ctx, device_name, handle);
404 }
405 return handle;
406 } else {
407 return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
408 }
409 }
410
411 } // namespace tensorflow
412
413 extern "C" {
414
// Number of bytes reserved inside EagerTensor, directly after the Python
// object header, for the state of the runtime-selected parent class.
static const int kMaxEagerTensorParentSize = 64;

// C-level instance layout of the Python EagerTensor type.
// TODO(agarwal): store context handle in EagerTensor.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // The wrapped tensor handle. It is deleted in EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Unique id of this tensor; assigned from get_uid() in EagerTensor_init.
  int64_t id;
  // Indicates whether it's a packed tensor or not.
  bool is_packed;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating new
  // Status objects in the different functions that operate on EagerTensor and
  // need to use a TF_Status object. However, note that accesses to `status`
  // are not thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;
461
462 namespace {
463
464 // Returns true on success - successfully invoked or no profiler registered.
465 // Returns false if some error occurred.
MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor * created_tensor)466 bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
467 if (eager_tensor_profiler != nullptr) {
468 #if PY_MAJOR_VERSION < 3
469 PyObject* created_method_name = PyString_InternFromString("created");
470 #else
471 PyObject* created_method_name = PyUnicode_InternFromString("created");
472 #endif
473 if (created_method_name == nullptr) {
474 return false;
475 }
476 PyObject* result = PyObject_CallMethodObjArgs(
477 eager_tensor_profiler, created_method_name, created_tensor, NULL);
478 if (result == nullptr) {
479 LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
480 // While we can potentially continue because the error is related to
481 // profiling, we choose to return an error because:
482 // - If profiling is used, the user likely wants to stop execution on
483 // profiling errors.
484 // - Error in profiling code might have left some state in an invalid
485 // form that can lead to an error later on. Better to fail fast.
486 Py_DECREF(created_method_name);
487 return false;
488 }
489 Py_DECREF(created_method_name);
490 Py_DECREF(result);
491 }
492 return true;
493 }
494
495 } // namespace
496
497 // tp_init for EagerTensor.
EagerTensor_init(EagerTensor * self,PyObject * args,PyObject * kwds)498 int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
499 self->id = get_uid();
500 self->handle = nullptr;
501 self->is_packed = false;
502 Py_INCREF(Py_None);
503 self->handle_data = Py_None;
504 Py_INCREF(Py_None);
505 self->tensor_shape = Py_None;
506 self->status.status = tensorflow::Status::OK();
507 self->dict = nullptr;
508 self->weakreflist = nullptr;
509 self->context = nullptr;
510 PyObject* value;
511 const char* device_name = nullptr;
512 tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
513 const char* kwlist[] = {"value", "device", "dtype", nullptr};
514 if (!PyArg_ParseTupleAndKeywords(
515 args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
516 ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
517 return -1;
518 }
519
520 PyObject* py_context = GetPyEagerContext();
521 if (py_context == nullptr) return -1;
522 self->context = py_context;
523
524 auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
525 value, dtype, device_name);
526 if (handle == nullptr) return -1;
527 self->handle = handle;
528
529 if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
530 return -1;
531 }
532
533 return 0;
534 }
535
// tp_dealloc for EagerTensor.
//
// Releases everything the instance holds, in this order: weak references,
// `handle_data`, `tensor_shape`, the attribute dict, the tensor handle, and
// then the context. The tensor id is copied before the memory is freed so
// that it can still be passed to TFE_Py_TapeSetDeleteTrace afterwards.
void EagerTensor_dealloc(EagerTensor* self) {
  // Unhook the object from python's GC so that the weakref deleter doesn't
  // try to re-delete this.
  PyObject_GC_UnTrack((PyObject*)self);

  // Clear weak references to self.
  // Needs to happen before any actual destruction.
  PyObject_ClearWeakRefs((PyObject*)self);

  Py_DECREF(self->handle_data);
  Py_DECREF(self->tensor_shape);
  // If an attribute dictionary has been created, release it. Note that this
  // is only ever created by CPython's attribute setting methods; we don't
  // create it ourselves.
  Py_CLEAR(self->dict);
  if (self->handle != nullptr) {
    TFE_DeleteTensorHandle(self->handle);
    self->handle = nullptr;
  }

  // Decref context after deleting the tensor handle.
  Py_XDECREF(self->context);

  // We have the global interpreter lock, so use this chance to perform delayed
  // refcount decrements.
  tensorflow::ClearDecrefCache();
  // `self` must not be touched after tp_free, so read the id first.
  auto id = self->id;
  Py_TYPE(self)->tp_free(self);
  TFE_Py_TapeSetDeleteTrace(id);
}
567
568 // Getter for `_id`.
EagerTensor_getid(EagerTensor * self,void * closure)569 static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
570 return PyLong_FromLongLong(self->id);
571 }
572
573 // Getter for `_datatype_enum`.
EagerTensor_datatype_enum(EagerTensor * self)574 static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
575 return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
576 }
577
578 // Getter for `_shape_tuple`.
EagerTensor_shape_tuple(EagerTensor * self)579 static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
580 auto handle = self->handle;
581 int n = TFE_TensorHandleNumDims(handle, &self->status);
582 TF_Code code = TF_GetCode(&self->status);
583 if (code != TF_OK) {
584 RaiseExceptionTypeFromTFStatus(&self->status);
585 // Cleanup self->status before returning.
586 self->status.status = tensorflow::Status::OK();
587 return nullptr;
588 }
589 PyObject* shape = PyTuple_New(n);
590 if (PyErr_Occurred()) return nullptr;
591 for (int i = 0; i < n; ++i) {
592 int64_t dim_c_value = TFE_TensorHandleDim(handle, i, &self->status);
593 PyObject* dim;
594 // The C++ convention is -1 for unknown/variable axis lengths. Translate
595 // that to the Python "None" convention. Unknown axis lengths are unusual
596 // for eager tensors.
597 if (dim_c_value < 0) {
598 Py_IncRef(Py_None);
599 dim = Py_None;
600 } else {
601 dim = PyLong_FromLongLong(dim_c_value);
602 }
603 code = TF_GetCode(&self->status);
604 if (code != TF_OK || dim == nullptr ||
605 PyTuple_SetItem(shape, i, dim) != 0) {
606 if (code != TF_OK) {
607 RaiseExceptionTypeFromTFStatus(&self->status);
608 } else {
609 PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
610 }
611 // Cleanup self->status before returning.
612 self->status.status = tensorflow::Status::OK();
613 Py_DECREF(shape);
614 if (dim != nullptr) Py_DECREF(dim);
615 return nullptr;
616 }
617 }
618 return shape;
619 }
620
621 // Getter for `_rank`.
EagerTensor_rank(EagerTensor * self)622 static PyObject* EagerTensor_rank(EagerTensor* self) {
623 int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
624 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
625 // Cleanup self->status before returning.
626 self->status.status = tensorflow::Status::OK();
627 return nullptr;
628 }
629 #if PY_MAJOR_VERSION < 3
630 return PyInt_FromLong(num_dims);
631 #else
632 return PyLong_FromLong(num_dims);
633 #endif
634 }
635
636 // Getter for `_num_elements`.
EagerTensor_num_elements(EagerTensor * self)637 static PyObject* EagerTensor_num_elements(EagerTensor* self) {
638 auto handle = self->handle;
639 int n = TFE_TensorHandleNumElements(handle, &self->status);
640 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
641 // Cleanup self->status before returning.
642 self->status.status = tensorflow::Status::OK();
643 return nullptr;
644 }
645 return PyLong_FromLongLong(n);
646 }
647
EagerTensor_handle_data(EagerTensor * self,void * unused)648 static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
649 Py_INCREF(self->handle_data);
650 return self->handle_data;
651 }
652
EagerTensor_sethandle_data(EagerTensor * self,PyObject * value,void * unused)653 static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
654 void* unused) {
655 Py_DECREF(self->handle_data);
656 Py_INCREF(value);
657 self->handle_data = value;
658 return 0;
659 }
660
EagerTensor_tensor_shape(EagerTensor * self,void * unused)661 static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
662 Py_INCREF(self->tensor_shape);
663 return self->tensor_shape;
664 }
665
EagerTensor_settensor_shape(EagerTensor * self,PyObject * value,void * unused)666 static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
667 void* unused) {
668 Py_DECREF(self->tensor_shape);
669 Py_INCREF(value);
670 self->tensor_shape = value;
671 return 0;
672 }
673
674 // Function `_copy_to_device`.
EagerTensor_copy_to_device(EagerTensor * self,PyObject * args,PyObject * kwds)675 static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
676 PyObject* kwds) {
677 if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;
678
679 const char* device_name = nullptr;
680 if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
681 &device_name)) {
682 return nullptr;
683 }
684
685 // Note that this is a shallow copy and will share the underlying buffer
686 // if copying to the same device.
687 TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
688 self->handle, GetContextHandle(self->context), device_name,
689 &self->status);
690 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_RuntimeError)) {
691 // Cleanup self->status before returning.
692 self->status.status = tensorflow::Status::OK();
693 return nullptr;
694 }
695
696 return EagerTensorFromHandle(handle);
697 }
698
699 // Function `_numpy_internal`.
700 // Convert an EagerTensor to a Python numpy.ndarray object.
701 // The two may share underlying storage so changes to one may reflect in the
702 // other.
703 // Note that if `self` is not on CPU, we raise an Exception.
EagerTensor_numpy_internal(EagerTensor * self)704 static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
705 auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
706 if (MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
707 Py_XDECREF(py_array);
708 // Cleanup self->status before returning.
709 self->status.status = tensorflow::Status::OK();
710 return nullptr;
711 } else {
712 return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
713 }
714 }
715
716 // Function `_prefer_custom_summarizer`.
717 //
718 // A hint that callers should prefer `SummarizeValue` to resolving this handle
719 // and formatting the tensor.
EagerTensor_prefer_custom_summarizer(EagerTensor * self)720 static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) {
721 if (tensorflow::unwrap(self->handle)->PreferCustomSummarizer()) {
722 Py_RETURN_TRUE;
723 } else {
724 Py_RETURN_FALSE;
725 }
726 }
727
728 // Function `_summarize_value`.
729 //
730 // Returns a string PyObject which summarizes the value of this tensor. It does
731 // not include a shape or dtype.
EagerTensor_summarize_value(EagerTensor * self)732 static PyObject* EagerTensor_summarize_value(EagerTensor* self) {
733 std::string summary;
734 tensorflow::Status status =
735 tensorflow::unwrap(self->handle)->SummarizeValue(summary);
736 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
737 return nullptr;
738 }
739 return PyUnicode_FromString(summary.c_str());
740 }
741
742 // Getter `device`.
EagerTensor_device(EagerTensor * self)743 static PyObject* EagerTensor_device(EagerTensor* self) {
744 const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
745 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
746 // Cleanup self->status before returning.
747 self->status.status = tensorflow::Status::OK();
748 return nullptr;
749 }
750 #if PY_MAJOR_VERSION >= 3
751 return PyUnicode_FromString(device);
752 #else
753 return PyBytes_FromString(device);
754 #endif
755 }
756
757 // Getter `backing_device`.
EagerTensor_backing_device(EagerTensor * self)758 static PyObject* EagerTensor_backing_device(EagerTensor* self) {
759 const char* device =
760 TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
761 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_ValueError)) {
762 // Cleanup self->status before returning.
763 self->status.status = tensorflow::Status::OK();
764 return nullptr;
765 }
766 #if PY_MAJOR_VERSION >= 3
767 return PyUnicode_FromString(device);
768 #else
769 return PyBytes_FromString(device);
770 #endif
771 }
772
773 // Getter `is_packed`.
EagerTensor_is_packed(EagerTensor * self)774 static PyObject* EagerTensor_is_packed(EagerTensor* self) {
775 return PyBool_FromLong(self->is_packed);
776 }
777
// Attribute table for EagerTensor. Each entry is
// {name, getter, setter, docstring, closure}. Entries whose setter is nullptr
// are read-only; only `_handle_data` and `_tensor_shape` can be assigned.
static PyGetSetDef EagerTensor_getsetters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};
798
#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
// Exposes the per-instance attribute dictionary slot (`EagerTensor::dict`) as
// a read-only `__dict__` member.
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif
807
// Method table for EagerTensor. Each entry is
// {python name, C implementation, calling convention flags, docstring}.
// Every method takes no arguments except `_copy_to_device`, which is
// registered with METH_VARARGS | METH_KEYWORDS.
static PyMethodDef EagerTensor_methods[] = {
    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
     PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("The DType of the tensor as an enum.")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("The shape of the tensor as a python tuple.")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
     PyDoc_STR("The rank of the tensor.")},
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS,
     PyDoc_STR("Copies the tensor to the desired device.")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("Number of elements in the tensor.")},
    {"_prefer_custom_summarizer",
     (PyCFunction)EagerTensor_prefer_custom_summarizer, METH_NOARGS,
     PyDoc_STR("Indicates whether _numpy_internal loses information.")},
    {"_summarize_value", (PyCFunction)EagerTensor_summarize_value, METH_NOARGS,
     PyDoc_STR("A string which summarizes the value of this tensor.")},
    {nullptr, nullptr},
};
829
EagerTensor_getbuffer(EagerTensor * self,Py_buffer * view,int flags)830 static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
831 int flags) {
832 if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
833 PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
834 return -1;
835 }
836
837 // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
838 // DT_STRING so the following is only slightly slower than a NumPy-free
839 // implementation.
840 auto py_array = tensorflow::make_safe(
841 TFE_TensorHandleToNumpy(self->handle, &self->status));
842 if (MaybeRaiseExceptionFromTFStatus(&self->status, PyExc_BufferError)) {
843 // Cleanup self->status before returning.
844 self->status.status = tensorflow::Status::OK();
845 return -1;
846 }
847 if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
848 return -1;
849 }
850 view->readonly = 1;
851 return 0;
852 }
853
854 static PyBufferProcs EagerTensor_as_buffer = {
855 #if PY_MAJOR_VERSION < 3
856 nullptr, nullptr, nullptr, nullptr,
857 #endif
858 (getbufferproc)EagerTensor_getbuffer,
859 // Never called because getbufferproc delegates to NumPy.
860 (releasebufferproc) nullptr};
861
862 // Note that here we are trying to dynamically create a new class as a subclass
863 // of a "HEAPTYPE" class that is itself created in python code and passed in at
864 // runtime. This is fairly atypical and undocumented.
865 //
866 // We use the following strategy for this. Unfortunately, we have to use
867 // different approaches for python2.x vs python3.x
868 // For python2.x, we create the class as a static type and set its tp_base to
869 // the passed in type. Unfortunately setting tp_flags to include
870 // Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
871 // initialization of the underlying PyHeapTypeObject and not doing that leads to
872 // some random crashes especially during garbage collection.
873 // python3.x explicitly disables a static subclass of a HEAPTYPE base class.
874 // However it provides a new function, PyType_FromSpecWithBases, to create
875 // types dynamically.
876
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor.
// It is null until TFE_Py_InitEagerTensor assigns it. EagerTensor_CheckExact
// compares against it and EagerTensorFromHandle dereferences it without a
// null check, so the type must be initialized before tensors are created.
PyTypeObject* EagerTensorType = nullptr;
879
#if PY_MAJOR_VERSION >= 3
// Python 3: slot list referenced by the PyType_Spec that
// TFE_Py_InitEagerTensor hands to PyType_FromSpecWithBases. tp_dictoffset and
// tp_as_buffer are not listed here; TFE_Py_InitEagerTensor assigns them on
// the created type object afterwards.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    // Sentinel entry terminating the slot list.
    {0, nullptr},
};
#else

// Python 2: type flags for the statically defined type below.
#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// Python 2: statically defined type object. The initializer is positional, so
// the order of entries must match the field order of PyTypeObject. tp_base is
// left null here and filled in by TFE_Py_InitEagerTensor before
// PyType_Ready is called.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif
941
942 } // extern "C"
943
EagerTensor_CheckExact(const PyObject * o)944 bool EagerTensor_CheckExact(const PyObject* o) {
945 return Py_TYPE(o) == EagerTensorType;
946 }
947
EagerTensor_Handle(const PyObject * o)948 TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
949 return reinterpret_cast<const EagerTensor*>(o)->handle;
950 }
951
EagerTensorFromHandle(TFE_TensorHandle * handle,const bool is_packed)952 PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
953 const bool is_packed) {
954 if (handle == nullptr) {
955 return nullptr;
956 }
957 EagerTensor* t = reinterpret_cast<EagerTensor*>(
958 EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
959 if (t != nullptr) {
960 t->id = get_uid();
961 t->is_packed = is_packed;
962 Py_INCREF(Py_None);
963 t->handle_data = Py_None;
964 Py_INCREF(Py_None);
965 t->tensor_shape = Py_None;
966 t->handle = handle;
967 t->status.status = tensorflow::Status::OK();
968 t->weakreflist = nullptr;
969 PyObject* py_context = GetPyEagerContext();
970 if (py_context == nullptr) {
971 LOG(ERROR) << "Cannot create an eager tensor before eager context has "
972 "been set or after it has been deleted";
973 return nullptr;
974 }
975 t->context = py_context;
976
977 if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
978 return nullptr;
979 }
980 }
981 return reinterpret_cast<PyObject*>(t);
982 }
983
PyEagerTensor_ID(const PyObject * tensor)984 tensorflow::int64 PyEagerTensor_ID(const PyObject* tensor) {
985 DCHECK(EagerTensor_CheckExact(tensor));
986 return reinterpret_cast<const EagerTensor*>(tensor)->id;
987 }
988
PyEagerTensor_Dtype(const PyObject * tensor)989 tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
990 DCHECK(EagerTensor_CheckExact(tensor));
991 return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
992 reinterpret_cast<const EagerTensor*>(tensor)->handle));
993 }
994
PyEagerTensor_NumElements(PyObject * tensor)995 tensorflow::int64 PyEagerTensor_NumElements(PyObject* tensor) {
996 DCHECK(EagerTensor_CheckExact(tensor));
997 EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
998 int64_t result = TFE_TensorHandleNumElements(as_c_eager_tensor->handle,
999 &as_c_eager_tensor->status);
1000
1001 if (MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
1002 PyExc_ValueError)) {
1003 // Cleanup status before returning.
1004 as_c_eager_tensor->status.status = tensorflow::Status::OK();
1005 return -1;
1006 }
1007
1008 return result;
1009 }
1010
TFE_Py_InitEagerTensor(PyObject * base_class)1011 PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
1012 if (!PyType_Check(base_class)) {
1013 PyErr_SetString(
1014 PyExc_TypeError,
1015 tensorflow::strings::StrCat(
1016 "Expecting a class definition for `base_class` passed to ",
1017 "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
1018 .c_str());
1019 return nullptr;
1020 }
1021 // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
1022 // EagerTensor to allow for the space usage of the base class.
1023 PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
1024 if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
1025 PyErr_SetString(
1026 PyExc_TypeError,
1027 tensorflow::strings::StrCat(
1028 "Unable to create subclass EagerTensor from base class ",
1029 Py_TYPE(base_class)->tp_name,
1030 ". Need its size to be <= ", kMaxEagerTensorParentSize)
1031 .c_str());
1032 return nullptr;
1033 }
1034 if (base_class_type->tp_itemsize != 0) {
1035 PyErr_SetString(
1036 PyExc_TypeError,
1037 tensorflow::strings::StrCat(
1038 "Unable to create subclass EagerTensor from base class ",
1039 Py_TYPE(base_class)->tp_name,
1040 " which supports variable length instances.")
1041 .c_str());
1042 return nullptr;
1043 }
1044 Py_INCREF(base_class);
1045 #if PY_MAJOR_VERSION >= 3
1046 PyObject* bases = PyTuple_New(1);
1047 PyTuple_SET_ITEM(bases, 0, base_class);
1048
1049 tensorflow::Safe_PyObjectPtr base_class_module(
1050 PyObject_GetAttrString(base_class, "__module__"));
1051 const char* module = nullptr;
1052 if (PyErr_Occurred()) {
1053 PyErr_Clear();
1054 module = "__builtin__";
1055 } else {
1056 module = PyBytes_AsString(base_class_module.get());
1057 if (module == nullptr) {
1058 PyErr_Clear();
1059 module = PyUnicode_AsUTF8(base_class_module.get());
1060 if (module == nullptr) {
1061 PyErr_Clear();
1062 module = "__builtin__";
1063 }
1064 }
1065 }
1066
1067 // NOTE: The c_str from this string needs to outlast the function, hence is
1068 // static.
1069 static tensorflow::string fully_qualified_name =
1070 tensorflow::strings::StrCat(module, ".EagerTensor");
1071
1072 static PyType_Spec EagerTensor_Type_spec = {
1073 fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
1074 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
1075
1076 EagerTensorType = reinterpret_cast<PyTypeObject*>(
1077 PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
1078 if (PyErr_Occurred()) {
1079 return nullptr;
1080 }
1081 if (EagerTensorType == nullptr) {
1082 PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
1083 return nullptr;
1084 }
1085 EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
1086 EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
1087 #else
1088 _EagerTensorType.tp_base = base_class_type;
1089
1090 if (PyType_Ready(&_EagerTensorType) < 0) {
1091 if (PyErr_Occurred()) return nullptr;
1092 PyErr_SetString(PyExc_RuntimeError,
1093 "Error while creating EagerTensor type.");
1094 return nullptr;
1095 }
1096 EagerTensorType = &_EagerTensorType;
1097 Py_INCREF(EagerTensorType);
1098 #endif
1099 return reinterpret_cast<PyObject*>(EagerTensorType);
1100 }
1101
TFE_Py_SetEagerTensorProfiler(PyObject * profiler)1102 PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
1103 Py_XDECREF(eager_tensor_profiler);
1104
1105 if (profiler == Py_None) {
1106 eager_tensor_profiler = nullptr;
1107 } else {
1108 eager_tensor_profiler = profiler;
1109 Py_INCREF(eager_tensor_profiler);
1110 }
1111 Py_RETURN_NONE;
1112 }
1113
TFE_Py_TensorShapeSlice(PyObject * tensors,int slice_dim)1114 PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
1115 if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
1116 PyErr_SetString(PyExc_TypeError,
1117 tensorflow::strings::StrCat(
1118 "tensors argument must be a list or a tuple. Got \"",
1119 Py_TYPE(tensors)->tp_name, "\"")
1120 .c_str());
1121 return nullptr;
1122 }
1123 if (slice_dim < 0) {
1124 PyErr_SetString(
1125 PyExc_ValueError,
1126 tensorflow::strings::StrCat("Slice dimension must be non-negative. "
1127 "Got ",
1128 slice_dim)
1129 .c_str());
1130 return nullptr;
1131 }
1132
1133 PyObject* py_context = GetPyEagerContext();
1134 if (py_context == nullptr) {
1135 PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
1136 "Cannot create EagerTensor when "
1137 "EagerContext is not valid")
1138 .c_str());
1139 return nullptr;
1140 }
1141
1142 TFE_Context* ctx = GetContextHandle(py_context);
1143
1144 Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
1145 PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
1146 int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
1147
1148 auto status = tensorflow::make_safe(TF_NewStatus());
1149
1150 // Create an empty tensor.
1151 auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
1152 tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
1153
1154 if (num_tensors_int > 0) {
1155 int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
1156
1157 // Fill the tensor with dims.
1158 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
1159 PyObject* tensor_obj = tensors_array[i];
1160 if (!EagerTensor_CheckExact(tensor_obj)) {
1161 PyErr_SetString(
1162 PyExc_TypeError,
1163 tensorflow::strings::StrCat("Expected a list of EagerTensors but "
1164 "element ",
1165 i, " has type \"",
1166 Py_TYPE(tensor_obj)->tp_name, "\"")
1167 .c_str());
1168 return nullptr;
1169 }
1170
1171 EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
1172 TFE_TensorHandle* handle = t->handle;
1173 int num_dims = TFE_TensorHandleNumDims(handle, status.get());
1174 if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1175 return nullptr;
1176 }
1177 if (slice_dim >= num_dims) {
1178 PyErr_SetString(
1179 PyExc_IndexError,
1180 tensorflow::strings::StrCat("Slice dimension (", slice_dim,
1181 ") must be smaller than rank of all "
1182 "tensors, but tensor at index ",
1183 i, " has rank ", num_dims)
1184 .c_str());
1185 return nullptr;
1186 }
1187 int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
1188 if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) {
1189 return nullptr;
1190 }
1191 data[i] = dim;
1192 }
1193 }
1194
1195 TFE_TensorHandle* handle =
1196 tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
1197
1198 if (!status->status.ok()) {
1199 PyErr_SetString(
1200 PyExc_RuntimeError,
1201 tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
1202 TF_Message(status.get()))
1203 .c_str());
1204 return nullptr;
1205 }
1206
1207 return EagerTensorFromHandle(handle);
1208 }
1209
TFE_Py_TensorShapeOnDevice(PyObject * tensor)1210 PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1211 if (!EagerTensor_CheckExact(tensor)) {
1212 PyErr_SetString(
1213 PyExc_TypeError,
1214 tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1215 Py_TYPE(tensor)->tp_name, "\"")
1216 .c_str());
1217 return nullptr;
1218 }
1219 TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1220
1221 auto status = tensorflow::make_safe(TF_NewStatus());
1222 TFE_TensorDebugInfo* debug_info =
1223 TFE_TensorHandleTensorDebugInfo(handle, status.get());
1224 if (!status->status.ok()) {
1225 PyErr_SetString(
1226 PyExc_RuntimeError,
1227 tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1228 TF_Message(status.get()))
1229 .c_str());
1230 return nullptr;
1231 }
1232
1233 int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1234 PyObject* shape = PyTuple_New(rank);
1235 for (int i = 0; i < rank; ++i) {
1236 int64_t dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1237 PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1238 }
1239 TFE_DeleteTensorDebugInfo(debug_info);
1240
1241 return shape;
1242 }
1243