/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/outfeed_receiver.h"

#include <sys/types.h>

#include <memory>
#include <queue>
#include <sstream>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Implementation notes:
//
// Startup:
// --------
//
// The startup is initiated by a call from Python to StartOutfeedReceiver. For
// each local device there is one thread listening for outfeeds from the
// device, one queue of received outfeeds, and one thread for invoking the
// Python callbacks.
//
// Framing protocol:
// -----------------
//
// The outfeed mechanism has a single channel, so the receiver must know
// exactly the shape and number of the outfeed operations issued by the
// compiled code. This makes it hard to use outfeed in conditionals and loops,
// especially when outfeeding different-shaped data.
//
// To address this, when we compile the code we capture the shape of the
// data being outfed, and we generate a consumer ID (uint32_t) that is unique
// across the lifetime of the program for each combination of: the Python
// callable to invoke, the shape of the arguments, and the keyword arguments
// to pass to the callable. Each outfeed payload is preceded by a header (of
// shape u32[2]) containing a special first value and the consumer ID. We
// maintain a registry of shapes by consumer ID. When receiving, we look up
// the shape by consumer ID and then read the payload.
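//
// Concretely, for consumer ID C with registered payload shape S, one outfeed
// transfer on the channel looks like:
//
//   u32[2] header = {kOutfeedHeaderStart, C}   // always present
//   S      payload                             // absent for the special IDs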
//
// Back pressure:
// --------------
//
// We maintain a sum of the bytes of all the data waiting in the callback
// queues. The listening threads will wait for that sum to drop below a
// configurable threshold, by default 256 MB. While a listening thread is
// waiting, the next outfeed operation from the device will block on CPU and
// GPU. On TPU there is a buffer, but eventually the TPU will also block.
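// (The wait happens in EnqueueReceivedData below, which blocks on
// CallbackQueueHasSpace before pushing into a callback queue.)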
//
// Shutdown:
// ---------
//
// The shutdown is initiated automatically when the last reference to the
// outfeed receiver object is dropped and the Python garbage collector invokes
// the destructor.
//
// The shutdown sequence is implemented as follows:
// * we enqueue on all devices a computation that outfeeds a special header
//   with consumer ID kOutfeedCidShutdown.
// * when a listening thread gets the shutdown header, it decrements the
//   counter of listening threads and enqueues a special shutdown callback.
// * when a callback thread gets the shutdown callback marker, it terminates.
// * the shutdown code waits until all threads terminate.
//
// Since we currently keep the shape registry in the OutfeedReceiver, it is
// not safe to replace the OutfeedReceiver instance during the lifetime of
// the JAX program, or else previously cached jitted computations may refer
// to shapes that are registered only with the old instance. This can be
// solved, but for now we disallow replacing the OutfeedReceiver and do not
// provide a Shutdown API to the Python program.
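//
// Usage sketch (illustrative only; OutfeedReceiver, Start, and
// AddOutfeedToBuilder are the public API defined in this file, while the
// surrounding setup code is hypothetical):
//
//   auto receiver = std::make_unique<OutfeedReceiver>(
//       callback, clients, /*max_callback_queue_size_bytes=*/256 << 20);
//   receiver->Start();
//   // While tracing a computation that should call back into Python:
//   token = receiver->AddOutfeedToBuilder(&builder, token, consumer_id,
//                                         arrays).ValueOrDie();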

namespace xla {

// The header contains:
// 0. kOutfeedHeaderStart
// 1. consumer id
int constexpr kOutfeedHeaderWords = 2;
uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;
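// User-chosen consumer IDs must not collide with the reserved values above;
// OutfeedReceiver::AddOutfeedToBuilder rejects kOutfeedCidShutdown.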

// Encapsulates data received from a device outfeed.
class OutfeedData {
 public:
  OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
      : device_(device),
        consumer_id_(consumer_id),
        shape_(shape),
        literal_(nullptr),
        literal_size_bytes_(0) {}

  PjRtDevice* device() { return device_; }
  uint32_t consumer_id() const { return consumer_id_; }
  Shape shape() const { return shape_; }
  std::unique_ptr<Literal> literal() {
    CHECK(literal_);
    return std::move(literal_);
  }

  void SetLiteral(std::unique_ptr<Literal> literal);

  ssize_t literal_size_bytes() const { return literal_size_bytes_; }

  std::string DebugString() const;

 private:
  PjRtDevice* device_;
  uint32_t consumer_id_;
  Shape shape_;
  std::unique_ptr<Literal> literal_;
  ssize_t literal_size_bytes_;
};

void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
  literal_ = std::move(literal);
  shape_ = literal_->shape();
  int total_size_bytes = 0;
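  // Sum the byte sizes of all leaf (non-tuple) subshapes; the second argument
  // to ShapeUtil::ByteSizeOf is the pointer size used for non-array leaves.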
  ShapeUtil::ForEachSubshape(
      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (!literal_subshape.IsTuple()) {
          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
        }
      });
  literal_size_bytes_ = total_size_bytes;
}

std::string OutfeedData::DebugString() const {
  return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
                         consumer_id_, shape_.ToString());
}

class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  void Start();

  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop(int device_idx);
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue.
  void EnqueueReceivedData(uint32_t device_idx,
                           std::unique_ptr<OutfeedData> received)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the ensemble of callback queues.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ TF_GUARDED_BY(mu_);
  // How many bytes of Literal are in the ensemble of callback queues.
  uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ TF_GUARDED_BY(mu_);
  bool shutdown_started_ TF_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ TF_GUARDED_BY(mu_);

  std::vector<std::queue<std::unique_ptr<OutfeedData>>> callback_queues_
      TF_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};

OutfeedReceiverImpl::OutfeedReceiverImpl(
    OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
    ssize_t max_callback_queue_size_bytes) {
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
    for (auto device : client->addressable_devices()) {
      devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);
  callback_queues_ =
      std::vector<std::queue<std::unique_ptr<OutfeedData>>>(devices_.size());

  callback_queue_size_bytes_ = 0;
  num_listening_threads_ = 0;
  num_working_callback_threads_ = 0;
  shutdown_started_ = false;
}

void OutfeedReceiverImpl::Start() {
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
  }

  int num_threads = 2 * devices_.size();
  threads_ = absl::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "outfeed_receiver", num_threads);
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    threads_->Schedule(
        [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
    threads_->Schedule(
        [this, device_idx]() { CallbackThreadLoop(device_idx); });
  }
}

void OutfeedReceiverImpl::Shutdown() {
  VLOG(2) << "Shutdown start";
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
    shutdown_started_ = true;
  }
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    CHECK(SendShutdownOutfeedHeader(device_idx).ok());
  }
  VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
  absl::MutexLock lock(&mu_);
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
  VLOG(2) << "Shutdown done";
}

OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}

void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32_t>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      VLOG(2) << "[" << device->DebugString() << "] Enqueue shutdown callback";
      EnqueueReceivedData(device_idx, std::move(received));
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(device_idx, std::move(received));
  }
}

void OutfeedReceiverImpl::EnqueueReceivedData(
    uint32_t device_idx, std::unique_ptr<OutfeedData> received)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
  ssize_t literal_size_bytes = received->literal_size_bytes();
  callback_queue_size_bytes_ += literal_size_bytes;
  VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
          << literal_size_bytes << " bytes; "
          << (1 + callback_queues_[device_idx].size())
          << " callbacks in queue of total size " << callback_queue_size_bytes_
          << " bytes.\n";
  callback_queues_[device_idx].push(std::move(received));
}

StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    PjRtDevice* device, const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);
  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
  return literal;
}

void OutfeedReceiverImpl::CallbackThreadLoop(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
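      // The unary '+' converts the capture-less lambda to a function pointer,
      // selecting absl::Condition's (function, argument) constructor overload.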
      mu_.Await(absl::Condition(
          +[](std::queue<std::unique_ptr<OutfeedData>>* queue) {
            return !queue->empty();
          },
          &callback_queues_[device_idx]));
      received = std::move(callback_queues_[device_idx].front());
      callback_queues_[device_idx].pop();
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "[" << device->DebugString() << "] Dequeued callback for "
              << received->DebugString() << "; "
              << callback_queues_[device_idx].size()
              << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        CHECK(callback_queues_[device_idx].empty());
        --num_working_callback_threads_;
      }
      VLOG(2) << "[" << device->DebugString() << "] Callback loop done";
      return;
    }
    {
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}

Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  constexpr int consumer_id = kOutfeedCidShutdown;
  VLOG(2) << "[" << device->DebugString()
          << "] SendSpecialHeader cons=" << consumer_id;
  XlaBuilder builder(
      absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
  XlaOp send =
      AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
          .ValueOrDie();
  XlaComputation computation = builder.Build(send).ValueOrDie();

  CompileOptions compile_options;
  compile_options.executable_build_options.set_num_replicas(1);
  compile_options.executable_build_options.set_num_partitions(1);
  DeviceAssignment device_assignment(1, 1);
  device_assignment(0, 0) = device->id();
  compile_options.executable_build_options.set_device_assignment(
      device_assignment);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      devices_[device_idx]->client()->Compile(
                          computation, std::move(compile_options)));
  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
      executable->Execute({{}}, execute_options));
  return Status::OK();
}

StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  XlaOp data = Tuple(builder, std::move(arrays));
  Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
  ShapeUtil::ForEachMutableSubshape(
      &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
        if (!subshape->has_layout()) {
          LayoutUtil::SetToDefaultLayout(subshape);
        }
      });
  VLOG(2) << "RegisterShape cons=" << consumer_id
          << "; shape=" << shape_with_layout.ToString();
  {
    absl::MutexLock lock(&mu_);
    auto found = shape_registry_.find(consumer_id);
    if (found != shape_registry_.end()) {
      if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
        return InvalidArgument(
            "Shape %s does not match previous shape %s used "
            "for consumer id %d",
            shape_with_layout.DebugString(), found->second.DebugString(),
            consumer_id);
      }
    } else {
      shape_registry_.insert({consumer_id, shape_with_layout});
    }
  }

  std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
  XlaOp header_op = ConstantR1<uint32_t>(builder, header);
  // We assign the outfeed to the first device. This must match the sharding
  // for the paired infeed.
  builder->SetSharding(sharding_builder::AssignDevice(0));
  token = OutfeedWithToken(
      header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
  if (consumer_id != kOutfeedCidShutdown) {
    token = OutfeedWithToken(data, token, shape_with_layout, "");
  }
  builder->ClearSharding();
  return token;
}

OutfeedReceiver::OutfeedReceiver(Callback callback,
                                 absl::Span<PjRtClient* const> clients,
                                 ssize_t max_callback_queue_size_bytes) {
  p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
      callback, clients, max_callback_queue_size_bytes);
}

OutfeedReceiver::~OutfeedReceiver() {}

void OutfeedReceiver::Start() { p_impl_->Start(); }

StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  if (consumer_id == kOutfeedCidShutdown) {
    return InvalidArgument("Consumer ID cannot be a reserved value: %d",
                           consumer_id);
  }
  return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
}

}  // namespace xla