/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_

#include <deque>
#include <functional>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {
namespace gpu {

// TODO(b/30467474) Once the GPU outfeed implementation settles, consider
// folding the CPU and GPU outfeed implementations back into a generic one
// if possible.

// Manages a thread-safe queue of buffers.
template <typename BufferType>
class XfeedQueue {
 public:
  // Adds a tree of buffers to the queue. The individual buffers correspond to
  // the elements of a tuple and may be nullptr if the buffer is a tuple index
  // buffer.
  void EnqueueDestination(BufferType buffers) {
    tensorflow::mutex_lock l(mu_);
    enqueued_buffers_.push_back(std::move(buffers));
    cv_.notify_one();
  }

  // Blocks until the queue is non-empty, then returns the buffer at the head
  // of the queue.
  BufferType BlockingGetNextDestination() {
    for (const auto& callback : before_get_next_dest_callbacks_) {
      callback();
    }

    bool became_empty;
    BufferType current_buffer;
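    // Pop the next buffer while holding the lock; record whether the queue
    // became empty so the on-empty callbacks can run after the lock is
    // released.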
    {
      tensorflow::mutex_lock l(mu_);
      while (enqueued_buffers_.empty()) {
        cv_.wait(l);
      }
      current_buffer = std::move(enqueued_buffers_.front());
      enqueued_buffers_.pop_front();
      became_empty = enqueued_buffers_.empty();
    }
    if (became_empty) {
      for (const auto& callback : on_empty_callbacks_) {
        callback();
      }
    }
    return current_buffer;
  }

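  // Registers a callback that is invoked whenever a call to
  // BlockingGetNextDestination() leaves 'enqueued_buffers_' empty.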
  void RegisterOnEmptyCallback(std::function<void()> callback) {
    on_empty_callbacks_.push_back(std::move(callback));
  }
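
  // Registers a callback that is invoked at the start of every call to
  // BlockingGetNextDestination(), before the queue is consulted.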
  void RegisterBeforeGetNextDestinationCallback(
      std::function<void()> callback) {
    before_get_next_dest_callbacks_.push_back(std::move(callback));
  }

 private:
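  // Guards the queue of enqueued buffers.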
  tensorflow::mutex mu_;

  // Condition variable that is signaled every time a buffer is enqueued.
  tensorflow::condition_variable cv_;

  // The queue of trees of buffers. Buffer* queue contents are not owned.
  std::deque<BufferType> enqueued_buffers_ ABSL_GUARDED_BY(mu_);

  // List of callbacks which will be called when 'enqueued_buffers_' becomes
  // empty.
  std::vector<std::function<void()>> on_empty_callbacks_;

  // List of callbacks which will be called before BlockingGetNextDestination()
  // is called. This lets you e.g. call EnqueueDestination() for each call to
  // BlockingGetNextDestination().
  std::vector<std::function<void()>> before_get_next_dest_callbacks_;
};
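
// Example usage (an illustrative sketch, not part of the original header; the
// int* element type and the producer/consumer split below are assumptions):
//
//   XfeedQueue<int*> queue;
//
//   // Producer thread: hands a destination buffer to the queue and wakes one
//   // waiting consumer.
//   queue.EnqueueDestination(buffer);
//
//   // Consumer thread: runs the registered "before" callbacks, then blocks
//   // until a buffer is available and pops it from the head of the queue.
//   int* next_destination = queue.BlockingGetNextDestination();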

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_XFEED_QUEUE_H_