1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/outfeed_receiver.h"
17
18 #include <sys/types.h>
19
20 #include <memory>
21 #include <queue>
22 #include <sstream>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "absl/strings/str_format.h"
27 #include "tensorflow/compiler/xla/client/sharding_builder.h"
28 #include "tensorflow/compiler/xla/client/xla_builder.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/core/profiler/lib/traceme.h"
33
34 // Implementation notes:
35 //
36 // Startup:
37 // -------
38 //
39 // The startup is initiated by a call from Python to StartOutfeedReceiver. For
40 // each local device there is one thread for listening for outfeeds from the
41 // device, one queue of received outfeeds, and one thread for invoking the
42 // Python callbacks.
43 //
44 // Framing protocol
45 // ----------------
46 //
47 // The outfeed mechanism has a single channel and the receiver must know
48 // exactly the shape and number of outfeed operations issued by the compiled
49 // code. This makes it hard to use outfeed in conditionals and loops and
50 // especially when outfeeding different-shaped data.
51 //
52 // To address this, when we compile the code we capture the shape of the
53 // data being outfed, and we generate a consumer ID (uint32_t) that is unique
54 // across the lifetime of the program to: the Python callable to callback to,
55 // the shape of the arguments, the keyword arguments to pass to the callable.
// Each outfeed payload is preceded by a header (of shape u32[2]) with a
57 // special first value and the consumer ID. We maintain a registry of shapes
58 // by consumer ID. When receiving we lookup the shape by consumer ID, and then
59 // we read the payload.
60 //
61 // Back pressure:
62 // --------------
63 //
64 // We maintain a sum of the bytes from all the data waiting in the callback
65 // queues. The listening threads will wait for the sum to drop below a
// configurable threshold, default 256MB. While the listening thread is waiting,
67 // on CPU and GPU the next outfeed operation from the device will block. On
68 // TPU there is a buffer, but eventually the TPU will also block.
69 //
70 // Shutdown:
71 // ---------
72 //
73 // The shutdown is initiated automatically when the last reference to the
74 // outfeed receiver object is dropped, and the Python garbage collector invokes
75 // the destructor.
76 //
77 // The shutdown sequence is implemented as follows:
78 // * we enqueue on all devices a computation that outfeeds a special header
// with consumer ID kOutfeedCidShutdown.
// * when each listening thread gets the shutdown header, it decrements
81 // a counter of listening threads, and it
82 // enqueues a special shutdown callback.
83 // * when each callback thread gets the shutdown callback marker, it terminates.
84 // * the shutdown code waits until all threads terminate.
85 //
86 // Since we currently keep the shape registry in the OutfeedReceiver, it is
87 // not safe to replace the OutfeedReceiver instance during the lifetime of
88 // the JAX program, or else previously cached jitted computations may refer
89 // to previously cached shapes. This can be solved, but for now we disallow
90 // replacing the OutfeedReceiver, and do not provide a Shutdown API to the
91 // Python program.
92
93 namespace xla {
94
// Framing header layout (shape u32[2]):
//   word 0: kOutfeedHeaderStart, a fixed marker used as a sanity check
//   word 1: the consumer id
constexpr int kOutfeedHeaderWords = 2;
constexpr uint32_t kOutfeedHeaderStart = 271828;
// Reserved consumer IDs; these carry no outfeed payload after the header.
constexpr uint32_t kOutfeedCidShutdown = 0;
102
103 // Encapsulates data received from a device outfeed.
104 class OutfeedData {
105 public:
OutfeedData(PjRtDevice * device,uint32_t consumer_id,Shape shape)106 OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
107 : device_(device),
108 consumer_id_(consumer_id),
109 shape_(shape),
110 literal_(nullptr),
111 literal_size_bytes_(0) {}
112
device()113 PjRtDevice* device() { return device_; }
consumer_id() const114 uint32_t consumer_id() const { return consumer_id_; }
shape() const115 Shape shape() const { return shape_; }
literal()116 std::unique_ptr<Literal> literal() {
117 CHECK(literal_);
118 return std::move(literal_);
119 }
120
121 void SetLiteral(std::unique_ptr<Literal> literal);
122
literal_size_bytes() const123 ssize_t literal_size_bytes() const { return literal_size_bytes_; }
124
125 std::string DebugString() const;
126
127 private:
128 PjRtDevice* device_;
129 uint32_t consumer_id_;
130 Shape shape_;
131 std::unique_ptr<Literal> literal_;
132 ssize_t literal_size_bytes_;
133 };
134
SetLiteral(std::unique_ptr<Literal> literal)135 void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
136 literal_ = std::move(literal);
137 shape_ = literal_->shape();
138 int total_size_bytes = 0;
139 ShapeUtil::ForEachSubshape(
140 shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
141 if (!literal_subshape.IsTuple()) {
142 total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
143 }
144 });
145 literal_size_bytes_ = total_size_bytes;
146 }
147
DebugString() const148 std::string OutfeedData::DebugString() const {
149 return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
150 consumer_id_, shape_.ToString());
151 }
152
// Implementation behind OutfeedReceiver (pImpl). Owns, per local device, one
// listener thread, one callback thread, and one queue of received outfeeds.
// See the implementation notes at the top of this file.
class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  // Starts the listener and callback threads (two per device).
  void Start();

  // Adds to `builder` outfeed ops framed with a header carrying
  // `consumer_id`, and registers the outfed shape under that id.
  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  // True while the total bytes queued for callbacks is under the limit.
  // Listener threads Await() this condition for back pressure.
  bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  // True once every listener and callback thread has exited.
  bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop(int device_idx);
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue for `device_idx`.
  // Blocks (via mu_.Await) until the queues have space.
  void EnqueueReceivedData(uint32_t device_idx,
                           std::unique_ptr<OutfeedData> received)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the ensemble of callback queues.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ TF_GUARDED_BY(mu_);
  // How many bytes of Literal are in the ensemble of callback queues.
  uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ TF_GUARDED_BY(mu_);
  bool shutdown_started_ TF_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ TF_GUARDED_BY(mu_);

  // One queue per device of outfeeds received but not yet dispatched.
  std::vector<std::queue<std::unique_ptr<OutfeedData>>> callback_queues_
      TF_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};
