1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 17 18 #include <deque> 19 20 #include "tensorflow/core/platform/macros.h" 21 #include "tensorflow/core/platform/mutex.h" 22 #include "tensorflow/core/platform/status.h" 23 #include "tensorflow/core/platform/statusor.h" 24 25 namespace tensorflow { 26 namespace data { 27 28 // A thread-safe bounded buffer with cancellation support. 29 template <class T> 30 class ThreadSafeBuffer final { 31 public: 32 // Creates a buffer with the specified `buffer_size`. 33 // REQUIRES: buffer_size > 0 34 explicit ThreadSafeBuffer(size_t buffer_size); 35 36 // Gets the next element. Blocks if the buffer is empty. Returns an error if 37 // a non-OK status was pushed or the buffer has been cancelled. 38 StatusOr<T> Pop(); 39 40 // Writes the next element. Blocks if the buffer is full. Returns an error if 41 // the buffer has been cancelled. 42 Status Push(StatusOr<T> value); 43 44 // Cancels the buffer with `status` and notifies waiting threads. After 45 // cancelling, all `Push` and `Pop` calls will return `status`. 46 // REQUIRES: !status.ok() 47 void Cancel(Status status); 48 49 private: 50 const size_t buffer_size_; 51 52 mutex mu_; 53 condition_variable ready_to_pop_; 54 condition_variable ready_to_push_; 55 std::deque<StatusOr<T>> results_ TF_GUARDED_BY(mu_); 56 Status status_ TF_GUARDED_BY(mu_) = Status::OK(); 57 58 TF_DISALLOW_COPY_AND_ASSIGN(ThreadSafeBuffer); 59 }; 60 61 template <class T> ThreadSafeBuffer(size_t buffer_size)62ThreadSafeBuffer<T>::ThreadSafeBuffer(size_t buffer_size) 63 : buffer_size_(buffer_size) { 64 DCHECK_GT(buffer_size, 0) 65 << "ThreadSafeBuffer must have a postive buffer size. Got " << buffer_size 66 << "."; 67 } 68 69 template <class T> Pop()70StatusOr<T> ThreadSafeBuffer<T>::Pop() { 71 mutex_lock l(mu_); 72 while (status_.ok() && results_.empty()) { 73 ready_to_pop_.wait(l); 74 } 75 if (!status_.ok()) { 76 return status_; 77 } 78 StatusOr<T> result = std::move(results_.front()); 79 results_.pop_front(); 80 ready_to_push_.notify_one(); 81 return result; 82 } 83 84 template <class T> Push(StatusOr<T> value)85Status ThreadSafeBuffer<T>::Push(StatusOr<T> value) { 86 mutex_lock l(mu_); 87 while (status_.ok() && results_.size() >= buffer_size_) { 88 ready_to_push_.wait(l); 89 } 90 if (!status_.ok()) { 91 return status_; 92 } 93 results_.push_back(std::move(value)); 94 ready_to_pop_.notify_one(); 95 return Status::OK(); 96 } 97 98 template <class T> Cancel(Status status)99void ThreadSafeBuffer<T>::Cancel(Status status) { 100 DCHECK(!status.ok()) 101 << "Cancelling ThreadSafeBuffer requires a non-OK status. Got " << status; 102 mutex_lock l(mu_); 103 status_ = std::move(status); 104 ready_to_push_.notify_all(); 105 ready_to_pop_.notify_all(); 106 } 107 108 } // namespace data 109 } // namespace tensorflow 110 111 #endif // TENSORFLOW_CORE_DATA_SERVICE_THREAD_SAFE_BUFFER_H_ 112