• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "Python.h"
19 #include "absl/strings/str_format.h"
20 #include "pybind11/chrono.h"
21 #include "pybind11/complex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "pybind11/pytypes.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_experimental.h"
28 #include "tensorflow/c/eager/c_api.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/dlpack.h"
32 #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
35 #include "tensorflow/c/tf_status.h"
36 #include "tensorflow/c/tf_status_helper.h"
37 #include "tensorflow/compiler/jit/flags.h"
38 #include "tensorflow/compiler/jit/get_compiler_ir.h"
39 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
40 #include "tensorflow/python/eager/pywrap_tfe.h"
41 #include "tensorflow/python/lib/core/py_exception_registry.h"
42 #include "tensorflow/python/lib/core/pybind11_lib.h"
43 #include "tensorflow/python/lib/core/pybind11_status.h"
44 #include "tensorflow/python/lib/core/safe_ptr.h"
45 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
46 #include "tensorflow/python/util/util.h"
47 
48 namespace py = pybind11;
49 
50 PYBIND11_MAKE_OPAQUE(TFE_Executor);
51 PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
52 PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
53 
54 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
55 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
56 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
57 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
58 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
59 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
60 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
61 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
62 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
63 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
64 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
65 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
66 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
67 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
68 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
69 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
70 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
71 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
72 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
73 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);
74 
75 PYBIND11_MAKE_OPAQUE(TF_DeviceList);
76 PYBIND11_MAKE_OPAQUE(TF_Function);
77 PYBIND11_MAKE_OPAQUE(TF_Buffer);
78 
79 // Eager helper functions migrated from pywrap_tfe.i.
80 
81 namespace tensorflow {
82 
83 // We cannot use Context as an opaque type. SWIG also had
84 // difficulty passing the pointer around directly. These
85 // typemaps are migrated over from pywrap_tfe.i. I tried
86 // using a custom type caster, but we get segfaults periodically.
87 
88 // TODO(amitpatankar): Move input and output logic of Context into a
89 // pybind11 custom type caster.
90 
InputTFE_Context(const py::handle & ctx)91 TFE_Context* InputTFE_Context(const py::handle& ctx) {
92   return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
93 }
94 
OutputTFE_Context(TFE_Context * context)95 PyObject* OutputTFE_Context(TFE_Context* context) {
96   return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
97 }
98 
ProtoStringToTFBuffer(PyObject * input)99 TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
100   // Convert a Python string object to TF_Buffer.
101   char* c_string;
102   Py_ssize_t py_size;
103   // PyBytes_AsStringAndSize() does not copy but simply interprets the input
104   if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
105     // Python has raised an error (likely TypeError or UnicodeEncodeError).
106     throw py::error_already_set();
107   }
108   return TF_NewBufferFromString(static_cast<void*>(c_string),
109                                 static_cast<size_t>(py_size));
110 }
111 
112 // These functions are typemaps from the Python side. I did not use
113 // a custom type caster since the logic is slightly harder to follow. This
114 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
InputTFE_InputTensorHandles(const py::handle & input_tensors)115 TFE_InputTensorHandles InputTFE_InputTensorHandles(
116     const py::handle& input_tensors) {
117   TFE_InputTensorHandles input_tensor_handles;
118   if (input_tensors.ptr() != Py_None) {
119     if (!PyList_Check(input_tensors.ptr())) {
120       tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
121     }
122     Py_ssize_t len = PyList_Size(input_tensors.ptr());
123     input_tensor_handles.resize(len);
124     for (Py_ssize_t i = 0; i < len; ++i) {
125       PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
126       if (!elem) {
127         tensorflow::ThrowTypeError("Input Tensor does not exist.");
128       }
129       if (EagerTensor_CheckExact(elem)) {
130         (input_tensor_handles)[i] = EagerTensor_Handle(elem);
131       } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
132         // Use equivalent of object.__getattribute__ to get the underlying
133         // tf wrapped EagerTensor (if there is one).
134         tensorflow::Safe_PyObjectPtr tf_should_use_attr(
135 #if PY_MAJOR_VERSION < 3
136             PyString_InternFromString("_tf_should_use_wrapped_value")
137 #else
138             PyUnicode_InternFromString("_tf_should_use_wrapped_value")
139 #endif
140         );
141         tensorflow::Safe_PyObjectPtr value_attr(
142             PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
143         if (value_attr) {
144           // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
145           (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
146         } else {
147           // This is a subclass of EagerTensor that we don't support.
148           PyErr_Clear();
149           tensorflow::ThrowTypeError(
150               tensorflow::strings::StrCat(
151                   "Saw an object that is an instance of a strict subclass of "
152                   "EagerTensor, which is not supported.  Item ",
153                   i, " is type: ", elem->ob_type->tp_name)
154                   .c_str());
155         }
156       } else if (tensorflow::swig::IsTensor(elem)) {
157         // If it isnt an EagerTensor, but is still a Tensor, it must be a graph
158         // tensor.
159         tensorflow::Safe_PyObjectPtr name_attr(
160             PyObject_GetAttrString(elem, "name"));
161         tensorflow::ThrowTypeError(
162             tensorflow::strings::StrCat(
163                 "An op outside of the function building code is being passed\n"
164                 "a \"Graph\" tensor. It is possible to have Graph tensors\n"
165                 "leak out of the function building context by including a\n"
166                 "tf.init_scope in your function building code.\n"
167                 "For example, the following function will fail:\n",
168                 "  @tf.function\n", "  def has_init_scope():\n",
169                 "    my_constant = tf.constant(1.)\n",
170                 "    with tf.init_scope():\n",
171                 "      added = my_constant * 2\n",
172                 "The graph tensor has name: ",
173                 name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>")
174                 .c_str());
175       } else {
176         tensorflow::ThrowTypeError(
177             tensorflow::strings::StrCat(
178                 "provided list of inputs contains objects other "
179                 "than 'EagerTensor'. Item ",
180                 i, " is type: ", elem->ob_type->tp_name)
181                 .c_str());
182       }
183     }
184   }
185   return input_tensor_handles;
186 }
187 
188 // These functions are typemaps from the Python side. I did not use
189 // a custom type caster since the logic is slightly harder to follow. This
190 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
191 // This function actually takes a number rather than an output Tensor holder.
InputTFE_OutputTensorHandles(const py::handle & num_outputs)192 TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
193     const py::handle& num_outputs) {
194   TFE_OutputTensorHandles output_tensor_handles;
195 #if PY_MAJOR_VERSION < 3
196   if (!PyInt_Check(num_outputs.ptr())) {
197 #else
198   if (!PyLong_Check(num_outputs.ptr())) {
199 #endif
200     PyErr_SetString(PyExc_TypeError,
201                     "expected an integer value (size of the number of "
202                     "outputs of the operation)");
203     throw py::error_already_set();
204   }
205 #if PY_MAJOR_VERSION < 3
206   long sz = PyInt_AsLong(num_outputs.ptr());  // NOLINT
207 #else
208   long sz = PyLong_AsLong(num_outputs.ptr());  // NOLINT
209 #endif
210   // We can't handle more than int32 sizes for number of outputs.
211   if (static_cast<long>(static_cast<int32>(sz)) != sz) {  // NOLINT
212     PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
213                                           "Number of outputs is too big: ", sz)
214                                           .c_str());
215     throw py::error_already_set();
216   }
217   if (sz > 0) {
218 #if PY_MAJOR_VERSION < 3
219     output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
220 #else
221     output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
222 #endif
223   }
224   return output_tensor_handles;
225 }
226 
227 // Packs multiple `EagerTensor`s of the same dtype and shape into one
228 // `EagerTensor`.
229 py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
230                                            const py::handle& tensors) {
231   TFE_Context* ctx = tensorflow::InputTFE_Context(context);
232   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
233   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
234   int size = handles.size();
235   TFE_TensorHandle* packed_handle =
236       TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
237   tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
238   PyObject* packed_tensor =
239       EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
240   return tensorflow::PyoOrThrow(packed_tensor);
241 }
242 
243 // This function was created from fusing the typemap logic in platform/base.i.
244 py::object TFE_Py_ExecuteCancelable_wrapper(
245     const py::handle& context, const char* device_name, const char* op_name,
246     const py::handle& inputs, const py::handle& attrs,
247     tensorflow::CancellationManager* cancellation_manager,
248     const py::handle& num_outputs) {
249   TFE_Context* ctx = tensorflow::InputTFE_Context(context);
250   TFE_InputTensorHandles input_tensor_handles =
251       InputTFE_InputTensorHandles(inputs);
252   TFE_OutputTensorHandles output_tensor_handles =
253       InputTFE_OutputTensorHandles(num_outputs);
254   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
255   TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
256                            attrs.ptr(), tensorflow::wrap(cancellation_manager),
257                            &output_tensor_handles, status.get());
258 
259   int output_len = output_tensor_handles.size();
260   PyObject* output_list = PyList_New(output_len);
261   for (int i = 0; i < output_len; ++i) {
262     PyObject* output;
263     output = EagerTensorFromHandle(output_tensor_handles.at(i));
264     PyList_SetItem(output_list, i, output);
265   }
266   tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
267   return tensorflow::PyoOrThrow(output_list);
268 }
269 
270 static py::object TF_ListPhysicalDevices() {
271   std::vector<string> devices;
272   tensorflow::Status s =
273       tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
274   MaybeRaiseRegisteredFromStatus(s);
275   PyObject* result = PyList_New(devices.size());
276   int i = 0;
277   for (auto& dev : devices) {
278     PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
279     PyList_SetItem(result, i, dev_obj);
280     ++i;
281   }
282   return tensorflow::PyoOrThrow(result);
283 }
284 
285 static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
286   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
287   std::unordered_map<string, string> device_details;
288   tensorflow::Status s =
289       tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
290   tensorflow::Set_TF_Status_from_Status(status.get(), s);
291   MaybeRaiseRegisteredFromTFStatus(status.get());
292   return device_details;
293 }
294 
295 static py::object TFE_ClearScalarCache() {
296   tensorflow::TFE_TensorHandleCache::Get()->Clear();
297   return py::none();
298 }
299 
300 // Returns compiler IR for a given function.
301 static py::bytes TFE_GetCompilerIr(py::handle& ctx,
302                                    const char* concrete_function_name,
303                                    const char* stage, const char* device_name,
304                                    py::handle& inputs) {
305   EagerContext* context = ContextFromInterface(
306       reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
307 
308   std::string s_stage(stage);
309   IrExportStage selected_stage = [&] {
310     if (s_stage == "hlo") {
311       return IrExportStage::HLO;
312     } else if (s_stage == "hlo_serialized") {
313       return IrExportStage::HLO_SERIALIZED;
314     } else if (s_stage == "optimized_hlo") {
315       return IrExportStage::OPTIMIZED_HLO;
316     } else if (s_stage == "optimized_hlo_serialized") {
317       return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
318     } else if (s_stage == "optimized_hlo_dot") {
319       return IrExportStage::OPTIMIZED_HLO_DOT;
320     } else {
321       ThrowValueError(
322           absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
323                           "'hlo', 'optimized_hlo', 'optimized_hlo_dot'",
324                           s_stage)
325               .c_str());
326     }
327   }();
328 
329   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
330 
331   std::vector<const TensorHandle*> input_handles;
332   for (TFE_TensorHandle* tensor_handle : handles) {
333     AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
334     input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
335   }
336 
337   DeviceNameUtils::ParsedName input_device_name;
338   if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
339     ThrowValueError(
340         absl::StrFormat("Failed parsing device name: '%s'", device_name)
341             .c_str());
342   }
343 
344   std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
345   auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
346     return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
347                                                   d->parsed_name());
348   });
349   if (selected_device == devices.end()) {
350     ThrowValueError(
351         absl::StrFormat("No matching device found for '%s'", device_name)
352             .c_str());
353   }
354 
355   xla::StatusOr<std::string> hlo_str =
356       GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
357                     *selected_device, context, input_handles);
358 
359   if (!hlo_str.ok()) {
360     ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
361                                     hlo_str.status().error_message())
362                         .c_str());
363   }
364   return py::bytes(*hlo_str);
365 }
366 
367 }  // namespace tensorflow
368 
369 namespace {
370 
371 // Wrapper around the EagerContextThreadLocalData struct (defined in
372 // pywrap_tfe.h), so it can be accessed from Python.
373 //
374 // For PyObject* fields, the get_*() methods return a new reference; and the
375 // set_*() methods create a new reference (i.e., they do not steal a reference).
376 class EagerContextThreadLocalDataWrapper {
377  public:
EagerContextThreadLocalDataWrapper(py::handle py_eager_context,py::handle is_eager,py::handle device_spec)378   explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
379                                               py::handle is_eager,
380                                               py::handle device_spec)
381       : py_eager_context_(py_eager_context.ptr()) {
382     tensorflow::MakeEagerContextThreadLocalData(
383         py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
384   }
385 
~EagerContextThreadLocalDataWrapper()386   ~EagerContextThreadLocalDataWrapper() {
387     tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
388   }
389 
get_is_eager() const390   bool get_is_eager() const { return GetData()->is_eager; }
set_is_eager(bool v)391   void set_is_eager(bool v) { GetData()->is_eager = v; }
392 
get_invoking_op_callbacks() const393   bool get_invoking_op_callbacks() const {
394     return GetData()->invoking_op_callbacks;
395   }
set_invoking_op_callbacks(bool v)396   void set_invoking_op_callbacks(bool v) {
397     GetData()->invoking_op_callbacks = v;
398   }
399 
get_device_name() const400   py::handle get_device_name() const {
401     return GetPyObject(&GetData()->device_name);
402   }
set_device_name(py::handle v)403   void set_device_name(py::handle v) {
404     SetPyObject(v, &GetData()->device_name);
405   }
406 
get_scope_name() const407   py::handle get_scope_name() const {
408     return GetPyObject(&GetData()->scope_name);
409   }
set_scope_name(py::handle v)410   void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
411 
get_device_spec() const412   py::handle get_device_spec() const {
413     return GetPyObject(&GetData()->device_spec);
414   }
set_device_spec(py::handle v)415   void set_device_spec(py::handle v) {
416     SetPyObject(v, &GetData()->device_spec);
417   }
418 
get_function_call_options() const419   py::handle get_function_call_options() const {
420     return GetPyObject(&GetData()->function_call_options);
421   }
set_function_call_options(py::handle v)422   void set_function_call_options(py::handle v) {
423     SetPyObject(v, &GetData()->function_call_options);
424   }
425 
get_executor() const426   py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
set_executor(py::handle v)427   void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
428 
get_op_callbacks() const429   py::handle get_op_callbacks() const {
430     return GetPyObject(&GetData()->op_callbacks);
431   }
set_op_callbacks(py::handle v)432   void set_op_callbacks(py::handle v) {
433     SetPyObject(v, &GetData()->op_callbacks);
434   }
435 
436  private:
GetData() const437   tensorflow::EagerContextThreadLocalData* GetData() const {
438     auto* result =
439         tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
440     if (!result) {
441       throw py::error_already_set();
442     }
443     return result;
444   }
445 
GetPyObject(tensorflow::Safe_PyObjectPtr * obj) const446   py::handle GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
447     Py_INCREF(obj->get());
448     return obj->get();
449   }
450 
SetPyObject(py::handle value,tensorflow::Safe_PyObjectPtr * ptr)451   void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
452     Py_INCREF(value.ptr());
453     ptr->reset(value.ptr());
454   }
455 
456   PyObject* py_eager_context_;  // not owned (borrowed reference).
457 };
458 
459 }  // namespace
460 
461 // py::return_value_policy::reference is defined as specified by the
462 // pybind11 documents listed here.
463 // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
464 // This means that C++ maintains ownership of the object. We
465 // are only assigning this to functions that return opaque types.
466 
PYBIND11_MODULE(_pywrap_tfe,m)467 PYBIND11_MODULE(_pywrap_tfe, m) {
468   py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
469   py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
470                                                           "TFE_ContextOptions");
471   py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
472       m, "TFE_MonitoringCounter0");
473   py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
474       m, "TFE_MonitoringCounter1");
475   py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
476       m, "TFE_MonitoringCounter2");
477   py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
478       m, "TFE_MonitoringStringGauge0");
479   py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
480       m, "TFE_MonitoringStringGauge1");
481   py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
482       m, "TFE_MonitoringStringGauge2");
483   py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
484       m, "TFE_MonitoringIntGauge0");
485   py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
486       m, "TFE_MonitoringIntGauge1");
487   py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
488       m, "TFE_MonitoringIntGauge2");
489   py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
490       m, "TFE_MonitoringBoolGauge0");
491   py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
492       m, "TFE_MonitoringBoolGauge1");
493   py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
494       m, "TFE_MonitoringBoolGauge2");
495   py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
496       m, "TFE_MonitoringCounterCell");
497   py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
498       m, "TFE_MonitoringIntGaugeCell");
499   py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
500       m, "TFE_MonitoringStringGaugeCell");
501   py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
502       m, "TFE_MonitoringBoolGaugeCell");
503   py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
504       m, "TFE_MonitoringSamplerCell");
505   py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
506       m, "TFE_MonitoringBuckets");
507   py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
508       m, "TFE_MonitoringSampler0");
509   py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
510       m, "TFE_MonitoringSampler1");
511   py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
512       m, "TFE_MonitoringSampler2");
513   py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
514       m, "TFE_CancellationManager");
515 
516   py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
517   py::class_<TF_Function> TF_Function_class(m, "TF_Function");
518 
519   m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
520     return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
521   });
522   m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
523     return tensorflow::PyoOrThrow(
524         TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
525   });
526 
527   m.def(
528       "TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
529         auto* context =
530             reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
531                 tensorflow::InputTFE_Context(ctx));
532 
533         tensorflow::DeviceNameUtils::ParsedName input_device_name;
534         if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(
535                 device_name, &input_device_name)) {
536           tensorflow::ThrowValueError(
537               absl::StrFormat("Failed parsing device name: '%s'", device_name)
538                   .c_str());
539         }
540 
541         std::vector<tensorflow::Device*> devices =
542             context->ListLocalTfDevices();
543 
544         tensorflow::Device* matched_device = nullptr;
545         for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
546           tensorflow::Device* device = devices[device_idx];
547 
548           if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
549                   input_device_name, device->parsed_name())) {
550             if (device->device_type() == tensorflow::DEVICE_CPU) {
551               tensorflow::ThrowValueError(
552                   "CPU does not support getting allocator information");
553             }
554 
555             if (matched_device != nullptr) {
556               tensorflow::ThrowValueError(
557                   absl::StrFormat(
558                       "Multiple devices matching the provided string "
559                       "'%s': '%s' and "
560                       "'%s' ",
561                       device_name, matched_device->name(), device->name())
562                       .c_str());
563             }
564             matched_device = device;
565           }
566         }
567 
568         if (matched_device == nullptr) {
569           tensorflow::ThrowValueError(
570               absl::StrFormat("No matching devices found for '%s'", device_name)
571                   .c_str());
572         }
573 
574         tensorflow::AllocatorAttributes attrs;
575         tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
576 
577         if (absl::optional<tensorflow::AllocatorStats> stats =
578                 allocator->GetStats()) {
579           return std::map<std::string, int64_t>{
580               {"current", stats->bytes_in_use},
581               {"peak", stats->peak_bytes_in_use}};
582         }
583 
584         tensorflow::ThrowTypeError(
585             absl::StrFormat("Allocator stats not available for device '%s'",
586                             matched_device->name())
587                 .c_str());
588       });
589 
590   // XLA Eager Logic
591   m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
592   m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
593   m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
594   m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
595   m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
596   m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
597   m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
598 
599   // MLIR Logic
600   m.def("TF_IsMlirBridgeEnabled", [] {
601     // Since python protobuf enums are integers, cast to an integer before
602     // returning the enum to python.
603     return static_cast<int32_t>(
604         tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
605   });
606   m.def("TF_EnableMlirBridge", [](bool enabled) {
607     tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
608         enabled
609             ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
610             : tensorflow::ConfigProto::Experimental::
611                   MLIR_BRIDGE_ROLLOUT_DISABLED;
612   });
613   m.def("TF_EnableXlaDevices", [] {
614     tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
615   });
616 
617   // // TFE_Context Logic
618   m.def(
619       "TFE_NewContext",
620       [](const TFE_ContextOptions* opts) {
621         tensorflow::Safe_TF_StatusPtr status =
622             tensorflow::make_safe(TF_NewStatus());
623         TFE_Context* context = TFE_NewContext(opts, status.get());
624         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
625         return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
626       },
627       py::return_value_policy::reference);
628   m.def("TFE_DeleteContext", [](py::handle& o) {
629     TFE_DeleteContext(tensorflow::InputTFE_Context(o));
630   });
631   m.def(
632       "TFE_ContextListDevices",
633       [](py::handle& o) {
634         tensorflow::Safe_TF_StatusPtr status =
635             tensorflow::make_safe(TF_NewStatus());
636         auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
637                                              status.get());
638         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
639         return output;
640       },
641       py::return_value_policy::reference);
642   m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
643     TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
644   });
645   m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
646     tensorflow::Safe_TF_StatusPtr status =
647         tensorflow::make_safe(TF_NewStatus());
648     TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
649                            status.get());
650     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
651   });
652   m.def("TFE_ContextAddFunctionDef",
653         [](py::handle& ctx, const char* serialized_function_def, size_t size) {
654           tensorflow::Safe_TF_StatusPtr status =
655               tensorflow::make_safe(TF_NewStatus());
656           TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
657                                     serialized_function_def, size,
658                                     status.get());
659           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
660         });
661   m.def("TFE_ContextGetFunctionDef",
662         [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
663           tensorflow::Safe_TF_StatusPtr status =
664               tensorflow::make_safe(TF_NewStatus());
665           TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
666                                     function_name, &buf, status.get());
667           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
668         });
669   m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
670     tensorflow::Safe_TF_StatusPtr status =
671         tensorflow::make_safe(TF_NewStatus());
672     TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
673                               status.get());
674     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
675   });
676   m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
677     tensorflow::Safe_TF_StatusPtr status =
678         tensorflow::make_safe(TF_NewStatus());
679     auto output =
680         TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
681     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
682     return output;
683   });
684   m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
685     return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
686         ->ListFunctionNames();
687   });
688   m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
689     TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
690   });
691   m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
692     TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
693   });
694   m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
695     TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
696   });
697   m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
698     TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
699   });
700   m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
701     tensorflow::Safe_TF_StatusPtr status =
702         tensorflow::make_safe(TF_NewStatus());
703     TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
704                                  status.get());
705     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
706   });
707   m.def("TFE_ContextClearCaches", [](py::handle& o) {
708     TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
709   });
710   m.def("TFE_GetContextId", [](py::handle& ctx) {
711     return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
712   });
713   m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
714     return TFE_ContextGetDevicePlacementPolicy(
715         tensorflow::InputTFE_Context(ctx));
716   });
717   m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
718         [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
719           TFE_ContextSetThreadLocalDevicePlacementPolicy(
720               tensorflow::InputTFE_Context(ctx), policy);
721         });
722   m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
723                                       py::bytes proto) {
724     tensorflow::Safe_TF_StatusPtr status =
725         tensorflow::make_safe(TF_NewStatus());
726     tensorflow::Safe_TF_BufferPtr buf =
727         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
728     TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
729                             buf.get()->data, buf.get()->length, status.get());
730     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
731   });
732   m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
733                                          py::bytes proto) {
734     tensorflow::Safe_TF_StatusPtr status =
735         tensorflow::make_safe(TF_NewStatus());
736     tensorflow::Safe_TF_BufferPtr buf =
737         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
738     Py_BEGIN_ALLOW_THREADS;
739     TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
740                                keep_alive_secs, buf.get()->data,
741                                buf.get()->length, status.get());
742     Py_END_ALLOW_THREADS;
743     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
744   });
745   m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
746     tensorflow::Safe_TF_StatusPtr status =
747         tensorflow::make_safe(TF_NewStatus());
748     bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
749                                         worker_name, status.get());
750     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
751     return output;
752   });
753   m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
754     tensorflow::Safe_TF_StatusPtr status =
755         tensorflow::make_safe(TF_NewStatus());
756     TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
757     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
758   });
759   m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
760     tensorflow::Safe_TF_StatusPtr status =
761         tensorflow::make_safe(TF_NewStatus());
762     TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
763     // NOTE: different from TFE_ContextSyncExecutors that raises potential
764     // errors, deliberately ignore executor statuses in cleanup.
765   });
766   m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
767     tensorflow::Safe_TF_StatusPtr status =
768         tensorflow::make_safe(TF_NewStatus());
769     TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
770                                       status.get());
771   });
772   m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
773     tensorflow::Safe_TF_StatusPtr status =
774         tensorflow::make_safe(TF_NewStatus());
775     TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
776                                       status.get());
777   });
778 
779   // TFE_Executor logic
780   m.def(
781       "TFE_NewExecutor",
782       [](const bool is_async) {
783         TFE_Executor* exc = TFE_NewExecutor(is_async);
784         return exc;
785       },
786       py::return_value_policy::reference);
787   m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
788   m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
789   m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
790     tensorflow::Safe_TF_StatusPtr status =
791         tensorflow::make_safe(TF_NewStatus());
792     // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
793     Py_BEGIN_ALLOW_THREADS;
794     TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
795     Py_END_ALLOW_THREADS;
796     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
797   });
798   m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
799   m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
800                                               TFE_Executor& exc) {
801     TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
802   });
803   m.def(
804       "TFE_ContextGetExecutorForThread",
805       [](py::handle& o) {
806         return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
807       },
808       py::return_value_policy::reference);
809 
810   m.def("TFE_OpNameGetAttrType",
811         [](py::handle& ctx, const char* op_or_function_name,
812            const char* attr_name) {
813           int temp = 0;
814           unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
815           tensorflow::Safe_TF_StatusPtr status =
816               tensorflow::make_safe(TF_NewStatus());
817           auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
818                                               op_or_function_name, attr_name,
819                                               is_list, status.get());
820           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
821 #if PY_MAJOR_VERSION < 3
822           PyObject* output_pyo = PyInt_FromLong(output);
823 #else
824           PyObject* output_pyo = PyLong_FromLong(output);
825 #endif
826           if (*is_list == 1) {
827             PyObject* list = PyList_New(1);
828             PyList_SetItem(list, 0, output_pyo);
829             return tensorflow::PyoOrThrow(list);
830           }
831           return tensorflow::PyoOrThrow(output_pyo);
832         });
833   m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
834     return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
835   });
836   m.def("TFE_Py_PackEagerTensors",
837         [](const py::handle& context, const py::handle& handles) {
838           return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
839         });
840   m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
841   m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
842     return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
843   });
844   m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
845     return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
846   });
847   m.def("TFE_Py_Execute",
848         [](const py::handle& context, const char* device_name,
849            const char* op_name, const py::handle& inputs,
850            const py::handle& attrs, const py::handle& num_outputs) {
851           return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
852               context, device_name, op_name, inputs, attrs.ptr(), nullptr,
853               num_outputs);
854         });
855   m.def(
856       "TFE_Py_ExecuteCancelable",
857       [](const py::handle& context, const char* device_name,
858          const char* op_name, const py::handle& inputs, const py::handle& attrs,
859          tensorflow::CancellationManager& cancellation_manager,
860          const py::handle& num_outputs) {
861         return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
862             context, device_name, op_name, inputs, attrs.ptr(),
863             &cancellation_manager, num_outputs);
864       });
865   m.def("TFE_Py_FastPathExecute", [](const py::args args) {
866     // TFE_Py_FastPathExecute requires error checking prior to returning.
867     return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
868   });
869   m.def("TFE_Py_RecordGradient",
870         [](const py::handle& op_name, const py::handle& inputs,
871            const py::handle& attrs, const py::handle& results,
872            const py::handle& forward_pass_name_scope) {
873           return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
874               op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
875               forward_pass_name_scope.ptr()));
876         });
877   m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
878 
879   // TFE_Py_Tape Logic
880   m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
881                                 const py::handle& watch_accessed_variables) {
882     return tensorflow::PyoOrThrow(
883         TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
884   });
885   m.def("TFE_Py_TapeSetAdd",
886         [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
887   m.def("TFE_Py_TapeSetRemove",
888         [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
889   m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
890   m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
891   m.def("TFE_Py_TapeSetIsStopped",
892         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
893   m.def("TFE_Py_TapeSetIsEmpty",
894         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
895   m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
896     return tensorflow::PyoOrThrow(
897         TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
898   });
899   m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
900     return tensorflow::PyoOrThrow(
901         TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
902   });
903   m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
904   m.def("TFE_Py_TapeSetRecordOperation",
905         [](const py::handle& op_type, const py::handle& output_tensors,
906            const py::handle& input_tensors, const py::handle& backward_function,
907            const py::handle& forward_function) {
908           return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
909               op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
910               backward_function.ptr(), forward_function.ptr()));
911         });
912   m.def(
913       "TFE_Py_TapeSetRecordOperationBackprop",
914       [](const py::handle& op_type, const py::handle& output_tensors,
915          const py::handle& input_tensors, const py::handle& backward_function) {
916         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
917             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
918             backward_function.ptr()));
919       });
920   m.def(
921       "TFE_Py_TapeSetRecordOperationForwardprop",
922       [](const py::handle& op_type, const py::handle& output_tensors,
923          const py::handle& input_tensors, const py::handle& backward_function,
924          const py::handle& forwardprop_output_indices) {
925         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
926             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
927             backward_function.ptr(), forwardprop_output_indices.ptr()));
928       });
929   m.def("TFE_Py_TapeGradient",
930         [](const py::handle& tape, const py::handle& target,
931            const py::handle& sources, const py::handle& output_gradients,
932            const py::handle& sources_raw,
933            const py::handle& unconnected_gradients) {
934           tensorflow::Safe_TF_StatusPtr status =
935               tensorflow::make_safe(TF_NewStatus());
936           PyObject* output = TFE_Py_TapeGradient(
937               tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
938               sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
939           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
940           return tensorflow::PyoOrThrow(output);
941         });
942 
943   m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
944     TFE_Py_TapeVariableAccessed(variable.ptr());
945   });
946   m.def("TFE_Py_TapeWatch",
947         [](const py::handle& tape, const py::handle& tensor) {
948           TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
949         });
950   m.def("TFE_Py_TapeWatchVariable",
951         [](const py::handle& tape, const py::handle& variable) {
952           TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
953         });
954   m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
955     return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
956   });
957 
958   // TFE_Py_VariableWatcher logic.
959   m.def("TFE_Py_VariableWatcherNew",
960         []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
961   m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
962     TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
963   });
964   m.def("TFE_Py_VariableWatcherVariableAccessed",
965         [](const py::handle& variable) {
966           TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
967         });
968   m.def("TFE_Py_VariableWatcherWatchedVariables",
969         [](const py::handle& variable_watcher) {
970           return tensorflow::PyoOrThrow(
971               TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
972         });
973 
974   // TFE_Py_ForwardAccumulator logic.
975   m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
976     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
977   });
978 
979   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
980     return tensorflow::PyoOrThrow(
981         TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
982   });
983   m.def("TFE_Py_ForwardAccumulatorSetRemove",
984         [](const py::handle& accumulator) {
985           TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
986         });
987 
988   m.def("TFE_Py_ForwardAccumulatorWatch",
989         [](const py::handle& accumulator, const py::handle& tensor,
990            const py::handle& tangent) {
991           TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
992                                          tangent.ptr());
993         });
994   m.def("TFE_Py_ForwardAccumulatorJVP",
995         [](const py::handle& accumulator, const py::handle& tensor) {
996           return tensorflow::PyoOrThrow(
997               TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
998         });
999   m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1000     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1001   });
1002   m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1003     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1004   });
1005   m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1006     return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1007   });
1008 
1009   // TFE_ContextOptions Logic
1010   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1011         py::return_value_policy::reference);
1012   m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1013                                           py::bytes proto) {
1014     tensorflow::Safe_TF_StatusPtr status =
1015         tensorflow::make_safe(TF_NewStatus());
1016     tensorflow::Safe_TF_BufferPtr buf =
1017         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1018     TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1019                                 status.get());
1020     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1021   });
1022   m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1023         &TFE_ContextOptionsSetDevicePlacementPolicy);
1024   m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1025   m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1026   m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1027         py::return_value_policy::reference);
1028 
1029   // TFE_Py_TensorShape Logic
1030   m.def("TFE_Py_TensorShapeSlice",
1031         [](const py::handle& tensors, int slice_dim) {
1032           return tensorflow::PyoOrThrow(
1033               TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1034         });
1035   m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1036                                          int slice_dim) {
1037     return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1038   });
1039   m.def("TFE_Py_EnableInteractivePythonLogging",
1040         &TFE_Py_EnableInteractivePythonLogging);
1041 
1042   // Additional Context Logic
1043   m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1044     return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1045   });
1046   m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1047     return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1048   });
1049   m.def("TFE_Py_EncodeArg",
1050         [](const py::handle& o, bool include_tensor_ranks_only) {
1051           return tensorflow::PyoOrThrow(
1052               TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only));
1053         });
1054   m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1055     tensorflow::Safe_TF_StatusPtr status =
1056         tensorflow::make_safe(TF_NewStatus());
1057     tensorflow::Safe_TF_BufferPtr buf =
1058         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1059     TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1060                             buf.get()->length, status.get());
1061     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1062   });
1063   m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1064                                      const char* message) {
1065     tensorflow::Safe_TF_StatusPtr status =
1066         tensorflow::make_safe(TF_NewStatus());
1067     TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1068     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1069   });
1070   m.def("TFE_CollectiveOpsCheckPeerHealth",
1071         [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1072           tensorflow::Safe_TF_StatusPtr status =
1073               tensorflow::make_safe(TF_NewStatus());
1074           TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1075                                            task, timeout_in_ms, status.get());
1076           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1077         });
1078   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1079   m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1080   m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1081         py::return_value_policy::reference);
1082   m.def("TF_DeviceListCount", &TF_DeviceListCount);
1083   m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1084     tensorflow::Safe_TF_StatusPtr status =
1085         tensorflow::make_safe(TF_NewStatus());
1086     auto output = TF_DeviceListName(list, index, status.get());
1087     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1088     return output;
1089   });
1090   m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1091     tensorflow::Safe_TF_StatusPtr status =
1092         tensorflow::make_safe(TF_NewStatus());
1093     auto output = TF_DeviceListType(list, index, status.get());
1094     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1095     return output;
1096   });
1097 
1098   m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1099 
1100   // TFE_MonitoringCounter Logic
1101   m.def("TFE_MonitoringCounterCellIncrementBy",
1102         &TFE_MonitoringCounterCellIncrementBy);
1103   m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1104   m.def(
1105       "TFE_MonitoringNewCounter0",
1106       [](const char* name, const char* description) {
1107         tensorflow::Safe_TF_StatusPtr status =
1108             tensorflow::make_safe(TF_NewStatus());
1109         auto output =
1110             TFE_MonitoringNewCounter0(name, status.get(), description);
1111         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1112         return output;
1113       },
1114       py::return_value_policy::reference);
1115   m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1116         py::return_value_policy::reference);
1117   m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1118         py::return_value_policy::reference);
1119   m.def(
1120       "TFE_MonitoringNewCounter1",
1121       [](const char* name, const char* description, const char* label1) {
1122         tensorflow::Safe_TF_StatusPtr status =
1123             tensorflow::make_safe(TF_NewStatus());
1124         auto output =
1125             TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1126         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1127         return output;
1128       },
1129       py::return_value_policy::reference);
1130   m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1131         py::return_value_policy::reference);
1132   m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1133         py::return_value_policy::reference);
1134   m.def(
1135       "TFE_MonitoringNewCounter2",
1136       [](const char* name, const char* description, const char* label1,
1137          const char* label2) {
1138         tensorflow::Safe_TF_StatusPtr status =
1139             tensorflow::make_safe(TF_NewStatus());
1140         auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1141                                                 label1, label2);
1142         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1143         return output;
1144       },
1145       py::return_value_policy::reference);
1146   m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1147         py::return_value_policy::reference);
1148   m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1149         py::return_value_policy::reference);
1150 
1151   // TFE_MonitoringIntGauge Logic
1152   m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1153   m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1154   m.def(
1155       "TFE_MonitoringNewIntGauge0",
1156       [](const char* name, const char* description) {
1157         tensorflow::Safe_TF_StatusPtr status =
1158             tensorflow::make_safe(TF_NewStatus());
1159         auto output =
1160             TFE_MonitoringNewIntGauge0(name, status.get(), description);
1161         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1162         return output;
1163       },
1164       py::return_value_policy::reference);
1165   m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1166         py::return_value_policy::reference);
1167   m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1168         py::return_value_policy::reference);
1169   m.def(
1170       "TFE_MonitoringNewIntGauge1",
1171       [](const char* name, const char* description, const char* label1) {
1172         tensorflow::Safe_TF_StatusPtr status =
1173             tensorflow::make_safe(TF_NewStatus());
1174         auto output =
1175             TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1176         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1177         return output;
1178       },
1179       py::return_value_policy::reference);
1180   m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1181         py::return_value_policy::reference);
1182   m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1183         py::return_value_policy::reference);
1184   m.def(
1185       "TFE_MonitoringNewIntGauge2",
1186       [](const char* name, const char* description, const char* label1,
1187          const char* label2) {
1188         tensorflow::Safe_TF_StatusPtr status =
1189             tensorflow::make_safe(TF_NewStatus());
1190         auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1191                                                  description, label1, label2);
1192         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1193         return output;
1194       },
1195       py::return_value_policy::reference);
1196   m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1197         py::return_value_policy::reference);
1198   m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1199         py::return_value_policy::reference);
1200   m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1201   m.def("TFE_MonitoringStringGaugeCellValue",
1202         &TFE_MonitoringStringGaugeCellValue);
1203   m.def(
1204       "TFE_MonitoringNewStringGauge0",
1205       [](const char* name, const char* description) {
1206         tensorflow::Safe_TF_StatusPtr status =
1207             tensorflow::make_safe(TF_NewStatus());
1208         auto output =
1209             TFE_MonitoringNewStringGauge0(name, status.get(), description);
1210         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1211         return output;
1212       },
1213       py::return_value_policy::reference);
1214 
1215   // TFE_MonitoringStringGauge Logic
1216   m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1217   m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1218         py::return_value_policy::reference);
1219   m.def(
1220       "TFE_MonitoringNewStringGauge1",
1221       [](const char* name, const char* description, const char* label1) {
1222         tensorflow::Safe_TF_StatusPtr status =
1223             tensorflow::make_safe(TF_NewStatus());
1224         auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1225                                                     description, label1);
1226         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1227         return output;
1228       },
1229       py::return_value_policy::reference);
1230   m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1231   m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1232         py::return_value_policy::reference);
1233   m.def(
1234       "TFE_MonitoringNewStringGauge2",
1235       [](const char* name, const char* description, const char* label1,
1236          const char* label2) {
1237         tensorflow::Safe_TF_StatusPtr status =
1238             tensorflow::make_safe(TF_NewStatus());
1239         auto output = TFE_MonitoringNewStringGauge2(
1240             name, status.get(), description, label1, label2);
1241         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1242         return output;
1243       },
1244       py::return_value_policy::reference);
1245   m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1246   m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1247         py::return_value_policy::reference);
1248 
1249   // TFE_MonitoringBoolGauge Logic
1250   m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1251   m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1252   m.def(
1253       "TFE_MonitoringNewBoolGauge0",
1254       [](const char* name, const char* description) {
1255         tensorflow::Safe_TF_StatusPtr status =
1256             tensorflow::make_safe(TF_NewStatus());
1257         auto output =
1258             TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1259         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1260         return output;
1261       },
1262       py::return_value_policy::reference);
1263   m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1264         py::return_value_policy::reference);
1265   m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1266         py::return_value_policy::reference);
1267   m.def(
1268       "TFE_MonitoringNewBoolGauge1",
1269       [](const char* name, const char* description, const char* label1) {
1270         tensorflow::Safe_TF_StatusPtr status =
1271             tensorflow::make_safe(TF_NewStatus());
1272         auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1273                                                   description, label1);
1274         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1275         return output;
1276       },
1277       py::return_value_policy::reference);
1278   m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1279         py::return_value_policy::reference);
1280   m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1281         py::return_value_policy::reference);
1282   m.def(
1283       "TFE_MonitoringNewBoolGauge2",
1284       [](const char* name, const char* description, const char* label1,
1285          const char* label2) {
1286         tensorflow::Safe_TF_StatusPtr status =
1287             tensorflow::make_safe(TF_NewStatus());
1288         auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1289                                                   description, label1, label2);
1290         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1291         return output;
1292       },
1293       py::return_value_policy::reference);
1294   m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1295         py::return_value_policy::reference);
1296   m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1297         py::return_value_policy::reference);
1298 
1299   // TFE_MonitoringSampler Logic
1300   m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1301   m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1302   m.def("TFE_MonitoringNewExponentialBuckets",
1303         &TFE_MonitoringNewExponentialBuckets,
1304         py::return_value_policy::reference);
1305   m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1306         py::return_value_policy::reference);
1307   m.def(
1308       "TFE_MonitoringNewSampler0",
1309       [](const char* name, TFE_MonitoringBuckets* buckets,
1310          const char* description) {
1311         tensorflow::Safe_TF_StatusPtr status =
1312             tensorflow::make_safe(TF_NewStatus());
1313         auto output =
1314             TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1315         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1316         return output;
1317       },
1318       py::return_value_policy::reference);
1319   m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1320         py::return_value_policy::reference);
1321   m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1322         py::return_value_policy::reference);
1323   m.def(
1324       "TFE_MonitoringNewSampler1",
1325       [](const char* name, TFE_MonitoringBuckets* buckets,
1326          const char* description, const char* label1) {
1327         tensorflow::Safe_TF_StatusPtr status =
1328             tensorflow::make_safe(TF_NewStatus());
1329         auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1330                                                 description, label1);
1331         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1332         return output;
1333       },
1334       py::return_value_policy::reference);
1335   m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1336         py::return_value_policy::reference);
1337   m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1338         py::return_value_policy::reference);
1339   m.def(
1340       "TFE_MonitoringNewSampler2",
1341       [](const char* name, TFE_MonitoringBuckets* buckets,
1342          const char* description, const char* label1, const char* label2) {
1343         tensorflow::Safe_TF_StatusPtr status =
1344             tensorflow::make_safe(TF_NewStatus());
1345         auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1346                                                 description, label1, label2);
1347         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1348         return output;
1349       },
1350       py::return_value_policy::reference);
1351   m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1352         py::return_value_policy::reference);
1353   m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1354         py::return_value_policy::reference);
1355 
  // TFE_CancellationManager Logic
  //
  // TFE_NewCancellationManager returns a heap-allocated
  // tensorflow::CancellationManager. No return_value_policy is passed, so
  // pybind11's default (`automatic`, i.e. take_ownership for a raw pointer)
  // applies. NOTE(review): this relies on CancellationManager being registered
  // with py::class_ elsewhere in this module — confirm.
  m.def("TFE_NewCancellationManager",
        []() { return new tensorflow::CancellationManager(); });
  // The next two bind member-function pointers directly; pybind11 exposes
  // them as plain functions whose first argument is the manager instance.
  m.def("TFE_CancellationManagerIsCancelled",
        &tensorflow::CancellationManager::IsCancelled);
  m.def("TFE_CancellationManagerStartCancel",
        &tensorflow::CancellationManager::StartCancel);

  // Forwards directly to tensorflow::TFE_ClearScalarCache.
  m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);

  // Util buffer helper functions
  //
  // return_value_policy::reference: Python does not take ownership of the
  // returned TF_Buffer*. Presumably the caller frees it with TF_DeleteBuffer —
  // TODO confirm against the Python call sites.
  m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
        py::return_value_policy::reference);
