1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "Python.h"
19 #include "absl/strings/str_format.h"
20 #include "pybind11/chrono.h"
21 #include "pybind11/complex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "pybind11/pytypes.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_experimental.h"
28 #include "tensorflow/c/eager/c_api.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/dlpack.h"
32 #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
35 #include "tensorflow/c/tf_status.h"
36 #include "tensorflow/c/tf_status_helper.h"
37 #include "tensorflow/compiler/jit/flags.h"
38 #include "tensorflow/compiler/jit/get_compiler_ir.h"
39 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
40 #include "tensorflow/python/eager/pywrap_tfe.h"
41 #include "tensorflow/python/lib/core/py_exception_registry.h"
42 #include "tensorflow/python/lib/core/pybind11_lib.h"
43 #include "tensorflow/python/lib/core/pybind11_status.h"
44 #include "tensorflow/python/lib/core/safe_ptr.h"
45 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
46 #include "tensorflow/python/util/util.h"
47
48 namespace py = pybind11;
49
50 PYBIND11_MAKE_OPAQUE(TFE_Executor);
51 PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
52 PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
53
54 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
55 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
56 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
57 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
58 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
59 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
60 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
61 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
62 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
63 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
64 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
65 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
66 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
67 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
68 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
69 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
70 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
71 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
72 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
73 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);
74
75 PYBIND11_MAKE_OPAQUE(TF_DeviceList);
76 PYBIND11_MAKE_OPAQUE(TF_Function);
77 PYBIND11_MAKE_OPAQUE(TF_Buffer);
78
79 // Eager helper functions migrated from pywrap_tfe.i.
80
81 namespace tensorflow {
82
83 // We cannot use Context as an opaque type. SWIG also had
84 // difficulty passing the pointer around directly. These
85 // typemaps are migrated over from pywrap_tfe.i. I tried
86 // using a custom type caster, but it caused periodic segfaults.
87
88 // TODO(amitpatankar): Move input and output logic of Context into a
89 // pybind11 custom type caster.
90
91 TFE_Context* InputTFE_Context(const py::handle& ctx) {
92 return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
93 }
94
95 PyObject* OutputTFE_Context(TFE_Context* context) {
96 return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
97 }
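// Illustrative sketch (not compiled): a context created on the C++ side is
// handed to Python as a PyCapsule and later unwrapped again through the two
// helpers above, e.g.
//
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   PyObject* capsule = OutputTFE_Context(ctx);                 // wrap for Python
//   TFE_Context* same = InputTFE_Context(py::handle(capsule));  // unwrap
//
// The capsule destructor (TFE_DeleteContextCapsule) releases the context when
// the Python object is garbage collected.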
98
99 TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
100 // Convert a Python string object to TF_Buffer.
101 char* c_string;
102 Py_ssize_t py_size;
103 // PyBytes_AsStringAndSize() does not copy but simply interprets the input
104 if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
105 // Python has raised an error (likely TypeError or UnicodeEncodeError).
106 throw py::error_already_set();
107 }
108 return TF_NewBufferFromString(static_cast<void*>(c_string),
109 static_cast<size_t>(py_size));
110 }
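// Illustrative sketch (not compiled): callers in this module typically pass a
// serialized proto as py::bytes and wrap the result so the buffer is released
// automatically, e.g.
//
//   tensorflow::Safe_TF_BufferPtr buf =
//       tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
//   // buf.get()->data / buf.get()->length now describe the serialized proto.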
111
112 // These functions serve as typemaps for inputs coming from the Python side.
113 // I did not use a custom type caster since the logic is slightly harder to
114 // follow. This converter is also only used once, in `TFE_Py_ExecuteCancelable_wrapper`.
115 TFE_InputTensorHandles InputTFE_InputTensorHandles(
116 const py::handle& input_tensors) {
117 TFE_InputTensorHandles input_tensor_handles;
118 if (input_tensors.ptr() != Py_None) {
119 if (!PyList_Check(input_tensors.ptr())) {
120 tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
121 }
122 Py_ssize_t len = PyList_Size(input_tensors.ptr());
123 input_tensor_handles.resize(len);
124 for (Py_ssize_t i = 0; i < len; ++i) {
125 PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
126 if (!elem) {
127 tensorflow::ThrowTypeError("Input Tensor does not exist.");
128 }
129 if (EagerTensor_CheckExact(elem)) {
130 (input_tensor_handles)[i] = EagerTensor_Handle(elem);
131 } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
132 // Use equivalent of object.__getattribute__ to get the underlying
133 // tf wrapped EagerTensor (if there is one).
134 tensorflow::Safe_PyObjectPtr tf_should_use_attr(
135 #if PY_MAJOR_VERSION < 3
136 PyString_InternFromString("_tf_should_use_wrapped_value")
137 #else
138 PyUnicode_InternFromString("_tf_should_use_wrapped_value")
139 #endif
140 );
141 tensorflow::Safe_PyObjectPtr value_attr(
142 PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
143 if (value_attr) {
144 // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
145 (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
146 } else {
147 // This is a subclass of EagerTensor that we don't support.
148 PyErr_Clear();
149 tensorflow::ThrowTypeError(
150 tensorflow::strings::StrCat(
151 "Saw an object that is an instance of a strict subclass of "
152 "EagerTensor, which is not supported. Item ",
153 i, " is type: ", elem->ob_type->tp_name)
154 .c_str());
155 }
156 } else if (tensorflow::swig::IsTensor(elem)) {
157 // If it isn't an EagerTensor, but is still a Tensor, it must be a graph
158 // tensor.
159 tensorflow::Safe_PyObjectPtr name_attr(
160 PyObject_GetAttrString(elem, "name"));
161 tensorflow::ThrowTypeError(
162 tensorflow::strings::StrCat(
163 "An op outside of the function building code is being passed\n"
164 "a \"Graph\" tensor. It is possible to have Graph tensors\n"
165 "leak out of the function building context by including a\n"
166 "tf.init_scope in your function building code.\n"
167 "For example, the following function will fail:\n",
168 " @tf.function\n", " def has_init_scope():\n",
169 " my_constant = tf.constant(1.)\n",
170 " with tf.init_scope():\n",
171 " added = my_constant * 2\n",
172 "The graph tensor has name: ",
173 name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>")
174 .c_str());
175 } else {
176 tensorflow::ThrowTypeError(
177 tensorflow::strings::StrCat(
178 "provided list of inputs contains objects other "
179 "than 'EagerTensor'. Item ",
180 i, " is type: ", elem->ob_type->tp_name)
181 .c_str());
182 }
183 }
184 }
185 return input_tensor_handles;
186 }
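// Illustrative note (not compiled): the converter above accepts None or a
// Python list whose items are EagerTensors (possibly wrapped by TFShouldUse);
// graph Tensors and any other object types raise TypeError. For a list of n
// tensors it yields a vector of n TFE_TensorHandle* in the same order.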
187
188 // These functions serve as typemaps for inputs coming from the Python side.
189 // I did not use a custom type caster since the logic is slightly harder to
190 // follow. This converter is also only used once, in `TFE_Py_ExecuteCancelable_wrapper`.
191 // Note that this function takes a number rather than an output Tensor holder.
192 TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
193 const py::handle& num_outputs) {
194 TFE_OutputTensorHandles output_tensor_handles;
195 #if PY_MAJOR_VERSION < 3
196 if (!PyInt_Check(num_outputs.ptr())) {
197 #else
198 if (!PyLong_Check(num_outputs.ptr())) {
199 #endif
200 PyErr_SetString(PyExc_TypeError,
201 "expected an integer value (size of the number of "
202 "outputs of the operation)");
203 throw py::error_already_set();
204 }
205 #if PY_MAJOR_VERSION < 3
206 long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT
207 #else
208 long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT
209 #endif
210 // We can't handle more than int32 sizes for number of outputs.
211 if (static_cast<long>(static_cast<int32>(sz)) != sz) { // NOLINT
212 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
213 "Number of outputs is too big: ", sz)
214 .c_str());
215 throw py::error_already_set();
216 }
217 if (sz > 0) {
218 #if PY_MAJOR_VERSION < 3
219 output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
220 #else
221 output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
222 #endif
223 }
224 return output_tensor_handles;
225 }
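// Illustrative note (not compiled): the Python caller passes the expected
// number of outputs as an integer; the converter above pre-sizes the handle
// vector that TFE_Py_ExecuteCancelable later fills in, roughly:
//
//   TFE_OutputTensorHandles outs = InputTFE_OutputTensorHandles(num_outputs);
//   // outs.size() == num_outputs, every element initialized to nullptr.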
226
227 // Packs multiple `EagerTensor`s of the same dtype and shape into one
228 // `EagerTensor`.
229 py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
230 const py::handle& tensors) {
231 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
232 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
233 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
234 int size = handles.size();
235 TFE_TensorHandle* packed_handle =
236 TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
237 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
238 PyObject* packed_tensor =
239 EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
240 return tensorflow::PyoOrThrow(packed_tensor);
241 }
242
243 // This function was created from fusing the typemap logic in platform/base.i.
244 py::object TFE_Py_ExecuteCancelable_wrapper(
245 const py::handle& context, const char* device_name, const char* op_name,
246 const py::handle& inputs, const py::handle& attrs,
247 tensorflow::CancellationManager* cancellation_manager,
248 const py::handle& num_outputs) {
249 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
250 TFE_InputTensorHandles input_tensor_handles =
251 InputTFE_InputTensorHandles(inputs);
252 TFE_OutputTensorHandles output_tensor_handles =
253 InputTFE_OutputTensorHandles(num_outputs);
254 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
255 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
256 attrs.ptr(), tensorflow::wrap(cancellation_manager),
257 &output_tensor_handles, status.get());
258
259 int output_len = output_tensor_handles.size();
260 PyObject* output_list = PyList_New(output_len);
261 for (int i = 0; i < output_len; ++i) {
262 PyObject* output;
263 output = EagerTensorFromHandle(output_tensor_handles.at(i));
264 PyList_SetItem(output_list, i, output);
265 }
266 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
267 return tensorflow::PyoOrThrow(output_list);
268 }
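// Illustrative sketch (not compiled): executing a single op through the
// wrapper above roughly looks like the following, where `inputs` is a Python
// list of EagerTensors, `attrs` is the flat attribute tuple, `num_outputs` is
// a Python int, and the device and op names are only examples:
//
//   py::object results = TFE_Py_ExecuteCancelable_wrapper(
//       context, "/job:localhost/replica:0/task:0/device:CPU:0", "Identity",
//       inputs, attrs, /*cancellation_manager=*/nullptr, num_outputs);
//
// The result is a Python list of EagerTensors.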
269
270 static py::object TF_ListPhysicalDevices() {
271 std::vector<string> devices;
272 tensorflow::Status s =
273 tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
274 MaybeRaiseRegisteredFromStatus(s);
275 PyObject* result = PyList_New(devices.size());
276 int i = 0;
277 for (auto& dev : devices) {
278 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
279 PyList_SetItem(result, i, dev_obj);
280 ++i;
281 }
282 return tensorflow::PyoOrThrow(result);
283 }
284
285 static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
286 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
287 std::unordered_map<string, string> device_details;
288 tensorflow::Status s =
289 tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
290 tensorflow::Set_TF_Status_from_Status(status.get(), s);
291 MaybeRaiseRegisteredFromTFStatus(status.get());
292 return device_details;
293 }
294
295 static py::object TFE_ClearScalarCache() {
296 tensorflow::TFE_TensorHandleCache::Get()->Clear();
297 return py::none();
298 }
299
300 // Returns compiler IR for a given function.
301 static py::bytes TFE_GetCompilerIr(py::handle& ctx,
302 const char* concrete_function_name,
303 const char* stage, const char* device_name,
304 py::handle& inputs) {
305 EagerContext* context = ContextFromInterface(
306 reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
307
308 std::string s_stage(stage);
309 IrExportStage selected_stage = [&] {
310 if (s_stage == "hlo") {
311 return IrExportStage::HLO;
312 } else if (s_stage == "hlo_serialized") {
313 return IrExportStage::HLO_SERIALIZED;
314 } else if (s_stage == "optimized_hlo") {
315 return IrExportStage::OPTIMIZED_HLO;
316 } else if (s_stage == "optimized_hlo_serialized") {
317 return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
318 } else if (s_stage == "optimized_hlo_dot") {
319 return IrExportStage::OPTIMIZED_HLO_DOT;
320 } else {
321 ThrowValueError(
322 absl::StrFormat("Invalid stage selected: '%s'. Valid values are: 'hlo', "
323 "'hlo_serialized', 'optimized_hlo', 'optimized_hlo_serialized', 'optimized_hlo_dot'",
324 s_stage)
325 .c_str());
326 }
327 }();
328
329 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
330
331 std::vector<const TensorHandle*> input_handles;
332 for (TFE_TensorHandle* tensor_handle : handles) {
333 AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
334 input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
335 }
336
337 DeviceNameUtils::ParsedName input_device_name;
338 if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
339 ThrowValueError(
340 absl::StrFormat("Failed parsing device name: '%s'", device_name)
341 .c_str());
342 }
343
344 std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
345 auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
346 return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
347 d->parsed_name());
348 });
349 if (selected_device == devices.end()) {
350 ThrowValueError(
351 absl::StrFormat("No matching device found for '%s'", device_name)
352 .c_str());
353 }
354
355 xla::StatusOr<std::string> hlo_str =
356 GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
357 *selected_device, context, input_handles);
358
359 if (!hlo_str.ok()) {
360 ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
361 hlo_str.status().error_message())
362 .c_str());
363 }
364 return py::bytes(*hlo_str);
365 }
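// Illustrative note (not compiled): through the "TF_GetCompilerIr" binding
// below this is invoked as, e.g.,
//
//   TFE_GetCompilerIr(ctx, "f", "optimized_hlo", "/device:CPU:0", inputs);
//
// where "f" stands in for a concrete function name and the stage string is
// one of "hlo", "hlo_serialized", "optimized_hlo", "optimized_hlo_serialized"
// or "optimized_hlo_dot"; the selected IR is returned as bytes.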
366
367 } // namespace tensorflow
368
369 namespace {
370
371 // Wrapper around the EagerContextThreadLocalData struct (defined in
372 // pywrap_tfe.h), so it can be accessed from Python.
373 //
374 // For PyObject* fields, the get_*() methods return a new reference; and the
375 // set_*() methods create a new reference (i.e., they do not steal a reference).
376 class EagerContextThreadLocalDataWrapper {
377 public:
378 explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
379 py::handle is_eager,
380 py::handle device_spec)
381 : py_eager_context_(py_eager_context.ptr()) {
382 tensorflow::MakeEagerContextThreadLocalData(
383 py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
384 }
385
386 ~EagerContextThreadLocalDataWrapper() {
387 tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
388 }
389
390 bool get_is_eager() const { return GetData()->is_eager; }
391 void set_is_eager(bool v) { GetData()->is_eager = v; }
392
393 bool get_invoking_op_callbacks() const {
394 return GetData()->invoking_op_callbacks;
395 }
396 void set_invoking_op_callbacks(bool v) {
397 GetData()->invoking_op_callbacks = v;
398 }
399
400 py::handle get_device_name() const {
401 return GetPyObject(&GetData()->device_name);
402 }
403 void set_device_name(py::handle v) {
404 SetPyObject(v, &GetData()->device_name);
405 }
406
407 py::handle get_scope_name() const {
408 return GetPyObject(&GetData()->scope_name);
409 }
410 void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
411
412 py::handle get_device_spec() const {
413 return GetPyObject(&GetData()->device_spec);
414 }
415 void set_device_spec(py::handle v) {
416 SetPyObject(v, &GetData()->device_spec);
417 }
418
419 py::handle get_function_call_options() const {
420 return GetPyObject(&GetData()->function_call_options);
421 }
422 void set_function_call_options(py::handle v) {
423 SetPyObject(v, &GetData()->function_call_options);
424 }
425
426 py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
427 void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
428
429 py::handle get_op_callbacks() const {
430 return GetPyObject(&GetData()->op_callbacks);
431 }
432 void set_op_callbacks(py::handle v) {
433 SetPyObject(v, &GetData()->op_callbacks);
434 }
435
436 private:
437 tensorflow::EagerContextThreadLocalData* GetData() const {
438 auto* result =
439 tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
440 if (!result) {
441 throw py::error_already_set();
442 }
443 return result;
444 }
445
446 py::handle GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
447 Py_INCREF(obj->get());
448 return obj->get();
449 }
450
451 void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
452 Py_INCREF(value.ptr());
453 ptr->reset(value.ptr());
454 }
455
456 PyObject* py_eager_context_; // not owned (borrowed reference).
457 };
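// Illustrative sketch (not compiled, names hypothetical): the wrapper gives
// each Python thread its own view of the eager state, e.g. on the C++ side
//
//   EagerContextThreadLocalDataWrapper data(py_ctx, is_eager, device_spec);
//   data.set_is_eager(true);
//   bool eager = data.get_is_eager();  // reads the same thread-local slot
//
// get_*() returns a new reference and set_*() takes its own reference, as
// described in the class comment above.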
458
459 } // namespace
460
461 // py::return_value_policy::reference is used as documented in the pybind11
462 // reference below.
463 // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
464 // It means that C++ retains ownership of the object. We only assign this
465 // policy to functions that return opaque types.
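//
// Illustrative sketch (not compiled): a binding annotated with this policy
// returns a raw pointer that Python only borrows, e.g.
//
//   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
//         py::return_value_policy::reference);
//
// Python never frees the TFE_ContextOptions*; it is released explicitly
// through the TFE_DeleteContextOptions binding.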
466
467 PYBIND11_MODULE(_pywrap_tfe, m) {
468 py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
469 py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
470 "TFE_ContextOptions");
471 py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
472 m, "TFE_MonitoringCounter0");
473 py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
474 m, "TFE_MonitoringCounter1");
475 py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
476 m, "TFE_MonitoringCounter2");
477 py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
478 m, "TFE_MonitoringStringGauge0");
479 py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
480 m, "TFE_MonitoringStringGauge1");
481 py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
482 m, "TFE_MonitoringStringGauge2");
483 py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
484 m, "TFE_MonitoringIntGauge0");
485 py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
486 m, "TFE_MonitoringIntGauge1");
487 py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
488 m, "TFE_MonitoringIntGauge2");
489 py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
490 m, "TFE_MonitoringBoolGauge0");
491 py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
492 m, "TFE_MonitoringBoolGauge1");
493 py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
494 m, "TFE_MonitoringBoolGauge2");
495 py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
496 m, "TFE_MonitoringCounterCell");
497 py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
498 m, "TFE_MonitoringIntGaugeCell");
499 py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
500 m, "TFE_MonitoringStringGaugeCell");
501 py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
502 m, "TFE_MonitoringBoolGaugeCell");
503 py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
504 m, "TFE_MonitoringSamplerCell");
505 py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
506 m, "TFE_MonitoringBuckets");
507 py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
508 m, "TFE_MonitoringSampler0");
509 py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
510 m, "TFE_MonitoringSampler1");
511 py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
512 m, "TFE_MonitoringSampler2");
513 py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
514 m, "TFE_CancellationManager");
515
516 py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
517 py::class_<TF_Function> TF_Function_class(m, "TF_Function");
518
519 m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
520 return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
521 });
522 m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
523 return tensorflow::PyoOrThrow(
524 TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
525 });
526
527 m.def(
528 "TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
529 auto* context =
530 reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
531 tensorflow::InputTFE_Context(ctx));
532
533 tensorflow::DeviceNameUtils::ParsedName input_device_name;
534 if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
535 device_name, &input_device_name)) {
536 tensorflow::ThrowValueError(
537 absl::StrFormat("Failed parsing device name: '%s'", device_name)
538 .c_str());
539 }
540
541 std::vector<tensorflow::Device*> devices =
542 context->ListLocalTfDevices();
543
544 tensorflow::Device* matched_device = nullptr;
545 for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
546 tensorflow::Device* device = devices[device_idx];
547
548 if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
549 input_device_name, device->parsed_name())) {
550 if (device->device_type() == tensorflow::DEVICE_CPU) {
551 tensorflow::ThrowValueError(
552 "CPU does not support getting allocator information");
553 }
554
555 if (matched_device != nullptr) {
556 tensorflow::ThrowValueError(
557 absl::StrFormat(
558 "Multiple devices matching the provided string "
559 "'%s': '%s' and "
560 "'%s' ",
561 device_name, matched_device->name(), device->name())
562 .c_str());
563 }
564 matched_device = device;
565 }
566 }
567
568 if (matched_device == nullptr) {
569 tensorflow::ThrowValueError(
570 absl::StrFormat("No matching devices found for '%s'", device_name)
571 .c_str());
572 }
573
574 tensorflow::AllocatorAttributes attrs;
575 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
576
577 if (absl::optional<tensorflow::AllocatorStats> stats =
578 allocator->GetStats()) {
579 return std::map<std::string, int64_t>{
580 {"current", stats->bytes_in_use},
581 {"peak", stats->peak_bytes_in_use}};
582 }
583
584 tensorflow::ThrowTypeError(
585 absl::StrFormat("Allocator stats not available for device '%s'",
586 matched_device->name())
587 .c_str());
588 });
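  // Illustrative note (not compiled): on success the binding above returns a
  // {"current": bytes_in_use, "peak": peak_bytes_in_use} mapping to Python; a
  // CPU device, an ambiguous device string, an unmatched device name, or
  // missing allocator stats raises an error instead.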
589
590 // XLA Eager Logic
591 m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
592 m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
593 m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
594 m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
595 m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
596 m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
597 m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
598
599 // MLIR Logic
600 m.def("TF_IsMlirBridgeEnabled", [] {
601 // Since python protobuf enums are integers, cast to an integer before
602 // returning the enum to python.
603 return static_cast<int32_t>(
604 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
605 });
606 m.def("TF_EnableMlirBridge", [](bool enabled) {
607 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
608 enabled
609 ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
610 : tensorflow::ConfigProto::Experimental::
611 MLIR_BRIDGE_ROLLOUT_DISABLED;
612 });
613 m.def("TF_EnableXlaDevices", [] {
614 tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
615 });
616
617 // TFE_Context Logic
618 m.def(
619 "TFE_NewContext",
620 [](const TFE_ContextOptions* opts) {
621 tensorflow::Safe_TF_StatusPtr status =
622 tensorflow::make_safe(TF_NewStatus());
623 TFE_Context* context = TFE_NewContext(opts, status.get());
624 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
625 return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
626 },
627 py::return_value_policy::reference);
628 m.def("TFE_DeleteContext", [](py::handle& o) {
629 TFE_DeleteContext(tensorflow::InputTFE_Context(o));
630 });
631 m.def(
632 "TFE_ContextListDevices",
633 [](py::handle& o) {
634 tensorflow::Safe_TF_StatusPtr status =
635 tensorflow::make_safe(TF_NewStatus());
636 auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
637 status.get());
638 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
639 return output;
640 },
641 py::return_value_policy::reference);
642 m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
643 TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
644 });
645 m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
646 tensorflow::Safe_TF_StatusPtr status =
647 tensorflow::make_safe(TF_NewStatus());
648 TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
649 status.get());
650 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
651 });
652 m.def("TFE_ContextAddFunctionDef",
653 [](py::handle& ctx, const char* serialized_function_def, size_t size) {
654 tensorflow::Safe_TF_StatusPtr status =
655 tensorflow::make_safe(TF_NewStatus());
656 TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
657 serialized_function_def, size,
658 status.get());
659 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
660 });
661 m.def("TFE_ContextGetFunctionDef",
662 [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
663 tensorflow::Safe_TF_StatusPtr status =
664 tensorflow::make_safe(TF_NewStatus());
665 TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
666 function_name, &buf, status.get());
667 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
668 });
669 m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
670 tensorflow::Safe_TF_StatusPtr status =
671 tensorflow::make_safe(TF_NewStatus());
672 TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
673 status.get());
674 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
675 });
676 m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
677 tensorflow::Safe_TF_StatusPtr status =
678 tensorflow::make_safe(TF_NewStatus());
679 auto output =
680 TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
681 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
682 return output;
683 });
684 m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
685 return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
686 ->ListFunctionNames();
687 });
688 m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
689 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
690 });
691 m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
692 TFE_ContextDisableRunMetadata(tensorflow::InputTFE_Context(ctx));
693 });
694 m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
695 TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
696 });
697 m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
698 TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
699 });
700 m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
701 tensorflow::Safe_TF_StatusPtr status =
702 tensorflow::make_safe(TF_NewStatus());
703 TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
704 status.get());
705 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
706 });
707 m.def("TFE_ContextClearCaches", [](py::handle& o) {
708 TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
709 });
710 m.def("TFE_GetContextId", [](py::handle& ctx) {
711 return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
712 });
713 m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
714 return TFE_ContextGetDevicePlacementPolicy(
715 tensorflow::InputTFE_Context(ctx));
716 });
717 m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
718 [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
719 TFE_ContextSetThreadLocalDevicePlacementPolicy(
720 tensorflow::InputTFE_Context(ctx), policy);
721 });
722 m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
723 py::bytes proto) {
724 tensorflow::Safe_TF_StatusPtr status =
725 tensorflow::make_safe(TF_NewStatus());
726 tensorflow::Safe_TF_BufferPtr buf =
727 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
728 TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
729 buf.get()->data, buf.get()->length, status.get());
730 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
731 });
732 m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
733 py::bytes proto) {
734 tensorflow::Safe_TF_StatusPtr status =
735 tensorflow::make_safe(TF_NewStatus());
736 tensorflow::Safe_TF_BufferPtr buf =
737 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
738 Py_BEGIN_ALLOW_THREADS;
739 TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
740 keep_alive_secs, buf.get()->data,
741 buf.get()->length, status.get());
742 Py_END_ALLOW_THREADS;
743 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
744 });
745 m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
746 tensorflow::Safe_TF_StatusPtr status =
747 tensorflow::make_safe(TF_NewStatus());
748 bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
749 worker_name, status.get());
750 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
751 return output;
752 });
753 m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
754 tensorflow::Safe_TF_StatusPtr status =
755 tensorflow::make_safe(TF_NewStatus());
756 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
757 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
758 });
759 m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
760 tensorflow::Safe_TF_StatusPtr status =
761 tensorflow::make_safe(TF_NewStatus());
762 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
763 // NOTE: unlike TFE_ContextSyncExecutors, which raises potential errors,
764 // we deliberately ignore executor statuses during cleanup.
765 });
766 m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
767 tensorflow::Safe_TF_StatusPtr status =
768 tensorflow::make_safe(TF_NewStatus());
769 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
770 status.get());
771 });
772 m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
773 tensorflow::Safe_TF_StatusPtr status =
774 tensorflow::make_safe(TF_NewStatus());
775 TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
776 status.get());
777 });
778
779 // TFE_Executor logic
780 m.def(
781 "TFE_NewExecutor",
782 [](const bool is_async) {
783 TFE_Executor* exc = TFE_NewExecutor(is_async);
784 return exc;
785 },
786 py::return_value_policy::reference);
787 m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
788 m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
789 m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
790 tensorflow::Safe_TF_StatusPtr status =
791 tensorflow::make_safe(TF_NewStatus());
792 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
793 Py_BEGIN_ALLOW_THREADS;
794 TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
795 Py_END_ALLOW_THREADS;
796 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
797 });
798 m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
799 m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
800 TFE_Executor& exc) {
801 TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
802 });
803 m.def(
804 "TFE_ContextGetExecutorForThread",
805 [](py::handle& o) {
806 return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
807 },
808 py::return_value_policy::reference);
809
810 m.def("TFE_OpNameGetAttrType",
811 [](py::handle& ctx, const char* op_or_function_name,
812 const char* attr_name) {
813 int temp = 0;
814 unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
815 tensorflow::Safe_TF_StatusPtr status =
816 tensorflow::make_safe(TF_NewStatus());
817 auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
818 op_or_function_name, attr_name,
819 is_list, status.get());
820 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
821 #if PY_MAJOR_VERSION < 3
822 PyObject* output_pyo = PyInt_FromLong(output);
823 #else
824 PyObject* output_pyo = PyLong_FromLong(output);
825 #endif
826 if (*is_list == 1) {
827 PyObject* list = PyList_New(1);
828 PyList_SetItem(list, 0, output_pyo);
829 return tensorflow::PyoOrThrow(list);
830 }
831 return tensorflow::PyoOrThrow(output_pyo);
832 });
833 m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
834 return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
835 });
836 m.def("TFE_Py_PackEagerTensors",
837 [](const py::handle& context, const py::handle& handles) {
838 return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
839 });
840 m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
841 m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
842 return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
843 });
844 m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
845 return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
846 });
847 m.def("TFE_Py_Execute",
848 [](const py::handle& context, const char* device_name,
849 const char* op_name, const py::handle& inputs,
850 const py::handle& attrs, const py::handle& num_outputs) {
851 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
852 context, device_name, op_name, inputs, attrs.ptr(), nullptr,
853 num_outputs);
854 });
855 m.def(
856 "TFE_Py_ExecuteCancelable",
857 [](const py::handle& context, const char* device_name,
858 const char* op_name, const py::handle& inputs, const py::handle& attrs,
859 tensorflow::CancellationManager& cancellation_manager,
860 const py::handle& num_outputs) {
861 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
862 context, device_name, op_name, inputs, attrs.ptr(),
863 &cancellation_manager, num_outputs);
864 });
865 m.def("TFE_Py_FastPathExecute", [](const py::args args) {
866 // TFE_Py_FastPathExecute requires error checking prior to returning.
867 return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
868 });
869 m.def("TFE_Py_RecordGradient",
870 [](const py::handle& op_name, const py::handle& inputs,
871 const py::handle& attrs, const py::handle& results,
872 const py::handle& forward_pass_name_scope) {
873 return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
874 op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
875 forward_pass_name_scope.ptr()));
876 });
877 m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
878
879 // TFE_Py_Tape Logic
880 m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
881 const py::handle& watch_accessed_variables) {
882 return tensorflow::PyoOrThrow(
883 TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
884 });
885 m.def("TFE_Py_TapeSetAdd",
886 [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
887 m.def("TFE_Py_TapeSetRemove",
888 [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
889 m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
890 m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
891 m.def("TFE_Py_TapeSetIsStopped",
892 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
893 m.def("TFE_Py_TapeSetIsEmpty",
894 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
895 m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
896 return tensorflow::PyoOrThrow(
897 TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
898 });
899 m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
900 return tensorflow::PyoOrThrow(
901 TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
902 });
903 m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
904 m.def("TFE_Py_TapeSetRecordOperation",
905 [](const py::handle& op_type, const py::handle& output_tensors,
906 const py::handle& input_tensors, const py::handle& backward_function,
907 const py::handle& forward_function) {
908 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
909 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
910 backward_function.ptr(), forward_function.ptr()));
911 });
912 m.def(
913 "TFE_Py_TapeSetRecordOperationBackprop",
914 [](const py::handle& op_type, const py::handle& output_tensors,
915 const py::handle& input_tensors, const py::handle& backward_function) {
916 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
917 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
918 backward_function.ptr()));
919 });
920 m.def(
921 "TFE_Py_TapeSetRecordOperationForwardprop",
922 [](const py::handle& op_type, const py::handle& output_tensors,
923 const py::handle& input_tensors, const py::handle& backward_function,
924 const py::handle& forwardprop_output_indices) {
925 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
926 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
927 backward_function.ptr(), forwardprop_output_indices.ptr()));
928 });
929 m.def("TFE_Py_TapeGradient",
930 [](const py::handle& tape, const py::handle& target,
931 const py::handle& sources, const py::handle& output_gradients,
932 const py::handle& sources_raw,
933 const py::handle& unconnected_gradients) {
934 tensorflow::Safe_TF_StatusPtr status =
935 tensorflow::make_safe(TF_NewStatus());
936 PyObject* output = TFE_Py_TapeGradient(
937 tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
938 sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
939 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
940 return tensorflow::PyoOrThrow(output);
941 });
942
943 m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
944 TFE_Py_TapeVariableAccessed(variable.ptr());
945 });
946 m.def("TFE_Py_TapeWatch",
947 [](const py::handle& tape, const py::handle& tensor) {
948 TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
949 });
950 m.def("TFE_Py_TapeWatchVariable",
951 [](const py::handle& tape, const py::handle& variable) {
952 TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
953 });
954 m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
955 return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
956 });
957
958 // TFE_Py_VariableWatcher logic.
959 m.def("TFE_Py_VariableWatcherNew",
960 []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
961 m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
962 TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
963 });
964 m.def("TFE_Py_VariableWatcherVariableAccessed",
965 [](const py::handle& variable) {
966 TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
967 });
968 m.def("TFE_Py_VariableWatcherWatchedVariables",
969 [](const py::handle& variable_watcher) {
970 return tensorflow::PyoOrThrow(
971 TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
972 });
973
974 // TFE_Py_ForwardAccumulator logic.
975 m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
976 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
977 });
978
979 m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
980 return tensorflow::PyoOrThrow(
981 TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
982 });
983 m.def("TFE_Py_ForwardAccumulatorSetRemove",
984 [](const py::handle& accumulator) {
985 TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
986 });
987
988 m.def("TFE_Py_ForwardAccumulatorWatch",
989 [](const py::handle& accumulator, const py::handle& tensor,
990 const py::handle& tangent) {
991 TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
992 tangent.ptr());
993 });
994 m.def("TFE_Py_ForwardAccumulatorJVP",
995 [](const py::handle& accumulator, const py::handle& tensor) {
996 return tensorflow::PyoOrThrow(
997 TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
998 });
999 m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1000 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1001 });
1002 m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1003 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1004 });
1005 m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1006 return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1007 });
1008
1009 // TFE_ContextOptions Logic
1010 m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1011 py::return_value_policy::reference);
1012 m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1013 py::bytes proto) {
1014 tensorflow::Safe_TF_StatusPtr status =
1015 tensorflow::make_safe(TF_NewStatus());
1016 tensorflow::Safe_TF_BufferPtr buf =
1017 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1018 TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1019 status.get());
1020 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1021 });
1022 m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1023 &TFE_ContextOptionsSetDevicePlacementPolicy);
1024 m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1025 m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1026 m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1027 py::return_value_policy::reference);
1028
1029 // TFE_Py_TensorShape Logic
1030 m.def("TFE_Py_TensorShapeSlice",
1031 [](const py::handle& tensors, int slice_dim) {
1032 return tensorflow::PyoOrThrow(
1033 TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1034 });
1035 m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1036 int slice_dim) {
1037 return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1038 });
1039 m.def("TFE_Py_EnableInteractivePythonLogging",
1040 &TFE_Py_EnableInteractivePythonLogging);
1041
1042 // Additional Context Logic
1043 m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1044 return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1045 });
1046 m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1047 return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1048 });
1049 m.def("TFE_Py_EncodeArg",
1050 [](const py::handle& o, bool include_tensor_ranks_only) {
1051 return tensorflow::PyoOrThrow(
1052 TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only));
1053 });
1054 m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1055 tensorflow::Safe_TF_StatusPtr status =
1056 tensorflow::make_safe(TF_NewStatus());
1057 tensorflow::Safe_TF_BufferPtr buf =
1058 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1059 TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1060 buf.get()->length, status.get());
1061 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1062 });
1063 m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1064 const char* message) {
1065 tensorflow::Safe_TF_StatusPtr status =
1066 tensorflow::make_safe(TF_NewStatus());
1067 TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1068 TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1069 });
1070 m.def("TFE_CollectiveOpsCheckPeerHealth",
1071 [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1072 tensorflow::Safe_TF_StatusPtr status =
1073 tensorflow::make_safe(TF_NewStatus());
1074 TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1075 task, timeout_in_ms, status.get());
1076 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1077 });
1078 m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1079 m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1080 m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1081 py::return_value_policy::reference);
1082 m.def("TF_DeviceListCount", &TF_DeviceListCount);
1083 m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1084 tensorflow::Safe_TF_StatusPtr status =
1085 tensorflow::make_safe(TF_NewStatus());
1086 auto output = TF_DeviceListName(list, index, status.get());
1087 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1088 return output;
1089 });
1090 m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1091 tensorflow::Safe_TF_StatusPtr status =
1092 tensorflow::make_safe(TF_NewStatus());
1093 auto output = TF_DeviceListType(list, index, status.get());
1094 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1095 return output;
1096 });
1097
1098 m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1099
1100 // TFE_MonitoringCounter Logic
1101 m.def("TFE_MonitoringCounterCellIncrementBy",
1102 &TFE_MonitoringCounterCellIncrementBy);
1103 m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1104 m.def(
1105 "TFE_MonitoringNewCounter0",
1106 [](const char* name, const char* description) {
1107 tensorflow::Safe_TF_StatusPtr status =
1108 tensorflow::make_safe(TF_NewStatus());
1109 auto output =
1110 TFE_MonitoringNewCounter0(name, status.get(), description);
1111 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1112 return output;
1113 },
1114 py::return_value_policy::reference);
1115 m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1116 py::return_value_policy::reference);
1117 m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1118 py::return_value_policy::reference);
1119 m.def(
1120 "TFE_MonitoringNewCounter1",
1121 [](const char* name, const char* description, const char* label1) {
1122 tensorflow::Safe_TF_StatusPtr status =
1123 tensorflow::make_safe(TF_NewStatus());
1124 auto output =
1125 TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1126 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1127 return output;
1128 },
1129 py::return_value_policy::reference);
1130 m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1131 py::return_value_policy::reference);
1132 m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1133 py::return_value_policy::reference);
1134 m.def(
1135 "TFE_MonitoringNewCounter2",
1136 [](const char* name, const char* description, const char* label1,
1137 const char* label2) {
1138 tensorflow::Safe_TF_StatusPtr status =
1139 tensorflow::make_safe(TF_NewStatus());
1140 auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1141 label1, label2);
1142 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1143 return output;
1144 },
1145 py::return_value_policy::reference);
1146 m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1147 py::return_value_policy::reference);
1148 m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1149 py::return_value_policy::reference);
1150
1151 // TFE_MonitoringIntGauge Logic
1152 m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1153 m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1154 m.def(
1155 "TFE_MonitoringNewIntGauge0",
1156 [](const char* name, const char* description) {
1157 tensorflow::Safe_TF_StatusPtr status =
1158 tensorflow::make_safe(TF_NewStatus());
1159 auto output =
1160 TFE_MonitoringNewIntGauge0(name, status.get(), description);
1161 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1162 return output;
1163 },
1164 py::return_value_policy::reference);
1165 m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1166 py::return_value_policy::reference);
1167 m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1168 py::return_value_policy::reference);
1169 m.def(
1170 "TFE_MonitoringNewIntGauge1",
1171 [](const char* name, const char* description, const char* label1) {
1172 tensorflow::Safe_TF_StatusPtr status =
1173 tensorflow::make_safe(TF_NewStatus());
1174 auto output =
1175 TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1176 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1177 return output;
1178 },
1179 py::return_value_policy::reference);
1180 m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1181 py::return_value_policy::reference);
1182 m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1183 py::return_value_policy::reference);
1184 m.def(
1185 "TFE_MonitoringNewIntGauge2",
1186 [](const char* name, const char* description, const char* label1,
1187 const char* label2) {
1188 tensorflow::Safe_TF_StatusPtr status =
1189 tensorflow::make_safe(TF_NewStatus());
1190 auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1191 description, label1, label2);
1192 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1193 return output;
1194 },
1195 py::return_value_policy::reference);
1196 m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1197 py::return_value_policy::reference);
1198 m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1199 py::return_value_policy::reference);
1200 m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1201 m.def("TFE_MonitoringStringGaugeCellValue",
1202 &TFE_MonitoringStringGaugeCellValue);
1203 m.def(
1204 "TFE_MonitoringNewStringGauge0",
1205 [](const char* name, const char* description) {
1206 tensorflow::Safe_TF_StatusPtr status =
1207 tensorflow::make_safe(TF_NewStatus());
1208 auto output =
1209 TFE_MonitoringNewStringGauge0(name, status.get(), description);
1210 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1211 return output;
1212 },
1213 py::return_value_policy::reference);
1214
1215 // TFE_MonitoringStringGauge Logic
1216 m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1217 m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1218 py::return_value_policy::reference);
1219 m.def(
1220 "TFE_MonitoringNewStringGauge1",
1221 [](const char* name, const char* description, const char* label1) {
1222 tensorflow::Safe_TF_StatusPtr status =
1223 tensorflow::make_safe(TF_NewStatus());
1224 auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1225 description, label1);
1226 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1227 return output;
1228 },
1229 py::return_value_policy::reference);
1230 m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1231 m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1232 py::return_value_policy::reference);
1233 m.def(
1234 "TFE_MonitoringNewStringGauge2",
1235 [](const char* name, const char* description, const char* label1,
1236 const char* label2) {
1237 tensorflow::Safe_TF_StatusPtr status =
1238 tensorflow::make_safe(TF_NewStatus());
1239 auto output = TFE_MonitoringNewStringGauge2(
1240 name, status.get(), description, label1, label2);
1241 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1242 return output;
1243 },
1244 py::return_value_policy::reference);
1245 m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1246 m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1247 py::return_value_policy::reference);
1248
1249 // TFE_MonitoringBoolGauge Logic
1250 m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1251 m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1252 m.def(
1253 "TFE_MonitoringNewBoolGauge0",
1254 [](const char* name, const char* description) {
1255 tensorflow::Safe_TF_StatusPtr status =
1256 tensorflow::make_safe(TF_NewStatus());
1257 auto output =
1258 TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1259 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1260 return output;
1261 },
1262 py::return_value_policy::reference);
1263 m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1264 py::return_value_policy::reference);
1265 m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1266 py::return_value_policy::reference);
1267 m.def(
1268 "TFE_MonitoringNewBoolGauge1",
1269 [](const char* name, const char* description, const char* label1) {
1270 tensorflow::Safe_TF_StatusPtr status =
1271 tensorflow::make_safe(TF_NewStatus());
1272 auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1273 description, label1);
1274 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1275 return output;
1276 },
1277 py::return_value_policy::reference);
1278 m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1279 py::return_value_policy::reference);
1280 m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1281 py::return_value_policy::reference);
1282 m.def(
1283 "TFE_MonitoringNewBoolGauge2",
1284 [](const char* name, const char* description, const char* label1,
1285 const char* label2) {
1286 tensorflow::Safe_TF_StatusPtr status =
1287 tensorflow::make_safe(TF_NewStatus());
1288 auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1289 description, label1, label2);
1290 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1291 return output;
1292 },
1293 py::return_value_policy::reference);
1294 m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1295 py::return_value_policy::reference);
1296 m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1297 py::return_value_policy::reference);
1298
1299 // TFE_MonitoringSampler Logic
  m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
  m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
  m.def("TFE_MonitoringNewExponentialBuckets",
        &TFE_MonitoringNewExponentialBuckets,
        py::return_value_policy::reference);
  m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
        py::return_value_policy::reference);
  m.def(
      "TFE_MonitoringNewSampler0",
      [](const char* name, TFE_MonitoringBuckets* buckets,
         const char* description) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output =
            TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);
  m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
        py::return_value_policy::reference);
  m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
        py::return_value_policy::reference);
  m.def(
      "TFE_MonitoringNewSampler1",
      [](const char* name, TFE_MonitoringBuckets* buckets,
         const char* description, const char* label1) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
                                                description, label1);
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);
  m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
        py::return_value_policy::reference);
  m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
        py::return_value_policy::reference);
  m.def(
      "TFE_MonitoringNewSampler2",
      [](const char* name, TFE_MonitoringBuckets* buckets,
         const char* description, const char* label1, const char* label2) {
        tensorflow::Safe_TF_StatusPtr status =
            tensorflow::make_safe(TF_NewStatus());
        auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
                                                description, label1, label2);
        tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
        return output;
      },
      py::return_value_policy::reference);
  m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
        py::return_value_policy::reference);
  m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
        py::return_value_policy::reference);

  // TFE_CancellationManager Logic
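  // pybind11 binds the member functions below as free functions that take the
  // CancellationManager instance as their first argument. Illustrative Python
  // usage (names match the bindings defined here):
  //   manager = TFE_NewCancellationManager()
  //   TFE_CancellationManagerStartCancel(manager)
  //   assert TFE_CancellationManagerIsCancelled(manager)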
  m.def("TFE_NewCancellationManager",
        []() { return new tensorflow::CancellationManager(); });
  m.def("TFE_CancellationManagerIsCancelled",
        &tensorflow::CancellationManager::IsCancelled);
  m.def("TFE_CancellationManagerStartCancel",
        &tensorflow::CancellationManager::StartCancel);

  m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);

  // Util buffer helper functions
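  // TF_NewBufferFromString copies the given bytes into a newly allocated
  // TF_Buffer; the reference return policy means pybind11 does not take
  // ownership, so the buffer is released through the C API rather than by
  // Python.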
  m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
        py::return_value_policy::reference);

  // DLPack functions
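  // DLPack interop: TFE_ToDlpackCapsule wraps an EagerTensor's handle in a
  // PyCapsule named "dltensor"; TFE_FromDlpackCapsule consumes such a capsule
  // and renames it to "used_dltensor" so it cannot be imported twice.
  // Illustrative Python usage (how the module is exposed to user code is
  // outside this file):
  //   capsule = TFE_ToDlpackCapsule(eager_tensor)
  //   tensor_handle = TFE_FromDlpackCapsule(capsule, context)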
  m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
    PyObject* eager_tensor_pyobject_ptr = o.ptr();
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());

    if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
      status->status = tensorflow::errors::InvalidArgument(
          "The argument to `to_dlpack` must be a TF tensor, not a Python "
          "object");
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }

    TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
    void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

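    // The capsule destructor below runs the DLManagedTensor deleter only if
    // the capsule still carries the "dltensor" name, i.e. it was never
    // consumed by an importer (importers rename it to "used_dltensor").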
    py::capsule capsule(
        dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
          if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
            void* dlm_rptr =
                PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
            if (dlm_rptr) {
              tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
              PyCapsule_SetDestructor(capsule, nullptr);
            }
          }
        });
    return capsule;
  });

  m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
                                    const py::handle& context) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    if (absl::string_view(pycapsule.name()) !=
        tensorflow::kDlTensorCapsuleName) {
      // errors::InvalidArgument concatenates its arguments verbatim, so the
      // message is pre-formatted with absl::StrFormat to substitute the
      // capsule name.
      status->status = tensorflow::errors::InvalidArgument(absl::StrFormat(
          "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
          "Note that a DLPack tensor may be consumed at most once.",
          absl::string_view(pycapsule.name())));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }

    TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
        pycapsule, status.get(), tensorflow::InputTFE_Context(context));

    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());

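    // Mark the capsule as consumed: rename it and drop its destructor so the
    // DLManagedTensor deleter is not run again on memory that the new tensor
    // handle now owns.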
    PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
    PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);

    PyObject* pyhandle = EagerTensorFromHandle(thandle);
    return tensorflow::PyoOrThrow(pyhandle);
  });

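  // Registers a custom device implementation with the eager context. The
  // device callbacks and the opaque per-device state arrive from Python as
  // PyCapsules whose names are validated below before registration.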
  m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
                                          const py::capsule& device,
                                          const char* device_name,
                                          const py::capsule& device_info) {
    tensorflow::Safe_TF_StatusPtr status =
        tensorflow::make_safe(TF_NewStatus());
    if (absl::string_view(device.name()) != "TFE_CustomDevice") {
      status->status = tensorflow::errors::InvalidArgument(
          "Expected a capsule named 'TFE_CustomDevice' for the `device` "
          "argument, got ",
          absl::string_view(device.name()));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }
    if (absl::string_view(device_info.name()) !=
        "TFE_CustomDevice_DeviceInfo") {
      status->status = tensorflow::errors::InvalidArgument(
          "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
          "the `device_info` argument, got ",
          absl::string_view(device_info.name()));
      tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
    }
    // TFE_RegisterCustomDevice takes ownership of device_info, so clear the
    // capsule's destructor to avoid freeing the device state twice.
    PyCapsule_SetDestructor(device_info.ptr(), nullptr);
    TFE_RegisterCustomDevice(
        tensorflow::InputTFE_Context(context),
        *reinterpret_cast<TFE_CustomDevice*>(
            PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
        device_name,
        PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
        status.get());
    tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
  });

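  // Thread-local eager context state, exposed to Python as mutable properties
  // (eager-mode flag, device and scope names, executor, op callbacks, etc.).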
  py::class_<EagerContextThreadLocalDataWrapper>(m,
                                                 "EagerContextThreadLocalData")
      .def(py::init<py::handle, py::handle, py::handle>(),
           py::arg("py_eager_context"), py::arg("is_eager"),
           py::arg("device_spec"))
      .def_property("is_eager",
                    &EagerContextThreadLocalDataWrapper::get_is_eager,
                    &EagerContextThreadLocalDataWrapper::set_is_eager)
      .def_property(
          "invoking_op_callbacks",
          &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
          &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
      .def_property("device_name",
                    &EagerContextThreadLocalDataWrapper::get_device_name,
                    &EagerContextThreadLocalDataWrapper::set_device_name)
      .def_property("scope_name",
                    &EagerContextThreadLocalDataWrapper::get_scope_name,
                    &EagerContextThreadLocalDataWrapper::set_scope_name)
      .def_property("device_spec",
                    &EagerContextThreadLocalDataWrapper::get_device_spec,
                    &EagerContextThreadLocalDataWrapper::set_device_spec)
      .def_property(
          "function_call_options",
          &EagerContextThreadLocalDataWrapper::get_function_call_options,
          &EagerContextThreadLocalDataWrapper::set_function_call_options)
      .def_property("executor",
                    &EagerContextThreadLocalDataWrapper::get_executor,
                    &EagerContextThreadLocalDataWrapper::set_executor)
      .def_property("op_callbacks",
                    &EagerContextThreadLocalDataWrapper::get_op_callbacks,
                    &EagerContextThreadLocalDataWrapper::set_op_callbacks);

  // C API Enum

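  // export_values() additionally copies each enum member into the module
  // namespace, so the values are reachable both as attributes of the enum
  // class and as module-level constants (e.g. TFE_DEVICE_PLACEMENT_WARN).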
  py::enum_<TFE_ContextDevicePlacementPolicy>(
      m, "TFE_ContextDevicePlacementPolicy")
      .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
      .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
      .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
      .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
             TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
      .export_values();

  py::enum_<TF_AttrType>(m, "TF_AttrType")
      .value("TF_ATTR_STRING", TF_ATTR_STRING)
      .value("TF_ATTR_INT", TF_ATTR_INT)
      .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
      .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
      .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
      .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
      .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
      .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
      .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
      .export_values();
};