1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "Python.h"
19 #include "absl/strings/str_format.h"
20 #include "pybind11/chrono.h"
21 #include "pybind11/complex.h"
22 #include "pybind11/functional.h"
23 #include "pybind11/pybind11.h"
24 #include "pybind11/pytypes.h"
25 #include "pybind11/stl.h"
26 #include "tensorflow/c/c_api.h"
27 #include "tensorflow/c/c_api_experimental.h"
28 #include "tensorflow/c/eager/c_api.h"
29 #include "tensorflow/c/eager/c_api_experimental.h"
30 #include "tensorflow/c/eager/c_api_internal.h"
31 #include "tensorflow/c/eager/dlpack.h"
32 #include "tensorflow/c/eager/tfe_cancellation_manager_internal.h"
33 #include "tensorflow/c/eager/tfe_context_internal.h"
34 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
35 #include "tensorflow/c/tf_status.h"
36 #include "tensorflow/c/tf_status_helper.h"
37 #include "tensorflow/compiler/jit/flags.h"
38 #include "tensorflow/compiler/jit/get_compiler_ir.h"
39 #include "tensorflow/python/eager/pywrap_tensor_conversion.h"
40 #include "tensorflow/python/eager/pywrap_tfe.h"
41 #include "tensorflow/python/lib/core/py_exception_registry.h"
42 #include "tensorflow/python/lib/core/pybind11_lib.h"
43 #include "tensorflow/python/lib/core/pybind11_status.h"
44 #include "tensorflow/python/lib/core/safe_ptr.h"
45 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
46 #include "tensorflow/python/util/util.h"
47
namespace py = pybind11;

// PYBIND11_MAKE_OPAQUE tells pybind11 to treat each of the following types as
// an opaque bound class instead of attempting a by-value type conversion. Most
// of them are registered with py::class_ inside PYBIND11_MODULE(_pywrap_tfe).

// Executor, context options and cancellation manager.
PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(tensorflow::CancellationManager);

// TFE monitoring counters, gauges and samplers, plus their cells.
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge3);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge4);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell);

