1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "Python.h"
19 #include "absl/strings/str_format.h"
20 #include "pybind11/chrono.h"
21 #include "pybind11/complex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "pybind11/pytypes.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_experimental.h"
28 #include "tensorflow/c/eager/c_api.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/dlpack.h"
32 #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
35 #include "tensorflow/c/tf_status.h"
36 #include "tensorflow/c/tf_status_helper.h"
37 #include "tensorflow/compiler/jit/flags.h"
38 #include "tensorflow/compiler/jit/get_compiler_ir.h"
39 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
40 #include "tensorflow/python/eager/pywrap_tfe.h"
41 #include "tensorflow/python/lib/core/py_exception_registry.h"
42 #include "tensorflow/python/lib/core/pybind11_lib.h"
43 #include "tensorflow/python/lib/core/pybind11_status.h"
44 #include "tensorflow/python/lib/core/safe_ptr.h"
45 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
46 #include "tensorflow/python/util/util.h"
47 
48 namespace py = pybind11;
49 
50 PYBIND11_MAKE_OPAQUE(TFE_Executor);
51 PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
52 PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);
53 
54 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
55 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
56 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
57 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
58 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
59 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
60 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
61 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
62 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
63 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
64 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
65 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
66 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
67 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
68 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
69 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
70 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
71 PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
72 PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
73 PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
74 PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
75 PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);
76 
77 PYBIND11_MAKE_OPAQUE(TF_DeviceList);
78 PYBIND11_MAKE_OPAQUE(TF_Function);
79 PYBIND11_MAKE_OPAQUE(TF_Buffer);
80 
81 // Eager helper functions migrated from pywrap_tfe.i.
82 
83 namespace tensorflow {
84 
85 // We cannot use Context as an opaque type. SWIG also had
86 // difficulty directly passing the pointer around. These
87 // typemaps are migrated over from pywrap_tfe.i. I tried
88 // using a custom type caster, but we get segfaults periodically.
89 
90 // TODO(amitpatankar): Move input and output logic of Context into a
91 // pybind11 custom type caster.
92 
93 TFE_Context* InputTFE_Context(const py::handle& ctx) {
94   return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
95 }
96 
97 PyObject* OutputTFE_Context(TFE_Context* context) {
98   return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
99 }
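// Example (illustrative sketch, not part of the module): these helpers
// round-trip a TFE_Context through a PyCapsule, which is how the Python eager
// context object carries the native context pointer.
//
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TF_Status* s = TF_NewStatus();
//   TFE_Context* ctx = TFE_NewContext(opts, s);
//   PyObject* capsule = OutputTFE_Context(ctx);            // new PyCapsule
//   TFE_Context* same_ctx = InputTFE_Context(py::handle(capsule));
//   // same_ctx == ctx; destroying the capsule deletes the context via
//   // TFE_DeleteContextCapsule.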
100 
101 TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
102   // Convert a Python string object to TF_Buffer.
103   char* c_string;
104   Py_ssize_t py_size;
105   // PyBytes_AsStringAndSize() does not copy but simply interprets the input
106   if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
107     // Python has raised an error (likely TypeError or UnicodeEncodeError).
108     throw py::error_already_set();
109   }
110   return TF_NewBufferFromString(static_cast<void*>(c_string),
111                                 static_cast<size_t>(py_size));
112 }
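// Example (sketch): converting a serialized proto held in a Python bytes
// object into a TF_Buffer. TF_NewBufferFromString copies the bytes, so the
// resulting buffer does not depend on the Python object's lifetime.
//
//   PyObject* proto_bytes = ...;  // a Python bytes object (serialized proto)
//   TF_Buffer* buf = ProtoStringToTFBuffer(proto_bytes);
//   // pass buf->data / buf->length to a C API call, then:
//   TF_DeleteBuffer(buf);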
113 
114 // These functions are typemaps from the Python side. I did not use
115 // a custom type caster since the logic is slightly harder to follow. This
116 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
117 TFE_InputTensorHandles InputTFE_InputTensorHandles(
118     const py::handle& input_tensors) {
119   TFE_InputTensorHandles input_tensor_handles;
120   if (input_tensors.ptr() != Py_None) {
121     if (!PyList_Check(input_tensors.ptr())) {
122       tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
123     }
124     Py_ssize_t len = PyList_Size(input_tensors.ptr());
125     input_tensor_handles.resize(len);
126     for (Py_ssize_t i = 0; i < len; ++i) {
127       PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
128       if (!elem) {
129         tensorflow::ThrowTypeError("Input Tensor does not exist.");
130       }
131       if (EagerTensor_CheckExact(elem)) {
132         (input_tensor_handles)[i] = EagerTensor_Handle(elem);
133       } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
134         // Use equivalent of object.__getattribute__ to get the underlying
135         // tf wrapped EagerTensor (if there is one).
136         tensorflow::Safe_PyObjectPtr tf_should_use_attr(
137 #if PY_MAJOR_VERSION < 3
138             PyString_InternFromString("_tf_should_use_wrapped_value")
139 #else
140             PyUnicode_InternFromString("_tf_should_use_wrapped_value")
141 #endif
142         );
143         tensorflow::Safe_PyObjectPtr value_attr(
144             PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
145         if (value_attr) {
146           // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
147           (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
148         } else {
149           // This is a subclass of EagerTensor that we don't support.
150           PyErr_Clear();
151           tensorflow::ThrowTypeError(
152               tensorflow::strings::StrCat(
153                   "Saw an object that is an instance of a strict subclass of "
154                   "EagerTensor, which is not supported.  Item ",
155                   i, " is type: ", elem->ob_type->tp_name)
156                   .c_str());
157         }
158       } else if (tensorflow::swig::IsTensor(elem)) {
159         // If it isn't an EagerTensor, but is still a Tensor, it must be a graph
160         // tensor.
161         tensorflow::Safe_PyObjectPtr name_attr(
162             PyObject_GetAttrString(elem, "name"));
163         tensorflow::ThrowTypeError(
164             tensorflow::strings::StrCat(
165                 "An op outside of the function building code is being passed\n"
166                 "a \"Graph\" tensor. It is possible to have Graph tensors\n"
167                 "leak out of the function building context by including a\n"
168                 "tf.init_scope in your function building code.\n"
169                 "For example, the following function will fail:\n",
170                 "  @tf.function\n", "  def has_init_scope():\n",
171                 "    my_constant = tf.constant(1.)\n",
172                 "    with tf.init_scope():\n",
173                 "      added = my_constant * 2\n",
174                 "The graph tensor has name: ",
175                 name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>")
176                 .c_str());
177       } else {
178         tensorflow::ThrowTypeError(
179             tensorflow::strings::StrCat(
180                 "provided list of inputs contains objects other "
181                 "than 'EagerTensor'. Item ",
182                 i, " is type: ", elem->ob_type->tp_name)
183                 .c_str());
184       }
185     }
186   }
187   return input_tensor_handles;
188 }
189 
190 // These functions are typemaps from the Python side. I did not use
191 // a custom type caster since the logic is slightly harder to follow. This
192 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
193 // This function actually takes a number rather than an output Tensor holder.
194 TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
195     const py::handle& num_outputs) {
196   TFE_OutputTensorHandles output_tensor_handles;
197 #if PY_MAJOR_VERSION < 3
198   if (!PyInt_Check(num_outputs.ptr())) {
199 #else
200   if (!PyLong_Check(num_outputs.ptr())) {
201 #endif
202     PyErr_SetString(PyExc_TypeError,
203                     "expected an integer value (size of the number of "
204                     "outputs of the operation)");
205     throw py::error_already_set();
206   }
207 #if PY_MAJOR_VERSION < 3
208   long sz = PyInt_AsLong(num_outputs.ptr());  // NOLINT
209 #else
210   long sz = PyLong_AsLong(num_outputs.ptr());  // NOLINT
211 #endif
212   // PyLong_AsLong might throw an error if an overflow occurs.
213   if (PyErr_Occurred()) {
214     PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
215                                           "Number of outputs is too big: ", sz)
216                                           .c_str());
217     throw py::error_already_set();
218   }
219   // We can't handle more than int32 sizes for number of outputs.
220   if (static_cast<long>(static_cast<int32_t>(sz)) != sz) {  // NOLINT
221     PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
222                                           "Number of outputs is too big: ", sz)
223                                           .c_str());
224     throw py::error_already_set();
225   }
226   if (sz > 0) {
227 #if PY_MAJOR_VERSION < 3
228     output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
229 #else
230     output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
231 #endif
232   }
233   return output_tensor_handles;
234 }
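// Example (sketch): the argument is a Python integer giving the number of
// outputs; the result is a vector of that many null handles for the executed
// op to fill in.
//
//   py::int_ three(3);
//   TFE_OutputTensorHandles outs = InputTFE_OutputTensorHandles(three);
//   // outs.size() == 3 and every element is nullptr until execution runs.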
235 
236 tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
237   auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
238       tensorflow::InputTFE_Context(ctx));
239 
240   tensorflow::DeviceNameUtils::ParsedName input_device_name;
241   if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
242                                                          &input_device_name)) {
243     tensorflow::ThrowValueError(
244         absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
245                         "string should at least contain a device type and a "
246                         "device index, like \"GPU:0\".",
247                         device_name)
248             .c_str());
249   }
250 
251   std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
252 
253   tensorflow::Device* matched_device = nullptr;
254   for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
255     tensorflow::Device* device = devices[device_idx];
256 
257     if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
258             input_device_name, device->parsed_name())) {
259       if (matched_device != nullptr) {
260         tensorflow::ThrowValueError(
261             absl::StrFormat("Multiple devices match the provided string "
262                             "'%s': '%s' and '%s'.",
263                             device_name, matched_device->name(), device->name())
264                 .c_str());
265       }
266       matched_device = device;
267     }
268   }
269 
270   if (matched_device == nullptr) {
271     tensorflow::ThrowValueError(
272         absl::StrFormat("No matching devices found for '%s'", device_name)
273             .c_str());
274   }
275 
276   return matched_device;
277 }
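// Example (sketch; `ctx_handle` stands for the PyCapsule handle wrapping the
// eager context): resolving a user-provided device string against the
// context's local devices. Throws ValueError if no device or more than one
// device matches.
//
//   tensorflow::Device* gpu0 = GetMatchedDevice(ctx_handle, "GPU:0");
//   // gpu0->name() is the fully qualified device name, e.g.
//   // "/job:localhost/replica:0/task:0/device:GPU:0".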
278 
279 // Packs multiple `EagerTensor`s of the same dtype and shape into one
280 // `EagerTensor`.
281 py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
282                                            const py::handle& tensors) {
283   TFE_Context* ctx = tensorflow::InputTFE_Context(context);
284   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
285   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
286   int size = handles.size();
287   TFE_TensorHandle* packed_handle =
288       TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
289   tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
290   PyObject* packed_tensor =
291       EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
292   return tensorflow::PyoOrThrow(packed_tensor);
293 }
294 
295 // This function was created from fusing the typemap logic in platform/base.i.
296 py::object TFE_Py_ExecuteCancelable_wrapper(
297     const py::handle& context, const char* device_name, const char* op_name,
298     const py::handle& inputs, const py::handle& attrs,
299     tensorflow::CancellationManager* cancellation_manager,
300     const py::handle& num_outputs) {
301   TFE_Context* ctx = tensorflow::InputTFE_Context(context);
302   TFE_InputTensorHandles input_tensor_handles =
303       InputTFE_InputTensorHandles(inputs);
304   TFE_OutputTensorHandles output_tensor_handles =
305       InputTFE_OutputTensorHandles(num_outputs);
306   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
307   TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
308                            attrs.ptr(), tensorflow::wrap(cancellation_manager),
309                            &output_tensor_handles, status.get());
310 
311   int output_len = output_tensor_handles.size();
312   PyObject* output_list = PyList_New(output_len);
313   for (int i = 0; i < output_len; ++i) {
314     PyObject* output;
315     output = EagerTensorFromHandle(output_tensor_handles.at(i));
316     PyList_SetItem(output_list, i, output);
317   }
318   tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
319   return tensorflow::PyoOrThrow(output_list);
320 }
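// Example (sketch; `ctx_handle`, `inputs_list`, and `attrs_tuple` stand for
// Python objects supplied by the caller): executing one op with cancellation
// support.
//
//   tensorflow::CancellationManager cm;
//   py::object outputs = TFE_Py_ExecuteCancelable_wrapper(
//       ctx_handle, "/device:CPU:0", "Identity", inputs_list, attrs_tuple,
//       &cm, py::int_(1));
//   // `outputs` is a Python list of EagerTensors; calling cm.StartCancel()
//   // from another thread aborts a still-pending execution.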
321 
322 static py::object TF_ListPhysicalDevices() {
323   std::vector<string> devices;
324   tensorflow::Status s =
325       tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
326   MaybeRaiseRegisteredFromStatus(s);
327   PyObject* result = PyList_New(devices.size());
328   int i = 0;
329   for (auto& dev : devices) {
330     PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
331     PyList_SetItem(result, i, dev_obj);
332     ++i;
333   }
334   return tensorflow::PyoOrThrow(result);
335 }
336 
337 static py::object TF_ListPluggablePhysicalDevices() {
338   std::vector<string> devices;
339   tensorflow::Status s =
340       tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
341   MaybeRaiseRegisteredFromStatus(s);
342   Safe_PyObjectPtr result(PyList_New(devices.size()));
343   int i = 0;
344   for (auto& dev : devices) {
345     PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
346     PyList_SetItem(result.get(), i, dev_obj);
347     ++i;
348   }
349   return tensorflow::PyoOrThrow(result.release());
350 }
351 
352 static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
353   tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
354   std::unordered_map<string, string> device_details;
355   tensorflow::Status s =
356       tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
357   tensorflow::Set_TF_Status_from_Status(status.get(), s);
358   MaybeRaiseRegisteredFromTFStatus(status.get());
359   return device_details;
360 }
361 
362 static py::object TFE_ClearScalarCache() {
363   tensorflow::TFE_TensorHandleCache::Get()->Clear();
364   return py::none();
365 }
366 
367 // Returns compiler IR for a given function.
368 static py::bytes TFE_GetCompilerIr(py::handle& ctx,
369                                    const char* concrete_function_name,
370                                    const char* stage, const char* device_name,
371                                    py::handle& inputs) {
372   EagerContext* context = ContextFromInterface(
373       reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
374 
375   std::string s_stage(stage);
376   IrExportStage selected_stage = [&] {
377     if (s_stage == "hlo") {
378       return IrExportStage::HLO;
379     } else if (s_stage == "hlo_serialized") {
380       return IrExportStage::HLO_SERIALIZED;
381     } else if (s_stage == "optimized_hlo") {
382       return IrExportStage::OPTIMIZED_HLO;
383     } else if (s_stage == "optimized_hlo_serialized") {
384       return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
385     } else if (s_stage == "optimized_hlo_proto_serialized") {
386       return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
387     } else if (s_stage == "optimized_hlo_dot") {
388       return IrExportStage::OPTIMIZED_HLO_DOT;
389     } else {
390       ThrowValueError(
391           absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
392                           "'hlo', 'hlo_serialized', 'optimized_hlo', "
393                           "'optimized_hlo_serialized', 'optimized_hlo_proto_serialized', 'optimized_hlo_dot'",
394                           s_stage)
395               .c_str());
396     }
397   }();
398 
399   TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
400 
401   std::vector<const TensorHandle*> input_handles;
402   for (TFE_TensorHandle* tensor_handle : handles) {
403     AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
404     input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
405   }
406 
407   DeviceNameUtils::ParsedName input_device_name;
408   if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
409     ThrowValueError(
410         absl::StrFormat("Failed parsing device name: '%s'", device_name)
411             .c_str());
412   }
413 
414   std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
415   auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
416     return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
417                                                   d->parsed_name());
418   });
419   if (selected_device == devices.end()) {
420     ThrowValueError(
421         absl::StrFormat("No matching device found for '%s'", device_name)
422             .c_str());
423   }
424 
425   xla::StatusOr<std::string> hlo_str =
426       GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
427                     *selected_device, context, input_handles);
428 
429   if (!hlo_str.ok()) {
430     ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
431                                     hlo_str.status().error_message())
432                         .c_str());
433   }
434   return py::bytes(*hlo_str);
435 }
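// Example (sketch; the function name, device string, and `inputs_list` are
// illustrative): fetching optimized HLO text for a traced concrete function.
//
//   py::bytes hlo = TFE_GetCompilerIr(ctx_handle, "__inference_f_42",
//                                     "optimized_hlo", "GPU:0", inputs_list);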
436 
437 }  // namespace tensorflow
438 
439 namespace {
440 
441 // Wrapper around the EagerContextThreadLocalData struct (defined in
442 // pywrap_tfe.h), so it can be accessed from Python.
443 //
444 // For PyObject* fields, the get_*() methods return a new reference; and the
445 // set_*() methods create a new reference (i.e., they do not steal a reference).
446 class EagerContextThreadLocalDataWrapper {
447  public:
448   explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
449                                               py::handle is_eager,
450                                               py::handle device_spec)
451       : py_eager_context_(py_eager_context.ptr()) {
452     tensorflow::MakeEagerContextThreadLocalData(
453         py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
454   }
455 
456   ~EagerContextThreadLocalDataWrapper() {
457     tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
458   }
459 
460   bool get_is_eager() const { return GetData()->is_eager; }
461   void set_is_eager(bool v) { GetData()->is_eager = v; }
462 
463   bool get_invoking_op_callbacks() const {
464     return GetData()->invoking_op_callbacks;
465   }
466   void set_invoking_op_callbacks(bool v) {
467     GetData()->invoking_op_callbacks = v;
468   }
469 
470   py::object get_device_name() const {
471     return GetPyObject(&GetData()->device_name);
472   }
473   void set_device_name(py::handle v) {
474     SetPyObject(v, &GetData()->device_name);
475   }
476 
477   py::object get_scope_name() const {
478     return GetPyObject(&GetData()->scope_name);
479   }
480   void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
481 
482   py::object get_device_spec() const {
483     return GetPyObject(&GetData()->device_spec);
484   }
485   void set_device_spec(py::handle v) {
486     SetPyObject(v, &GetData()->device_spec);
487   }
488 
489   py::object get_function_call_options() const {
490     return GetPyObject(&GetData()->function_call_options);
491   }
492   void set_function_call_options(py::handle v) {
493     SetPyObject(v, &GetData()->function_call_options);
494   }
495 
496   py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
497   void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
498 
499   py::object get_op_callbacks() const {
500     return GetPyObject(&GetData()->op_callbacks);
501   }
502   void set_op_callbacks(py::handle v) {
503     SetPyObject(v, &GetData()->op_callbacks);
504   }
505 
506  private:
507   tensorflow::EagerContextThreadLocalData* GetData() const {
508     auto* result =
509         tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
510     if (!result) {
511       throw py::error_already_set();
512     }
513     return result;
514   }
515 
516   py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
517     return pybind11::reinterpret_borrow<py::object>(obj->get());
518   }
519 
520   void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
521     Py_INCREF(value.ptr());
522     ptr->reset(value.ptr());
523   }
524 
525   PyObject* py_eager_context_;  // not owned (borrowed reference).
526 };
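// Example (sketch; `py_ctx` and `spec` stand for Python objects supplied by
// the caller): per-thread eager state is read and written through the
// wrapper, with PyObject-valued fields following the reference rules above.
//
//   EagerContextThreadLocalDataWrapper data(py_ctx, py::bool_(true), spec);
//   data.set_is_eager(false);
//   bool eager = data.get_is_eager();          // false
//   py::object name = data.get_device_name();  // returns its own reference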
527 
528 }  // namespace
529 
530 // py::return_value_policy::reference is used as described in the pybind11
531 // documentation linked below.
532 // https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
533 // It means that C++ retains ownership of the returned object. We only
534 // apply this policy to functions that return opaque types.
535 
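// A minimal sketch of that pattern (`GetGlobalRegistry` is hypothetical and
// returns a pointer the C++ runtime keeps ownership of):
//
//   m.def("GetGlobalRegistry", &GetGlobalRegistry,
//         py::return_value_policy::reference);  // Python must not delete it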
536 PYBIND11_MODULE(_pywrap_tfe, m) {
537   py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
538   py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
539                                                           "TFE_ContextOptions");
540   py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
541       m, "TFE_MonitoringCounter0");
542   py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
543       m, "TFE_MonitoringCounter1");
544   py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
545       m, "TFE_MonitoringCounter2");
546   py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
547       m, "TFE_MonitoringStringGauge0");
548   py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
549       m, "TFE_MonitoringStringGauge1");
550   py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
551       m, "TFE_MonitoringStringGauge2");
552   py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
553       m, "TFE_MonitoringStringGauge3");
554   py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
555       m, "TFE_MonitoringStringGauge4");
556   py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
557       m, "TFE_MonitoringIntGauge0");
558   py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
559       m, "TFE_MonitoringIntGauge1");
560   py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
561       m, "TFE_MonitoringIntGauge2");
562   py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
563       m, "TFE_MonitoringBoolGauge0");
564   py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
565       m, "TFE_MonitoringBoolGauge1");
566   py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
567       m, "TFE_MonitoringBoolGauge2");
568   py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
569       m, "TFE_MonitoringCounterCell");
570   py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
571       m, "TFE_MonitoringIntGaugeCell");
572   py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
573       m, "TFE_MonitoringStringGaugeCell");
574   py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
575       m, "TFE_MonitoringBoolGaugeCell");
576   py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
577       m, "TFE_MonitoringSamplerCell");
578   py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
579       m, "TFE_MonitoringBuckets");
580   py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
581       m, "TFE_MonitoringSampler0");
582   py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
583       m, "TFE_MonitoringSampler1");
584   py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
585       m, "TFE_MonitoringSampler2");
586   py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
587       m, "TFE_CancellationManager");
588 
589   py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
590   py::class_<TF_Function> TF_Function_class(m, "TF_Function");
591 
592   m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
593     return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
594   });
595   m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
596     return tensorflow::PyoOrThrow(
597         TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
598   });
599 
600   m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
601     tensorflow::Device* matched_device =
602         tensorflow::GetMatchedDevice(ctx, device_name);
603 
604     tensorflow::AllocatorAttributes attrs;
605     tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
606 
607     if (absl::optional<tensorflow::AllocatorStats> stats =
608             allocator->GetStats()) {
609       return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
610                                             {"peak", stats->peak_bytes_in_use}};
611     }
612 
613     tensorflow::ThrowValueError(
614         absl::StrFormat("Allocator stats not available for device '%s'",
615                         device_name)
616             .c_str());
617   });
618 
619   m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
620     tensorflow::Device* matched_device =
621         tensorflow::GetMatchedDevice(ctx, device_name);
622 
623     tensorflow::AllocatorAttributes attrs;
624     tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
625 
626     if (!allocator->ClearStats()) {
627       tensorflow::ThrowValueError(
628           absl::StrFormat("Cannot reset memory stats for device '%s'",
629                           device_name)
630               .c_str());
631     }
632   });
633 
634   // XLA Eager Logic
635   m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
636   m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
637   m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
638   m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
639   m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
640   m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
641   m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
642 
643   // MLIR Logic
644   m.def("TF_IsMlirBridgeEnabled", [] {
645     // Since python protobuf enums are integers, cast to an integer before
646     // returning the enum to python.
647     return static_cast<int32_t>(
648         tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
649   });
650   m.def("TF_EnableMlirBridge", [](bool enabled) {
651     tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
652         enabled
653             ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
654             : tensorflow::ConfigProto::Experimental::
655                   MLIR_BRIDGE_ROLLOUT_DISABLED;
656   });
657   m.def("TF_EnableXlaDevices", [] {
658     tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
659   });
660 
661   // TFE_Context Logic
662   m.def(
663       "TFE_NewContext",
664       [](const TFE_ContextOptions* opts) {
665         tensorflow::Safe_TF_StatusPtr status =
666             tensorflow::make_safe(TF_NewStatus());
667         TFE_Context* context = TFE_NewContext(opts, status.get());
668         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
669         return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
670       },
671       py::return_value_policy::reference);
672   m.def("TFE_DeleteContext", [](py::handle& o) {
673     TFE_DeleteContext(tensorflow::InputTFE_Context(o));
674   });
675   m.def(
676       "TFE_ContextListDevices",
677       [](py::handle& o) {
678         tensorflow::Safe_TF_StatusPtr status =
679             tensorflow::make_safe(TF_NewStatus());
680         auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
681                                              status.get());
682         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
683         return output;
684       },
685       py::return_value_policy::reference);
686   m.def(
687       "TFE_SetLogicalCpuDevices",
688       [](py::handle& ctx, int num_cpus, const char* prefix) {
689         tensorflow::Safe_TF_StatusPtr status =
690             tensorflow::make_safe(TF_NewStatus());
691         TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
692                                  prefix, status.get());
693         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
694       },
695       py::return_value_policy::reference);
696   m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
697     TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
698   });
699   m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
700     tensorflow::Safe_TF_StatusPtr status =
701         tensorflow::make_safe(TF_NewStatus());
702     TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
703                            status.get());
704     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
705   });
706   m.def("TFE_ContextAddFunctionDef",
707         [](py::handle& ctx, const char* serialized_function_def, size_t size) {
708           tensorflow::Safe_TF_StatusPtr status =
709               tensorflow::make_safe(TF_NewStatus());
710           TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
711                                     serialized_function_def, size,
712                                     status.get());
713           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
714         });
715   m.def("TFE_ContextGetFunctionDef",
716         [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
717           tensorflow::Safe_TF_StatusPtr status =
718               tensorflow::make_safe(TF_NewStatus());
719           TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
720                                     function_name, &buf, status.get());
721           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
722         });
723   m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
724     tensorflow::Safe_TF_StatusPtr status =
725         tensorflow::make_safe(TF_NewStatus());
726     TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
727                               status.get());
728     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
729   });
730   m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
731     tensorflow::Safe_TF_StatusPtr status =
732         tensorflow::make_safe(TF_NewStatus());
733     auto output =
734         TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
735     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
736     return output;
737   });
738   m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
739     return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
740         ->ListFunctionNames();
741   });
742   m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
743     TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
744   });
745   m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
746     TFE_ContextDisableRunMetadata(tensorflow::InputTFE_Context(ctx));
747   });
748   m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
749     TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
750   });
751   m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
752     TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
753   });
754   m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
755     tensorflow::Safe_TF_StatusPtr status =
756         tensorflow::make_safe(TF_NewStatus());
757     TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
758                                  status.get());
759     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
760   });
761   m.def("TFE_ContextClearCaches", [](py::handle& o) {
762     TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
763   });
764   m.def("TFE_GetContextId", [](py::handle& ctx) {
765     return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
766   });
767   m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
768     return TFE_ContextGetDevicePlacementPolicy(
769         tensorflow::InputTFE_Context(ctx));
770   });
771   m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
772         [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
773           TFE_ContextSetThreadLocalDevicePlacementPolicy(
774               tensorflow::InputTFE_Context(ctx), policy);
775         });
776   m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
777                                       py::bytes proto) {
778     tensorflow::Safe_TF_StatusPtr status =
779         tensorflow::make_safe(TF_NewStatus());
780     tensorflow::Safe_TF_BufferPtr buf =
781         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
782     TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
783                             buf.get()->data, buf.get()->length, status.get());
784     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
785   });
786   m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
787                                          py::bytes proto) {
788     tensorflow::Safe_TF_StatusPtr status =
789         tensorflow::make_safe(TF_NewStatus());
790     tensorflow::Safe_TF_BufferPtr buf =
791         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
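    // NOTE: release the Python GIL while updating the server def, since the
    // update may block on network I/O to remote workers.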
792     Py_BEGIN_ALLOW_THREADS;
793     TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
794                                keep_alive_secs, buf.get()->data,
795                                buf.get()->length, status.get());
796     Py_END_ALLOW_THREADS;
797     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
798   });
799   m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
800     tensorflow::Safe_TF_StatusPtr status =
801         tensorflow::make_safe(TF_NewStatus());
802     bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
803                                         worker_name, status.get());
804     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
805     return output;
806   });
807   m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
808     tensorflow::Safe_TF_StatusPtr status =
809         tensorflow::make_safe(TF_NewStatus());
810     // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
811     Py_BEGIN_ALLOW_THREADS;
812     TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
813     Py_END_ALLOW_THREADS;
814     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
815   });
816   m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
817     tensorflow::Safe_TF_StatusPtr status =
818         tensorflow::make_safe(TF_NewStatus());
819     // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
820     Py_BEGIN_ALLOW_THREADS;
821     TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
822     Py_END_ALLOW_THREADS;
823     // NOTE: unlike TFE_ContextSyncExecutors, which raises potential errors,
824     // deliberately ignore executor statuses during cleanup.
825   });
826   m.def(
827       "TFE_InsertConfigKeyValue",
828       [](py::handle& ctx, const char* config_key, const char* config_value) {
829         tensorflow::Safe_TF_StatusPtr status =
830             tensorflow::make_safe(TF_NewStatus());
831         Py_BEGIN_ALLOW_THREADS;
832         TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
833                                  config_value, status.get());
834         Py_END_ALLOW_THREADS;
835         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
836       },
837       py::return_value_policy::reference);
838   m.def(
839       "TFE_GetConfigKeyValue",
840       [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) {
841         tensorflow::Safe_TF_StatusPtr status =
842             tensorflow::make_safe(TF_NewStatus());
843         Py_BEGIN_ALLOW_THREADS;
844         TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
845                               &config_value, status.get());
846         Py_END_ALLOW_THREADS;
847         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
848       },
849       py::return_value_policy::reference);
850   m.def(
851       "TFE_DeleteConfigKeyValue",
852       [](py::handle& ctx, const char* config_key) {
853         tensorflow::Safe_TF_StatusPtr status =
854             tensorflow::make_safe(TF_NewStatus());
855         Py_BEGIN_ALLOW_THREADS;
856         TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
857                                  status.get());
858         Py_END_ALLOW_THREADS;
859         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
860       },
861       py::return_value_policy::reference);
862   m.def(
863       "TFE_ReportErrorToCluster",
864       [](py::handle& ctx, int error_code, const char* error_message) {
865         tensorflow::Safe_TF_StatusPtr status =
866             tensorflow::make_safe(TF_NewStatus());
867         TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
868                                  error_message, status.get());
869         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
870       },
871       py::return_value_policy::reference);
872   m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
873     tensorflow::Safe_TF_StatusPtr status =
874         tensorflow::make_safe(TF_NewStatus());
875     TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
876                                       status.get());
877   });
878   m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
879     tensorflow::Safe_TF_StatusPtr status =
880         tensorflow::make_safe(TF_NewStatus());
881     TFE_ContextSetLogDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
882                                       status.get());
883   });
884 
885   // TFE_Executor logic
886   m.def(
887       "TFE_NewExecutor",
888       [](const bool is_async) {
889         TFE_Executor* exc = TFE_NewExecutor(is_async);
890         return exc;
891       },
892       py::return_value_policy::reference);
893   m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
894   m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
895   m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
896     tensorflow::Safe_TF_StatusPtr status =
897         tensorflow::make_safe(TF_NewStatus());
898     // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
899     Py_BEGIN_ALLOW_THREADS;
900     TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
901     Py_END_ALLOW_THREADS;
902     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
903   });
904   m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
905   m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
906                                               TFE_Executor& exc) {
907     TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
908   });
909   m.def(
910       "TFE_ContextGetExecutorForThread",
911       [](py::handle& o) {
912         return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
913       },
914       py::return_value_policy::reference);
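  // Example (sketch): on the native side, the executor calls bound above are
  // typically used together like this.
  //
  //   TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/true);
  //   TFE_ContextSetExecutorForThread(ctx, executor);
  //   ...
  //   TF_Status* s = TF_NewStatus();
  //   TFE_ExecutorWaitForAllPendingNodes(executor, s);
  //   TFE_DeleteExecutor(executor);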
915 
916   m.def("TFE_OpNameGetAttrType",
917         [](py::handle& ctx, const char* op_or_function_name,
918            const char* attr_name) {
919           int temp = 0;
920           unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
921           tensorflow::Safe_TF_StatusPtr status =
922               tensorflow::make_safe(TF_NewStatus());
923           auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
924                                               op_or_function_name, attr_name,
925                                               is_list, status.get());
926           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
927 #if PY_MAJOR_VERSION < 3
928           PyObject* output_pyo = PyInt_FromLong(output);
929 #else
930           PyObject* output_pyo = PyLong_FromLong(output);
931 #endif
932           if (*is_list == 1) {
933             PyObject* list = PyList_New(1);
934             PyList_SetItem(list, 0, output_pyo);
935             return tensorflow::PyoOrThrow(list);
936           }
937           return tensorflow::PyoOrThrow(output_pyo);
938         });
939   m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
940     return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
941   });
942   m.def("TFE_Py_PackEagerTensors",
943         [](const py::handle& context, const py::handle& handles) {
944           return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
945         });
946   m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
947   m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
948     return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
949   });
950   m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
951     return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
952   });
953   m.def("TFE_Py_Execute",
954         [](const py::handle& context, const char* device_name,
955            const char* op_name, const py::handle& inputs,
956            const py::handle& attrs, const py::handle& num_outputs) {
957           return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
958               context, device_name, op_name, inputs, attrs.ptr(), nullptr,
959               num_outputs);
960         });
961   m.def(
962       "TFE_Py_ExecuteCancelable",
963       [](const py::handle& context, const char* device_name,
964          const char* op_name, const py::handle& inputs, const py::handle& attrs,
965          tensorflow::CancellationManager& cancellation_manager,
966          const py::handle& num_outputs) {
967         return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
968             context, device_name, op_name, inputs, attrs.ptr(),
969             &cancellation_manager, num_outputs);
970       });
971   m.def("TFE_Py_FastPathExecute", [](const py::args args) {
972     // TFE_Py_FastPathExecute requires error checking prior to returning.
973     return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
974   });
975   m.def("TFE_Py_RecordGradient",
976         [](const py::handle& op_name, const py::handle& inputs,
977            const py::handle& attrs, const py::handle& results,
978            const py::handle& forward_pass_name_scope) {
979           return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
980               op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
981               forward_pass_name_scope.ptr()));
982         });
983   m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
984 
985   // TFE_Py_Tape Logic
986   m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
987                                 const py::handle& watch_accessed_variables) {
988     return tensorflow::PyoOrThrow(
989         TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
990   });
991   m.def("TFE_Py_TapeSetAdd",
992         [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
993   m.def("TFE_Py_TapeSetRemove",
994         [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
995   m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
996   m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
997   m.def("TFE_Py_TapeSetIsStopped",
998         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
999   m.def("TFE_Py_TapeSetIsEmpty",
1000         []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
1001   m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
1002     return tensorflow::PyoOrThrow(
1003         TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
1004   });
1005   m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
1006     return tensorflow::PyoOrThrow(
1007         TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
1008   });
1009   m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
1010   m.def("TFE_Py_TapeSetRecordOperation",
1011         [](const py::handle& op_type, const py::handle& output_tensors,
1012            const py::handle& input_tensors, const py::handle& backward_function,
1013            const py::handle& forward_function) {
1014           return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
1015               op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1016               backward_function.ptr(), forward_function.ptr()));
1017         });
1018   m.def(
1019       "TFE_Py_TapeSetRecordOperationBackprop",
1020       [](const py::handle& op_type, const py::handle& output_tensors,
1021          const py::handle& input_tensors, const py::handle& backward_function) {
1022         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
1023             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1024             backward_function.ptr()));
1025       });
1026   m.def(
1027       "TFE_Py_TapeSetRecordOperationForwardprop",
1028       [](const py::handle& op_type, const py::handle& output_tensors,
1029          const py::handle& input_tensors, const py::handle& backward_function,
1030          const py::handle& forwardprop_output_indices) {
1031         return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
1032             op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1033             backward_function.ptr(), forwardprop_output_indices.ptr()));
1034       });
1035   m.def("TFE_Py_TapeGradient",
1036         [](const py::handle& tape, const py::handle& target,
1037            const py::handle& sources, const py::handle& output_gradients,
1038            const py::handle& sources_raw,
1039            const py::handle& unconnected_gradients) {
1040           tensorflow::Safe_TF_StatusPtr status =
1041               tensorflow::make_safe(TF_NewStatus());
1042           PyObject* output = TFE_Py_TapeGradient(
1043               tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
1044               sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
1045           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1046           return tensorflow::PyoOrThrow(output);
1047         });
1048 
1049   m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
1050     TFE_Py_TapeVariableAccessed(variable.ptr());
1051   });
1052   m.def("TFE_Py_TapeWatch",
1053         [](const py::handle& tape, const py::handle& tensor) {
1054           TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
1055         });
1056   m.def("TFE_Py_TapeWatchVariable",
1057         [](const py::handle& tape, const py::handle& variable) {
1058           TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
1059         });
1060   m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
1061     return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
1062   });
1063 
1064   // TFE_Py_VariableWatcher logic.
1065   m.def("TFE_Py_VariableWatcherNew",
1066         []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
1067   m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
1068     TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
1069   });
1070   m.def("TFE_Py_VariableWatcherVariableAccessed",
1071         [](const py::handle& variable) {
1072           TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
1073         });
1074   m.def("TFE_Py_VariableWatcherWatchedVariables",
1075         [](const py::handle& variable_watcher) {
1076           return tensorflow::PyoOrThrow(
1077               TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
1078         });
1079 
1080   // TFE_Py_ForwardAccumulator logic.
1081   m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
1082     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
1083   });
1084 
1085   m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
1086     return tensorflow::PyoOrThrow(
1087         TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
1088   });
1089   m.def("TFE_Py_ForwardAccumulatorSetRemove",
1090         [](const py::handle& accumulator) {
1091           TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
1092         });
1093 
1094   m.def("TFE_Py_ForwardAccumulatorWatch",
1095         [](const py::handle& accumulator, const py::handle& tensor,
1096            const py::handle& tangent) {
1097           TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
1098                                          tangent.ptr());
1099         });
1100   m.def("TFE_Py_ForwardAccumulatorJVP",
1101         [](const py::handle& accumulator, const py::handle& tensor) {
1102           return tensorflow::PyoOrThrow(
1103               TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
1104         });
1105   m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1106     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1107   });
1108   m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1109     return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1110   });
1111   m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1112     return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1113   });
1114 
1115   // TFE_ContextOptions Logic
1116   m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1117         py::return_value_policy::reference);
1118   m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1119                                           py::bytes proto) {
1120     tensorflow::Safe_TF_StatusPtr status =
1121         tensorflow::make_safe(TF_NewStatus());
1122     tensorflow::Safe_TF_BufferPtr buf =
1123         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1124     TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1125                                 status.get());
1126     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1127   });
1128   m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1129         &TFE_ContextOptionsSetDevicePlacementPolicy);
1130   m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1131   m.def("TFE_ContextOptionsSetTfrtDistributedRuntime",
1132         &TFE_ContextOptionsSetTfrtDistributedRuntime);
1133   // Experimental feature, intentionally not exposed as a C API yet.
1134   m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
1135         [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
1136           options->run_eager_op_as_function = run_eager_op_as_function;
1137         });
1138   m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1139   m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1140         py::return_value_policy::reference);
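  // Example (sketch): typical native-side lifecycle of the options object
  // wrapped by the bindings above.
  //
  //   TFE_ContextOptions* opts = TFE_NewContextOptions();
  //   TFE_ContextOptionsSetAsync(opts, 1);
  //   TF_Status* s = TF_NewStatus();
  //   TFE_Context* ctx = TFE_NewContext(opts, s);
  //   TFE_DeleteContextOptions(opts);  // safe once the context is created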
1141 
1142   // TFE_Py_TensorShape Logic
1143   m.def("TFE_Py_TensorShapeSlice",
1144         [](const py::handle& tensors, int slice_dim) {
1145           return tensorflow::PyoOrThrow(
1146               TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1147         });
1148   m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1149                                          int slice_dim) {
1150     return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1151   });
1152   m.def("TFE_Py_EnableInteractivePythonLogging",
1153         &TFE_Py_EnableInteractivePythonLogging);
1154 
1155   // Additional Context Logic
1156   m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1157     return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1158   });
1159   m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1160     return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1161   });
1162   m.def("TFE_Py_EncodeArg",
1163         [](const py::handle& o, bool include_tensor_ranks_only) {
1164           return tensorflow::PyoOrThrow(
1165               TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only));
1166         });
1167   m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1168     tensorflow::Safe_TF_StatusPtr status =
1169         tensorflow::make_safe(TF_NewStatus());
1170     tensorflow::Safe_TF_BufferPtr buf =
1171         tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1172     TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1173                             buf.get()->length, status.get());
1174     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1175   });
1176   m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1177                                      const char* message) {
1178     tensorflow::Safe_TF_StatusPtr status =
1179         tensorflow::make_safe(TF_NewStatus());
1180     TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1181     TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1182   });
1183   m.def("TFE_CollectiveOpsCheckPeerHealth",
1184         [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1185           tensorflow::Safe_TF_StatusPtr status =
1186               tensorflow::make_safe(TF_NewStatus());
1187           TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1188                                            task, timeout_in_ms, status.get());
1189           tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1190         });
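  // Hedged sketch of driving the collective-ops helpers above from Python;
  // the module name, task string, and ServerDef payload are assumptions:
  //   _pywrap_tfe.TFE_EnableCollectiveOps(ctx, server_def_bytes)
  //   _pywrap_tfe.TFE_CollectiveOpsCheckPeerHealth(
  //       ctx, "/job:worker/replica:0/task:1", 5000)  # timeout in ms
  //   _pywrap_tfe.TFE_AbortCollectiveOps(ctx, error_code, "peer failed")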
1191   m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1192   m.def("TF_ListPluggablePhysicalDevices",
1193         &tensorflow::TF_ListPluggablePhysicalDevices);
1194   m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1195   m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1196         py::return_value_policy::reference);
1197   m.def("TF_DeviceListCount", &TF_DeviceListCount);
1198   m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1199     tensorflow::Safe_TF_StatusPtr status =
1200         tensorflow::make_safe(TF_NewStatus());
1201     auto output = TF_DeviceListName(list, index, status.get());
1202     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1203     return output;
1204   });
1205   m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1206     tensorflow::Safe_TF_StatusPtr status =
1207         tensorflow::make_safe(TF_NewStatus());
1208     auto output = TF_DeviceListType(list, index, status.get());
1209     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1210     return output;
1211   });
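  // Hedged sketch of walking a TF_DeviceList from Python (the module name and
  // the way the list is obtained are assumptions):
  //   for i in range(_pywrap_tfe.TF_DeviceListCount(device_list)):
  //       name = _pywrap_tfe.TF_DeviceListName(device_list, i)
  //       dev_type = _pywrap_tfe.TF_DeviceListType(device_list, i)
  //   _pywrap_tfe.TF_DeleteDeviceList(device_list)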
1212 
1213   m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1214 
1215   // TFE_MonitoringCounter Logic
1216   m.def("TFE_MonitoringCounterCellIncrementBy",
1217         &TFE_MonitoringCounterCellIncrementBy);
1218   m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1219   m.def(
1220       "TFE_MonitoringNewCounter0",
1221       [](const char* name, const char* description) {
1222         tensorflow::Safe_TF_StatusPtr status =
1223             tensorflow::make_safe(TF_NewStatus());
1224         auto output =
1225             TFE_MonitoringNewCounter0(name, status.get(), description);
1226         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1227         return output;
1228       },
1229       py::return_value_policy::reference);
1230   m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1231         py::return_value_policy::reference);
1232   m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1233         py::return_value_policy::reference);
1234   m.def(
1235       "TFE_MonitoringNewCounter1",
1236       [](const char* name, const char* description, const char* label1) {
1237         tensorflow::Safe_TF_StatusPtr status =
1238             tensorflow::make_safe(TF_NewStatus());
1239         auto output =
1240             TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1241         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1242         return output;
1243       },
1244       py::return_value_policy::reference);
1245   m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1246         py::return_value_policy::reference);
1247   m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1248         py::return_value_policy::reference);
1249   m.def(
1250       "TFE_MonitoringNewCounter2",
1251       [](const char* name, const char* description, const char* label1,
1252          const char* label2) {
1253         tensorflow::Safe_TF_StatusPtr status =
1254             tensorflow::make_safe(TF_NewStatus());
1255         auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1256                                                 label1, label2);
1257         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1258         return output;
1259       },
1260       py::return_value_policy::reference);
1261   m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1262         py::return_value_policy::reference);
1263   m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1264         py::return_value_policy::reference);
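  // The numeric suffix on the monitoring bindings is the number of string
  // labels the metric takes (Counter0 has none, Counter2 has two). The New*
  // constructors hand back raw pointers under return_value_policy::reference;
  // they are freed by the matching TFE_MonitoringDelete* call, not by Python.
  // Hedged usage sketch (module name and metric name are assumptions):
  //   counter = _pywrap_tfe.TFE_MonitoringNewCounter1(
  //       "/test/counter", "example counter", "status")
  //   cell = _pywrap_tfe.TFE_MonitoringGetCellCounter1(counter, "ok")
  //   _pywrap_tfe.TFE_MonitoringCounterCellIncrementBy(cell, 1)
  //   _pywrap_tfe.TFE_MonitoringDeleteCounter1(counter)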
1265 
1266   // TFE_MonitoringIntGauge Logic
1267   m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1268   m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1269   m.def(
1270       "TFE_MonitoringNewIntGauge0",
1271       [](const char* name, const char* description) {
1272         tensorflow::Safe_TF_StatusPtr status =
1273             tensorflow::make_safe(TF_NewStatus());
1274         auto output =
1275             TFE_MonitoringNewIntGauge0(name, status.get(), description);
1276         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1277         return output;
1278       },
1279       py::return_value_policy::reference);
1280   m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1281         py::return_value_policy::reference);
1282   m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1283         py::return_value_policy::reference);
1284   m.def(
1285       "TFE_MonitoringNewIntGauge1",
1286       [](const char* name, const char* description, const char* label1) {
1287         tensorflow::Safe_TF_StatusPtr status =
1288             tensorflow::make_safe(TF_NewStatus());
1289         auto output =
1290             TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1291         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1292         return output;
1293       },
1294       py::return_value_policy::reference);
1295   m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1296         py::return_value_policy::reference);
1297   m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1298         py::return_value_policy::reference);
1299   m.def(
1300       "TFE_MonitoringNewIntGauge2",
1301       [](const char* name, const char* description, const char* label1,
1302          const char* label2) {
1303         tensorflow::Safe_TF_StatusPtr status =
1304             tensorflow::make_safe(TF_NewStatus());
1305         auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1306                                                  description, label1, label2);
1307         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1308         return output;
1309       },
1310       py::return_value_policy::reference);
1311   m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1312         py::return_value_policy::reference);
1313   m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1314         py::return_value_policy::reference);
1315   m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1316   m.def("TFE_MonitoringStringGaugeCellValue",
1317         &TFE_MonitoringStringGaugeCellValue);
1318   m.def(
1319       "TFE_MonitoringNewStringGauge0",
1320       [](const char* name, const char* description) {
1321         tensorflow::Safe_TF_StatusPtr status =
1322             tensorflow::make_safe(TF_NewStatus());
1323         auto output =
1324             TFE_MonitoringNewStringGauge0(name, status.get(), description);
1325         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1326         return output;
1327       },
1328       py::return_value_policy::reference);
1329 
1330   // TFE_MonitoringStringGauge Logic
1331   m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1332   m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1333         py::return_value_policy::reference);
1334   m.def(
1335       "TFE_MonitoringNewStringGauge1",
1336       [](const char* name, const char* description, const char* label1) {
1337         tensorflow::Safe_TF_StatusPtr status =
1338             tensorflow::make_safe(TF_NewStatus());
1339         auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1340                                                     description, label1);
1341         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1342         return output;
1343       },
1344       py::return_value_policy::reference);
1345   m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1346   m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1347         py::return_value_policy::reference);
1348   m.def(
1349       "TFE_MonitoringNewStringGauge2",
1350       [](const char* name, const char* description, const char* label1,
1351          const char* label2) {
1352         tensorflow::Safe_TF_StatusPtr status =
1353             tensorflow::make_safe(TF_NewStatus());
1354         auto output = TFE_MonitoringNewStringGauge2(
1355             name, status.get(), description, label1, label2);
1356         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1357         return output;
1358       },
1359       py::return_value_policy::reference);
1360   m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1361   m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1362         py::return_value_policy::reference);
1363 
1364   m.def(
1365       "TFE_MonitoringNewStringGauge3",
1366       [](const char* name, const char* description, const char* label1,
1367          const char* label2, const char* label3) {
1368         tensorflow::Safe_TF_StatusPtr status =
1369             tensorflow::make_safe(TF_NewStatus());
1370         auto output = TFE_MonitoringNewStringGauge3(
1371             name, status.get(), description, label1, label2, label3);
1372         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1373         return output;
1374       },
1375       py::return_value_policy::reference);
1376   m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
1377   m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
1378         py::return_value_policy::reference);
1379 
1380   m.def(
1381       "TFE_MonitoringNewStringGauge4",
1382       [](const char* name, const char* description, const char* label1,
1383          const char* label2, const char* label3, const char* label4) {
1384         tensorflow::Safe_TF_StatusPtr status =
1385             tensorflow::make_safe(TF_NewStatus());
1386         auto output = TFE_MonitoringNewStringGauge4(
1387             name, status.get(), description, label1, label2, label3, label4);
1388         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1389         return output;
1390       },
1391       py::return_value_policy::reference);
1392   m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
1393   m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
1394         py::return_value_policy::reference);
1395 
1396   // TFE_MonitoringBoolGauge Logic
1397   m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1398   m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1399   m.def(
1400       "TFE_MonitoringNewBoolGauge0",
1401       [](const char* name, const char* description) {
1402         tensorflow::Safe_TF_StatusPtr status =
1403             tensorflow::make_safe(TF_NewStatus());
1404         auto output =
1405             TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1406         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1407         return output;
1408       },
1409       py::return_value_policy::reference);
1410   m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1411         py::return_value_policy::reference);
1412   m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1413         py::return_value_policy::reference);
1414   m.def(
1415       "TFE_MonitoringNewBoolGauge1",
1416       [](const char* name, const char* description, const char* label1) {
1417         tensorflow::Safe_TF_StatusPtr status =
1418             tensorflow::make_safe(TF_NewStatus());
1419         auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1420                                                   description, label1);
1421         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1422         return output;
1423       },
1424       py::return_value_policy::reference);
1425   m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1426         py::return_value_policy::reference);
1427   m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1428         py::return_value_policy::reference);
1429   m.def(
1430       "TFE_MonitoringNewBoolGauge2",
1431       [](const char* name, const char* description, const char* label1,
1432          const char* label2) {
1433         tensorflow::Safe_TF_StatusPtr status =
1434             tensorflow::make_safe(TF_NewStatus());
1435         auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1436                                                   description, label1, label2);
1437         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1438         return output;
1439       },
1440       py::return_value_policy::reference);
1441   m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1442         py::return_value_policy::reference);
1443   m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1444         py::return_value_policy::reference);
1445 
1446   // TFE_MonitoringSampler Logic
1447   m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1448   m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1449   m.def("TFE_MonitoringNewExponentialBuckets",
1450         &TFE_MonitoringNewExponentialBuckets,
1451         py::return_value_policy::reference);
1452   m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1453         py::return_value_policy::reference);
1454   m.def(
1455       "TFE_MonitoringNewSampler0",
1456       [](const char* name, TFE_MonitoringBuckets* buckets,
1457          const char* description) {
1458         tensorflow::Safe_TF_StatusPtr status =
1459             tensorflow::make_safe(TF_NewStatus());
1460         auto output =
1461             TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1462         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1463         return output;
1464       },
1465       py::return_value_policy::reference);
1466   m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1467         py::return_value_policy::reference);
1468   m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1469         py::return_value_policy::reference);
1470   m.def(
1471       "TFE_MonitoringNewSampler1",
1472       [](const char* name, TFE_MonitoringBuckets* buckets,
1473          const char* description, const char* label1) {
1474         tensorflow::Safe_TF_StatusPtr status =
1475             tensorflow::make_safe(TF_NewStatus());
1476         auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1477                                                 description, label1);
1478         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1479         return output;
1480       },
1481       py::return_value_policy::reference);
1482   m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1483         py::return_value_policy::reference);
1484   m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1485         py::return_value_policy::reference);
1486   m.def(
1487       "TFE_MonitoringNewSampler2",
1488       [](const char* name, TFE_MonitoringBuckets* buckets,
1489          const char* description, const char* label1, const char* label2) {
1490         tensorflow::Safe_TF_StatusPtr status =
1491             tensorflow::make_safe(TF_NewStatus());
1492         auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1493                                                 description, label1, label2);
1494         tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1495         return output;
1496       },
1497       py::return_value_policy::reference);
1498   m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1499         py::return_value_policy::reference);
1500   m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1501         py::return_value_policy::reference);
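  // Samplers record a distribution of values into buckets created separately.
  // Hedged sketch (module name and bucket parameters are assumptions):
  //   buckets = _pywrap_tfe.TFE_MonitoringNewExponentialBuckets(1.0, 2.0, 10)
  //   sampler = _pywrap_tfe.TFE_MonitoringNewSampler0(
  //       "/test/sampler", buckets, "example sampler")
  //   cell = _pywrap_tfe.TFE_MonitoringGetCellSampler0(sampler)
  //   _pywrap_tfe.TFE_MonitoringSamplerCellAdd(cell, 0.5)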
1502 
1503   // TFE_CancellationManager Logic
1504   m.def("TFE_NewCancellationManager",
1505         []() { return new tensorflow::CancellationManager(); });
1506   m.def("TFE_CancellationManagerIsCancelled",
1507         &tensorflow::CancellationManager::IsCancelled);
1508   m.def("TFE_CancellationManagerStartCancel",
1509         &tensorflow::CancellationManager::StartCancel);
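  // CancellationManager is exposed as an opaque type, so the two member
  // functions above take the manager instance as their only argument from
  // Python. Hedged sketch (module name assumed):
  //   cm = _pywrap_tfe.TFE_NewCancellationManager()
  //   _pywrap_tfe.TFE_CancellationManagerStartCancel(cm)
  //   assert _pywrap_tfe.TFE_CancellationManagerIsCancelled(cm)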
1510 
1511   m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
1512 
1513   // Util buffer helper functions
1514   m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
1515         py::return_value_policy::reference);
1516 
1517   // DLPack functions
1518   m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1519     PyObject* eager_tensor_pyobject_ptr = o.ptr();
1520     tensorflow::Safe_TF_StatusPtr status =
1521         tensorflow::make_safe(TF_NewStatus());
1522 
1523     if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1524       status->status = tensorflow::errors::InvalidArgument(
1525           "The argument to `to_dlpack` must be a TF tensor, not a Python object");
1526       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1527     }
1528 
1529     TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1530     void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1531     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1532 
1533     py::capsule capsule(
1534         dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1535           if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1536             void* dlm_rptr =
1537                 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1538             if (dlm_rptr) {
1539               tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1540               PyCapsule_SetDestructor(capsule, nullptr);
1541             }
1542           }
1543         });
1544     return capsule;
1545   });
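  // The capsule above follows the DLPack convention: it is named "dltensor"
  // and carries a DLManagedTensor pointer. Its destructor only invokes the
  // DLPack deleter while the capsule still has that name; a consumer renames
  // it (see TFE_FromDlpackCapsule below), which prevents a double free.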
1546 
1547   m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1548                                     const py::handle& context) {
1549     tensorflow::Safe_TF_StatusPtr status =
1550         tensorflow::make_safe(TF_NewStatus());
1551     if (absl::string_view(pycapsule.name()) !=
1552         tensorflow::kDlTensorCapsuleName) {
1553       status->status = tensorflow::errors::InvalidArgument(
1554           "DLPack tensor must be a capsule with name \"dltensor\", got \"",
1555           absl::string_view(pycapsule.name()),
1556           "\". Note that a DLPack tensor may be consumed at most once.");
1557       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1558     }
1559 
1560     TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1561         pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1562 
1563     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1564 
1565     PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1566     PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1567 
1568     PyObject* pyhandle = EagerTensorFromHandle(thandle);
1569     return tensorflow::PyoOrThrow(pyhandle);
1570   });
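  // Hedged round-trip sketch of the two DLPack bindings (module name assumed):
  //   capsule = _pywrap_tfe.TFE_ToDlpackCapsule(eager_tensor)
  //   tensor = _pywrap_tfe.TFE_FromDlpackCapsule(capsule, ctx)
  // After the second call the capsule is renamed "used_dltensor" and cannot
  // be consumed again.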
1571 
1572   m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1573                                           const py::capsule& device,
1574                                           const char* device_name,
1575                                           const py::capsule& device_info) {
1576     tensorflow::Safe_TF_StatusPtr status =
1577         tensorflow::make_safe(TF_NewStatus());
1578     if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1579       status->status = tensorflow::errors::InvalidArgument(
1580           "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1581           "argument, got ",
1582           absl::string_view(device.name()));
1583       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1584     }
1585     if (absl::string_view(device_info.name()) !=
1586         "TFE_CustomDevice_DeviceInfo") {
1587       status->status = tensorflow::errors::InvalidArgument(
1588           "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1589           "the `device_info` argument, got ",
1590           absl::string_view(device_info.name()));
1591       tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1592     }
1593     // TFE_RegisterCustomDevice takes ownership of device_info.
1594     PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1595     TFE_RegisterCustomDevice(
1596         tensorflow::InputTFE_Context(context),
1597         *reinterpret_cast<TFE_CustomDevice*>(
1598             PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1599         device_name,
1600         PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1601         status.get());
1602     tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1603   });
1604 
1605   py::class_<EagerContextThreadLocalDataWrapper>(m,
1606                                                  "EagerContextThreadLocalData")
1607       .def(py::init<py::handle, py::handle, py::handle>(),
1608            py::arg("py_eager_context"), py::arg("is_eager"),
1609            py::arg("device_spec"))
1610       .def_property("is_eager",
1611                     &EagerContextThreadLocalDataWrapper::get_is_eager,
1612                     &EagerContextThreadLocalDataWrapper::set_is_eager)
1613       .def_property(
1614           "invoking_op_callbacks",
1615           &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1616           &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1617       .def_property("device_name",
1618                     &EagerContextThreadLocalDataWrapper::get_device_name,
1619                     &EagerContextThreadLocalDataWrapper::set_device_name)
1620       .def_property("scope_name",
1621                     &EagerContextThreadLocalDataWrapper::get_scope_name,
1622                     &EagerContextThreadLocalDataWrapper::set_scope_name)
1623       .def_property("device_spec",
1624                     &EagerContextThreadLocalDataWrapper::get_device_spec,
1625                     &EagerContextThreadLocalDataWrapper::set_device_spec)
1626       .def_property(
1627           "function_call_options",
1628           &EagerContextThreadLocalDataWrapper::get_function_call_options,
1629           &EagerContextThreadLocalDataWrapper::set_function_call_options)
1630       .def_property("executor",
1631                     &EagerContextThreadLocalDataWrapper::get_executor,
1632                     &EagerContextThreadLocalDataWrapper::set_executor)
1633       .def_property("op_callbacks",
1634                     &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1635                     &EagerContextThreadLocalDataWrapper::set_op_callbacks);
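  // EagerContextThreadLocalData exposes per-thread eager state as plain
  // Python properties. Hedged sketch (how the handles are obtained is an
  // assumption):
  //   tld = _pywrap_tfe.EagerContextThreadLocalData(
  //       py_eager_context, is_eager=True, device_spec=spec)
  //   tld.scope_name = "outer/"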
1636 
1637   // C API Enum
1638 
1639   py::enum_<TFE_ContextDevicePlacementPolicy>(
1640       m, "TFE_ContextDevicePlacementPolicy")
1641       .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1642       .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1643       .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1644       .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1645              TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1646       .export_values();
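  // export_values() additionally re-exports each enumerator at module scope,
  // so Python can reference e.g. TFE_DEVICE_PLACEMENT_SILENT directly rather
  // than going through the enum class.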
1647 
1648   py::enum_<TF_AttrType>(m, "TF_AttrType")
1649       .value("TF_ATTR_STRING", TF_ATTR_STRING)
1650       .value("TF_ATTR_INT", TF_ATTR_INT)
1651       .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1652       .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1653       .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1654       .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1655       .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1656       .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1657       .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1658       .export_values();
1659 };
1660