1369 
1370   // DLPack functions
1371   m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1372     PyObject* eager_tensor_pyobject_ptr = o.ptr();
1373     tensorflow::Safe_TF_StatusPtr status =
1374         tensorflow::make_safe(TF_NewStatus());
1375 
1376     if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1377       status->status = tensorflow::errors::InvalidArgument(
1378           "The argument to `to_dlpack` must be a TF tensor, not Python object");
1379       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1380     }
1381 
1382     TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1383     void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1384     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1385 
1386     py::capsule capsule(
1387         dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1388           if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1389             void* dlm_rptr =
1390                 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1391             if (dlm_rptr) {
1392               tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1393               PyCapsule_SetDestructor(capsule, nullptr);
1394             }
1395           }
1396         });
1397     return capsule;
1398   });
1399 
1400   m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1401                                     const py::handle& context) {
1402     tensorflow::Safe_TF_StatusPtr status =
1403         tensorflow::make_safe(TF_NewStatus());
1404     if (absl::string_view(pycapsule.name()) !=
1405         tensorflow::kDlTensorCapsuleName) {
1406       status->status = tensorflow::errors::InvalidArgument(
1407           "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
1408           "Note that a DLPack tensor may be consumed at most once.",
1409           absl::string_view(pycapsule.name()));
1410       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1411     }
1412 
1413     TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1414         pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1415 
1416     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1417 
1418     PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1419     PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1420 
1421     PyObject* pyhandle = EagerTensorFromHandle(thandle);
1422     return tensorflow::PyoOrThrow(pyhandle);
1423   });
1424 
1425   m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1426                                           const py::capsule& device,
1427                                           const char* device_name,
1428                                           const py::capsule& device_info) {
1429     tensorflow::Safe_TF_StatusPtr status =
1430         tensorflow::make_safe(TF_NewStatus());
1431     if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1432       status->status = tensorflow::errors::InvalidArgument(
1433           "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1434           "argument, got ",
1435           absl::string_view(device.name()));
1436       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1437     }
1438     if (absl::string_view(device_info.name()) !=
1439         "TFE_CustomDevice_DeviceInfo") {
1440       status->status = tensorflow::errors::InvalidArgument(
1441           "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1442           "the `device_info` argument, got ",
1443           absl::string_view(device_info.name()));
1444       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1445     }
1446     // TFE_RegisterCustomDevice takes ownership
1447     PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1448     TFE_RegisterCustomDevice(
1449         tensorflow::InputTFE_Context(context),
1450         *reinterpret_cast<TFE_CustomDevice*>(
1451             PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1452         device_name,
1453         PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1454         status.get());
1455     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1456   });
1457 
1458   py::class_<EagerContextThreadLocalDataWrapper>(m,
1459                                                  "EagerContextThreadLocalData")
1460       .def(py::init<py::handle, py::handle, py::handle>(),
1461            py::arg("py_eager_context"), py::arg("is_eager"),
1462            py::arg("device_spec"))
1463       .def_property("is_eager",
1464                     &EagerContextThreadLocalDataWrapper::get_is_eager,
1465                     &EagerContextThreadLocalDataWrapper::set_is_eager)
1466       .def_property(
1467           "invoking_op_callbacks",
1468           &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1469           &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1470       .def_property("device_name",
1471                     &EagerContextThreadLocalDataWrapper::get_device_name,
1472                     &EagerContextThreadLocalDataWrapper::set_device_name)
1473       .def_property("scope_name",
1474                     &EagerContextThreadLocalDataWrapper::get_scope_name,
1475                     &EagerContextThreadLocalDataWrapper::set_scope_name)
1476       .def_property("device_spec",
1477                     &EagerContextThreadLocalDataWrapper::get_device_spec,
1478                     &EagerContextThreadLocalDataWrapper::set_device_spec)
1479       .def_property(
1480           "function_call_options",
1481           &EagerContextThreadLocalDataWrapper::get_function_call_options,
1482           &EagerContextThreadLocalDataWrapper::set_function_call_options)
1483       .def_property("executor",
1484                     &EagerContextThreadLocalDataWrapper::get_executor,
1485                     &EagerContextThreadLocalDataWrapper::set_executor)
1486       .def_property("op_callbacks",
1487                     &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1488                     &EagerContextThreadLocalDataWrapper::set_op_callbacks);
1489 
1490   // C API Enum
1491 
1492   py::enum_<TFE_ContextDevicePlacementPolicy>(
1493       m, "TFE_ContextDevicePlacementPolicy")
1494       .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1495       .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1496       .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1497       .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1498              TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1499       .export_values();
1500 
1501   py::enum_<TF_AttrType>(m, "TF_AttrType")
1502       .value("TF_ATTR_STRING", TF_ATTR_STRING)
1503       .value("TF_ATTR_INT", TF_ATTR_INT)
1504       .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1505       .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1506       .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1507       .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1508       .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1509       .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1510       .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1511       .export_values();
1512 };
1513