226
OutfeedReceiverImpl(OutfeedReceiver::Callback callback,absl::Span<PjRtClient * const> clients,ssize_t max_callback_queue_size_bytes)227 OutfeedReceiverImpl::OutfeedReceiverImpl(
228 OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
229 ssize_t max_callback_queue_size_bytes) {
230 callback_ = callback;
231 max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
232 for (const auto& client : clients) {
233 for (auto device : client->addressable_devices()) {
234 devices_.push_back(device);
235 }
236 }
237 CHECK_GT(devices_.size(), 0);
238 callback_queues_ =
239 std::vector<std::queue<std::unique_ptr<OutfeedData>>>(devices_.size());
240
241 callback_queue_size_bytes_ = 0;
242 num_listening_threads_ = 0;
243 num_working_callback_threads_ = 0;
244 shutdown_started_ = false;
245 }
246
Start()247 void OutfeedReceiverImpl::Start() {
248 {
249 absl::MutexLock lock(&mu_);
250 CHECK(!shutdown_started_);
251 }
252
253 int num_threads = 2 * devices_.size();
254 threads_ = absl::make_unique<tensorflow::thread::ThreadPool>(
255 tensorflow::Env::Default(), "outfeed_receiver", num_threads);
256 for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
257 threads_->Schedule(
258 [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
259 threads_->Schedule(
260 [this, device_idx]() { CallbackThreadLoop(device_idx); });
261 }
262 }
263
Shutdown()264 void OutfeedReceiverImpl::Shutdown() {
265 VLOG(2) << "Shutdown start";
266 {
267 absl::MutexLock lock(&mu_);
268 CHECK(!shutdown_started_);
269 shutdown_started_ = true;
270 }
271 for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
272 CHECK(SendShutdownOutfeedHeader(device_idx).ok());
273 }
274 VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
275 absl::MutexLock lock(&mu_);
276 mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
277 VLOG(2) << "Shutdown done";
278 }
279
// Destruction triggers the full shutdown sequence; blocks until all
// listener and callback threads have terminated (see Shutdown()).
OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}
284
// Listener thread body for one device. Loop: read a u32[2] framing header
// from the device outfeed, look up the registered shape for the header's
// consumer id, read the payload of that shape, and enqueue it for the
// callback thread. Exits when the shutdown header (kOutfeedCidShutdown)
// arrives, after forwarding a payload-less shutdown marker to the callback
// queue.
void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    // Word 0 is a fixed marker that validates framing; word 1 is the id.
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      VLOG(2) << "[" << device->DebugString() << "] Enqueue shutdown callback";
      // The marker (no literal set) tells the callback thread to exit too.
      EnqueueReceivedData(device_idx, std::move(received));
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    // This lock is held to the end of the iteration: EnqueueReceivedData
    // requires mu_ and may block (via Await) until the queues have space.
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(device_idx, std::move(received));
  }
}
330
EnqueueReceivedData(uint32_t device_idx,std::unique_ptr<OutfeedData> received)331 void OutfeedReceiverImpl::EnqueueReceivedData(
332 uint32_t device_idx, std::unique_ptr<OutfeedData> received)
333 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
334 mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
335 ssize_t literal_size_bytes = received->literal_size_bytes();
336 callback_queue_size_bytes_ += literal_size_bytes;
337 VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
338 << literal_size_bytes << " bytes; "
339 << (1 + callback_queues_[device_idx].size())
340 << " callbacks in queue of total size " << callback_queue_size_bytes_
341 << " bytes.\n";
342 callback_queues_[device_idx].push(std::move(received));
343 }
344
ReceiveRawFromOutfeed(PjRtDevice * device,const Shape & shape)345 StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
346 PjRtDevice* device, const Shape& shape) {
347 auto literal = std::make_unique<Literal>(shape);
348 TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
349 return literal;
350 }
351
// Callback thread body for one device. Loop: wait for this device's queue
// to be non-empty, dequeue one item (updating the back-pressure byte count
// under mu_), then invoke the user callback with mu_ released. Exits when
// the shutdown marker enqueued by the listener thread is dequeued.
void OutfeedReceiverImpl::CallbackThreadLoop(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
      // Block until the listener pushes something onto this device's queue.
      // The `+` converts the capture-less lambda to a function pointer, as
      // required by this absl::Condition constructor.
      mu_.Await(absl::Condition(
          +[](std::queue<std::unique_ptr<OutfeedData>>* queue) {
            return !queue->empty();
          },
          &callback_queues_[device_idx]));
      received = std::move(callback_queues_[device_idx].front());
      callback_queues_[device_idx].pop();
      // Shrinking the byte count may unblock listeners waiting for space.
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "[" << device->DebugString() << "] Dequeued callback for "
              << received->DebugString() << "; "
              << callback_queues_[device_idx].size()
              << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        // The shutdown marker is the last item the listener enqueues.
        CHECK(callback_queues_[device_idx].empty());
        --num_working_callback_threads_;
      }
      VLOG(2) << "[" << device->DebugString() << "] Callback loop done";
      return;
    }
    {
      // User callback runs without holding mu_.
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}
394
SendShutdownOutfeedHeader(int device_idx)395 Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
396 const PjRtDevice* device = devices_[device_idx];
397 constexpr int consumer_id = kOutfeedCidShutdown;
398 VLOG(2) << "[" << device->DebugString()
399 << "] SendSpecialHeader cons=" << consumer_id;
400 XlaBuilder builder(
401 absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
402 XlaOp send =
403 AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
404 .ValueOrDie();
405 XlaComputation computation = builder.Build(send).ValueOrDie();
406
407 CompileOptions compile_options;
408 compile_options.executable_build_options.set_num_replicas(1);
409 compile_options.executable_build_options.set_num_partitions(1);
410 DeviceAssignment device_assignment(1, 1);
411 device_assignment(0, 0) = device->id();
412 compile_options.executable_build_options.set_device_assignment(
413 device_assignment);
414
415 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
416 devices_[device_idx]->client()->Compile(
417 computation, std::move(compile_options)));
418 ExecuteOptions execute_options;
419 TF_ASSIGN_OR_RETURN(
420 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
421 executable->Execute({{}}, execute_options));
422 return Status::OK();
423 }
424
AddOutfeedToBuilder(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)425 StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
426 XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
427 std::vector<XlaOp> arrays) {
428 XlaOp data = Tuple(builder, std::move(arrays));
429 Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
430 ShapeUtil::ForEachMutableSubshape(
431 &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
432 if (!subshape->has_layout()) {
433 LayoutUtil::SetToDefaultLayout(subshape);
434 }
435 });
436 VLOG(2) << "RegisterShape cons=" << consumer_id
437 << "; shape=" << shape_with_layout.ToString();
438 {
439 absl::MutexLock lock(&mu_);
440 auto found = shape_registry_.find(consumer_id);
441 if (found != shape_registry_.end()) {
442 if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
443 return InvalidArgument(
444 "Shape %s does not match previous shape %s used "
445 "for consumer id %d",
446 shape_with_layout.DebugString(), found->second.DebugString(),
447 consumer_id);
448 }
449 } else {
450 shape_registry_.insert({consumer_id, shape_with_layout});
451 }
452 }
453
454 std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
455 XlaOp header_op = ConstantR1<uint32_t>(builder, header);
456 // We assign the outfeed to the first device. This must match the sharding
457 // for the paired infeed.
458 builder->SetSharding(sharding_builder::AssignDevice(0));
459 token = OutfeedWithToken(
460 header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
461 if (consumer_id != kOutfeedCidShutdown) {
462 token = OutfeedWithToken(data, token, shape_with_layout, "");
463 }
464 builder->ClearSharding();
465 return token;
466 }
467
OutfeedReceiver(Callback callback,absl::Span<PjRtClient * const> clients,ssize_t max_callback_queue_size_bytes)468 OutfeedReceiver::OutfeedReceiver(Callback callback,
469 absl::Span<PjRtClient* const> clients,
470 ssize_t max_callback_queue_size_bytes) {
471 p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
472 callback, clients, max_callback_queue_size_bytes);
473 }
474
~OutfeedReceiver()475 OutfeedReceiver::~OutfeedReceiver() {}
476
// Starts the per-device listener and callback threads on the implementation.
void OutfeedReceiver::Start() { p_impl_->Start(); }
478
AddOutfeedToBuilder(XlaBuilder * builder,XlaOp token,uint32_t consumer_id,std::vector<XlaOp> arrays)479 StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
480 XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
481 std::vector<XlaOp> arrays) {
482 if (consumer_id == kOutfeedCidShutdown) {
483 return InvalidArgument("Consumer ID cannot be a reserved value: %d",
484 consumer_id);
485 }
486 return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
487 }
488
489 } // namespace xla
490