• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
17 
18 #include <memory>
19 
20 #include "absl/algorithm/container.h"
21 #include "absl/memory/memory.h"
22 #include "absl/synchronization/mutex.h"
23 #include "pybind11/functional.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
27 #include "tensorflow/compiler/xla/python/outfeed_receiver.h"
28 #include "tensorflow/compiler/xla/python/py_client.h"
29 #include "tensorflow/compiler/xla/python/types.h"
30 
31 namespace xla {
32 
33 namespace py = pybind11;
34 
35 namespace {
36 
37 // A wrapper for OutfeedReceiver for use from Python, useful for ensuring
38 // that the GIL is released before destroying the OutfeedReceiver.
39 class OutfeedReceiverForPython {
40  public:
41   // A callback to Python takes: consumer id, received literal.
42   using CallbackToPython =
43       std::function<void(ClientAndPtr<PjRtDevice>, uint32_t, pybind11::object)>;
44 
OutfeedReceiverForPython(CallbackToPython callback_python,std::vector<std::shared_ptr<PyClient>> clients,ssize_t max_callback_queue_size_bytes)45   OutfeedReceiverForPython(CallbackToPython callback_python,
46                            std::vector<std::shared_ptr<PyClient>> clients,
47                            ssize_t max_callback_queue_size_bytes)
48       : callback_python_(std::move(callback_python)),
49         clients_(std::move(clients)) {
50     OutfeedReceiver::Callback callback =
51         [this](PjRtDevice* device, uint32_t consumer_id,
52                std::shared_ptr<Literal> literal) {
53           this->Callback(device, consumer_id, std::move(literal));
54         };
55     std::vector<PjRtClient*> client_ptrs(clients_.size());
56     absl::c_transform(clients_, client_ptrs.begin(),
57                       [](const std::shared_ptr<PyClient>& client) {
58                         return client->pjrt_client();
59                       });
60     outfeed_receiver_ = absl::make_unique<OutfeedReceiver>(
61         callback, client_ptrs, max_callback_queue_size_bytes);
62   }
63   OutfeedReceiverForPython(const OutfeedReceiverForPython&) = delete;
64   OutfeedReceiverForPython& operator=(const OutfeedReceiverForPython&) = delete;
65 
~OutfeedReceiverForPython()66   ~OutfeedReceiverForPython() {
67     // This destructor is called from the Python GC. Release it for the duration
68     // of the destruction, including the destruction of the OutfeedReceiver,
69     // when we may actually have to wait for threads to end. During this time
70     // we do not callback to Python (sometimes we get an exception
71     // "std::runtime_error: scoped_acquire::dec_ref(): thread state must
72     // be current!"").
73     {
74       absl::MutexLock lock(&mu_);
75       outfeed_receiver_shutting_down_ = true;
76     }
77     py::gil_scoped_release gil_release;
78     outfeed_receiver_ = nullptr;  // Shutdown the outfeed receiver.
79   }
80 
Start()81   void Start() { outfeed_receiver_->Start(); }
82 
AddOutfeed(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)83   StatusOr<XlaOp> AddOutfeed(XlaBuilder* builder, XlaOp token,
84                              uint32_t consumer_id, std::vector<XlaOp> arrays) {
85     return outfeed_receiver_->AddOutfeedToBuilder(builder, token, consumer_id,
86                                                   arrays);
87   }
88 
Callback(PjRtDevice * device,uint32_t consumer_id,std::shared_ptr<Literal> literal)89   void Callback(PjRtDevice* device, uint32_t consumer_id,
90                 std::shared_ptr<Literal> literal) {
91     {
92       absl::MutexLock lock(&mu_);
93       if (outfeed_receiver_shutting_down_) {
94         VLOG(2) << "Ignoring unsafe callback to Python during shutdown";
95         return;
96       }
97     }
98     // We expect the number of clients to be small, so an O(n) search is fine.
99     auto it = absl::c_find_if(
100         clients_, [device](const std::shared_ptr<PyClient>& client) {
101           return client->pjrt_client() == device->client();
102         });
103     CHECK(it != clients_.end());
104     py::gil_scoped_acquire gil_acquire;  // Need GIL also for LiteralToPython
105     py::object literal_python =
106         LiteralToPython(std::move(literal)).ValueOrDie();
107     // The callback_ should handle all exceptions in user-code. If we get
108     // an exception here, it is a bug in the callback and we should stop.
109     callback_python_(WrapWithClient<PjRtDevice>(*it, device), consumer_id,
110                      std::move(literal_python));
111   }
112 
113  private:
114   CallbackToPython callback_python_;
115   absl::Mutex mu_;
116   bool outfeed_receiver_shutting_down_ TF_GUARDED_BY(mu_) = false;
117   std::vector<std::shared_ptr<PyClient>> clients_;
118   std::unique_ptr<OutfeedReceiver> outfeed_receiver_;
119 };
120 
121 }  // namespace
122 
BuildOutfeedReceiverSubmodule(py::module * m)123 void BuildOutfeedReceiverSubmodule(py::module* m) {
124   py::module outfeed_receiver =
125       m->def_submodule("outfeed_receiver", "Outfeed receiver");
126   outfeed_receiver.def(
127       "start",
128       [](OutfeedReceiverForPython::CallbackToPython callback_to_python,
129          std::vector<std::shared_ptr<PyClient>> clients,
130          ssize_t max_callback_queue_size_bytes)
131           -> std::unique_ptr<OutfeedReceiverForPython> {
132         auto server = absl::make_unique<OutfeedReceiverForPython>(
133             callback_to_python, clients, max_callback_queue_size_bytes);
134         server->Start();
135         return server;
136       },
137       py::arg("callback_to_python"), py::arg("backends"),
138       py::arg("max_queue_size_bytes") = 256 * 1024 * 1024,
139       R"(Starts a multithreaded outfeed receiver.
140 
141       There is one thread for each of the specified devices. When Python
142       drops the last reference to the returned object, the receiver is shut
143       down. The destructor will block until all data is received from
144       devices.
145 
146       Args:
147         * callback_to_python: a Python callback to call, with <consumer_id>
148           and the data received.
149         * backends: the list of backends to listen on.
150         * max_queue_size_bytes: an optional integer to bound the maximum size
151             of arrays in the callback queue. When this limit is reached the
152             device listener pauses.
153       )",
154       py::call_guard<py::gil_scoped_release>());
155 
156   py::class_<OutfeedReceiverForPython> outfeed_receiver_class(
157       outfeed_receiver, "OutfeedReceiverForPython");
158 
159   outfeed_receiver_class.def(
160       "add_outfeed", &OutfeedReceiverForPython::AddOutfeed, py::arg("builder"),
161       py::arg("token"), py::arg("consumer_id"), py::arg("arrays"),
162       R"(Adds an outfeed into the given computation builder.
163 
164       Has the side-effect of registering the sent shape along with the consumer
165       ID. Returns error if the outfeed shape is not compatible with previously
166       used shape for the same consumer ID.)",
167       py::call_guard<py::gil_scoped_release>());
168 }
169 
170 }  // namespace xla
171