1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
17
18 #include <memory>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "absl/synchronization/mutex.h"
23 #include "pybind11/functional.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
27 #include "tensorflow/compiler/xla/python/outfeed_receiver.h"
28 #include "tensorflow/compiler/xla/python/py_client.h"
29 #include "tensorflow/compiler/xla/python/types.h"
30
31 namespace xla {
32
33 namespace py = pybind11;
34
35 namespace {
36
37 // A wrapper for OutfeedReceiver for use from Python, useful for ensuring
38 // that the GIL is released before destroying the OutfeedReceiver.
39 class OutfeedReceiverForPython {
40 public:
41 // A callback to Python takes: consumer id, received literal.
42 using CallbackToPython =
43 std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
44
OutfeedReceiverForPython(CallbackToPython callback_python,std::vector<std::shared_ptr<PyClient>> clients,ssize_t max_callback_queue_size_bytes)45 OutfeedReceiverForPython(CallbackToPython callback_python,
46 std::vector<std::shared_ptr<PyClient>> clients,
47 ssize_t max_callback_queue_size_bytes)
48 : callback_python_(std::move(callback_python)),
49 clients_(std::move(clients)) {
50 OutfeedReceiver::Callback callback =
51 [this](PjRtDevice* device, uint32_t consumer_id,
52 std::shared_ptr<Literal> literal) {
53 this->Callback(device, consumer_id, std::move(literal));
54 };
55 std::vector<PjRtClient*> client_ptrs(clients_.size());
56 absl::c_transform(clients_, client_ptrs.begin(),
57 [](const std::shared_ptr<PyClient>& client) {
58 return client->pjrt_client();
59 });
60 outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
61 callback, client_ptrs, max_callback_queue_size_bytes);
62 }
63 OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
64 OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
65
~OutfeedReceiverForPython()66 ~OutfeedReceiverForPython() {
67 // This destructor is called from the Python GC. Release it for the duration
68 // of the destruction, including the destruction of the OutfeedReceiver,
69 // when we may actually have to wait for threads to end. During this time
70 // we do not callback to Python (sometimes we get an exception
71 // "std::runtime_error: scoped_acquire::dec_ref(): thread state must
72 // be current!"").
73 {
74 absl::MutexLock lock(&mu_);
75 outfeed_receiver_shutting_down_ = true;
76 }
77 py::gil_scoped_release gil_release;
78 outfeed_receiver_ = nullptr; // Shutdown the outfeed receiver.
79 }
80
Start()81 void Start() { outfeed_receiver_->Start(); }
82
AddOutfeed(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)83 StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token,
84 uint32_t consumer_id, std::vector<XlaOp> arrays) {
85 return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
86 arrays);
87 }
88
Callback(PjRtDevice * device,uint32_t consumer_id,std::shared_ptr<Literal> literal)89 void Callback(PjRtDevice* device, uint32_t consumer_id,
90 std::shared_ptr<Literal> literal) {
91 {
92 absl::MutexLock lock(&mu_);
93 if (outfeed_receiver_shutting_down_) {
94 VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
95 return;
96 }
97 }
98 // We expect the number of clients to be small, so an O(n) search is fine.
99 auto it = absl::c_find_if(
100 clients_, [device](const std::shared_ptr<PyClient>& client) {
101 return client->pjrt_client() == device->client();
102 });
103 CHECK(it != clients_.end());
104 py::gil_scoped_acquire gil_acquire; // Need GIL also for LiteralToPython
105 py::object literal_python =
106 LiteralToPython(std::move(literal)).ValueOrDie();
107 // The callback_ should handle all exceptions in user-code. If we get
108 // an exception here, it is a bug in the callback and we should stop.
109 callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
110 std::move(literal_python));
111 }
112
113 private:
114 CallbackToPython callback_python_;
115 absl::Mutex mu_;
116 bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
117 std::vector<std::shared_ptr<PyClient>> clients_;
118 std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
119 };
120
121 } // namespace
122
BuildOutfeedReceiverSubmodule(py::module * m)123 void BuildOutfeedReceiverSubmodule(py::module* m) {
124 py::module outfeed_receiver =
125 m->def_submodule("outfeed_receiver", "Outfeed receiver");
126 outfeed_receiver.def(
127 "start",
128 [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
129 std::vector<std::shared_ptr<PyClient>> clients,
130 ssize_t max_callback_queue_size_bytes)
131 -> std::unique_ptr<OutfeedReceiverForPython> {
132 auto server = absl::make_unique<OutfeedReceiverForPython>(
133 callback_to_python, clients, max_callback_queue_size_bytes);
134 server->Start();
135 return server;
136 },
137 py::arg("callback_to_python"), py::arg("backends"),
138 py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
139 R"(Starts a multithreaded outfeed receiver.
140
141 There is one thread for each of the specified devices. When Python
142 drops the last reference to the returned object, the receiver is shut
143 down. The destructor will block until all data is received from
144 devices.
145
146 Args:
147 * callback_to_python: a Python callback to call, with <consumer_id>
148 and the data received.
149 * backends: the list of backends to listen on.
150 * max_queue_size_bytes: an optional integer to bound the maximum size
151 of arrays in the callback queue. When this limit is reached the
152 device listener pauses.
153 )",
154 py::call_guard<py::gil_scoped_release>());
155
156 py::class_<OutfeedReceiverForPython> outfeed_receiver_class(
157 outfeed_receiver, "OutfeedReceiverForPython");
158
159 outfeed_receiver_class.def(
160 "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
161 py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
162 R"(Adds an outfeed into the given computation builder.
163
164 Has the side-effect of registering the sent shape along with the consumer
165 ID. Returns error if the outfeed shape is not compatible with previously
166 used shape for the same consumer ID.)",
167 py::call_guard<py::gil_scoped_release>());
168 }
169
170 } // namespace xla
171