1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <pybind11/stl.h>
17
18 #include <memory>
19
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/c/eager/abstract_context.h"
22 #include "tensorflow/c/eager/abstract_function.h"
23 #include "tensorflow/c/eager/abstract_operation.h"
24 #include "tensorflow/c/eager/abstract_tensor_handle.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/c_api_unified_experimental.h"
28 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
29 #include "tensorflow/c/eager/immediate_execution_context.h"
30 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
31 #include "tensorflow/c/eager/tfe_context_internal.h"
32 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/refcount.h"
39 #include "tensorflow/python/eager/pywrap_tensor.h"
40 #include "tensorflow/python/lib/core/pybind11_lib.h"
41 #include "tensorflow/python/lib/core/pybind11_status.h"
42 #include "tensorflow/python/lib/core/safe_ptr.h"
43
44 namespace py = pybind11;
45
46 using tensorflow::AbstractContext;
47 using tensorflow::AbstractContextPtr;
48 using tensorflow::AbstractFunction;
49 using tensorflow::AbstractOperation;
50 using tensorflow::AbstractOperationPtr;
51 using tensorflow::AbstractTensorHandle;
52 using tensorflow::AbstractTensorHandlePtr;
53 using tensorflow::OutputList;
54
55 using tensorflow::tracing::TracingContext;
56 using tensorflow::tracing::TracingOperation;
57 using tensorflow::tracing::TracingTensorHandle;
58
59 using tensorflow::ImmediateContextPtr;
60 using tensorflow::ImmediateExecutionContext;
61 using tensorflow::ImmediateExecutionTensorHandle;
62
63 using tensorflow::dyn_cast;
64 using tensorflow::isa;
65 using tensorflow::unwrap;
66 using tensorflow::wrap;
67
68 using tensorflow::DataType;
69 using tensorflow::make_safe;
70 using tensorflow::MaybeRaiseRegisteredFromStatus;
71 using tensorflow::MaybeRaiseRegisteredFromTFStatus;
72 using tensorflow::Pyo;
73 using tensorflow::Safe_TF_StatusPtr;
74 using tensorflow::Status;
75 using tensorflow::string;
76 using tensorflow::TFE_TensorHandleToNumpy;
77 using tensorflow::core::RefCountPtr;
78
79 using tensorflow::errors::Internal;
80 using tensorflow::errors::InvalidArgument;
81
// Python bindings for the experimental unified (tracing / eager) C++ API.
//
// Failures are reported by converting a non-OK status into the registered
// Python exception via MaybeRaiseRegisteredFromStatus /
// MaybeRaiseRegisteredFromTFStatus; code after such a call only runs on the
// success path.
PYBIND11_MODULE(_unified_api, m) {
  // Context creation functions.

  // Forwards `impl` to TF_SetTracingImplementation and raises on failure.
  m.def("SetTracingImplementation", [](const char* impl) {
    Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
    TF_SetTracingImplementation(impl, status.get());
    MaybeRaiseRegisteredFromStatus(status->status);
  });

  // Creates a TracingContext for a function named `fn_name`.
  // The context is held in an owning AbstractContextPtr until every check has
  // passed, so it is released (not leaked) on any error path; only then is the
  // raw pointer handed over to Python.
  m.def("NewTracingContext", [](const char* fn_name) {
    Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
    AbstractContextPtr ctx(unwrap(TF_CreateFunction(fn_name, status.get())));
    MaybeRaiseRegisteredFromTFStatus(status.get());
    if (!ctx) {
      MaybeRaiseRegisteredFromStatus(
          Internal("TF_CreateFunction returned nullptr"));
    }
    if (!isa<TracingContext>(ctx.get())) {
      // TODO(srbs): Add a helper to convert the kind enum to a user-friendly
      // string.
      MaybeRaiseRegisteredFromStatus(
          Internal("TF_CreateFunction must return a TracingContext, found ",
                   ctx->getKind()));
    }
    return dyn_cast<TracingContext>(ctx.release());
  });

  // Extracts the TFE_Context* stored in the PyCapsule `obj` and returns it as
  // an ImmediateExecutionContext. Python does not take ownership of it (the
  // ImmediateExecutionContext class below is registered with py::nodelete).
  m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
    TFE_Context* ctx =
        static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
    if (!ctx) {
      MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
    }
    return unwrap(ctx);
  });

  // Unified execution context.
  py::class_<AbstractContext, AbstractContextPtr>(m, "AbstractContext")
      // Creates an operation on this context and resets it to run `op` on
      // `raw_device_name`. The operation lives in an owning
      // AbstractOperationPtr so it is released if Reset fails; on success the
      // holder is moved into the Python object.
      .def("CreateOperation",
           [](AbstractContext* self, const char* op,
              const char* raw_device_name) {
             AbstractOperationPtr operation(self->CreateOperation());
             // Surface Reset failures instead of returning an operation that
             // was never successfully initialized.
             Status s = operation->Reset(op, raw_device_name);
             MaybeRaiseRegisteredFromStatus(s);
             return operation;
           })
      // Registers `f` with this context; raises on failure.
      .def("RegisterFunction",
           [](AbstractContext* self, AbstractFunction* f) {
             Status s = self->RegisterFunction(f);
             MaybeRaiseRegisteredFromStatus(s);
           })
      // Removes the function named `func` from this context; raises on
      // failure.
      .def("RemoveFunction", [](AbstractContext* self, const string& func) {
        Status s = self->RemoveFunction(func);
        MaybeRaiseRegisteredFromStatus(s);
      });

  py::class_<TracingContext, AbstractContext>(m, "TracingContext")
      // Adds a function parameter of type `dtype` (with an unknown shape) and
      // returns its tensor handle.
      .def("AddParameter",
           [](TracingContext* self, DataType dtype) {
             TracingTensorHandle* handle = nullptr;
             // TODO(srbs): Add shape argument to this function.
             tensorflow::PartialTensorShape shape;
             Status s = self->AddParameter(dtype, shape, &handle);
             MaybeRaiseRegisteredFromStatus(s);
             return static_cast<AbstractTensorHandle*>(handle);
           })
      // Finalizes the traced function. `outputs` is either None or a Python
      // list whose elements are graph (TracingTensorHandle) tensors.
      .def("Finalize", [](TracingContext* self, py::handle& outputs) {
        // TODO(srbs): Using OutputList seems like an overkill here. Should we
        // simply pass in an absl::Span?
        OutputList output_list;
        if (outputs.ptr() != Py_None) {
          if (!PyList_Check(outputs.ptr())) {
            MaybeRaiseRegisteredFromStatus(
                InvalidArgument("must provide a list of Tensors as outputs"));
          }
          const Py_ssize_t len = PyList_Size(outputs.ptr());
          output_list.outputs.resize(len);
          for (Py_ssize_t i = 0; i < len; ++i) {
            // Borrowed reference; the list keeps the element alive.
            PyObject* elem = PyList_GetItem(outputs.ptr(), i);
            // pybind11 converts None to a null pointer, and isa<> must not be
            // given a null pointer, so reject None before casting.
            if (elem == nullptr || elem == Py_None) {
              MaybeRaiseRegisteredFromStatus(
                  InvalidArgument("Tensor at index ", i, " is None."));
            }
            py::handle elem_h = elem;
            AbstractTensorHandle* handle = elem_h.cast<AbstractTensorHandle*>();
            if (!isa<TracingTensorHandle>(handle)) {
              MaybeRaiseRegisteredFromStatus(InvalidArgument(
                  "Tensor at index ", i, " is not a graph tensor."));
            }
            output_list.outputs[i] = handle;
          }
        }
        AbstractFunction* f = nullptr;
        Status s = self->Finalize(&output_list, &f);
        MaybeRaiseRegisteredFromStatus(s);
        return f;
      });

  // Note: This does not take ownership of the C++ context, the lifetime of
  // which is managed by the python `Context` and is expected to outlive this
  // object.
  // TODO(srbs): Make AbstractContext refcounted so that the above comment is
  // not needed.
  // Registered as an unnamed temporary (like the other classes) so no local
  // variable shadows the ImmediateExecutionContext type in this scope.
  py::class_<ImmediateExecutionContext, AbstractContext,
             std::unique_ptr<ImmediateExecutionContext, py::nodelete>>(
      m, "ImmediateExecutionContext");

  // Unified execution operation.
  py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
      .def("Reset",
           [](AbstractOperation* self, const char* op,
              const char* raw_device_name) {
             Status s = self->Reset(op, raw_device_name);
             MaybeRaiseRegisteredFromStatus(s);
           })
      // Sets the op name on tracing operations; a no-op for any other kind.
      .def("SetOpName",
           [](AbstractOperation* self, const char* op_name) {
             // TODO(srbs): We could provide SetOpName on TracingOperation
             // but then we need to do a hasattr check or try/pass in python.
             // dyn_cast performs a checked downcast; reinterpret_cast is not a
             // valid base-to-derived conversion.
             if (auto* tracing_op = dyn_cast<TracingOperation>(self)) {
               Status s = tracing_op->SetOpName(op_name);
               MaybeRaiseRegisteredFromStatus(s);
             }
           })
      .def("Name", &AbstractOperation::Name)
      .def("DeviceName", &AbstractOperation::DeviceName)
      .def("SetDeviceName",
           [](AbstractOperation* self, const char* name) {
             Status s = self->SetDeviceName(name);
             MaybeRaiseRegisteredFromStatus(s);
           })
      .def("AddInput",
           [](AbstractOperation* self, AbstractTensorHandle* input) {
             Status s = self->AddInput(input);
             MaybeRaiseRegisteredFromStatus(s);
           })
      .def("SetAttrType",
           [](AbstractOperation* self, const char* attr_name, DataType value) {
             Status s = self->SetAttrType(attr_name, value);
             MaybeRaiseRegisteredFromStatus(s);
           })
      // Executes the operation with room for up to `num_outputs` results and
      // returns the produced handles as a Python list.
      .def("Execute", [](AbstractOperation* self, int num_outputs) {
        // A negative count would be converted to a huge size_t by the
        // std::vector size constructor below.
        if (num_outputs < 0) {
          MaybeRaiseRegisteredFromStatus(InvalidArgument(
              "num_outputs must be non-negative, got ", num_outputs));
        }
        std::vector<AbstractTensorHandle*> outputs(num_outputs);
        MaybeRaiseRegisteredFromStatus(
            self->Execute(absl::MakeSpan(outputs), &num_outputs));
        // Execute writes back the number of outputs it produced; drop the
        // unused, value-initialized (null) trailing slots so Python does not
        // receive spurious None entries.
        if (num_outputs >= 0 &&
            static_cast<size_t>(num_outputs) < outputs.size()) {
          outputs.resize(num_outputs);
        }
        return outputs;
      });

  // Unified execution tensor handle.
  py::class_<AbstractTensorHandle, AbstractTensorHandlePtr>(
      m, "AbstractTensorHandle")
      .def("DataType", &AbstractTensorHandle::DataType)
      // Converts an ImmediateExecutionTensorHandle to a numpy value; raises
      // Internal for any other handle kind.
      .def("numpy", [](AbstractTensorHandle* self) {
        // TODO(srbs): Export this on ImmediateExecutionTensorHandle only.
        if (!isa<ImmediateExecutionTensorHandle>(self)) {
          // TODO(srbs): Add a helper to convert the kind enum to a
          // user-friendly string.
          MaybeRaiseRegisteredFromStatus(Internal(
              "AbstractTensorHandle.numpy() must be called with an ",
              "ImmediateExecutionTensorHandle found type: ", self->getKind()));
        }
        TF_Status s;
        TFE_TensorHandle* handle =
            wrap(dyn_cast<ImmediateExecutionTensorHandle>(self));
        auto result = TFE_TensorHandleToNumpy(handle, &s);
        MaybeRaiseRegisteredFromStatus(s.status);
        return Pyo(result);
      });

  // Returns the tensor handle underlying an EagerTensor (exact type only).
  m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
    if (!EagerTensor_CheckExact(handle.ptr())) {
      MaybeRaiseRegisteredFromStatus(
          InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
                          "with non-EagerTensor."));
    }
    TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
    auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
    // Take an extra reference for the returned Python wrapper, whose
    // AbstractTensorHandlePtr holder presumably drops it on destruction —
    // the EagerTensor keeps its own reference.
    t->Ref();
    return t;
  });

  // Registered as an unnamed temporary so no local variable shadows the
  // AbstractFunction type name.
  py::class_<AbstractFunction, RefCountPtr<AbstractFunction>>(
      m, "AbstractFunction");
}
263