// TF C API types.
PYBIND11_MAKE_OPAQUE(TF_DeviceList);
PYBIND11_MAKE_OPAQUE(TF_Function);
PYBIND11_MAKE_OPAQUE(TF_Buffer);
80
81 // Eager helper functions migrated from pywrap_tfe.i.
82
83 namespace tensorflow {
84
// We cannot use Context as an opaque type, and SWIG also had
// difficulty passing the pointer around directly. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but it caused periodic segfaults.
89
90 // TODO(amitpatankar): Move input and output logic of Context into a
91 // pybind11 custom type caster.
92
InputTFE_Context(const py::handle & ctx)93 TFE_Context* InputTFE_Context(const py::handle& ctx) {
94 return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr));
95 }
96
OutputTFE_Context(TFE_Context * context)97 PyObject* OutputTFE_Context(TFE_Context* context) {
98 return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule);
99 }
100
ProtoStringToTFBuffer(PyObject * input)101 TF_Buffer* ProtoStringToTFBuffer(PyObject* input) {
102 // Convert a Python string object to TF_Buffer.
103 char* c_string;
104 Py_ssize_t py_size;
105 // PyBytes_AsStringAndSize() does not copy but simply interprets the input
106 if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) {
107 // Python has raised an error (likely TypeError or UnicodeEncodeError).
108 throw py::error_already_set();
109 }
110 return TF_NewBufferFromString(static_cast<void*>(c_string),
111 static_cast<size_t>(py_size));
112 }
113
114 // These functions are typemaps from the Python side. I did not use
115 // a custom type caster since the logic is slightly harder to follow. This
116 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
InputTFE_InputTensorHandles(const py::handle & input_tensors)117 TFE_InputTensorHandles InputTFE_InputTensorHandles(
118 const py::handle& input_tensors) {
119 TFE_InputTensorHandles input_tensor_handles;
120 if (input_tensors.ptr() != Py_None) {
121 if (!PyList_Check(input_tensors.ptr())) {
122 tensorflow::ThrowTypeError("must provide a list of Tensors as inputs");
123 }
124 Py_ssize_t len = PyList_Size(input_tensors.ptr());
125 input_tensor_handles.resize(len);
126 for (Py_ssize_t i = 0; i < len; ++i) {
127 PyObject* elem = PyList_GetItem(input_tensors.ptr(), i);
128 if (!elem) {
129 tensorflow::ThrowTypeError("Input Tensor does not exist.");
130 }
131 if (EagerTensor_CheckExact(elem)) {
132 (input_tensor_handles)[i] = EagerTensor_Handle(elem);
133 } else if (tensorflow::swig::IsEagerTensorSlow(elem)) {
134 // Use equivalent of object.__getattribute__ to get the underlying
135 // tf wrapped EagerTensor (if there is one).
136 tensorflow::Safe_PyObjectPtr tf_should_use_attr(
137 #if PY_MAJOR_VERSION < 3
138 PyString_InternFromString("_tf_should_use_wrapped_value")
139 #else
140 PyUnicode_InternFromString("_tf_should_use_wrapped_value")
141 #endif
142 );
143 tensorflow::Safe_PyObjectPtr value_attr(
144 PyObject_GenericGetAttr(elem, tf_should_use_attr.get()));
145 if (value_attr) {
146 // This is an EagerTensor wrapped inside a TFShouldUse wrapped object.
147 (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get());
148 } else {
149 // This is a subclass of EagerTensor that we don't support.
150 PyErr_Clear();
151 tensorflow::ThrowTypeError(
152 tensorflow::strings::StrCat(
153 "Saw an object that is an instance of a strict subclass of "
154 "EagerTensor, which is not supported. Item ",
155 i, " is type: ", elem->ob_type->tp_name)
156 .c_str());
157 }
158 } else if (tensorflow::swig::IsTensor(elem)) {
159 // If it isnt an EagerTensor, but is still a Tensor, it must be a graph
160 // tensor.
161 tensorflow::Safe_PyObjectPtr name_attr(
162 PyObject_GetAttrString(elem, "name"));
163 tensorflow::ThrowTypeError(
164 tensorflow::strings::StrCat(
165 "An op outside of the function building code is being passed\n"
166 "a \"Graph\" tensor. It is possible to have Graph tensors\n"
167 "leak out of the function building context by including a\n"
168 "tf.init_scope in your function building code.\n"
169 "For example, the following function will fail:\n",
170 " @tf.function\n", " def has_init_scope():\n",
171 " my_constant = tf.constant(1.)\n",
172 " with tf.init_scope():\n",
173 " added = my_constant * 2\n",
174 "The graph tensor has name: ",
175 name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>")
176 .c_str());
177 } else {
178 tensorflow::ThrowTypeError(
179 tensorflow::strings::StrCat(
180 "provided list of inputs contains objects other "
181 "than 'EagerTensor'. Item ",
182 i, " is type: ", elem->ob_type->tp_name)
183 .c_str());
184 }
185 }
186 }
187 return input_tensor_handles;
188 }
189
190 // These functions are typemaps from the Python side. I did not use
191 // a custom type caster since the logic is slightly harder to follow. This
192 // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`.
193 // This function actually takes a number rather than an output Tensor holder.
InputTFE_OutputTensorHandles(const py::handle & num_outputs)194 TFE_OutputTensorHandles InputTFE_OutputTensorHandles(
195 const py::handle& num_outputs) {
196 TFE_OutputTensorHandles output_tensor_handles;
197 #if PY_MAJOR_VERSION < 3
198 if (!PyInt_Check(num_outputs.ptr())) {
199 #else
200 if (!PyLong_Check(num_outputs.ptr())) {
201 #endif
202 PyErr_SetString(PyExc_TypeError,
203 "expected an integer value (size of the number of "
204 "outputs of the operation)");
205 throw py::error_already_set();
206 }
207 #if PY_MAJOR_VERSION < 3
208 long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT
209 #else
210 long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT
211 #endif
212 // PyLong_AsLong might throw an error if an overflow occurs.
213 if (PyErr_Occurred()) {
214 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
215 "Number of outputs is too big: ", sz)
216 .c_str());
217 throw py::error_already_set();
218 }
219 // We can't handle more than int32 sizes for number of outputs.
220 if (static_cast<long>(static_cast<int32_t>(sz)) != sz) { // NOLINT
221 PyErr_SetString(PyExc_ValueError, tensorflow::strings::StrCat(
222 "Number of outputs is too big: ", sz)
223 .c_str());
224 throw py::error_already_set();
225 }
226 if (sz > 0) {
227 #if PY_MAJOR_VERSION < 3
228 output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr);
229 #else
230 output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr);
231 #endif
232 }
233 return output_tensor_handles;
234 }
235
236 tensorflow::Device* GetMatchedDevice(py::handle& ctx, const char* device_name) {
237 auto* context = reinterpret_cast<tensorflow::ImmediateExecutionContext*>(
238 tensorflow::InputTFE_Context(ctx));
239
240 tensorflow::DeviceNameUtils::ParsedName input_device_name;
241 if (!tensorflow::DeviceNameUtils::ParseFullOrLocalName(device_name,
242 &input_device_name)) {
243 tensorflow::ThrowValueError(
244 absl::StrFormat("Failed parsing device name: '%s'. Note a valid device "
245 "string should at least contain a device type and a "
246 "device index, like \"GPU:0\".",
247 device_name)
248 .c_str());
249 }
250
251 std::vector<tensorflow::Device*> devices = context->ListLocalTfDevices();
252
253 tensorflow::Device* matched_device = nullptr;
254 for (int device_idx = 0; device_idx < devices.size(); device_idx++) {
255 tensorflow::Device* device = devices[device_idx];
256
257 if (tensorflow::DeviceNameUtils::AreCompatibleDevNames(
258 input_device_name, device->parsed_name())) {
259 if (matched_device != nullptr) {
260 tensorflow::ThrowValueError(
261 absl::StrFormat("Multiple devices match the provided string "
262 "'%s': '%s' and '%s'.",
263 device_name, matched_device->name(), device->name())
264 .c_str());
265 }
266 matched_device = device;
267 }
268 }
269
270 if (matched_device == nullptr) {
271 tensorflow::ThrowValueError(
272 absl::StrFormat("No matching devices found for '%s'", device_name)
273 .c_str());
274 }
275
276 return matched_device;
277 }
278
279 // Packs multiple `EagerTensor`s of the same dtype and shape into one
280 // `EagerTensor`.
281 py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context,
282 const py::handle& tensors) {
283 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
284 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors);
285 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
286 int size = handles.size();
287 TFE_TensorHandle* packed_handle =
288 TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get());
289 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
290 PyObject* packed_tensor =
291 EagerTensorFromHandle(packed_handle, /*is_packed=*/true);
292 return tensorflow::PyoOrThrow(packed_tensor);
293 }
294
295 // This function was created from fusing the typemap logic in platform/base.i.
296 py::object TFE_Py_ExecuteCancelable_wrapper(
297 const py::handle& context, const char* device_name, const char* op_name,
298 const py::handle& inputs, const py::handle& attrs,
299 tensorflow::CancellationManager* cancellation_manager,
300 const py::handle& num_outputs) {
301 TFE_Context* ctx = tensorflow::InputTFE_Context(context);
302 TFE_InputTensorHandles input_tensor_handles =
303 InputTFE_InputTensorHandles(inputs);
304 TFE_OutputTensorHandles output_tensor_handles =
305 InputTFE_OutputTensorHandles(num_outputs);
306 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
307 TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles,
308 attrs.ptr(), tensorflow::wrap(cancellation_manager),
309 &output_tensor_handles, status.get());
310
311 int output_len = output_tensor_handles.size();
312 PyObject* output_list = PyList_New(output_len);
313 for (int i = 0; i < output_len; ++i) {
314 PyObject* output;
315 output = EagerTensorFromHandle(output_tensor_handles.at(i));
316 PyList_SetItem(output_list, i, output);
317 }
318 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
319 return tensorflow::PyoOrThrow(output_list);
320 }
321
322 static py::object TF_ListPhysicalDevices() {
323 std::vector<string> devices;
324 tensorflow::Status s =
325 tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices);
326 MaybeRaiseRegisteredFromStatus(s);
327 PyObject* result = PyList_New(devices.size());
328 int i = 0;
329 for (auto& dev : devices) {
330 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
331 PyList_SetItem(result, i, dev_obj);
332 ++i;
333 }
334 return tensorflow::PyoOrThrow(result);
335 }
336
337 static py::object TF_ListPluggablePhysicalDevices() {
338 std::vector<string> devices;
339 tensorflow::Status s =
340 tensorflow::DeviceFactory::ListPluggablePhysicalDevices(&devices);
341 MaybeRaiseRegisteredFromStatus(s);
342 Safe_PyObjectPtr result(PyList_New(devices.size()));
343 int i = 0;
344 for (auto& dev : devices) {
345 PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size());
346 PyList_SetItem(result.get(), i, dev_obj);
347 ++i;
348 }
349 return tensorflow::PyoOrThrow(result.release());
350 }
351
352 static std::unordered_map<string, string> TF_GetDeviceDetails(int index) {
353 tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus());
354 std::unordered_map<string, string> device_details;
355 tensorflow::Status s =
356 tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details);
357 tensorflow::Set_TF_Status_from_Status(status.get(), s);
358 MaybeRaiseRegisteredFromTFStatus(status.get());
359 return device_details;
360 }
361
362 static py::object TFE_ClearScalarCache() {
363 tensorflow::TFE_TensorHandleCache::Get()->Clear();
364 return py::none();
365 }
366
367 // Returns compiler IR for a given function.
368 static py::bytes TFE_GetCompilerIr(py::handle& ctx,
369 const char* concrete_function_name,
370 const char* stage, const char* device_name,
371 py::handle& inputs) {
372 EagerContext* context = ContextFromInterface(
373 reinterpret_cast<ImmediateExecutionContext*>(InputTFE_Context(ctx)));
374
375 std::string s_stage(stage);
376 IrExportStage selected_stage = [&] {
377 if (s_stage == "hlo") {
378 return IrExportStage::HLO;
379 } else if (s_stage == "hlo_serialized") {
380 return IrExportStage::HLO_SERIALIZED;
381 } else if (s_stage == "optimized_hlo") {
382 return IrExportStage::OPTIMIZED_HLO;
383 } else if (s_stage == "optimized_hlo_serialized") {
384 return IrExportStage::OPTIMIZED_HLO_SERIALIZED;
385 } else if (s_stage == "optimized_hlo_proto_serialized") {
386 return IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED;
387 } else if (s_stage == "optimized_hlo_dot") {
388 return IrExportStage::OPTIMIZED_HLO_DOT;
389 } else {
390 ThrowValueError(
391 absl::StrFormat("Invalid stage selected: '%s'. Valid values are: "
392 "'hlo', 'hlo_serialized', 'optimized_hlo', "
393 "'optimized_hlo_serialized', 'optimized_hlo_dot'",
394 s_stage)
395 .c_str());
396 }
397 }();
398
399 TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(inputs);
400
401 std::vector<const TensorHandle*> input_handles;
402 for (TFE_TensorHandle* tensor_handle : handles) {
403 AbstractTensorHandle* abstract_tensor_handle = unwrap(tensor_handle);
404 input_handles.push_back(TensorHandleFromInterface(abstract_tensor_handle));
405 }
406
407 DeviceNameUtils::ParsedName input_device_name;
408 if (!DeviceNameUtils::ParseFullOrLocalName(device_name, &input_device_name)) {
409 ThrowValueError(
410 absl::StrFormat("Failed parsing device name: '%s'", device_name)
411 .c_str());
412 }
413
414 std::vector<Device*> devices = context->local_device_mgr()->ListDevices();
415 auto selected_device = absl::c_find_if(devices, [&](const Device* d) {
416 return DeviceNameUtils::AreCompatibleDevNames(input_device_name,
417 d->parsed_name());
418 });
419 if (selected_device == devices.end()) {
420 ThrowValueError(
421 absl::StrFormat("No matching device found for '%s'", device_name)
422 .c_str());
423 }
424
425 xla::StatusOr<std::string> hlo_str =
426 GetCompilerIr(selected_stage, context->pflr(), concrete_function_name,
427 *selected_device, context, input_handles);
428
429 if (!hlo_str.ok()) {
430 ThrowValueError(absl::StrFormat("Failed getting HLO text: '%s'",
431 hlo_str.status().error_message())
432 .c_str());
433 }
434 return py::bytes(*hlo_str);
435 }
436
437 } // namespace tensorflow
438
439 namespace {
440
441 // Wrapper around the EagerContextThreadLocalData struct (defined in
442 // pywrap_tfe.h), so it can be accessed from Python.
443 //
444 // For PyObject* fields, the get_*() methods return a new reference; and the
445 // set_*() methods create a new reference (i.e., they do not steal a reference).
446 class EagerContextThreadLocalDataWrapper {
447 public:
EagerContextThreadLocalDataWrapper(py::handle py_eager_context,py::handle is_eager,py::handle device_spec)448 explicit EagerContextThreadLocalDataWrapper(py::handle py_eager_context,
449 py::handle is_eager,
450 py::handle device_spec)
451 : py_eager_context_(py_eager_context.ptr()) {
452 tensorflow::MakeEagerContextThreadLocalData(
453 py_eager_context.ptr(), is_eager.ptr(), device_spec.ptr());
454 }
455
~EagerContextThreadLocalDataWrapper()456 ~EagerContextThreadLocalDataWrapper() {
457 tensorflow::DestroyEagerContextThreadLocalData(py_eager_context_);
458 }
459
get_is_eager() const460 bool get_is_eager() const { return GetData()->is_eager; }
set_is_eager(bool v)461 void set_is_eager(bool v) { GetData()->is_eager = v; }
462
get_invoking_op_callbacks() const463 bool get_invoking_op_callbacks() const {
464 return GetData()->invoking_op_callbacks;
465 }
set_invoking_op_callbacks(bool v)466 void set_invoking_op_callbacks(bool v) {
467 GetData()->invoking_op_callbacks = v;
468 }
469
get_device_name() const470 py::object get_device_name() const {
471 return GetPyObject(&GetData()->device_name);
472 }
set_device_name(py::handle v)473 void set_device_name(py::handle v) {
474 SetPyObject(v, &GetData()->device_name);
475 }
476
get_scope_name() const477 py::object get_scope_name() const {
478 return GetPyObject(&GetData()->scope_name);
479 }
set_scope_name(py::handle v)480 void set_scope_name(py::handle v) { SetPyObject(v, &GetData()->scope_name); }
481
get_device_spec() const482 py::object get_device_spec() const {
483 return GetPyObject(&GetData()->device_spec);
484 }
set_device_spec(py::handle v)485 void set_device_spec(py::handle v) {
486 SetPyObject(v, &GetData()->device_spec);
487 }
488
get_function_call_options() const489 py::object get_function_call_options() const {
490 return GetPyObject(&GetData()->function_call_options);
491 }
set_function_call_options(py::handle v)492 void set_function_call_options(py::handle v) {
493 SetPyObject(v, &GetData()->function_call_options);
494 }
495
get_executor() const496 py::handle get_executor() const { return GetPyObject(&GetData()->executor); }
set_executor(py::handle v)497 void set_executor(py::handle v) { SetPyObject(v, &GetData()->executor); }
498
get_op_callbacks() const499 py::object get_op_callbacks() const {
500 return GetPyObject(&GetData()->op_callbacks);
501 }
set_op_callbacks(py::handle v)502 void set_op_callbacks(py::handle v) {
503 SetPyObject(v, &GetData()->op_callbacks);
504 }
505
506 private:
GetData() const507 tensorflow::EagerContextThreadLocalData* GetData() const {
508 auto* result =
509 tensorflow::GetEagerContextThreadLocalData(py_eager_context_);
510 if (!result) {
511 throw py::error_already_set();
512 }
513 return result;
514 }
515
GetPyObject(tensorflow::Safe_PyObjectPtr * obj) const516 py::object GetPyObject(tensorflow::Safe_PyObjectPtr* obj) const {
517 return pybind11::reinterpret_borrow<py::object>(obj->get());
518 }
519
SetPyObject(py::handle value,tensorflow::Safe_PyObjectPtr * ptr)520 void SetPyObject(py::handle value, tensorflow::Safe_PyObjectPtr* ptr) {
521 Py_INCREF(value.ptr());
522 ptr->reset(value.ptr());
523 }
524
525 PyObject* py_eager_context_; // not owned (borrowed reference).
526 };
527
528 } // namespace
529
// py::return_value_policy::reference follows the semantics described in the
// pybind11 documentation:
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
// Under this policy C++ keeps ownership of the returned object, and Python
// never deletes it. We only apply it to functions that return opaque types.
535
PYBIND11_MODULE(_pywrap_tfe,m)536 PYBIND11_MODULE(_pywrap_tfe, m) {
537 py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor");
538 py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m,
539 "TFE_ContextOptions");
540 py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class(
541 m, "TFE_MonitoringCounter0");
542 py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class(
543 m, "TFE_MonitoringCounter1");
544 py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class(
545 m, "TFE_MonitoringCounter2");
546 py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class(
547 m, "TFE_MonitoringStringGauge0");
548 py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class(
549 m, "TFE_MonitoringStringGauge1");
550 py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class(
551 m, "TFE_MonitoringStringGauge2");
552 py::class_<TFE_MonitoringStringGauge3> TFE_MonitoringStringGauge3_class(
553 m, "TFE_MonitoringStringGauge3");
554 py::class_<TFE_MonitoringStringGauge4> TFE_MonitoringStringGauge4_class(
555 m, "TFE_MonitoringStringGauge4");
556 py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class(
557 m, "TFE_MonitoringIntGauge0");
558 py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class(
559 m, "TFE_MonitoringIntGauge1");
560 py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class(
561 m, "TFE_MonitoringIntGauge2");
562 py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class(
563 m, "TFE_MonitoringBoolGauge0");
564 py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class(
565 m, "TFE_MonitoringBoolGauge1");
566 py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class(
567 m, "TFE_MonitoringBoolGauge2");
568 py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class(
569 m, "TFE_MonitoringCounterCell");
570 py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class(
571 m, "TFE_MonitoringIntGaugeCell");
572 py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class(
573 m, "TFE_MonitoringStringGaugeCell");
574 py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class(
575 m, "TFE_MonitoringBoolGaugeCell");
576 py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class(
577 m, "TFE_MonitoringSamplerCell");
578 py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class(
579 m, "TFE_MonitoringBuckets");
580 py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class(
581 m, "TFE_MonitoringSampler0");
582 py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class(
583 m, "TFE_MonitoringSampler1");
584 py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class(
585 m, "TFE_MonitoringSampler2");
586 py::class_<tensorflow::CancellationManager> TFE_CancellationManager_class(
587 m, "TFE_CancellationManager");
588
589 py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
590 py::class_<TF_Function> TF_Function_class(m, "TF_Function");
591
592 m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) {
593 return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr()));
594 });
595 m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) {
596 return tensorflow::PyoOrThrow(
597 TFE_Py_RegisterFallbackExceptionClass(e.ptr()));
598 });
599
600 m.def("TFE_GetMemoryInfo", [](py::handle& ctx, const char* device_name) {
601 tensorflow::Device* matched_device =
602 tensorflow::GetMatchedDevice(ctx, device_name);
603
604 tensorflow::AllocatorAttributes attrs;
605 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
606
607 if (absl::optional<tensorflow::AllocatorStats> stats =
608 allocator->GetStats()) {
609 return std::map<std::string, int64_t>{{"current", stats->bytes_in_use},
610 {"peak", stats->peak_bytes_in_use}};
611 }
612
613 tensorflow::ThrowValueError(
614 absl::StrFormat("Allocator stats not available for device '%s'",
615 device_name)
616 .c_str());
617 });
618
619 m.def("TFE_ResetMemoryStats", [](py::handle& ctx, const char* device_name) {
620 tensorflow::Device* matched_device =
621 tensorflow::GetMatchedDevice(ctx, device_name);
622
623 tensorflow::AllocatorAttributes attrs;
624 tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs);
625
626 if (!allocator->ClearStats()) {
627 tensorflow::ThrowValueError(
628 absl::StrFormat("Cannot reset memory stats for device '%s'",
629 device_name)
630 .c_str());
631 }
632 });
633
634 // XLA Eager Logic
635 m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation);
636 m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit);
637 m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode);
638 m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled);
639 m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled);
640 m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize);
641 m.def("TF_GetCompilerIr", &tensorflow::TFE_GetCompilerIr);
642
643 // MLIR Logic
644 m.def("TF_IsMlirBridgeEnabled", [] {
645 // Since python protobuf enums are integers, cast to an integer before
646 // returning the enum to python.
647 return static_cast<int32_t>(
648 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge);
649 });
650 m.def("TF_EnableMlirBridge", [](bool enabled) {
651 tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
652 enabled
653 ? tensorflow::ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED
654 : tensorflow::ConfigProto::Experimental::
655 MLIR_BRIDGE_ROLLOUT_DISABLED;
656 });
657 m.def("TF_EnableXlaDevices", [] {
658 tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
659 });
660
661 // // TFE_Context Logic
662 m.def(
663 "TFE_NewContext",
664 [](const TFE_ContextOptions* opts) {
665 tensorflow::Safe_TF_StatusPtr status =
666 tensorflow::make_safe(TF_NewStatus());
667 TFE_Context* context = TFE_NewContext(opts, status.get());
668 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
669 return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context));
670 },
671 py::return_value_policy::reference);
672 m.def("TFE_DeleteContext", [](py::handle& o) {
673 TFE_DeleteContext(tensorflow::InputTFE_Context(o));
674 });
675 m.def(
676 "TFE_ContextListDevices",
677 [](py::handle& o) {
678 tensorflow::Safe_TF_StatusPtr status =
679 tensorflow::make_safe(TF_NewStatus());
680 auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o),
681 status.get());
682 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
683 return output;
684 },
685 py::return_value_policy::reference);
686 m.def(
687 "TFE_SetLogicalCpuDevices",
688 [](py::handle& ctx, int num_cpus, const char* prefix) {
689 tensorflow::Safe_TF_StatusPtr status =
690 tensorflow::make_safe(TF_NewStatus());
691 TFE_SetLogicalCpuDevices(tensorflow::InputTFE_Context(ctx), num_cpus,
692 prefix, status.get());
693 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
694 },
695 py::return_value_policy::reference);
696 m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
697 TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
698 });
699 m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) {
700 tensorflow::Safe_TF_StatusPtr status =
701 tensorflow::make_safe(TF_NewStatus());
702 TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func,
703 status.get());
704 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
705 });
706 m.def("TFE_ContextAddFunctionDef",
707 [](py::handle& ctx, const char* serialized_function_def, size_t size) {
708 tensorflow::Safe_TF_StatusPtr status =
709 tensorflow::make_safe(TF_NewStatus());
710 TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx),
711 serialized_function_def, size,
712 status.get());
713 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
714 });
715 m.def("TFE_ContextGetFunctionDef",
716 [](py::handle& ctx, const char* function_name, TF_Buffer& buf) {
717 tensorflow::Safe_TF_StatusPtr status =
718 tensorflow::make_safe(TF_NewStatus());
719 TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx),
720 function_name, &buf, status.get());
721 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
722 });
723 m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) {
724 tensorflow::Safe_TF_StatusPtr status =
725 tensorflow::make_safe(TF_NewStatus());
726 TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name,
727 status.get());
728 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
729 });
730 m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) {
731 tensorflow::Safe_TF_StatusPtr status =
732 tensorflow::make_safe(TF_NewStatus());
733 auto output =
734 TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name);
735 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
736 return output;
737 });
738 m.def("TFE_ContextListFunctionNames", [](py::handle& ctx) {
739 return tensorflow::unwrap(tensorflow::InputTFE_Context(ctx))
740 ->ListFunctionNames();
741 });
742 m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) {
743 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
744 });
745 m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) {
746 TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx));
747 });
748 m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) {
749 TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx));
750 });
751 m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) {
752 TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx));
753 });
754 m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) {
755 tensorflow::Safe_TF_StatusPtr status =
756 tensorflow::make_safe(TF_NewStatus());
757 TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf,
758 status.get());
759 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
760 });
761 m.def("TFE_ContextClearCaches", [](py::handle& o) {
762 TFE_ContextClearCaches(tensorflow::InputTFE_Context(o));
763 });
764 m.def("TFE_GetContextId", [](py::handle& ctx) {
765 return TFE_GetContextId(tensorflow::InputTFE_Context(ctx));
766 });
767 m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) {
768 return TFE_ContextGetDevicePlacementPolicy(
769 tensorflow::InputTFE_Context(ctx));
770 });
771 m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy",
772 [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) {
773 TFE_ContextSetThreadLocalDevicePlacementPolicy(
774 tensorflow::InputTFE_Context(ctx), policy);
775 });
776 m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs,
777 py::bytes proto) {
778 tensorflow::Safe_TF_StatusPtr status =
779 tensorflow::make_safe(TF_NewStatus());
780 tensorflow::Safe_TF_BufferPtr buf =
781 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
782 TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs,
783 buf.get()->data, buf.get()->length, status.get());
784 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
785 });
786 m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs,
787 py::bytes proto) {
788 tensorflow::Safe_TF_StatusPtr status =
789 tensorflow::make_safe(TF_NewStatus());
790 tensorflow::Safe_TF_BufferPtr buf =
791 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
792 Py_BEGIN_ALLOW_THREADS;
793 TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx),
794 keep_alive_secs, buf.get()->data,
795 buf.get()->length, status.get());
796 Py_END_ALLOW_THREADS;
797 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
798 });
799 m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) {
800 tensorflow::Safe_TF_StatusPtr status =
801 tensorflow::make_safe(TF_NewStatus());
802 bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx),
803 worker_name, status.get());
804 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
805 return output;
806 });
807 m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) {
808 tensorflow::Safe_TF_StatusPtr status =
809 tensorflow::make_safe(TF_NewStatus());
810 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
811 Py_BEGIN_ALLOW_THREADS;
812 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
813 Py_END_ALLOW_THREADS;
814 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
815 });
816 m.def("TFE_ContextClearExecutors", [](py::handle& ctx) {
817 tensorflow::Safe_TF_StatusPtr status =
818 tensorflow::make_safe(TF_NewStatus());
819 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
820 Py_BEGIN_ALLOW_THREADS;
821 TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get());
822 Py_END_ALLOW_THREADS;
823 // NOTE: different from TFE_ContextSyncExecutors that raises potential
824 // errors, deliberately ignore executor statuses in cleanup.
825 });
826 m.def(
827 "TFE_InsertConfigKeyValue",
828 [](py::handle& ctx, const char* config_key, const char* config_value) {
829 tensorflow::Safe_TF_StatusPtr status =
830 tensorflow::make_safe(TF_NewStatus());
831 Py_BEGIN_ALLOW_THREADS;
832 TFE_InsertConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
833 config_value, status.get());
834 Py_END_ALLOW_THREADS;
835 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
836 },
837 py::return_value_policy::reference);
838 m.def(
839 "TFE_GetConfigKeyValue",
840 [](py::handle& ctx, const char* config_key, TF_Buffer& config_value) {
841 tensorflow::Safe_TF_StatusPtr status =
842 tensorflow::make_safe(TF_NewStatus());
843 Py_BEGIN_ALLOW_THREADS;
844 TFE_GetConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
845 &config_value, status.get());
846 Py_END_ALLOW_THREADS;
847 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
848 },
849 py::return_value_policy::reference);
850 m.def(
851 "TFE_DeleteConfigKeyValue",
852 [](py::handle& ctx, const char* config_key) {
853 tensorflow::Safe_TF_StatusPtr status =
854 tensorflow::make_safe(TF_NewStatus());
855 Py_BEGIN_ALLOW_THREADS;
856 TFE_DeleteConfigKeyValue(tensorflow::InputTFE_Context(ctx), config_key,
857 status.get());
858 Py_END_ALLOW_THREADS;
859 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
860 },
861 py::return_value_policy::reference);
862 m.def(
863 "TFE_ReportErrorToCluster",
864 [](py::handle& ctx, int error_code, const char* error_message) {
865 tensorflow::Safe_TF_StatusPtr status =
866 tensorflow::make_safe(TF_NewStatus());
867 TFE_ReportErrorToCluster(tensorflow::InputTFE_Context(ctx), error_code,
868 error_message, status.get());
869 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
870 },
871 py::return_value_policy::reference);
872 m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) {
873 tensorflow::Safe_TF_StatusPtr status =
874 tensorflow::make_safe(TF_NewStatus());
875 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
876 status.get());
877 });
878 m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) {
879 tensorflow::Safe_TF_StatusPtr status =
880 tensorflow::make_safe(TF_NewStatus());
881 TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable,
882 status.get());
883 });
884
885 // TFE_Executor logic
886 m.def(
887 "TFE_NewExecutor",
888 [](const bool is_async) {
889 TFE_Executor* exc = TFE_NewExecutor(is_async);
890 return exc;
891 },
892 py::return_value_policy::reference);
893 m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor);
894 m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync);
895 m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) {
896 tensorflow::Safe_TF_StatusPtr status =
897 tensorflow::make_safe(TF_NewStatus());
898 // NOTE: release Python GIL for pending PyFunc ops to be executed properly.
899 Py_BEGIN_ALLOW_THREADS;
900 TFE_ExecutorWaitForAllPendingNodes(&exc, status.get());
901 Py_END_ALLOW_THREADS;
902 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
903 });
904 m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError);
905 m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx,
906 TFE_Executor& exc) {
907 TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc);
908 });
909 m.def(
910 "TFE_ContextGetExecutorForThread",
911 [](py::handle& o) {
912 return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o));
913 },
914 py::return_value_policy::reference);
915
916 m.def("TFE_OpNameGetAttrType",
917 [](py::handle& ctx, const char* op_or_function_name,
918 const char* attr_name) {
919 int temp = 0;
920 unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp);
921 tensorflow::Safe_TF_StatusPtr status =
922 tensorflow::make_safe(TF_NewStatus());
923 auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx),
924 op_or_function_name, attr_name,
925 is_list, status.get());
926 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
927 #if PY_MAJOR_VERSION < 3
928 PyObject* output_pyo = PyInt_FromLong(output);
929 #else
930 PyObject* output_pyo = PyLong_FromLong(output);
931 #endif
932 if (*is_list == 1) {
933 PyObject* list = PyList_New(1);
934 PyList_SetItem(list, 0, output_pyo);
935 return tensorflow::PyoOrThrow(list);
936 }
937 return tensorflow::PyoOrThrow(output_pyo);
938 });
939 m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) {
940 return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr()));
941 });
942 m.def("TFE_Py_PackEagerTensors",
943 [](const py::handle& context, const py::handle& handles) {
944 return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles);
945 });
946 m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler);
947 m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) {
948 return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr()));
949 });
950 m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) {
951 return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr()));
952 });
953 m.def("TFE_Py_Execute",
954 [](const py::handle& context, const char* device_name,
955 const char* op_name, const py::handle& inputs,
956 const py::handle& attrs, const py::handle& num_outputs) {
957 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
958 context, device_name, op_name, inputs, attrs.ptr(), nullptr,
959 num_outputs);
960 });
961 m.def(
962 "TFE_Py_ExecuteCancelable",
963 [](const py::handle& context, const char* device_name,
964 const char* op_name, const py::handle& inputs, const py::handle& attrs,
965 tensorflow::CancellationManager& cancellation_manager,
966 const py::handle& num_outputs) {
967 return tensorflow::TFE_Py_ExecuteCancelable_wrapper(
968 context, device_name, op_name, inputs, attrs.ptr(),
969 &cancellation_manager, num_outputs);
970 });
971 m.def("TFE_Py_FastPathExecute", [](const py::args args) {
972 // TFE_Py_FastPathExecute requires error checking prior to returning.
973 return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr()));
974 });
975 m.def("TFE_Py_RecordGradient",
976 [](const py::handle& op_name, const py::handle& inputs,
977 const py::handle& attrs, const py::handle& results,
978 const py::handle& forward_pass_name_scope) {
979 return tensorflow::PyoOrThrow(TFE_Py_RecordGradient(
980 op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(),
981 forward_pass_name_scope.ptr()));
982 });
983 m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); });
984
985 // TFE_Py_Tape Logic
986 m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent,
987 const py::handle& watch_accessed_variables) {
988 return tensorflow::PyoOrThrow(
989 TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr()));
990 });
991 m.def("TFE_Py_TapeSetAdd",
992 [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); });
993 m.def("TFE_Py_TapeSetRemove",
994 [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); });
995 m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread);
996 m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread);
997 m.def("TFE_Py_TapeSetIsStopped",
998 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); });
999 m.def("TFE_Py_TapeSetIsEmpty",
1000 []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); });
1001 m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) {
1002 return tensorflow::PyoOrThrow(
1003 TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr()));
1004 });
1005 m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) {
1006 return tensorflow::PyoOrThrow(
1007 TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr()));
1008 });
1009 m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace);
1010 m.def("TFE_Py_TapeSetRecordOperation",
1011 [](const py::handle& op_type, const py::handle& output_tensors,
1012 const py::handle& input_tensors, const py::handle& backward_function,
1013 const py::handle& forward_function) {
1014 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation(
1015 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1016 backward_function.ptr(), forward_function.ptr()));
1017 });
1018 m.def(
1019 "TFE_Py_TapeSetRecordOperationBackprop",
1020 [](const py::handle& op_type, const py::handle& output_tensors,
1021 const py::handle& input_tensors, const py::handle& backward_function) {
1022 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop(
1023 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1024 backward_function.ptr()));
1025 });
1026 m.def(
1027 "TFE_Py_TapeSetRecordOperationForwardprop",
1028 [](const py::handle& op_type, const py::handle& output_tensors,
1029 const py::handle& input_tensors, const py::handle& backward_function,
1030 const py::handle& forwardprop_output_indices) {
1031 return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop(
1032 op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(),
1033 backward_function.ptr(), forwardprop_output_indices.ptr()));
1034 });
1035 m.def("TFE_Py_TapeGradient",
1036 [](const py::handle& tape, const py::handle& target,
1037 const py::handle& sources, const py::handle& output_gradients,
1038 const py::handle& sources_raw,
1039 const py::handle& unconnected_gradients) {
1040 tensorflow::Safe_TF_StatusPtr status =
1041 tensorflow::make_safe(TF_NewStatus());
1042 PyObject* output = TFE_Py_TapeGradient(
1043 tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(),
1044 sources_raw.ptr(), unconnected_gradients.ptr(), status.get());
1045 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1046 return tensorflow::PyoOrThrow(output);
1047 });
1048
1049 m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) {
1050 TFE_Py_TapeVariableAccessed(variable.ptr());
1051 });
1052 m.def("TFE_Py_TapeWatch",
1053 [](const py::handle& tape, const py::handle& tensor) {
1054 TFE_Py_TapeWatch(tape.ptr(), tensor.ptr());
1055 });
1056 m.def("TFE_Py_TapeWatchVariable",
1057 [](const py::handle& tape, const py::handle& variable) {
1058 TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr());
1059 });
1060 m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) {
1061 return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr()));
1062 });
1063
1064 // TFE_Py_VariableWatcher logic.
1065 m.def("TFE_Py_VariableWatcherNew",
1066 []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); });
1067 m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) {
1068 TFE_Py_VariableWatcherRemove(variable_watcher.ptr());
1069 });
1070 m.def("TFE_Py_VariableWatcherVariableAccessed",
1071 [](const py::handle& variable) {
1072 TFE_Py_VariableWatcherVariableAccessed(variable.ptr());
1073 });
1074 m.def("TFE_Py_VariableWatcherWatchedVariables",
1075 [](const py::handle& variable_watcher) {
1076 return tensorflow::PyoOrThrow(
1077 TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr()));
1078 });
1079
1080 // TFE_Py_ForwardAccumulator logic.
1081 m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) {
1082 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch));
1083 });
1084
1085 m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) {
1086 return tensorflow::PyoOrThrow(
1087 TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr()));
1088 });
1089 m.def("TFE_Py_ForwardAccumulatorSetRemove",
1090 [](const py::handle& accumulator) {
1091 TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr());
1092 });
1093
1094 m.def("TFE_Py_ForwardAccumulatorWatch",
1095 [](const py::handle& accumulator, const py::handle& tensor,
1096 const py::handle& tangent) {
1097 TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(),
1098 tangent.ptr());
1099 });
1100 m.def("TFE_Py_ForwardAccumulatorJVP",
1101 [](const py::handle& accumulator, const py::handle& tensor) {
1102 return tensorflow::PyoOrThrow(
1103 TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr()));
1104 });
1105 m.def("TFE_Py_ForwardAccumulatorPushState", []() {
1106 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState());
1107 });
1108 m.def("TFE_Py_ForwardAccumulatorPopState", []() {
1109 return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState());
1110 });
1111 m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) {
1112 return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr()));
1113 });
1114
1115 // TFE_ContextOptions Logic
1116 m.def("TFE_NewContextOptions", &TFE_NewContextOptions,
1117 py::return_value_policy::reference);
1118 m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options,
1119 py::bytes proto) {
1120 tensorflow::Safe_TF_StatusPtr status =
1121 tensorflow::make_safe(TF_NewStatus());
1122 tensorflow::Safe_TF_BufferPtr buf =
1123 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1124 TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length,
1125 status.get());
1126 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1127 });
1128 m.def("TFE_ContextOptionsSetDevicePlacementPolicy",
1129 &TFE_ContextOptionsSetDevicePlacementPolicy);
1130 m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt);
1131 m.def("TFE_ContextOptionsSetTfrtDistributedRuntime",
1132 &TFE_ContextOptionsSetTfrtDistributedRuntime);
1133 // Experimental feature, intentionally not exposed as a C API yet.
1134 m.def("TFE_ContextOptionsSetRunEagerOpAsFunction",
1135 [](TFE_ContextOptions* options, bool run_eager_op_as_function) {
1136 options->run_eager_op_as_function = run_eager_op_as_function;
1137 });
1138 m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync);
1139 m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions,
1140 py::return_value_policy::reference);
1141
1142 // TFE_Py_TensorShape Logic
1143 m.def("TFE_Py_TensorShapeSlice",
1144 [](const py::handle& tensors, int slice_dim) {
1145 return tensorflow::PyoOrThrow(
1146 TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim));
1147 });
1148 m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors,
1149 int slice_dim) {
1150 return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr()));
1151 });
1152 m.def("TFE_Py_EnableInteractivePythonLogging",
1153 &TFE_Py_EnableInteractivePythonLogging);
1154
1155 // Additional Context Logic
1156 m.def("TFE_Py_SetEagerContext", [](const py::handle& o) {
1157 return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr()));
1158 });
1159 m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) {
1160 return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr()));
1161 });
1162 m.def("TFE_Py_EncodeArg",
1163 [](const py::handle& o, bool include_tensor_ranks_only) {
1164 return tensorflow::PyoOrThrow(
1165 TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only));
1166 });
1167 m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) {
1168 tensorflow::Safe_TF_StatusPtr status =
1169 tensorflow::make_safe(TF_NewStatus());
1170 tensorflow::Safe_TF_BufferPtr buf =
1171 tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr()));
1172 TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data,
1173 buf.get()->length, status.get());
1174 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1175 });
1176 m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code,
1177 const char* message) {
1178 tensorflow::Safe_TF_StatusPtr status =
1179 tensorflow::make_safe(TF_NewStatus());
1180 TF_SetStatus(status.get(), static_cast<TF_Code>(code), message);
1181 TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get());
1182 });
1183 m.def("TFE_CollectiveOpsCheckPeerHealth",
1184 [](const py::handle& ctx, const char* task, int64_t timeout_in_ms) {
1185 tensorflow::Safe_TF_StatusPtr status =
1186 tensorflow::make_safe(TF_NewStatus());
1187 TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx),
1188 task, timeout_in_ms, status.get());
1189 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1190 });
1191 m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices);
1192 m.def("TF_ListPluggablePhysicalDevices",
1193 &tensorflow::TF_ListPluggablePhysicalDevices);
1194 m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails);
1195 m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList,
1196 py::return_value_policy::reference);
1197 m.def("TF_DeviceListCount", &TF_DeviceListCount);
1198 m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) {
1199 tensorflow::Safe_TF_StatusPtr status =
1200 tensorflow::make_safe(TF_NewStatus());
1201 auto output = TF_DeviceListName(list, index, status.get());
1202 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1203 return output;
1204 });
1205 m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) {
1206 tensorflow::Safe_TF_StatusPtr status =
1207 tensorflow::make_safe(TF_NewStatus());
1208 auto output = TF_DeviceListType(list, index, status.get());
1209 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1210 return output;
1211 });
1212
1213 m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie);
1214
1215 // TFE_MonitoringCounter Logic
1216 m.def("TFE_MonitoringCounterCellIncrementBy",
1217 &TFE_MonitoringCounterCellIncrementBy);
1218 m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue);
1219 m.def(
1220 "TFE_MonitoringNewCounter0",
1221 [](const char* name, const char* description) {
1222 tensorflow::Safe_TF_StatusPtr status =
1223 tensorflow::make_safe(TF_NewStatus());
1224 auto output =
1225 TFE_MonitoringNewCounter0(name, status.get(), description);
1226 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1227 return output;
1228 },
1229 py::return_value_policy::reference);
1230 m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0,
1231 py::return_value_policy::reference);
1232 m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0,
1233 py::return_value_policy::reference);
1234 m.def(
1235 "TFE_MonitoringNewCounter1",
1236 [](const char* name, const char* description, const char* label1) {
1237 tensorflow::Safe_TF_StatusPtr status =
1238 tensorflow::make_safe(TF_NewStatus());
1239 auto output =
1240 TFE_MonitoringNewCounter1(name, status.get(), description, label1);
1241 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1242 return output;
1243 },
1244 py::return_value_policy::reference);
1245 m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1,
1246 py::return_value_policy::reference);
1247 m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1,
1248 py::return_value_policy::reference);
1249 m.def(
1250 "TFE_MonitoringNewCounter2",
1251 [](const char* name, const char* description, const char* label1,
1252 const char* label2) {
1253 tensorflow::Safe_TF_StatusPtr status =
1254 tensorflow::make_safe(TF_NewStatus());
1255 auto output = TFE_MonitoringNewCounter2(name, status.get(), description,
1256 label1, label2);
1257 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1258 return output;
1259 },
1260 py::return_value_policy::reference);
1261 m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2,
1262 py::return_value_policy::reference);
1263 m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2,
1264 py::return_value_policy::reference);
1265
1266 // TFE_MonitoringIntGauge Logic
1267 m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet);
1268 m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue);
1269 m.def(
1270 "TFE_MonitoringNewIntGauge0",
1271 [](const char* name, const char* description) {
1272 tensorflow::Safe_TF_StatusPtr status =
1273 tensorflow::make_safe(TF_NewStatus());
1274 auto output =
1275 TFE_MonitoringNewIntGauge0(name, status.get(), description);
1276 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1277 return output;
1278 },
1279 py::return_value_policy::reference);
1280 m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0,
1281 py::return_value_policy::reference);
1282 m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0,
1283 py::return_value_policy::reference);
1284 m.def(
1285 "TFE_MonitoringNewIntGauge1",
1286 [](const char* name, const char* description, const char* label1) {
1287 tensorflow::Safe_TF_StatusPtr status =
1288 tensorflow::make_safe(TF_NewStatus());
1289 auto output =
1290 TFE_MonitoringNewIntGauge1(name, status.get(), description, label1);
1291 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1292 return output;
1293 },
1294 py::return_value_policy::reference);
1295 m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1,
1296 py::return_value_policy::reference);
1297 m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1,
1298 py::return_value_policy::reference);
1299 m.def(
1300 "TFE_MonitoringNewIntGauge2",
1301 [](const char* name, const char* description, const char* label1,
1302 const char* label2) {
1303 tensorflow::Safe_TF_StatusPtr status =
1304 tensorflow::make_safe(TF_NewStatus());
1305 auto output = TFE_MonitoringNewIntGauge2(name, status.get(),
1306 description, label1, label2);
1307 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1308 return output;
1309 },
1310 py::return_value_policy::reference);
1311 m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2,
1312 py::return_value_policy::reference);
1313 m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2,
1314 py::return_value_policy::reference);
1315 m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet);
1316 m.def("TFE_MonitoringStringGaugeCellValue",
1317 &TFE_MonitoringStringGaugeCellValue);
1318 m.def(
1319 "TFE_MonitoringNewStringGauge0",
1320 [](const char* name, const char* description) {
1321 tensorflow::Safe_TF_StatusPtr status =
1322 tensorflow::make_safe(TF_NewStatus());
1323 auto output =
1324 TFE_MonitoringNewStringGauge0(name, status.get(), description);
1325 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1326 return output;
1327 },
1328 py::return_value_policy::reference);
1329
1330 // TFE_MonitoringStringGauge Logic
1331 m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0);
1332 m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0,
1333 py::return_value_policy::reference);
1334 m.def(
1335 "TFE_MonitoringNewStringGauge1",
1336 [](const char* name, const char* description, const char* label1) {
1337 tensorflow::Safe_TF_StatusPtr status =
1338 tensorflow::make_safe(TF_NewStatus());
1339 auto output = TFE_MonitoringNewStringGauge1(name, status.get(),
1340 description, label1);
1341 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1342 return output;
1343 },
1344 py::return_value_policy::reference);
1345 m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1);
1346 m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1,
1347 py::return_value_policy::reference);
1348 m.def(
1349 "TFE_MonitoringNewStringGauge2",
1350 [](const char* name, const char* description, const char* label1,
1351 const char* label2) {
1352 tensorflow::Safe_TF_StatusPtr status =
1353 tensorflow::make_safe(TF_NewStatus());
1354 auto output = TFE_MonitoringNewStringGauge2(
1355 name, status.get(), description, label1, label2);
1356 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1357 return output;
1358 },
1359 py::return_value_policy::reference);
1360 m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2);
1361 m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2,
1362 py::return_value_policy::reference);
1363
1364 m.def(
1365 "TFE_MonitoringNewStringGauge3",
1366 [](const char* name, const char* description, const char* label1,
1367 const char* label2, const char* label3) {
1368 tensorflow::Safe_TF_StatusPtr status =
1369 tensorflow::make_safe(TF_NewStatus());
1370 auto output = TFE_MonitoringNewStringGauge3(
1371 name, status.get(), description, label1, label2, label3);
1372 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1373 return output;
1374 },
1375 py::return_value_policy::reference);
1376 m.def("TFE_MonitoringDeleteStringGauge3", &TFE_MonitoringDeleteStringGauge3);
1377 m.def("TFE_MonitoringGetCellStringGauge3", &TFE_MonitoringGetCellStringGauge3,
1378 py::return_value_policy::reference);
1379
1380 m.def(
1381 "TFE_MonitoringNewStringGauge4",
1382 [](const char* name, const char* description, const char* label1,
1383 const char* label2, const char* label3, const char* label4) {
1384 tensorflow::Safe_TF_StatusPtr status =
1385 tensorflow::make_safe(TF_NewStatus());
1386 auto output = TFE_MonitoringNewStringGauge4(
1387 name, status.get(), description, label1, label2, label3, label4);
1388 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1389 return output;
1390 },
1391 py::return_value_policy::reference);
1392 m.def("TFE_MonitoringDeleteStringGauge4", &TFE_MonitoringDeleteStringGauge4);
1393 m.def("TFE_MonitoringGetCellStringGauge4", &TFE_MonitoringGetCellStringGauge4,
1394 py::return_value_policy::reference);
1395
1396 // TFE_MonitoringBoolGauge Logic
1397 m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet);
1398 m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue);
1399 m.def(
1400 "TFE_MonitoringNewBoolGauge0",
1401 [](const char* name, const char* description) {
1402 tensorflow::Safe_TF_StatusPtr status =
1403 tensorflow::make_safe(TF_NewStatus());
1404 auto output =
1405 TFE_MonitoringNewBoolGauge0(name, status.get(), description);
1406 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1407 return output;
1408 },
1409 py::return_value_policy::reference);
1410 m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0,
1411 py::return_value_policy::reference);
1412 m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0,
1413 py::return_value_policy::reference);
1414 m.def(
1415 "TFE_MonitoringNewBoolGauge1",
1416 [](const char* name, const char* description, const char* label1) {
1417 tensorflow::Safe_TF_StatusPtr status =
1418 tensorflow::make_safe(TF_NewStatus());
1419 auto output = TFE_MonitoringNewBoolGauge1(name, status.get(),
1420 description, label1);
1421 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1422 return output;
1423 },
1424 py::return_value_policy::reference);
1425 m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1,
1426 py::return_value_policy::reference);
1427 m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1,
1428 py::return_value_policy::reference);
1429 m.def(
1430 "TFE_MonitoringNewBoolGauge2",
1431 [](const char* name, const char* description, const char* label1,
1432 const char* label2) {
1433 tensorflow::Safe_TF_StatusPtr status =
1434 tensorflow::make_safe(TF_NewStatus());
1435 auto output = TFE_MonitoringNewBoolGauge2(name, status.get(),
1436 description, label1, label2);
1437 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1438 return output;
1439 },
1440 py::return_value_policy::reference);
1441 m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2,
1442 py::return_value_policy::reference);
1443 m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2,
1444 py::return_value_policy::reference);
1445
1446 // TFE_MonitoringSampler Logic
1447 m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd);
1448 m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue);
1449 m.def("TFE_MonitoringNewExponentialBuckets",
1450 &TFE_MonitoringNewExponentialBuckets,
1451 py::return_value_policy::reference);
1452 m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets,
1453 py::return_value_policy::reference);
1454 m.def(
1455 "TFE_MonitoringNewSampler0",
1456 [](const char* name, TFE_MonitoringBuckets* buckets,
1457 const char* description) {
1458 tensorflow::Safe_TF_StatusPtr status =
1459 tensorflow::make_safe(TF_NewStatus());
1460 auto output =
1461 TFE_MonitoringNewSampler0(name, buckets, status.get(), description);
1462 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1463 return output;
1464 },
1465 py::return_value_policy::reference);
1466 m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0,
1467 py::return_value_policy::reference);
1468 m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0,
1469 py::return_value_policy::reference);
1470 m.def(
1471 "TFE_MonitoringNewSampler1",
1472 [](const char* name, TFE_MonitoringBuckets* buckets,
1473 const char* description, const char* label1) {
1474 tensorflow::Safe_TF_StatusPtr status =
1475 tensorflow::make_safe(TF_NewStatus());
1476 auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(),
1477 description, label1);
1478 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1479 return output;
1480 },
1481 py::return_value_policy::reference);
1482 m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1,
1483 py::return_value_policy::reference);
1484 m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1,
1485 py::return_value_policy::reference);
1486 m.def(
1487 "TFE_MonitoringNewSampler2",
1488 [](const char* name, TFE_MonitoringBuckets* buckets,
1489 const char* description, const char* label1, const char* label2) {
1490 tensorflow::Safe_TF_StatusPtr status =
1491 tensorflow::make_safe(TF_NewStatus());
1492 auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(),
1493 description, label1, label2);
1494 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1495 return output;
1496 },
1497 py::return_value_policy::reference);
1498 m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2,
1499 py::return_value_policy::reference);
1500 m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2,
1501 py::return_value_policy::reference);
1502
1503 // TFE_CancellationManager Logic
1504 m.def("TFE_NewCancellationManager",
1505 []() { return new tensorflow::CancellationManager(); });
1506 m.def("TFE_CancellationManagerIsCancelled",
1507 &tensorflow::CancellationManager::IsCancelled);
1508 m.def("TFE_CancellationManagerStartCancel",
1509 &tensorflow::CancellationManager::StartCancel);
1510
1511 m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache);
1512
1513 // Util buffer helper functions
1514 m.def("TF_NewBufferFromString", &TF_NewBufferFromString,
1515 py::return_value_policy::reference);
1516
1517 // DLPack functions
1518 m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
1519 PyObject* eager_tensor_pyobject_ptr = o.ptr();
1520 tensorflow::Safe_TF_StatusPtr status =
1521 tensorflow::make_safe(TF_NewStatus());
1522
1523 if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
1524 status->status = tensorflow::errors::InvalidArgument(
1525 "The argument to `to_dlpack` must be a TF tensor, not Python object");
1526 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1527 }
1528
1529 TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
1530 void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
1531 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1532
1533 py::capsule capsule(
1534 dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) {
1535 if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) {
1536 void* dlm_rptr =
1537 PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName);
1538 if (dlm_rptr) {
1539 tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr);
1540 PyCapsule_SetDestructor(capsule, nullptr);
1541 }
1542 }
1543 });
1544 return capsule;
1545 });
1546
1547 m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule,
1548 const py::handle& context) {
1549 tensorflow::Safe_TF_StatusPtr status =
1550 tensorflow::make_safe(TF_NewStatus());
1551 if (absl::string_view(pycapsule.name()) !=
1552 tensorflow::kDlTensorCapsuleName) {
1553 status->status = tensorflow::errors::InvalidArgument(
1554 "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". "
1555 "Note that a DLPack tensor may be consumed at most once.",
1556 absl::string_view(pycapsule.name()));
1557 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1558 }
1559
1560 TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack(
1561 pycapsule, status.get(), tensorflow::InputTFE_Context(context));
1562
1563 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1564
1565 PyCapsule_SetName(pycapsule.ptr(), "used_dltensor");
1566 PyCapsule_SetDestructor(pycapsule.ptr(), nullptr);
1567
1568 PyObject* pyhandle = EagerTensorFromHandle(thandle);
1569 return tensorflow::PyoOrThrow(pyhandle);
1570 });
1571
1572 m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context,
1573 const py::capsule& device,
1574 const char* device_name,
1575 const py::capsule& device_info) {
1576 tensorflow::Safe_TF_StatusPtr status =
1577 tensorflow::make_safe(TF_NewStatus());
1578 if (absl::string_view(device.name()) != "TFE_CustomDevice") {
1579 status->status = tensorflow::errors::InvalidArgument(
1580 "Expected a capsule named 'TFE_CustomDevice' for the `device` "
1581 "argument, got ",
1582 absl::string_view(device.name()));
1583 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1584 }
1585 if (absl::string_view(device_info.name()) !=
1586 "TFE_CustomDevice_DeviceInfo") {
1587 status->status = tensorflow::errors::InvalidArgument(
1588 "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for "
1589 "the `device_info` argument, got ",
1590 absl::string_view(device_info.name()));
1591 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1592 }
1593 // TFE_RegisterCustomDevice takes ownership
1594 PyCapsule_SetDestructor(device_info.ptr(), nullptr);
1595 TFE_RegisterCustomDevice(
1596 tensorflow::InputTFE_Context(context),
1597 *reinterpret_cast<TFE_CustomDevice*>(
1598 PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")),
1599 device_name,
1600 PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"),
1601 status.get());
1602 tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
1603 });
1604
1605 py::class_<EagerContextThreadLocalDataWrapper>(m,
1606 "EagerContextThreadLocalData")
1607 .def(py::init<py::handle, py::handle, py::handle>(),
1608 py::arg("py_eager_context"), py::arg("is_eager"),
1609 py::arg("device_spec"))
1610 .def_property("is_eager",
1611 &EagerContextThreadLocalDataWrapper::get_is_eager,
1612 &EagerContextThreadLocalDataWrapper::set_is_eager)
1613 .def_property(
1614 "invoking_op_callbacks",
1615 &EagerContextThreadLocalDataWrapper::get_invoking_op_callbacks,
1616 &EagerContextThreadLocalDataWrapper::set_invoking_op_callbacks)
1617 .def_property("device_name",
1618 &EagerContextThreadLocalDataWrapper::get_device_name,
1619 &EagerContextThreadLocalDataWrapper::set_device_name)
1620 .def_property("scope_name",
1621 &EagerContextThreadLocalDataWrapper::get_scope_name,
1622 &EagerContextThreadLocalDataWrapper::set_scope_name)
1623 .def_property("device_spec",
1624 &EagerContextThreadLocalDataWrapper::get_device_spec,
1625 &EagerContextThreadLocalDataWrapper::set_device_spec)
1626 .def_property(
1627 "function_call_options",
1628 &EagerContextThreadLocalDataWrapper::get_function_call_options,
1629 &EagerContextThreadLocalDataWrapper::set_function_call_options)
1630 .def_property("executor",
1631 &EagerContextThreadLocalDataWrapper::get_executor,
1632 &EagerContextThreadLocalDataWrapper::set_executor)
1633 .def_property("op_callbacks",
1634 &EagerContextThreadLocalDataWrapper::get_op_callbacks,
1635 &EagerContextThreadLocalDataWrapper::set_op_callbacks);
1636
1637 // C API Enum
1638
1639 py::enum_<TFE_ContextDevicePlacementPolicy>(
1640 m, "TFE_ContextDevicePlacementPolicy")
1641 .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT)
1642 .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN)
1643 .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT)
1644 .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32",
1645 TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32)
1646 .export_values();
1647
1648 py::enum_<TF_AttrType>(m, "TF_AttrType")
1649 .value("TF_ATTR_STRING", TF_ATTR_STRING)
1650 .value("TF_ATTR_INT", TF_ATTR_INT)
1651 .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT)
1652 .value("TF_ATTR_BOOL", TF_ATTR_BOOL)
1653 .value("TF_ATTR_TYPE", TF_ATTR_TYPE)
1654 .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE)
1655 .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR)
1656 .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER)
1657 .value("TF_ATTR_FUNC", TF_ATTR_FUNC)
1658 .export_values();
1659 };
1660