• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
16 #define TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
17 
18 #include <deque>
19 
20 #include "tensorflow/core/platform/macros.h"
21 #include "tensorflow/core/platform/mutex.h"
22 #include "tensorflow/core/platform/status.h"
23 #include "tensorflow/core/platform/statusor.h"
24 
25 namespace tensorflow {
26 namespace data {
27 
28 // A thread-safe bounded buffer with cancellation support.
29 template <class T>
30 class ThreadSafeBuffer final {
31  public:
32   // Creates a buffer with the specified `buffer_size`.
33   // REQUIRES: buffer_size > 0
34   explicit ThreadSafeBuffer(size_t buffer_size);
35 
36   // Gets the next element. Blocks if the buffer is empty. Returns an error if
37   // a non-OK status was pushed or the buffer has been cancelled.
38   StatusOr<T> Pop();
39 
40   // Writes the next element. Blocks if the buffer is full. Returns an error if
41   // the buffer has been cancelled.
42   Status Push(StatusOr<T> value);
43 
44   // Cancels the buffer with `status` and notifies waiting threads. After
45   // cancelling, all `Push` and `Pop` calls will return `status`.
46   // REQUIRES: !status.ok()
47   void Cancel(Status status);
48 
49  private:
50   const size_t buffer_size_;
51 
52   mutex mu_;
53   condition_variable ready_to_pop_;
54   condition_variable ready_to_push_;
55   std::deque<StatusOr<T>> results_ TF_GUARDED_BY(mu_);
56   Status status_ TF_GUARDED_BY(mu_) = Status::OK();
57 
58   TF_DISALLOW_COPY_AND_ASSIGN(ThreadSafeBuffer);
59 };
60 
61 template <class T>
ThreadSafeBuffer(size_t buffer_size)62 ThreadSafeBuffer<T>::ThreadSafeBuffer(size_t buffer_size)
63     : buffer_size_(buffer_size) {
64   DCHECK_GT(buffer_size, 0)
65       << "ThreadSafeBuffer must have a postive buffer size. Got " << buffer_size
66       << ".";
67 }
68 
69 template <class T>
Pop()70 StatusOr<T> ThreadSafeBuffer<T>::Pop() {
71   mutex_lock l(mu_);
72   while (status_.ok() && results_.empty()) {
73     ready_to_pop_.wait(l);
74   }
75   if (!status_.ok()) {
76     return status_;
77   }
78   StatusOr<T> result = std::move(results_.front());
79   results_.pop_front();
80   ready_to_push_.notify_one();
81   return result;
82 }
83 
84 template <class T>
Push(StatusOr<T> value)85 Status ThreadSafeBuffer<T>::Push(StatusOr<T> value) {
86   mutex_lock l(mu_);
87   while (status_.ok() && results_.size() >= buffer_size_) {
88     ready_to_push_.wait(l);
89   }
90   if (!status_.ok()) {
91     return status_;
92   }
93   results_.push_back(std::move(value));
94   ready_to_pop_.notify_one();
95   return Status::OK();
96 }
97 
98 template <class T>
Cancel(Status status)99 void ThreadSafeBuffer<T>::Cancel(Status status) {
100   DCHECK(!status.ok())
101       << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status;
102   mutex_lock l(mu_);
103   status_ = std::move(status);
104   ready_to_push_.notify_all();
105   ready_to_pop_.notify_all();
106 }
107 
108 }  // namespace data
109 }  // namespace tensorflow
110 
111 #endif  // TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_
112