1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 17 #define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 18 19 #include <deque> 20 #include <vector> 21 22 #include "absl/base/macros.h" 23 #include "tensorflow/core/framework/op_kernel.h" 24 #include "tensorflow/core/framework/queue_interface.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/lib/gtl/array_slice.h" 29 #include "tensorflow/core/platform/macros.h" 30 #include "tensorflow/core/platform/mutex.h" 31 #include "tensorflow/core/platform/types.h" 32 33 namespace tensorflow { 34 35 // Functionality common to asynchronous QueueInterface implementations. 36 class QueueBase : public QueueInterface { 37 public: 38 // As a possible value of 'capacity'. 39 static constexpr int32 kUnbounded = INT_MAX; 40 41 // Args: 42 // component_dtypes: The types of each component in a queue-element tuple. 43 // component_shapes: The shapes of each component in a queue-element tuple, 44 // which must either be empty (if the shapes are not specified) or 45 // or have the same size as component_dtypes. 46 // name: A name to use for the queue. 47 QueueBase(int32 capacity, const DataTypeVector& component_dtypes, 48 const std::vector<TensorShape>& component_shapes, 49 const string& name); 50 51 // Implementations of QueueInterface methods -------------------------------- component_dtypes()52 const DataTypeVector& component_dtypes() const override { 53 return component_dtypes_; 54 } 55 56 Status ValidateTuple(const Tuple& tuple) override; 57 Status ValidateManyTuple(const Tuple& tuple) override; 58 59 void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, 60 DoneCallback callback) override; 61 62 // Other public methods ----------------------------------------------------- component_shapes()63 const std::vector<TensorShape>& component_shapes() const { 64 return component_shapes_; 65 } 66 capacity()67 int32 capacity() const { return capacity_; } 68 is_closed()69 bool is_closed() const override { 70 mutex_lock lock(mu_); 71 return closed_; 72 } 73 74 // Copies the index^th slice (in the first dimension) of parent into element. 75 static Status CopySliceToElement(const Tensor& parent, Tensor* element, 76 int64 index); 77 78 // Copies element into the index^th slice (in the first dimension) of parent. 79 // NOTE(mrry): This method is deprecated. Use 80 // `tensorflow::batch_util::CopySliceToElement()` defined in 81 // "./batch_util.h" instead. 82 ABSL_DEPRECATED( 83 "Use `tensorflow::batch_util::CopySliceToElement()` defined in " 84 "\"./batch_util.h\" instead.") 85 static Status CopyElementToSlice(const Tensor& element, Tensor* parent, 86 int64 index); 87 88 protected: 89 enum Action { kEnqueue, kDequeue }; 90 enum RunResult { kNoProgress, kProgress, kComplete }; 91 92 // Tries to enqueue/dequeue (or close) based on whatever is at the 93 // front of enqueue_attempts_/dequeue_attempts_. Appends to 94 // *finished the callback for any finished attempt (so it may be 95 // called once mu_ is released). Returns true if any progress was 96 // made. 97 struct CleanUp { CleanUpCleanUp98 CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm) 99 : finished(f), to_deregister(ct), cm(cm) {} 100 DoneCallback finished; 101 CancellationToken to_deregister; 102 CancellationManager* cm; 103 }; 104 105 // Returns the number of components in a queue-element tuple. num_components()106 int32 num_components() const { return component_dtypes_.size(); } 107 108 // True if shapes were specified. If so, inputs will be validated 109 // against them, etc. specified_shapes()110 bool specified_shapes() const { return component_shapes_.size() > 0; } 111 112 // Code common to Validate*Tuple(). 113 Status ValidateTupleCommon(const Tuple& tuple) const; 114 ManyOutShape(int i,int64 batch_size)115 TensorShape ManyOutShape(int i, int64 batch_size) { 116 TensorShape shape({batch_size}); 117 shape.AppendShape(component_shapes_[i]); 118 return shape; 119 } 120 121 void Cancel(Action action, CancellationManager* cancellation_manager, 122 CancellationToken token); 123 124 // Helper for cancelling all pending Enqueue(Many) operations when 125 // Close is called with cancel_pending_enqueues. 126 void CloseAndCancel(); 127 128 bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up) 129 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 130 131 // Tries to make progress on the enqueues or dequeues at the front 132 // of the *_attempts_ queues. 133 void FlushUnlocked(); 134 135 ~QueueBase() override; 136 137 // Helpers for implementing MatchesNodeDef(). 138 static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes); 139 Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const; 140 Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const; 141 Status MatchesNodeDefTypes(const NodeDef& node_def) const; 142 Status MatchesNodeDefShapes(const NodeDef& node_def) const; 143 144 protected: 145 const int32 capacity_; 146 const DataTypeVector component_dtypes_; 147 const std::vector<TensorShape> component_shapes_; 148 const string name_; 149 mutable mutex mu_; 150 bool closed_ TF_GUARDED_BY(mu_); 151 152 struct Attempt; 153 typedef std::function<RunResult(Attempt*)> RunCallback; 154 struct Attempt { 155 int32 elements_requested; 156 DoneCallback done_callback; // must be run outside mu_ 157 OpKernelContext* context; 158 CancellationManager* cancellation_manager; // not owned 159 CancellationToken cancellation_token; 160 RunCallback run_callback; // must be run while holding mu_ 161 bool is_cancelled; 162 Tuple tuple; 163 // tuples is used by some implementations allowing dynamic shapes. 164 std::vector<Tuple> tuples; 165 AttemptAttempt166 Attempt(int32 elements_requested, DoneCallback done_callback, 167 OpKernelContext* context, CancellationManager* cancellation_manager, 168 CancellationToken cancellation_token, RunCallback run_callback) 169 : elements_requested(elements_requested), 170 done_callback(done_callback), 171 context(context), 172 cancellation_manager(cancellation_manager), 173 cancellation_token(cancellation_token), 174 run_callback(run_callback), 175 is_cancelled(false) {} 176 }; 177 std::deque<Attempt> enqueue_attempts_ TF_GUARDED_BY(mu_); 178 std::deque<Attempt> dequeue_attempts_ TF_GUARDED_BY(mu_); 179 180 TF_DISALLOW_COPY_AND_ASSIGN(QueueBase); 181 }; 182 183 } // namespace tensorflow 184 185 #endif // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_ 186