/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/outfeed_receiver.h"

#include <sys/types.h>

#include <memory>
#include <queue>
#include <sstream>

#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/profiler/lib/traceme.h"

// Implementation notes:
//
// Startup:
// --------
//
// The startup is initiated by a call from Python to StartOutfeedReceiver,
// which starts N threads for listening to the N devices and for enqueueing
// the received data into a callback queue. There is one additional callback
// thread for dequeuing the data and invoking the Python callback.
//
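// Concretely (a sketch mirroring Start() below), for N devices the thread
// pool holds 1 + N threads: one callback thread plus one listener per device:
//
//   threads_->Schedule([this]() { CallbackThreadLoop(); });
//   for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
//     threads_->Schedule(
//         [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
//   }
//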
// Framing protocol
// ----------------
//
// The outfeed mechanism has a single channel and the receiver must know
// exactly the shape and number of outfeed operations issued by the compiled
// code. This makes it hard to use outfeed in conditionals and loops,
// especially when outfeeding different-shaped data.
//
// To address this, when we compile the code we capture the shape of the
// data being outfed, and we generate a consumer ID (uint32_t), unique across
// the lifetime of the program, that maps to: the Python callable to invoke,
// the shape of the arguments, and the keyword arguments to pass to the
// callable. Each outfeed payload is preceded by a header (of shape u32[2])
// containing a special first value and the consumer ID. We maintain a
// registry of shapes by consumer ID. When receiving, we look up the shape by
// consumer ID and then read the payload.
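//
// For example, a single outfeed of two f32[3] arrays registered under a
// hypothetical consumer ID 42 arrives on the channel as two back-to-back
// transfers (this sketches the framing only; the payload shape is whatever
// was registered at compile time):
//
//   u32[2]            {kOutfeedHeaderStart, 42}   header
//   (f32[3], f32[3])  {...}                       payload tuple; its shape is
//                                                 looked up in the registry
//                                                 under consumer ID 42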
//
// Back pressure:
// --------------
//
// We maintain a sum of the bytes from all the data waiting in the callback
// queue. The listening threads will wait for the sum to drop below a
// configurable threshold, 256 MB by default. While a listening thread is
// waiting, the next outfeed operation from the device will block on CPU and
// GPU. On TPU there is a buffer, but eventually the TPU will also block.
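//
// Concretely, the listeners block using the pattern in EnqueueReceivedData()
// below (quoted here only as an illustration):
//
//   mu_.Await(
//       absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
//
// i.e. a listener holds mu_ and sleeps until the callback thread has drained
// enough bytes from the queue.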
//
// Shutdown:
// ---------
//
// The shutdown is initiated automatically when the last reference to the
// outfeed receiver object is dropped, and the Python garbage collector invokes
// the destructor.
//
// The shutdown sequence is implemented as follows:
// * we enqueue on all devices a computation that outfeeds a special header
//   with consumer ID kOutfeedCidShutdown.
// * when each listening thread gets the shutdown header, it decrements
//   a counter of listening threads, and if the counter reaches 0, it
//   enqueues a special shutdown callback.
// * when the callback thread gets the shutdown callback marker, it terminates.
// * the shutdown code waits until all threads terminate.
//
// Since we currently keep the shape registry in the OutfeedReceiver, it is
// not safe to replace the OutfeedReceiver instance during the lifetime of
// the JAX program, or else previously cached jitted computations may refer
// to previously cached shapes. This can be solved, but for now we disallow
// replacing the OutfeedReceiver, and do not provide a Shutdown API to the
// Python program.
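//
// Typical C++ usage, sketched from the public OutfeedReceiver API defined at
// the end of this file (the authoritative Callback signature lives in
// outfeed_receiver.h; the clients vector, consumer ID 42, and the arrays are
// placeholders):
//
//   std::vector<PjRtClient*> clients = ...;  // set up elsewhere
//   auto receiver = absl::make_unique<OutfeedReceiver>(
//       [](PjRtDevice* device, uint32_t consumer_id, auto literal) {
//         // Runs on the callback thread, in outfeed order per device.
//       },
//       clients, /*max_callback_queue_size_bytes=*/256 * 1024 * 1024);
//   receiver->Start();
//
//   XlaBuilder builder("computation_with_outfeed");
//   StatusOr<XlaOp> token = receiver->AddOutfeedToBuilder(
//       &builder, CreateToken(&builder), /*consumer_id=*/42, /*arrays=*/...);
//
//   // Dropping the last reference triggers the shutdown sequence above.
//   receiver.reset();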

namespace xla {

// The header contains:
// 0. kOutfeedHeaderStart
// 1. consumer id
int constexpr kOutfeedHeaderWords = 2;
uint32_t constexpr kOutfeedHeaderStart = 271828;
// Special consumer IDs, without outfeed payload.
uint32_t constexpr kOutfeedCidShutdown = 0;

// Encapsulates data received from a device outfeed.
class OutfeedData {
 public:
  OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape)
      : device_(device),
        consumer_id_(consumer_id),
        shape_(shape),
        literal_(nullptr),
        literal_size_bytes_(0) {}

  PjRtDevice* device() { return device_; }
  uint32_t consumer_id() const { return consumer_id_; }
  Shape shape() const { return shape_; }
  std::unique_ptr<Literal> literal() {
    CHECK(literal_);
    return std::move(literal_);
  }

  void SetLiteral(std::unique_ptr<Literal> literal);

  ssize_t literal_size_bytes() const { return literal_size_bytes_; }

  std::string DebugString() const;

 private:
  PjRtDevice* device_;
  uint32_t consumer_id_;
  Shape shape_;
  std::unique_ptr<Literal> literal_;
  ssize_t literal_size_bytes_;
};

void OutfeedData::SetLiteral(std::unique_ptr<Literal> literal) {
  literal_ = std::move(literal);
  shape_ = literal_->shape();
  int total_size_bytes = 0;
  ShapeUtil::ForEachSubshape(
      shape_, [&](const Shape& literal_subshape, const ShapeIndex& index) {
        if (!literal_subshape.IsTuple()) {
          total_size_bytes += ShapeUtil::ByteSizeOf(literal_subshape, 8);
        }
      });
  literal_size_bytes_ = total_size_bytes;
}

std::string OutfeedData::DebugString() const {
  return absl::StrFormat("dev=%s; cons=%d; shape=%s", device_->DebugString(),
                         consumer_id_, shape_.ToString());
}

class OutfeedReceiverImpl {
 public:
  OutfeedReceiverImpl(OutfeedReceiver::Callback callback,
                      absl::Span<PjRtClient* const> clients,
                      ssize_t max_callback_queue_size_bytes);

  OutfeedReceiverImpl(const OutfeedReceiverImpl&) = delete;
  OutfeedReceiverImpl& operator=(const OutfeedReceiverImpl&) = delete;

  // Blocks until all data has been received from devices and all data
  // in the queue has been passed to Python.
  ~OutfeedReceiverImpl();

  void Start();

  StatusOr<XlaOp> AddOutfeedToBuilder(XlaBuilder* builder, XlaOp token,
                                      uint32_t consumer_id,
                                      std::vector<XlaOp> arrays);

 private:
  bool CallbackQueueNotEmpty() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return !callback_queue_.empty();
  }

  bool CallbackQueueHasSpace() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return callback_queue_size_bytes_ < max_callback_queue_size_bytes_;
  }

  bool ShutdownDone() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return (num_working_callback_threads_ == 0 && num_listening_threads_ == 0);
  }

  void CallbackThreadLoop();
  void DeviceListenerThreadLoop(int device_idx);

  // Enqueues to a device an outfeed operation with a shutdown consumer ID.
  Status SendShutdownOutfeedHeader(int device_idx);

  // Receives a raw Literal from a device outfeed.
  StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
                                                           const Shape& shape);

  // Enqueues received data in the callback queue.
  void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Shuts down the threads. See implementation notes at top of file.
  // It is not safe to restart an OutfeedReceiver after shutting down one.
  void Shutdown();

  OutfeedReceiver::Callback callback_;
  // The devices on which we are listening.
  std::vector<PjRtDevice*> devices_;
  // Maximum bytes capacity of the callback queue.
  uint64_t max_callback_queue_size_bytes_;

  absl::Mutex mu_;
  // Registered shapes by consumer id.
  // The shape registry must be alive as long as the program exists.
  // Right now we tell the user to never restart after Shutdown.
  absl::flat_hash_map<uint32_t, Shape> shape_registry_ TF_GUARDED_BY(mu_);
  // How many bytes of Literal are in the callback queue.
  uint64_t callback_queue_size_bytes_ TF_GUARDED_BY(mu_);
  // Threads listening.
  int num_listening_threads_ TF_GUARDED_BY(mu_);
  bool shutdown_started_ TF_GUARDED_BY(mu_);

  // How many callback threads are still working. Used for shutdown.
  int num_working_callback_threads_ TF_GUARDED_BY(mu_);

  std::queue<std::unique_ptr<OutfeedData>> callback_queue_ TF_GUARDED_BY(mu_);
  // The threadpool must come last to ensure the queue exists
  // when the pool destructor is called.
  std::unique_ptr<tensorflow::thread::ThreadPool> threads_;
};

OutfeedReceiverImpl::OutfeedReceiverImpl(
    OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients,
    ssize_t max_callback_queue_size_bytes) {
  callback_ = callback;
  max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
  for (const auto& client : clients) {
    for (auto device : client->addressable_devices()) {
      devices_.push_back(device);
    }
  }
  CHECK_GT(devices_.size(), 0);

  callback_queue_size_bytes_ = 0;
  num_listening_threads_ = 0;
  num_working_callback_threads_ = 0;
  shutdown_started_ = false;
}

void OutfeedReceiverImpl::Start() {
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
  }
  int num_threads = 1 + devices_.size();
  threads_ = absl::make_unique<tensorflow::thread::ThreadPool>(
      tensorflow::Env::Default(), "outfeed_receiver", num_threads);
  threads_->Schedule([this]() { CallbackThreadLoop(); });
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    threads_->Schedule(
        [this, device_idx]() { DeviceListenerThreadLoop(device_idx); });
  }
}

void OutfeedReceiverImpl::Shutdown() {
  VLOG(2) << "Shutdown start";
  {
    absl::MutexLock lock(&mu_);
    CHECK(!shutdown_started_);
    shutdown_started_ = true;
  }
  for (int device_idx = 0; device_idx < devices_.size(); ++device_idx) {
    CHECK(SendShutdownOutfeedHeader(device_idx).ok());
  }
  VLOG(2) << "Shutdown waiting for listening and callback threads to stop";
  absl::MutexLock lock(&mu_);
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::ShutdownDone));
  VLOG(2) << "Shutdown done";
}

OutfeedReceiverImpl::~OutfeedReceiverImpl() {
  VLOG(2) << "~OutfeedReceiverImpl";
  Shutdown();
}

void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) {
  {
    absl::MutexLock lock(&mu_);
    ++num_listening_threads_;
  }
  PjRtDevice* device = devices_[device_idx];
  while (true) {
    Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords});
    std::unique_ptr<Literal> header =
        ReceiveRawFromOutfeed(device, header_shape).ValueOrDie();
    absl::Span<uint32_t> header_data = header->data<uint32>();
    CHECK_EQ(header_data.size(), kOutfeedHeaderWords);
    CHECK_EQ(header_data[0], kOutfeedHeaderStart);
    uint32_t consumer_id = header_data[1];
    Shape shape;
    {
      absl::MutexLock lock(&mu_);
      auto registered_shape = shape_registry_.find(consumer_id);
      if (registered_shape == shape_registry_.end()) {
        LOG(FATAL)
            << "[" << device->DebugString()
            << "] Cannot find registered shape for consumer ID " << consumer_id
            << ". Perhaps the code was compiled with a different instance "
            << "of OutfeedReceiver.";
      }
      shape = registered_shape->second;
    }
    auto received = absl::make_unique<OutfeedData>(device, consumer_id, shape);
    VLOG(2) << "Listener received header " << received->DebugString();
    if (consumer_id == kOutfeedCidShutdown) {
      VLOG(2) << "[" << device->DebugString()
              << "] Listener received shutdown header";
      absl::MutexLock lock(&mu_);
      --num_listening_threads_;
      if (num_listening_threads_ == 0) {
        VLOG(2) << "Last listener shutdown; enqueue shutdown callback";
        EnqueueReceivedData(std::move(received));
      }
      return;
    }
    std::unique_ptr<Literal> data =
        ReceiveRawFromOutfeed(device, shape).ValueOrDie();
    received->SetLiteral(std::move(data));
    absl::MutexLock lock(&mu_);
    EnqueueReceivedData(std::move(received));
  }
}

void OutfeedReceiverImpl::EnqueueReceivedData(
    std::unique_ptr<OutfeedData> received) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  mu_.Await(absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueHasSpace));
  ssize_t literal_size_bytes = received->literal_size_bytes();
  callback_queue_size_bytes_ += literal_size_bytes;
  VLOG(2) << "Listener enqueues data " << received->DebugString() << " of size "
          << literal_size_bytes << " bytes; " << (1 + callback_queue_.size())
          << " callbacks in queue of total size " << callback_queue_size_bytes_
          << " bytes.\n";
  callback_queue_.push(std::move(received));
}

StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
    PjRtDevice* device, const Shape& shape) {
  auto literal = std::make_unique<Literal>(shape);
  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
  return literal;
}

void OutfeedReceiverImpl::CallbackThreadLoop() {
  {
    absl::MutexLock lock(&mu_);
    num_working_callback_threads_++;
    CHECK_EQ(num_working_callback_threads_, 1);
  }
  while (true) {
    std::unique_ptr<OutfeedData> received;
    {
      absl::MutexLock lock(&mu_);
      mu_.Await(
          absl::Condition(this, &OutfeedReceiverImpl::CallbackQueueNotEmpty));
      received = std::move(callback_queue_.front());
      callback_queue_.pop();
      callback_queue_size_bytes_ -= received->literal_size_bytes();
      VLOG(2) << "Dequeued callback for " << received->DebugString() << "; "
              << callback_queue_.size() << " callbacks in queue of total size "
              << callback_queue_size_bytes_ << " bytes.\n";
    }
    if (received->consumer_id() == kOutfeedCidShutdown) {
      VLOG(2) << "Callback loop received shutdown signal";
      {
        absl::MutexLock lock(&mu_);
        CHECK(callback_queue_.empty());
        CHECK_EQ(callback_queue_size_bytes_, 0);
        --num_working_callback_threads_;
      }
      VLOG(2) << "Callback loop done";
      return;
    }
    {
      tensorflow::profiler::TraceMe traceme("OutfeedReceiver::Callback");
      callback_(received->device(), received->consumer_id(),
                received->literal());
    }
  }
}

Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
  const PjRtDevice* device = devices_[device_idx];
  constexpr int consumer_id = kOutfeedCidShutdown;
  VLOG(2) << "[" << device->DebugString()
          << "] SendSpecialHeader cons=" << consumer_id;
  XlaBuilder builder(
      absl::StrFormat("special_outfeed_header_%d_%d", consumer_id, device_idx));
  XlaOp send =
      AddOutfeedToBuilder(&builder, CreateToken(&builder), consumer_id, {})
          .ValueOrDie();
  XlaComputation computation = builder.Build(send).ValueOrDie();

  CompileOptions compile_options;
  compile_options.executable_build_options.set_num_replicas(1);
  compile_options.executable_build_options.set_num_partitions(1);
  DeviceAssignment device_assignment(1, 1);
  device_assignment(0, 0) = device->id();
  compile_options.executable_build_options.set_device_assignment(
      device_assignment);

  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
                      devices_[device_idx]->client()->Compile(
                          computation, std::move(compile_options)));
  ExecuteOptions execute_options;
  TF_ASSIGN_OR_RETURN(
      std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers,
      executable->Execute({{}}, execute_options));
  return Status::OK();
}

StatusOr<XlaOp> OutfeedReceiverImpl::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  XlaOp data = Tuple(builder, std::move(arrays));
  Shape shape_with_layout = builder->GetShape(data).ValueOrDie();
  ShapeUtil::ForEachMutableSubshape(
      &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
        if (!subshape->has_layout()) {
          LayoutUtil::SetToDefaultLayout(subshape);
        }
      });
  VLOG(2) << "RegisterShape cons=" << consumer_id
          << "; shape=" << shape_with_layout.ToString();
  {
    absl::MutexLock lock(&mu_);
    auto found = shape_registry_.find(consumer_id);
    if (found != shape_registry_.end()) {
      if (!ShapeUtil::Equal(shape_with_layout, found->second)) {
        return InvalidArgument(
            "Shape %s does not match previous shape %s used "
            "for consumer id %d",
            shape_with_layout.DebugString(), found->second.DebugString(),
            consumer_id);
      }
    } else {
      shape_registry_.insert({consumer_id, shape_with_layout});
    }
  }

  std::vector<uint32_t> header{kOutfeedHeaderStart, consumer_id};
  XlaOp header_op = ConstantR1<uint32_t>(builder, header);
  token = OutfeedWithToken(
      header_op, token, ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}), "");
  if (consumer_id != kOutfeedCidShutdown) {
    token = OutfeedWithToken(data, token, shape_with_layout, "");
  }
  return token;
}

OutfeedReceiver::OutfeedReceiver(Callback callback,
                                 absl::Span<PjRtClient* const> clients,
                                 ssize_t max_callback_queue_size_bytes) {
  p_impl_ = absl::make_unique<OutfeedReceiverImpl>(
      callback, clients, max_callback_queue_size_bytes);
}

OutfeedReceiver::~OutfeedReceiver() {}

void OutfeedReceiver::Start() { p_impl_->Start(); }

StatusOr<XlaOp> OutfeedReceiver::AddOutfeedToBuilder(
    XlaBuilder* builder, XlaOp token, uint32_t consumer_id,
    std::vector<XlaOp> arrays) {
  if (consumer_id == kOutfeedCidShutdown) {
    return InvalidArgument("Consumer ID cannot be a reserved value: %d",
                           consumer_id);
  }
  return p_impl_->AddOutfeedToBuilder(builder, token, consumer_id, arrays);
}

}  // namespace xla