• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <pybind11/stl.h>
17 
18 #include <memory>
19 
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/c/eager/abstract_context.h"
22 #include "tensorflow/c/eager/abstract_function.h"
23 #include "tensorflow/c/eager/abstract_operation.h"
24 #include "tensorflow/c/eager/abstract_tensor_handle.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/c_api_unified_experimental.h"
28 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
29 #include "tensorflow/c/eager/immediate_execution_context.h"
30 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
31 #include "tensorflow/c/eager/tfe_context_internal.h"
32 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
37 #include "tensorflow/core/platform/errors.h"
38 #include "tensorflow/core/platform/refcount.h"
39 #include "tensorflow/python/eager/pywrap_tensor.h"
40 #include "tensorflow/python/lib/core/pybind11_lib.h"
41 #include "tensorflow/python/lib/core/pybind11_status.h"
42 #include "tensorflow/python/lib/core/safe_ptr.h"
43 
44 namespace py = pybind11;
45 
46 using tensorflow::AbstractContext;
47 using tensorflow::AbstractContextPtr;
48 using tensorflow::AbstractFunction;
49 using tensorflow::AbstractOperation;
50 using tensorflow::AbstractOperationPtr;
51 using tensorflow::AbstractTensorHandle;
52 using tensorflow::AbstractTensorHandlePtr;
53 using tensorflow::OutputList;
54 
55 using tensorflow::tracing::TracingContext;
56 using tensorflow::tracing::TracingOperation;
57 using tensorflow::tracing::TracingTensorHandle;
58 
59 using tensorflow::ImmediateContextPtr;
60 using tensorflow::ImmediateExecutionContext;
61 using tensorflow::ImmediateExecutionTensorHandle;
62 
63 using tensorflow::dyn_cast;
64 using tensorflow::isa;
65 using tensorflow::unwrap;
66 using tensorflow::wrap;
67 
68 using tensorflow::DataType;
69 using tensorflow::make_safe;
70 using tensorflow::MaybeRaiseRegisteredFromStatus;
71 using tensorflow::MaybeRaiseRegisteredFromTFStatus;
72 using tensorflow::Pyo;
73 using tensorflow::Safe_TF_StatusPtr;
74 using tensorflow::Status;
75 using tensorflow::string;
76 using tensorflow::TFE_TensorHandleToNumpy;
77 using tensorflow::core::RefCountPtr;
78 
79 using tensorflow::errors::Internal;
80 using tensorflow::errors::InvalidArgument;
81 
PYBIND11_MODULE(_unified_api,m)82 PYBIND11_MODULE(_unified_api, m) {
83   // Context creation functions.
84   m.def("SetTracingImplementation", [](const char* impl) {
85     Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
86     TF_SetTracingImplementation(impl, status.get());
87     MaybeRaiseRegisteredFromStatus(status->status);
88   });
89   m.def("NewTracingContext", [](const char* fn_name) {
90     Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
91     auto* ctx = unwrap(TF_CreateFunction(fn_name, status.get()));
92     MaybeRaiseRegisteredFromTFStatus(status.get());
93     if (!ctx) {
94       MaybeRaiseRegisteredFromStatus(
95           Internal("TF_CreateFunction returned nullptr"));
96     }
97     if (!isa<TracingContext>(ctx)) {
98       // TODO(srbs): Add a helper to convert the kind enum to a user-friendly
99       // string.
100       MaybeRaiseRegisteredFromStatus(
101           Internal("TF_CreateFunction must return a TracingContext, found ",
102                    ctx->getKind()));
103     }
104     return dyn_cast<TracingContext>(ctx);
105   });
106   m.def("EagerContextToImmediateExecutionContext", [](py::handle& obj) {
107     TFE_Context* ctx =
108         static_cast<TFE_Context*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
109     if (!ctx) {
110       MaybeRaiseRegisteredFromStatus(InvalidArgument("TFE_Context is nullptr"));
111     }
112     return unwrap(ctx);
113   });
114 
115   // Unified execution context.
116   py::class_<AbstractContext, AbstractContextPtr>(m, "AbstractContext")
117       .def("CreateOperation",
118            [](AbstractContext* self, const char* op,
119               const char* raw_device_name) {
120              auto operation = self->CreateOperation();
121              operation->Reset(op, raw_device_name);
122              return operation;
123            })
124       .def("RegisterFunction",
125            [](AbstractContext* self, AbstractFunction* f) {
126              Status s = self->RegisterFunction(f);
127              MaybeRaiseRegisteredFromStatus(s);
128            })
129       .def("RemoveFunction", [](AbstractContext* self, const string& func) {
130         Status s = self->RemoveFunction(func);
131         MaybeRaiseRegisteredFromStatus(s);
132       });
133 
134   py::class_<TracingContext, AbstractContext>(m, "TracingContext")
135       .def("AddParameter",
136            [](TracingContext* self, DataType dtype) {
137              TracingTensorHandle* handle = nullptr;
138              // TODO(srbs): Add shape argument to this function.
139              tensorflow::PartialTensorShape shape;
140              Status s = self->AddParameter(dtype, shape, &handle);
141              MaybeRaiseRegisteredFromStatus(s);
142              return static_cast<AbstractTensorHandle*>(handle);
143            })
144       .def("Finalize", [](TracingContext* self, py::handle& outputs) {
145         // TODO(srbs): Using OutputList seems like an overkill here. Should we
146         // simply pass in an absl::Span?
147         OutputList output_list;
148         if (outputs.ptr() != Py_None) {
149           if (!PyList_Check(outputs.ptr())) {
150             MaybeRaiseRegisteredFromStatus(
151                 InvalidArgument("must provide a list of Tensors as inputs"));
152           }
153           Py_ssize_t len = PyList_Size(outputs.ptr());
154           output_list.outputs.resize(len);
155           for (Py_ssize_t i = 0; i < len; ++i) {
156             PyObject* elem = PyList_GetItem(outputs.ptr(), i);
157             if (!elem) {
158               MaybeRaiseRegisteredFromStatus(
159                   InvalidArgument("Tensor at index  ", i, " is None."));
160             }
161             py::handle elem_h = elem;
162             AbstractTensorHandle* handle = elem_h.cast<AbstractTensorHandle*>();
163             if (!isa<TracingTensorHandle>(handle)) {
164               MaybeRaiseRegisteredFromStatus(InvalidArgument(
165                   "Tensor at index  ", i, " is not a graph tensor."));
166             }
167             output_list.outputs[i] = handle;
168           }
169         }
170         AbstractFunction* f = nullptr;
171         Status s = self->Finalize(&output_list, &f);
172         MaybeRaiseRegisteredFromStatus(s);
173         return f;
174       });
175 
176   // Note: This does not take ownership of the C++ context, the lifetime of
177   // which is managed by the python `Context` and is expected to outlive this
178   // object.
179   // TODO(srbs): Make AbstractContext refcounted so that the above comment is
180   // not needed.
181   py::class_<ImmediateExecutionContext, AbstractContext,
182              std::unique_ptr<ImmediateExecutionContext, py::nodelete>>
183       ImmediateExecutionContext(m, "ImmediateExecutionContext");
184 
185   // Unified execution operation.
186   py::class_<AbstractOperation, AbstractOperationPtr>(m, "AbstractOperation")
187       .def("Reset",
188            [](AbstractOperation* self, const char* op,
189               const char* raw_device_name) {
190              Status s = self->Reset(op, raw_device_name);
191              MaybeRaiseRegisteredFromStatus(s);
192            })
193       .def("SetOpName",
194            [](AbstractOperation* self, const char* op_name) {
195              // TODO(srbs): We could provide SetOpName on TracingOperation
196              // but then we need to do a hasattr check or try/pass in python.
197              if (isa<TracingOperation>(self)) {
198                auto tracing_op = reinterpret_cast<TracingOperation*>(self);
199                Status s = tracing_op->SetOpName(op_name);
200                MaybeRaiseRegisteredFromStatus(s);
201              }
202            })
203       .def("Name", &AbstractOperation::Name)
204       .def("DeviceName", &AbstractOperation::DeviceName)
205       .def("SetDeviceName",
206            [](AbstractOperation* self, const char* name) {
207              Status s = self->SetDeviceName(name);
208              MaybeRaiseRegisteredFromStatus(s);
209            })
210       .def("AddInput",
211            [](AbstractOperation* self, AbstractTensorHandle* input) {
212              Status s = self->AddInput(input);
213              MaybeRaiseRegisteredFromStatus(s);
214            })
215       .def("SetAttrType",
216            [](AbstractOperation* self, const char* attr_name, DataType value) {
217              Status s = self->SetAttrType(attr_name, value);
218              MaybeRaiseRegisteredFromStatus(s);
219            })
220       .def("Execute", [](AbstractOperation* self, int num_outputs) {
221         std::vector<AbstractTensorHandle*> outputs(num_outputs);
222         MaybeRaiseRegisteredFromStatus(
223             self->Execute(absl::MakeSpan(outputs), &num_outputs));
224         return outputs;
225       });
226 
227   // Unified execution tensor handle.
228   py::class_<AbstractTensorHandle, AbstractTensorHandlePtr>(
229       m, "AbstractTensorHandle")
230       .def("DataType", &AbstractTensorHandle::DataType)
231       .def("numpy", [](AbstractTensorHandle* self) {
232         // TODO(srbs): Export this on ImmediateExecutionTensorHandle only.
233         if (!isa<ImmediateExecutionTensorHandle>(self)) {
234           // TODO(srbs): Add a helper to convert the kind enum to a
235           // user-friendly string.
236           MaybeRaiseRegisteredFromStatus(Internal(
237               "AbstractTensorHandle.numpy() must be called with an ",
238               "ImmediateExecutionTensorHandle found type: ", self->getKind()));
239         }
240         TF_Status s;
241         TFE_TensorHandle* handle =
242             wrap(dyn_cast<ImmediateExecutionTensorHandle>(self));
243         auto result = TFE_TensorHandleToNumpy(handle, &s);
244         MaybeRaiseRegisteredFromStatus(s.status);
245         return Pyo(result);
246       });
247 
248   m.def("EagerTensorToImmediateExecutionTensorHandle", [](py::object handle) {
249     if (!EagerTensor_CheckExact(handle.ptr())) {
250       MaybeRaiseRegisteredFromStatus(
251           InvalidArgument("EagerTensorToImmediateExecutionTensorHandle called "
252                           "with non-EagerTensor."));
253     }
254     TFE_TensorHandle* eager_tensor = EagerTensor_Handle(handle.ptr());
255     auto t = static_cast<AbstractTensorHandle*>(unwrap(eager_tensor));
256     t->Ref();
257     return t;
258   });
259 
260   py::class_<AbstractFunction, RefCountPtr<AbstractFunction>> AbstractFunction(
261       m, "AbstractFunction");
262 }
263