1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 17 #define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 18 19 #include <deque> 20 #include <vector> 21 22 #include "absl/base/macros.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/queue_interface.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/platform/mutex.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace tensorflow { 34 35 namespace barrier { 36 class Barrier; 37 } // namespace barrier 38 39 // Functionality common to asynchronous QueueInterface implementations. 40 class QueueBase : public QueueInterface { 41 public: 42 // As a possible value of 'capacity'. 43 static const int32 kUnbounded = INT_MAX; 44 45 // Args: 46 // component_dtypes: The types of each component in a queue-element tuple. 47 // component_shapes: The shapes of each component in a queue-element tuple, 48 // which must either be empty (if the shapes are not specified) or 49 // or have the same size as component_dtypes. 50 // name: A name to use for the queue. 51 QueueBase(int32 capacity, const DataTypeVector& component_dtypes, 52 const std::vector<TensorShape>& component_shapes, 53 const string& name); 54 55 // Implementations of QueueInterface methods -------------------------------- component_dtypes()56 const DataTypeVector& component_dtypes() const override { 57 return component_dtypes_; 58 } 59 60 Status ValidateTuple(const Tuple& tuple) override; 61 Status ValidateManyTuple(const Tuple& tuple) override; 62 63 void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, 64 DoneCallback callback) override; 65 66 // Other public methods ----------------------------------------------------- component_shapes()67 const std::vector<TensorShape>& component_shapes() const { 68 return component_shapes_; 69 } 70 capacity()71 int32 capacity() const { return capacity_; } 72 is_closed()73 bool is_closed() const override { 74 mutex_lock lock(mu_); 75 return closed_; 76 } 77 78 // Copies the index^th slice (in the first dimension) of parent into element. 79 static Status CopySliceToElement(const Tensor& parent, Tensor* element, 80 int64 index); 81 82 // Copies element into the index^th slice (in the first dimension) of parent. 83 // NOTE(mrry): This method is deprecated. Use 84 // `tensorflow::batch_util::CopySliceToElement()` defined in 85 // "./batch_util.h" instead. 86 ABSL_DEPRECATED( 87 "Use `tensorflow::batch_util::CopySliceToElement()` defined in " 88 "\"./batch_util.h\" instead.") 89 static Status CopyElementToSlice(const Tensor& element, Tensor* parent, 90 int64 index); 91 92 protected: 93 enum Action { kEnqueue, kDequeue }; 94 enum RunResult { kNoProgress, kProgress, kComplete }; 95 96 // Tries to enqueue/dequeue (or close) based on whatever is at the 97 // front of enqueue_attempts_/dequeue_attempts_. Appends to 98 // *finished the callback for any finished attempt (so it may be 99 // called once mu_ is released). Returns true if any progress was 100 // made. 101 struct CleanUp { CleanUpCleanUp102 CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) 103 : finished(f), to_deregister(ct), cm(cm) {} 104 DoneCallback finished; 105 CancellationToken to_deregister; 106 CancellationManager* cm; 107 }; 108 109 // Returns the number of components in a queue-element tuple. num_components()110 int32 num_components() const { return component_dtypes_.size(); } 111 112 // True if shapes were specified. If so, inputs will be validated 113 // against them, etc. specified_shapes()114 bool specified_shapes() const { return component_shapes_.size() > 0; } 115 116 // Code common to Validate*Tuple(). 117 Status ValidateTupleCommon(const Tuple& tuple) const; 118 ManyOutShape(int i,int64 batch_size)119 TensorShape ManyOutShape(int i, int64 batch_size) { 120 TensorShape shape({batch_size}); 121 shape.AppendShape(component_shapes_[i]); 122 return shape; 123 } 124 125 void Cancel(Action action, CancellationManager* cancellation_manager, 126 CancellationToken token); 127 128 // Helper for cancelling all pending Enqueue(Many) operations when 129 // Close is called with cancel_pending_enqueues. 130 void CloseAndCancel(); 131 132 bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up) 133 EXCLUSIVE_LOCKS_REQUIRED(mu_); 134 135 // Tries to make progress on the enqueues or dequeues at the front 136 // of the *_attempts_ queues. 137 void FlushUnlocked(); 138 139 ~QueueBase() override; 140 141 // Helpers for implementing MatchesNodeDef(). 142 static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes); 143 Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const; 144 Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const; 145 Status MatchesNodeDefTypes(const NodeDef& node_def) const; 146 Status MatchesNodeDefShapes(const NodeDef& node_def) const; 147 148 protected: 149 const int32 capacity_; 150 const DataTypeVector component_dtypes_; 151 const std::vector<TensorShape> component_shapes_; 152 const string name_; 153 mutable mutex mu_; 154 bool closed_ GUARDED_BY(mu_); 155 156 struct Attempt; 157 typedef std::function<RunResult(Attempt*)> RunCallback; 158 struct Attempt { 159 int32 elements_requested; 160 DoneCallback done_callback; // must be run outside mu_ 161 OpKernelContext* context; 162 CancellationManager* cancellation_manager; // not owned 163 CancellationToken cancellation_token; 164 RunCallback run_callback; // must be run while holding mu_ 165 bool is_cancelled; 166 Tuple tuple; 167 // tuples is used by some implementations allowing dynamic shapes. 168 std::vector<Tuple> tuples; 169 AttemptAttempt170 Attempt(int32 elements_requested, DoneCallback done_callback, 171 OpKernelContext* context, CancellationManager* cancellation_manager, 172 CancellationToken cancellation_token, RunCallback run_callback) 173 : elements_requested(elements_requested), 174 done_callback(done_callback), 175 context(context), 176 cancellation_manager(cancellation_manager), 177 cancellation_token(cancellation_token), 178 run_callback(run_callback), 179 is_cancelled(false) {} 180 }; 181 std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_); 182 std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_); 183 184 TF_DISALLOW_COPY_AND_ASSIGN(QueueBase); 185 }; 186 187 } // namespace tensorflow 188 189 #endif // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 190