• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
18 
#include <cstdint>
#include <deque>
#include <functional>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 // TODO(b/30467474) Once GPU outfeed implementation settles, consider
33 // folding back the cpu and gpu outfeed implementations into a generic
34 // one if possible.
35 
36 // Manages a thread-safe queue of buffers.
37 template <typename BufferType>
38 class XfeedQueue {
39  public:
40   // Adds a tree of buffers to the queue. The individual buffers correspond to
41   // the elements of a tuple and may be nullptr if the buffer is a tuple index
42   // buffer.
EnqueueDestination(BufferType buffers)43   void EnqueueDestination(BufferType buffers) {
44     tensorflow::mutex_lock l(mu_);
45     enqueued_buffers_.push_back(std::move(buffers));
46     enqueue_cv_.notify_one();
47 
48     EnqueueHook();
49   }
50 
51   // Blocks until the queue is non-empty, then returns the buffer at the head of
52   // the queue.
BlockingGetNextDestination()53   BufferType BlockingGetNextDestination() {
54     for (const auto& callback : before_get_next_dest_callbacks_) {
55       callback();
56     }
57 
58     bool became_empty;
59     BufferType current_buffer;
60     {
61       tensorflow::mutex_lock l(mu_);
62       while (enqueued_buffers_.empty()) {
63         enqueue_cv_.wait(l);
64       }
65       current_buffer = std::move(enqueued_buffers_.front());
66       enqueued_buffers_.pop_front();
67       DequeueHook();
68       became_empty = enqueued_buffers_.empty();
69     }
70     if (became_empty) {
71       for (const auto& callback : on_empty_callbacks_) {
72         callback();
73       }
74     }
75     return current_buffer;
76   }
77 
RegisterOnEmptyCallback(std::function<void ()> callback)78   void RegisterOnEmptyCallback(std::function<void()> callback) {
79     on_empty_callbacks_.push_back(std::move(callback));
80   }
RegisterBeforeGetNextDestinationCallback(std::function<void ()> callback)81   void RegisterBeforeGetNextDestinationCallback(
82       std::function<void()> callback) {
83     before_get_next_dest_callbacks_.push_back(std::move(callback));
84   }
85 
~XfeedQueue()86   virtual ~XfeedQueue() {}
87 
88  protected:
DequeueHook()89   virtual void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}
EnqueueHook()90   virtual void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}
91 
92   tensorflow::mutex mu_;
93 
94   // The queue of trees of buffers. Buffer* queue contents are not owned.
95   std::deque<BufferType> enqueued_buffers_ ABSL_GUARDED_BY(mu_);
96 
97  private:
98   // Condition variable that is signaled every time a buffer is enqueued.
99   tensorflow::condition_variable enqueue_cv_;
100 
101   // List of callbacks which will be called when 'enqueued_buffers_' becomes
102   // empty.
103   std::vector<std::function<void()>> on_empty_callbacks_;
104 
105   // List of callbacks which will be called before BlockingGetNextDestination()
106   // is called. This lets you e.g. call EnqueueDestination() for each call to
107   // BlockingGetNextDestination().
108   std::vector<std::function<void()>> before_get_next_dest_callbacks_;
109 };
110 
111 // Like XfeedQueue but with a maximum capacity.  Clients can call
112 // `BlockUntilEnqueueSlotAvailable` to block until there are fewer than
113 // `max_pending_xfeeds_` capacity pending infeed items.
114 //
115 // We introduce a separate `BlockUntilEnqueueSlotAvailable` (as opposed to
116 // overriding `EnqueueDestination` to block) because we want to block before we
117 // copy the buffer to GPU memory, in order to bound the memory consumption due
118 // to pending infeeds.
119 template <typename BufferType>
120 class BlockingXfeedQueue : public XfeedQueue<BufferType> {
121  public:
BlockingXfeedQueue(int max_pending_xfeeds)122   explicit BlockingXfeedQueue(int max_pending_xfeeds)
123       : max_pending_xfeeds_(max_pending_xfeeds) {}
124 
BlockUntilEnqueueSlotAvailable()125   void BlockUntilEnqueueSlotAvailable() {
126     tensorflow::mutex_lock l{this->mu_};
127     while (pending_buffers_ + this->enqueued_buffers_.size() >=
128            max_pending_xfeeds_) {
129       VLOG(2) << "Capacity "
130               << (pending_buffers_ + this->enqueued_buffers_.size())
131               << " >= max capacity " << max_pending_xfeeds_;
132       dequeue_cv_.wait(l);
133     }
134 
135     pending_buffers_++;
136   }
137 
138  protected:
EnqueueHook()139   void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
140     pending_buffers_--;
141   }
142 
DequeueHook()143   void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
144     dequeue_cv_.notify_one();
145   }
146 
147  private:
148   const int max_pending_xfeeds_;
149 
150   // Condition variable that is signaled every time a buffer is dequeued.
151   tensorflow::condition_variable dequeue_cv_;
152 
153   // Keeps track of the number of buffers reserved but not added to
154   // enqueued_buffers_.
155   int pending_buffers_ ABSL_GUARDED_BY(this->mu_) = 0;
156 };
157 
158 }  // namespace gpu
159 }  // namespace xla
160 
161 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
162