/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_

#include <deque>
#include <functional>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {
namespace gpu {

// TODO(b/30467474) Once GPU outfeed implementation settles, consider
// folding back the cpu and gpu outfeed implementations into a generic
// one if possible.

// Manages a thread-safe queue of buffers.
template <typename BufferType>
class XfeedQueue {
 public:
  // Adds a tree of buffers to the queue. The individual buffers correspond to
  // the elements of a tuple and may be nullptr if the buffer is a tuple index
  // buffer.
  void EnqueueDestination(BufferType buffers) {
    tensorflow::mutex_lock l(mu_);
    enqueued_buffers_.push_back(std::move(buffers));
    enqueue_cv_.notify_one();

    EnqueueHook();
  }

  // Blocks until the queue is non-empty, then returns the buffer at the head
  // of the queue.
  BufferType BlockingGetNextDestination() {
    for (const auto& callback : before_get_next_dest_callbacks_) {
      callback();
    }

    bool became_empty;
    BufferType current_buffer;
    {
      tensorflow::mutex_lock l(mu_);
      while (enqueued_buffers_.empty()) {
        enqueue_cv_.wait(l);
      }
      current_buffer = std::move(enqueued_buffers_.front());
      enqueued_buffers_.pop_front();
      DequeueHook();
      became_empty = enqueued_buffers_.empty();
    }
    if (became_empty) {
      for (const auto& callback : on_empty_callbacks_) {
        callback();
      }
    }
    return current_buffer;
  }

  void RegisterOnEmptyCallback(std::function<void()> callback) {
    on_empty_callbacks_.push_back(std::move(callback));
  }
  void RegisterBeforeGetNextDestinationCallback(
      std::function<void()> callback) {
    before_get_next_dest_callbacks_.push_back(std::move(callback));
  }

  virtual ~XfeedQueue() {}

 protected:
  virtual void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}
  virtual void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) {}

  tensorflow::mutex mu_;

  // The queue of trees of buffers. Buffer* queue contents are not owned.
  std::deque<BufferType> enqueued_buffers_ ABSL_GUARDED_BY(mu_);

 private:
  // Condition variable that is signaled every time a buffer is enqueued.
  tensorflow::condition_variable enqueue_cv_;

  // List of callbacks which will be called when 'enqueued_buffers_' becomes
  // empty.
  std::vector<std::function<void()>> on_empty_callbacks_;

  // List of callbacks which will be called before BlockingGetNextDestination()
  // is called. This lets you e.g. call EnqueueDestination() for each call to
  // BlockingGetNextDestination().
  std::vector<std::function<void()>> before_get_next_dest_callbacks_;
};

// Like XfeedQueue but with a maximum capacity. Clients can call
// `BlockUntilEnqueueSlotAvailable` to block until there are fewer than
// `max_pending_xfeeds_` pending infeed items.
//
// We introduce a separate `BlockUntilEnqueueSlotAvailable` (as opposed to
// overriding `EnqueueDestination` to block) because we want to block before we
// copy the buffer to GPU memory, in order to bound the memory consumption due
// to pending infeeds.
template <typename BufferType>
class BlockingXfeedQueue : public XfeedQueue<BufferType> {
 public:
  explicit BlockingXfeedQueue(int max_pending_xfeeds)
      : max_pending_xfeeds_(max_pending_xfeeds) {}

  void BlockUntilEnqueueSlotAvailable() {
    tensorflow::mutex_lock l{this->mu_};
    while (pending_buffers_ + this->enqueued_buffers_.size() >=
           max_pending_xfeeds_) {
      VLOG(2) << "Capacity "
              << (pending_buffers_ + this->enqueued_buffers_.size())
              << " >= max capacity " << max_pending_xfeeds_;
      dequeue_cv_.wait(l);
    }

    pending_buffers_++;
  }

 protected:
  void EnqueueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
    pending_buffers_--;
  }

  void DequeueHook() ABSL_EXCLUSIVE_LOCKS_REQUIRED(this->mu_) override {
    dequeue_cv_.notify_one();
  }

 private:
  const int max_pending_xfeeds_;

  // Condition variable that is signaled every time a buffer is dequeued.
  tensorflow::condition_variable dequeue_cv_;

  // Keeps track of the number of buffers reserved but not added to
  // enqueued_buffers_.
  int pending_buffers_ ABSL_GUARDED_BY(this->mu_) = 0;
};
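
// Example usage (an illustrative sketch only; `DeviceBuffer`, `host_data`,
// `CopyToDevice`, and `Consume` are hypothetical placeholders, not names
// defined in this header). A producer thread reserves a slot before copying
// data to device memory, which bounds the memory held by pending infeeds,
// while a consumer thread blocks until a destination is available:
//
//   BlockingXfeedQueue<DeviceBuffer*> queue(/*max_pending_xfeeds=*/8);
//
//   // Producer thread: reserve a slot, then copy and enqueue.
//   queue.BlockUntilEnqueueSlotAvailable();
//   DeviceBuffer* destination = CopyToDevice(host_data);
//   queue.EnqueueDestination(destination);
//
//   // Consumer thread: blocks until the producer has enqueued a destination.
//   DeviceBuffer* next = queue.BlockingGetNextDestination();
//   Consume(next);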

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_