/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/outfeed_receiver.h"

#include <sys/types.h>

#include <memory>
#include <queue>
#include <sstream>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Implementation notes:
//
// Startup:
// --------
//
// The startup is initiated by a call from Python to StartOutfeedReceiver,
// which starts N threads for listening to the N devices and for enqueueing
// the received data into a callback queue. There is one additional callback
// thread for dequeuing the data and invoking the Python callback.
//
// Framing protocol
// ----------------
//
// The outfeed mechanism has a single channel and the receiver must know
// exactly the shape and number of outfeed operations issued by the compiled
// code. This makes it hard to use outfeed in conditionals and loops,
// especially when outfeeding different-shaped data.
//
// To address this, when we compile the code we capture the shape of the
// data being outfed, and we generate a consumer ID (uint32_t), unique across
// the lifetime of the program, that identifies: the Python callable to
// invoke, the shape of the arguments, and the keyword arguments to pass to
// the callable. Each outfeed payload is preceded by a header (of shape
// u32[2]) containing a special first value and the consumer ID. We maintain
// a registry of shapes by consumer ID. When receiving, we look up the shape
// by consumer ID, and then we read the payload.
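//
// For example, if consumer ID 42 is registered with payload shape f32[3]
// (both values purely illustrative), a single outfeed from the device
// arrives as two consecutive transfers:
//   u32[2] = {271828 /*kOutfeedHeaderStart*/, 42 /*consumer ID*/}   (header)
//   f32[3] = {...}                                                  (payload)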
//
// Back pressure:
// --------------
//
// We maintain a sum of the bytes from all the data waiting in the callback
// queue. The listening threads will wait for the sum to drop below a
// configurable threshold, by default 256 MB. While a listening thread is
// waiting, on CPU and GPU the next outfeed operation from the device will
// block. On TPU there is a buffer, but eventually the TPU will also block.
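//
// The wait itself is a condition on the mutex guarding the queue, roughly
// (this mirrors EnqueueReceivedData below):
//   mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));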
//
// Shutdown:
// ---------
//
// The shutdown is initiated automatically when the last reference to the
// outfeed receiver object is dropped, and the Python garbage collector
// invokes the destructor.
//
// The shutdown sequence is implemented as follows:
// * we enqueue on all devices a computation that outfeeds a special header
//   with consumer ID kOutfeedCidShutdown.
// * when a listening thread gets the shutdown header, it decrements
//   a counter of listening threads, and if the counter reaches 0, it
//   enqueues a special shutdown callback.
// * when the callback thread gets the shutdown callback marker, it terminates.
// * the shutdown code waits until all threads terminate.
//
// Since we currently keep the shape registry in the OutfeedReceiver, it is
// not safe to replace the OutfeedReceiver instance during the lifetime of
// the JAX program, or else previously cached jitted computations may refer
// to previously cached shapes. This can be solved, but for now we disallow
// replacing the OutfeedReceiver, and do not provide a Shutdown API to the
// Python program.
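//
// Example usage of the public API declared in outfeed_receiver.h (a minimal
// sketch; the callback body, the `client` variable, and the consumer ID are
// illustrative only, and the exact callback parameter types are those
// declared for OutfeedReceiver::Callback in the header):
//
//   OutfeedReceiver receiver(
//       [](PjRtDevice* device, uint32_t consumer_id, auto literal) {
//         /* hand the literal to the registered Python callable */
//       },
//       {client.get()}, /*max_callback_queue_size_bytes=*/256 * 1024 * 1024);
//   receiver.Start();
//
//   XlaBuilder builder("computation_with_outfeed");
//   XlaOp token = CreateToken(&builder);
//   token = receiver.AddOutfeedToBuilder(&builder, token, /*consumer_id=*/1,
//                                        {array_op})
//               .ValueOrDie();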

namespace xla {

// The header contains:
// 0. kOutfeedHeaderStart
// 1. consumer id
int constexpr kOutfeedHeaderWords = 2;
uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;

// Encapsulates data received from a device outfeed.
class OutfeedData {
 public:
  OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
      : device_(device),
        consumer_id_(consumer_id),
        shape_(shape),
        literal_(nullptr),
        literal_size_bytes_(0) {}

  PjRtDevice* device() { return device_; }
  uint32_t consumer_id() const { return consumer_id_; }
  Shape shape() const { return shape_; }
  std::unique_ptr<Literal> literal() {
    CHECK(literal_);
    return std::move(literal_);
  }

  void SetLiteral(std::unique_ptr<Literal> literal);

  ssize_t literal_size_bytes() const { return literal_size_bytes_; }

  std::string DebugString() const;

 private:
  PjRtDevice* device_;
  uint32_t consumer_id_;
  Shape shape_;
  std::unique_ptr<Literal> literal_;
  ssize_t literal_size_bytes_;
};

void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
  literal_ = std::move(literal);
  shape_ = literal_->shape();
  int total_size_bytes = 0;
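  // Sum the byte sizes of all leaf (non-tuple) subshapes of the literal.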
  ShapeUtil::ForEachSubshape(
      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (!literal_subshape.IsTuple()) {
          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
        }
      });
  literal_size_bytes_ = total_size_bytes;
}

std::string OutfeedData::DebugString() const {
  return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
                         consumer_id_, shape_.ToString());
}

class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  void Start();

  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  bool CallbackQueueNotEmpty() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !callback_queue_.empty();
  }

  bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop();
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue.
  void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the callback queue.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ TF_GUARDED_BY(mu_);
  // How many bytes of Literal are in the callback queue.
  uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ TF_GUARDED_BY(mu_);
  bool shutdown_started_ TF_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ TF_GUARDED_BY(mu_);

  std::queue<std::unique_ptr<OutfeedData>> callback_queue_ TF_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};

OutfeedReceiverImpl::OutfeedReceiverImpl(
    OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
    ssize_t max_callback_queue_size_bytes) {
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
    for (auto device : client->addressable_devices()) {
      devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);

  callback_queue_size_bytes_ = 0;
  num_listening_threads_ = 0;
  num_working_callback_threads_ = 0;
  shutdown_started_ = false;
}

void OutfeedReceiverImpl::Start() {
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
  }
  int num_threads = 1 + devices_.size();
  threads_ = absl::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "outfeed_receiver", num_threads);
  threads_->Schedule([this]() { CallbackThreadLoop(); });
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    threads_->Schedule(
        [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
  }
}

void OutfeedReceiverImpl::Shutdown() {
  VLOG(2) << "Shutdown start";
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
    shutdown_started_ = true;
  }
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    CHECK(SendShutdownOutfeedHeader(device_idx).ok());
  }
  VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
  absl::MutexLock lock(&mu_);
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
  VLOG(2) << "Shutdown done";
}

OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}

void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
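    // Each iteration reads one framed message: the u32[2] header, then the
    // payload with the registered shape (the shutdown header has no payload).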
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32_t>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      if (num_listening_threads_ == 0) {
        VLOG(2) << "Last listener shutdown; enqueue shutdown callback";
        EnqueueReceivedData(std::move(received));
      }
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(std::move(received));
  }
}

void OutfeedReceiverImpl::EnqueueReceivedData(
    std::unique_ptr<OutfeedData> received) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
  ssize_t literal_size_bytes = received->literal_size_bytes();
  callback_queue_size_bytes_ += literal_size_bytes;
  VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
          << literal_size_bytes << " bytes; " << (1 + callback_queue_.size())
          << " callbacks in queue of total size " << callback_queue_size_bytes_
          << " bytes.\n";
  callback_queue_.push(std::move(received));
}

StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    PjRtDevice* device, const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);
  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
  return literal;
}

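// Dequeues received data and invokes the user callback for each entry;
// terminates after dequeuing the shutdown marker enqueued by the last
// listener.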
void OutfeedReceiverImpl::CallbackThreadLoop() {
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
    CHECK_EQ(num_working_callback_threads_, 1);
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
      mu_.Await(
          absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueNotEmpty));
      received = std::move(callback_queue_.front());
      callback_queue_.pop();
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "Dequeued callback for " << received->DebugString() << "; "
              << callback_queue_.size() << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        CHECK(callback_queue_.empty());
        CHECK_EQ(callback_queue_size_bytes_, 0);
        --num_working_callback_threads_;
      }
      VLOG(2) << "Callback loop done";
      return;
    }
    {
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}

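// Builds, compiles, and runs on one device a trivial computation that
// outfeeds only the shutdown header; kOutfeedCidShutdown carries no payload.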
Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  constexpr int consumer_id = kOutfeedCidShutdown;
  VLOG(2) << "[" << device->DebugString()
          << "] SendSpecialHeader cons=" << consumer_id;
  XlaBuilder builder(
      absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
  XlaOp send =
      AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
          .ValueOrDie();
  XlaComputation computation = builder.Build(send).ValueOrDie();

  CompileOptions compile_options;
  compile_options.executable_build_options.set_num_replicas(1);
  compile_options.executable_build_options.set_num_partitions(1);
  DeviceAssignment device_assignment(1, 1);
  device_assignment(0, 0) = device->id();
  compile_options.executable_build_options.set_device_assignment(
      device_assignment);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      devices_[device_idx]->client()->Compile(
                          computation, std::move(compile_options)));
  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
      executable->Execute({{}}, execute_options));
  return Status::OK();
}

StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
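  // Outfeed all arrays as a single tuple, so that each consumer ID corresponds
  // to exactly one registered payload shape.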
  XlaOp data = Tuple(builder, std::move(arrays));
  Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
  ShapeUtil::ForEachMutableSubshape(
      &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
        if (!subshape->has_layout()) {
          LayoutUtil::SetToDefaultLayout(subshape);
        }
      });
  VLOG(2) << "RegisterShape cons=" << consumer_id
          << "; shape=" << shape_with_layout.ToString();
  {
    absl::MutexLock lock(&mu_);
    auto found = shape_registry_.find(consumer_id);
    if (found != shape_registry_.end()) {
      if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
        return InvalidArgument(
            "Shape %s does not match previous shape %s used "
            "for consumer id %d",
            shape_with_layout.DebugString(), found->second.DebugString(),
            consumer_id);
      }
    } else {
      shape_registry_.insert({consumer_id, shape_with_layout});
    }
  }

  std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
  XlaOp header_op = ConstantR1<uint32_t>(builder, header);
  token = OutfeedWithToken(
      header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
  if (consumer_id != kOutfeedCidShutdown) {
    token = OutfeedWithToken(data, token, shape_with_layout, "");
  }
  return token;
}

OutfeedReceiver::OutfeedReceiver(Callback callback,
                                 absl::Span<PjRtClient* const> clients,
                                 ssize_t max_callback_queue_size_bytes) {
  p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
      callback, clients, max_callback_queue_size_bytes);
}

OutfeedReceiver::~OutfeedReceiver() {}

void OutfeedReceiver::Start() { p_impl_->Start(); }

StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  if (consumer_id == kOutfeedCidShutdown) {
    return InvalidArgument("Consumer ID cannot be a reserved value: %d",
                           consumer_id);
  }
  return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
}

}  // namespace xla