• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
18 
#include <cstdint>
#include <deque>
#include <functional>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/platform/logging.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 // TODO(b/30467474) Once GPU outfeed implementation settles, consider
31 // folding back the cpu and gpu outfeed implementations into a generic
32 // one if possible.
33 
34 // Manages a thread-safe queue of buffers.
35 template <typename BufferType>
36 class XfeedQueue {
37  public:
38   // Adds a tree of buffers to the queue. The individual buffers correspond to
39   // the elements of a tuple and may be nullptr if the buffer is a tuple index
40   // buffer.
EnqueueDestination(BufferType buffers)41   void EnqueueDestination(BufferType buffers) {
42     absl::MutexLock l(&mu_);
43     enqueued_buffers_.push_back(std::move(buffers));
44     enqueue_cv_.Signal();
45 
46     EnqueueHook();
47   }
48 
49   // Blocks until the queue is non-empty, then returns the buffer at the head of
50   // the queue.
BlockingGetNextDestination()51   BufferType BlockingGetNextDestination() {
52     for (const auto& callback : before_get_next_dest_callbacks_) {
53       callback();
54     }
55 
56     bool became_empty;
57     BufferType current_buffer;
58     {
59       absl::MutexLock l(&mu_);
60       while (enqueued_buffers_.empty()) {
61         enqueue_cv_.Wait(&mu_);
62       }
63       current_buffer = std::move(enqueued_buffers_.front());
64       enqueued_buffers_.pop_front();
65       DequeueHook();
66       became_empty = enqueued_buffers_.empty();
67     }
68     if (became_empty) {
69       for (const auto& callback : on_empty_callbacks_) {
70         callback();
71       }
72     }
73     return current_buffer;
74   }
75 
RegisterOnEmptyCallback(std::function<void ()> callback)76   void RegisterOnEmptyCallback(std::function<void()> callback) {
77     on_empty_callbacks_.push_back(std::move(callback));
78   }
RegisterBeforeGetNextDestinationCallback(std::function<void ()> callback)79   void RegisterBeforeGetNextDestinationCallback(
80       std::function<void()> callback) {
81     before_get_next_dest_callbacks_.push_back(std::move(callback));
82   }
83 
~XfeedQueue()84   virtual ~XfeedQueue() {}
85 
86  protected:
DequeueHook()87   virtual void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}
EnqueueHook()88   virtual void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}
89 
90   absl::Mutex mu_;
91 
92   // The queue of trees of buffers. Buffer* queue contents are not owned.
93   std::deque<BufferType> enqueued_buffers_ ABSL_GUARDED_BY(mu_);
94 
95  private:
96   // Condition variable that is signaled every time a buffer is enqueued.
97   absl::CondVar enqueue_cv_;
98 
99   // List of callbacks which will be called when 'enqueued_buffers_' becomes
100   // empty.
101   std::vector<std::function<void()>> on_empty_callbacks_;
102 
103   // List of callbacks which will be called before BlockingGetNextDestination()
104   // is called. This lets you e.g. call EnqueueDestination() for each call to
105   // BlockingGetNextDestination().
106   std::vector<std::function<void()>> before_get_next_dest_callbacks_;
107 };
108 
109 // Like XfeedQueue but with a maximum capacity.  Clients can call
110 // `BlockUntilEnqueueSlotAvailable` to block until there are fewer than
111 // `max_pending_xfeeds_` capacity pending infeed items.
112 //
113 // We introduce a separate `BlockUntilEnqueueSlotAvailable` (as opposed to
114 // overriding `EnqueueDestination` to block) because we want to block before we
115 // copy the buffer to GPU memory, in order to bound the memory consumption due
116 // to pending infeeds.
117 template <typename BufferType>
118 class BlockingXfeedQueue : public XfeedQueue<BufferType> {
119  public:
BlockingXfeedQueue(int max_pending_xfeeds)120   explicit BlockingXfeedQueue(int max_pending_xfeeds)
121       : max_pending_xfeeds_(max_pending_xfeeds) {}
122 
BlockUntilEnqueueSlotAvailable()123   void BlockUntilEnqueueSlotAvailable() {
124     absl::MutexLock l{&this->mu_};
125     while (pending_buffers_ + this->enqueued_buffers_.size() >=
126            max_pending_xfeeds_) {
127       VLOG(2) << "Capacity "
128               << (pending_buffers_ + this->enqueued_buffers_.size())
129               << " >= max capacity " << max_pending_xfeeds_;
130       dequeue_cv_.Wait(&this->mu_);
131     }
132 
133     pending_buffers_++;
134   }
135 
136  protected:
EnqueueHook()137   void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
138     pending_buffers_--;
139   }
140 
DequeueHook()141   void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
142     dequeue_cv_.Signal();
143   }
144 
145  private:
146   const int max_pending_xfeeds_;
147 
148   // Condition variable that is signaled every time a buffer is dequeued.
149   absl::CondVar dequeue_cv_;
150 
151   // Keeps track of the number of buffers reserved but not added to
152   // enqueued_buffers_.
153   int pending_buffers_ ABSL_GUARDED_BY(this->mu_) = 0;
154 };
155 
156 }  // namespace gpu
157 }  // namespace xla
158 
159 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
160