/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_
#define TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_

#include <climits>
#include <deque>
#include <vector>

#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

namespace barrier {
class Barrier;
}  // namespace barrier

// Functionality common to asynchronous QueueInterface implementations.
class QueueBase : public QueueInterface {
 public:
  // A special value of 'capacity' meaning that the queue is unbounded.
  static const int32 kUnbounded = INT_MAX;

  // Args:
  //   component_dtypes: The types of each component in a queue-element tuple.
  //   component_shapes: The shapes of each component in a queue-element tuple,
  //     which must either be empty (if the shapes are not specified) or
  //     have the same size as component_dtypes.
  //   name: A name to use for the queue.
  QueueBase(int32 capacity, const DataTypeVector& component_dtypes,
            const std::vector<TensorShape>& component_shapes,
            const string& name);
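  //
  // A minimal usage sketch (illustrative only; `MyFIFOQueue` below is a
  // hypothetical concrete subclass, not part of this header), showing the
  // kind of arguments a subclass would forward to this constructor:
  //
  //   DataTypeVector dtypes({DT_FLOAT, DT_INT32});
  //   std::vector<TensorShape> shapes({TensorShape({2, 3}), TensorShape({})});
  //   MyFIFOQueue queue(/*capacity=*/10, dtypes, shapes, "my_queue");
  //
  // Passing an empty `shapes` vector means the component shapes are
  // unspecified, so enqueued tuples are not validated against them.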

  // Implementations of QueueInterface methods --------------------------------
  const DataTypeVector& component_dtypes() const override {
    return component_dtypes_;
  }

  Status ValidateTuple(const Tuple& tuple) override;
  Status ValidateManyTuple(const Tuple& tuple) override;
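
  // Illustrative expectation for the two validators (a sketch based on the
  // usual batching convention, not a verbatim contract): for a queue whose
  // single component has shape [2, 3],
  //
  //   Tuple one;
  //   one.push_back(Tensor(DT_FLOAT, TensorShape({2, 3})));      // ValidateTuple: OK
  //   Tuple many;
  //   many.push_back(Tensor(DT_FLOAT, TensorShape({5, 2, 3})));  // ValidateManyTuple: OK (batch of 5)
  //
  // i.e. ValidateManyTuple expects each component to carry a leading batch
  // dimension in front of the per-element shape.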

  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             DoneCallback callback) override;

  // Other public methods -----------------------------------------------------
  const std::vector<TensorShape>& component_shapes() const {
    return component_shapes_;
  }

  int32 capacity() const { return capacity_; }

  bool is_closed() const override {
    mutex_lock lock(mu_);
    return closed_;
  }

  // Copies the index^th slice (in the first dimension) of parent into element.
  static Status CopySliceToElement(const Tensor& parent, Tensor* element,
                                   int64 index);

  // Copies element into the index^th slice (in the first dimension) of parent.
  // NOTE(mrry): This method is deprecated. Use
  // `tensorflow::batch_util::CopyElementToSlice()` defined in
  // "./batch_util.h" instead.
  ABSL_DEPRECATED(
      "Use `tensorflow::batch_util::CopyElementToSlice()` defined in "
      "\"./batch_util.h\" instead.")
  static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
                                   int64 index);
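
  // Shape relationship for both helpers above (an illustrative sketch, not
  // taken verbatim from this header): with a parent of shape [4, 2, 3] and
  // an element of shape [2, 3],
  //
  //   Tensor parent(DT_FLOAT, TensorShape({4, 2, 3}));
  //   Tensor element(DT_FLOAT, TensorShape({2, 3}));
  //   TF_RETURN_IF_ERROR(CopySliceToElement(parent, &element, 1));  // element <- parent[1, :, :]
  //   TF_RETURN_IF_ERROR(CopyElementToSlice(element, &parent, 1));  // parent[1, :, :] <- element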

 protected:
  enum Action { kEnqueue, kDequeue };
  enum RunResult { kNoProgress, kProgress, kComplete };

  // Bundles the work that must happen after an attempt finishes but that
  // cannot run while mu_ is held: deregistering the cancellation callback
  // and invoking the attempt's DoneCallback.
  struct CleanUp {
    CleanUp(DoneCallback&& f, CancellationToken ct, CancellationManager* cm)
        : finished(f), to_deregister(ct), cm(cm) {}
    DoneCallback finished;
    CancellationToken to_deregister;
    CancellationManager* cm;
  };
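
  // A sketch of how CleanUp entries are meant to be used (implied by the
  // comments in this file; not a verbatim excerpt from queue_base.cc):
  // entries are accumulated while mu_ is held and only executed after the
  // lock is released.
  //
  //   std::vector<CleanUp> clean_up;
  //   {
  //     mutex_lock lock(mu_);
  //     clean_up.emplace_back(std::move(done), token, cm);
  //   }
  //   for (CleanUp& c : clean_up) {
  //     if (c.cm != nullptr) c.cm->DeregisterCallback(c.to_deregister);
  //     c.finished();  // the DoneCallback runs with mu_ released
  //   }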

  // Returns the number of components in a queue-element tuple.
  int32 num_components() const { return component_dtypes_.size(); }

  // True if shapes were specified.  If so, inputs will be validated
  // against them, etc.
  bool specified_shapes() const { return component_shapes_.size() > 0; }

  // Code common to Validate*Tuple().
  Status ValidateTupleCommon(const Tuple& tuple) const;

  TensorShape ManyOutShape(int i, int64 batch_size) {
    TensorShape shape({batch_size});
    shape.AppendShape(component_shapes_[i]);
    return shape;
  }
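
  // For example (an illustrative sketch): with component_shapes_[i] == [2, 3],
  //
  //   TensorShape s = ManyOutShape(i, /*batch_size=*/32);  // s == [32, 2, 3]
  //
  // i.e. the per-element shape prefixed with the batch dimension.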

  void Cancel(Action action, CancellationManager* cancellation_manager,
              CancellationToken token);

  // Helper for cancelling all pending Enqueue(Many) operations when
  // Close is called with cancel_pending_enqueues.
  void CloseAndCancel();

  // Tries to enqueue/dequeue (or close) based on whatever is at the
  // front of enqueue_attempts_/dequeue_attempts_.  Appends to
  // *clean_up the callback for any finished attempt (so it may be
  // called once mu_ is released).  Returns true if any progress was
  // made.
  bool TryAttemptLocked(Action action, std::vector<CleanUp>* clean_up)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Tries to make progress on the enqueues or dequeues at the front
  // of the *_attempts_ queues.
  void FlushUnlocked();
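
  // A sketch of the flush pattern suggested by FlushUnlocked() and
  // TryAttemptLocked() above (an assumption about the .cc implementation,
  // not a quote of it): take mu_, call TryAttemptLocked() for both actions
  // until neither makes progress, then run the collected CleanUp work
  // without the lock.
  //
  //   std::vector<CleanUp> clean_up;
  //   {
  //     mutex_lock lock(mu_);
  //     bool progress = true;
  //     while (progress) {
  //       progress = TryAttemptLocked(kEnqueue, &clean_up);
  //       progress = TryAttemptLocked(kDequeue, &clean_up) || progress;
  //     }
  //   }
  //   for (CleanUp& c : clean_up) c.finished();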

  ~QueueBase() override;

  // Helpers for implementing MatchesNodeDef().
  static string ShapeListString(const gtl::ArraySlice<TensorShape>& shapes);
  Status MatchesNodeDefOp(const NodeDef& node_def, const string& op) const;
  Status MatchesNodeDefCapacity(const NodeDef& node_def, int32 capacity) const;
  Status MatchesNodeDefTypes(const NodeDef& node_def) const;
  Status MatchesNodeDefShapes(const NodeDef& node_def) const;

 protected:
  const int32 capacity_;
  const DataTypeVector component_dtypes_;
  const std::vector<TensorShape> component_shapes_;
  const string name_;
  mutable mutex mu_;
  bool closed_ GUARDED_BY(mu_);

  struct Attempt;
  typedef std::function<RunResult(Attempt*)> RunCallback;
  struct Attempt {
    int32 elements_requested;
    DoneCallback done_callback;  // must be run outside mu_
    OpKernelContext* context;
    CancellationManager* cancellation_manager;  // not owned
    CancellationToken cancellation_token;
    RunCallback run_callback;  // must be run while holding mu_
    bool is_cancelled;
    Tuple tuple;
    // tuples is used by some implementations allowing dynamic shapes.
    std::vector<Tuple> tuples;

    Attempt(int32 elements_requested, DoneCallback done_callback,
            OpKernelContext* context, CancellationManager* cancellation_manager,
            CancellationToken cancellation_token, RunCallback run_callback)
        : elements_requested(elements_requested),
          done_callback(done_callback),
          context(context),
          cancellation_manager(cancellation_manager),
          cancellation_token(cancellation_token),
          run_callback(run_callback),
          is_cancelled(false) {}
  };
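
  // Illustrative only (queue_base.cc owns the real logic): a pending dequeue
  // might be recorded roughly like
  //
  //   dequeue_attempts_.emplace_back(
  //       /*elements_requested=*/1, std::move(done), ctx, cm, token,
  //       [this](Attempt* attempt) -> RunResult {
  //         if (storage_.empty()) return kNoProgress;  // `storage_` is hypothetical
  //         // ...move one element into attempt->tuple...
  //         return kComplete;
  //       });
  //
  // The RunCallback runs with mu_ held and reports whether the attempt made
  // progress or finished; done_callback is invoked later, outside mu_.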
  std::deque<Attempt> enqueue_attempts_ GUARDED_BY(mu_);
  std::deque<Attempt> dequeue_attempts_ GUARDED_BY(mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(QueueBase);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_QUEUE_BASE_H_