/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include <limits.h>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/priority_queue.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
// Needed for batch_util::CopyElementToSlice used in TryInsertMany below.
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

namespace barrier {

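// A Barrier is a blocking, keyed aggregator: values for a given string key
// arrive one component at a time (via BarrierInsertMany), and once every
// component for that key has been filled in, the completed tuple moves from
// the incomplete_ map into a PriorityQueue ordered by the key's first
// insertion, where BarrierTakeMany can dequeue it.  See TryInsertMany and
// TryTakeMany below for the exact semantics.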
class Barrier : public ResourceBase {
 public:
  typedef std::vector<Tensor> Tuple;
  typedef std::function<void()> DoneCallback;
  typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)>
      IndicesKeysValuesCallback;

  Barrier(const DataTypeVector& value_component_types,
          const std::vector<TensorShape>& value_component_shapes,
          const string& name)
      : closed_(false),
        queue_closed_(false),
        queue_cancelled_(false),
        cancel_pending_enqueues_(false),
        value_component_types_(value_component_types),
        value_component_shapes_(value_component_shapes),
        name_(name),
        input_index_(std::numeric_limits<int64>::min()) {
    DataTypeVector queue_component_types;
    std::vector<TensorShape> queue_component_shapes;

    // First queue component is for the input index;
    // Second queue component is for the key;
    // remaining queue components are for the value.
    queue_component_types.push_back(DT_INT64);
    queue_component_types.push_back(DT_STRING);
    for (DataType dt : value_component_types) {
      queue_component_types.push_back(dt);
    }

    // NOTE(mrry): PriorityQueue expects all shapes specified because
    // we'll be issuing TakeMany.
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.push_back(TensorShape({}));
    queue_component_shapes.insert(queue_component_shapes.end(),
                                  value_component_shapes.begin(),
                                  value_component_shapes.end());

    ready_queue_ = new PriorityQueue(
        QueueBase::kUnbounded /* capacity */, queue_component_types,
        queue_component_shapes, strings::StrCat(name_, "_queue"));
  }
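
  // Example (illustrative): for a barrier built with
  // value_component_types = {DT_FLOAT, DT_INT32} and
  // value_component_shapes = {[3], []}, each element of ready_queue_ is
  //   (int64 scalar insertion index, string scalar key,
  //    float tensor of shape [3], int32 scalar),
  // matching the queue_component_* vectors assembled in the constructor.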

  Status Initialize() { return ready_queue_->Initialize(); }

  template <typename T>
  void TryInsertMany(const Tensor& keys, int component_index,
                     const Tensor& values, OpKernelContext* ctx,
                     const DoneCallback& callback) {
    TensorShape element_shape = values.shape();
    OP_REQUIRES_ASYNC(
        ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0,
        errors::InvalidArgument("Tensors with no elements are not supported ",
                                name_, ": received shape ",
                                element_shape.DebugString()),
        callback);
    if (element_shape.dims() > 0) element_shape.RemoveDim(0);
    const std::size_t num_inserted = keys.NumElements();

    // For each key, update the corresponding incomplete tuple with the
    // given value at component_index.
    // This will be passed to the final callback at the very end.
    bool new_elements = false;

    // Will be used for the final insert into the queue.
    Tuple insert_tuple;

    {
      mutex_lock lock(mu_);
      if (closed_) {
        OP_REQUIRES_ASYNC(
            ctx,
            !cancel_pending_enqueues_ &&
                (num_inserted == 0 || !incomplete_.empty()),
            errors::Cancelled(
                "Barrier ", name_, " is closed.  Pending enqueues cancelled: ",
                cancel_pending_enqueues_,
                ".  Number of new insertions: ", num_inserted,
                ".  Number of incomplete keys: ", incomplete_.size(), "."),
            callback);
      }

      // Step 1: insert into the incomplete map and identify which
      // entries are, in fact, complete and ready for enqueueing.  Store
      // them in a vector.
      std::vector<Tuple> ready_tuples;

      for (int i = 0; i < num_inserted; ++i) {
        OP_REQUIRES_OK_ASYNC(
            ctx,
            InsertOneLocked<T>(ctx, keys, values, element_shape,
                               component_index, i, &ready_tuples,
                               &new_elements),
            callback);
      }

      if (new_elements) ++input_index_;

      // This probably won't happen before the heat death of the
      // universe, but who knows?  Moore's law FTW.
      OP_REQUIRES_ASYNC(
          ctx, input_index_ != std::numeric_limits<int64>::max(),
          errors::Internal(
              "Barrier has had ", input_index_,
              " insertions and can no longer keep track of new ones."),
          callback);

      if (ready_tuples.empty()) {
        // Nothing to insert into the queue - so return early.
        callback();
        return;
      }

      // We have something to Enqueue.  Convert the Tuples into a single
      // tuple by slicing entries into new Tensors.  This part is slow
      // but seems the cleanest solution for now.
      insert_tuple.reserve(2 + num_components());  // indices, keys, rest
      int insertion_size = ready_tuples.size();
      for (int i = 0; i < 2 + num_components(); ++i) {
        TensorShape component_shape(ready_tuples[0][i].shape());
        component_shape.InsertDim(0, insertion_size);
        Tensor component(ready_tuples[0][i].dtype(), component_shape);
        for (int b = 0; b < insertion_size; ++b) {
          OP_REQUIRES_OK_ASYNC(
              ctx,
              batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]),
                                             &component, b),
              callback);
        }
        insert_tuple.push_back(component);
      }
    }

    // Enqueue the batched, completed tuples into the ready queue.
    ready_queue_->TryEnqueueMany(
        insert_tuple, ctx,
        // To avoid closing the queue too early, only close it if the
        // barrier is closed, nothing is left in the incomplete set,
        // the queue is not already marked as closed, and (most
        // importantly) the queue has entries in it.
        [this, ctx, callback]() {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          {
            mutex_lock lock(mu_);
            int32 ready = ready_size();
            if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) {
              CloseQueueLocked(ctx, false, callback);
            } else {
              callback();
            }
            return;
          }
        });
  }
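
  // Walk-through (illustrative): with two value components, a call
  //   TryInsertMany(keys={"k"}, component_index=0, values=...)
  // creates an incomplete tuple for "k" holding only component 0.  A later
  //   TryInsertMany(keys={"k"}, component_index=1, values=...)
  // fills component 1, so InsertOneLocked marks the tuple complete, removes
  // "k" from incomplete_, and the tuple is batched and enqueued above for a
  // subsequent TryTakeMany to return.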

  void TryTakeMany(int num_elements, bool allow_small_batch, int64 timeout,
                   OpKernelContext* ctx,
                   const IndicesKeysValuesCallback& callback) {
    int num_elements_to_deliver = num_elements;
    {
      mutex_lock lock(mu_);
      if (closed_) {
        int available_elements = ready_size();
        if (allow_small_batch) {
          // We want to deliver at most num_elements; if fewer elements are
          // available, we deliver at most available_elements.  If no
          // elements are available, a call to TryTakeMany should fail with
          // OutOfRange; treating the request as at least 1 below triggers
          // that error.
          num_elements_to_deliver = std::min(num_elements, available_elements);
        } else {
          // We're happy to wait for additional elements to be completed.
          available_elements += incomplete_.size();
        }
        // If there are fewer available elements than the number we need to
        // deliver (at least one), fail with OutOfRange.
        if (available_elements < std::max(num_elements_to_deliver, 1)) {
          ctx->SetStatus(errors::OutOfRange(
              "Barrier '", name_, "' is closed and has ",
              "insufficient elements (requested ", num_elements_to_deliver,
              ", total size ", available_elements, ")"));
          callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple());
          return;
        }
      }
    }

    ready_queue_->TryDequeueMany(
        num_elements_to_deliver, ctx, allow_small_batch,
        [this, ctx, callback](const Tuple& t) {
          Tensor indices(DT_INT64);
          Tensor keys(DT_STRING);
          Tuple values;

          if (!ctx->status().ok()) {
            callback(indices, keys, values);
            return;
          }

          CHECK_EQ(t.size(), 2 + num_components());
          indices = t[0];
          keys = t[1];
          values.insert(values.begin(), t.begin() + 2, t.end());
          callback(indices, keys, values);
        });
  }
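
  // Note: ready_queue_ is a PriorityQueue whose first component is the int64
  // insertion index, so completed tuples come back from TryTakeMany ordered
  // by when their key first entered the barrier (see the priority comment in
  // InsertOneLocked below).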

  void Close(OpKernelContext* ctx, bool cancel_pending_enqueues,
             const DoneCallback& callback) {
    mutex_lock lock(mu_);
    // We're allowed to close twice if the first close wasn't a
    // cancel but the second one is.
    if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) {
      ctx->SetStatus(
          errors::Cancelled("Barrier '", name_, "' is already closed."));
      callback();
      return;
    }
    cancel_pending_enqueues_ = cancel_pending_enqueues;
    closed_ = true;
    if (cancel_pending_enqueues_ || incomplete_.empty()) {
      incomplete_.clear();
      // CloseQueueLocked runs the callback
      CloseQueueLocked(ctx, cancel_pending_enqueues_, callback);
      return;
    }
    callback();
  }

  int32 ready_size() { return ready_queue_->size(); }

  int32 incomplete_size() {
    mutex_lock lock(mu_);
    return incomplete_.size();
  }

  const string& name() const { return name_; }
  int num_components() const { return value_component_types_.size(); }
  DataType component_type(int i) const {
    CHECK_GE(i, 0);
    CHECK_LT(static_cast<size_t>(i), value_component_types_.size());
    return value_component_types_[i];
  }
  const DataTypeVector component_types() const {
    return value_component_types_;
  }
  const gtl::ArraySlice<TensorShape> component_shapes() const {
    return value_component_shapes_;
  }

  ~Barrier() override EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    mutex_lock lock(mu_);
    incomplete_.clear();
    ready_queue_->Unref();
  }

  string DebugString() const override { return "A barrier"; }

 protected:
  template <typename T>
  Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys,
                         const Tensor& values, const TensorShape& element_shape,
                         int component_index, int i,
                         std::vector<Tuple>* ready_tuples, bool* new_elements)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto keys_vec = keys.flat<string>();
    auto values_matrix = values.flat_outer_dims<T>();

    PersistentTuple* element_ptr;
    if (closed_) {
      element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i));
      if (element_ptr == nullptr) {
        return errors::Cancelled(
            "Barrier ", name_,
            " is closed, but attempted to insert a brand new key: ",
            keys_vec(i),
            ".  Pending enqueues cancelled: ", cancel_pending_enqueues_,
            ".  Insertion index: ", i,
            ".  Number of incomplete keys: ", incomplete_.size(), ".");
      }
    } else {
      element_ptr =
          &gtl::LookupOrInsert(&incomplete_, keys_vec(i), PersistentTuple());
    }
    PersistentTuple& element = *element_ptr;

    if (element.empty()) {  // Key has never been seen before.
      // Added a new element, for keeping track of the insertion index.
      *new_elements = true;

      // Initialize the incomplete tuple for a new key.
      element.reserve(1 + num_components());

      // The first entry in element is the priority: the
      // input_index_, so that tensors that entered the Barrier
      // earlier have higher priority in the queue.
      PersistentTensor index_persistent_tensor;
      Tensor* allocate_index_tensor;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_INT64, TensorShape({}),
                                                  &index_persistent_tensor,
                                                  &allocate_index_tensor));

      allocate_index_tensor->scalar<int64>()() = input_index_;
      element.push_back(index_persistent_tensor);

      // The rest of the element stores uninitialized Tensors with
      // the appropriate dtype.
      for (int j = 0; j < num_components(); ++j) {
        Tensor uninitialized(component_type(j));
        element.push_back(PersistentTensor(uninitialized));
      }
    }
    const PersistentTensor& component = element[1 + component_index];
    if (component.IsInitialized() && component.NumElements() > 0) {
      return errors::InvalidArgument("Key ", keys_vec(i),
                                     " already has a value for component ",
                                     component_index, " in barrier ", name());
    }

    // Extract the slice corresponding to the value from the value Tensor,
    // and store it in the incomplete tuple at component_index.
    PersistentTensor next_element;
    Tensor* allocated_element;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        values.dtype(), element_shape, &next_element, &allocated_element));
    element[1 + component_index] = next_element;
    allocated_element->flat<T>() = values_matrix.template chip<0>(i);

    // Check the components of the tuple to see if it has become complete
    // (i.e. all of its components are initialized). If so, add it to the
    // ready queue.
    bool is_complete = true;
    for (int j = 0; is_complete && j < element.size(); ++j) {
      is_complete = element[j].IsInitialized() && element[j].NumElements() > 0;
    }
    if (is_complete) {
      // Add tuple to the ready queue. A queue tuple has the index
      // as the first element and the key as the second element,
      // followed by the value components.
      Tuple ready_tuple;
      ready_tuple.reserve(2 + num_components());  // index, key, rest
      // Build a tensor for the key. TODO(mrry): Something more efficient.
      PersistentTensor key;
      Tensor* allocated_key;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_STRING, TensorShape({}),
                                                  &key, &allocated_key));
      ready_tuple.push_back(*element[0].AccessTensor(ctx));  // index
      ready_tuple.push_back(*allocated_key);                 // key
      ready_tuple[1].scalar<string>()() = keys_vec(i);       // set the key
      for (int j = 1; j < num_components() + 1; ++j) {
        ready_tuple.push_back(*element[j].AccessTensor(ctx));
      }
      incomplete_.erase(incomplete_.find(keys_vec(i)));
      TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple));
      ready_tuples->push_back(ready_tuple);
    }
    return Status::OK();
  }

  void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues,
                        const DoneCallback& callback)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // CloseQueueLocked may only be called with mu_ held.
    if (!cancel_pending_enqueues && queue_closed_) {
      callback();
      return;
    }
    if (cancel_pending_enqueues && queue_cancelled_) {
      callback();
      return;
    }
    queue_closed_ = true;
    if (cancel_pending_enqueues) queue_cancelled_ = true;
    if (!ready_queue_->is_closed()) {
      ready_queue_->Close(ctx, cancel_pending_enqueues, callback);
    }
  }

 private:
  typedef std::vector<PersistentTensor> PersistentTuple;
  mutex mu_;
  bool closed_ GUARDED_BY(mu_);
  bool queue_closed_ GUARDED_BY(mu_);
  bool queue_cancelled_ GUARDED_BY(mu_);
  bool cancel_pending_enqueues_ GUARDED_BY(mu_);
  const DataTypeVector value_component_types_;
  const std::vector<TensorShape>& value_component_shapes_;
  const string name_;
  int64 input_index_ GUARDED_BY(mu_);
  std::unordered_map<string, PersistentTuple> incomplete_ GUARDED_BY(mu_);
  PriorityQueue* ready_queue_;

  TF_DISALLOW_COPY_AND_ASSIGN(Barrier);
};
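
// The op kernels below expose the Barrier resource to the graph: BarrierOp
// creates (or looks up) a shared Barrier, and the remaining kernels resolve
// the resource from its handle and call TryInsertMany, TryTakeMany, Close,
// incomplete_size(), or ready_size() on it.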

class BarrierOp : public ResourceOpKernel<Barrier> {
 public:
  explicit BarrierOp(OpKernelConstruction* context)
      : ResourceOpKernel(context) {
    OP_REQUIRES_OK(
        context, context->GetAttr("component_types", &value_component_types_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("shapes", &value_component_shapes_));
    OP_REQUIRES(context,
                value_component_shapes_.size() == value_component_types_.size(),
                errors::InvalidArgument(
                    "All of the component shapes must be specified"));

    int32 value_capacity;
    OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity));
    OP_REQUIRES(context, value_capacity == -1,
                errors::InvalidArgument(
                    "Barrier only accepts capacity=-1.  Feed the "
                    "inputs to your Barrier through a queue to enforce a "
                    "limited capacity."));
  }

 private:
  Status CreateResource(Barrier** barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    *barrier = new Barrier(value_component_types_, value_component_shapes_,
                           cinfo_.name());
    if (*barrier == nullptr) {
      return errors::ResourceExhausted("Failed to allocate barrier");
    }
    return (*barrier)->Initialize();
  }

  Status VerifyResource(Barrier* barrier) override
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (barrier->component_types() != value_component_types_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component types ",
          DataTypeSliceString(barrier->component_types()),
          " but requested component types were ",
          DataTypeSliceString(value_component_types_));
    }
    if (barrier->component_shapes() != value_component_shapes_) {
      return errors::InvalidArgument(
          "Shared barrier '", cinfo_.name(), "' has component shapes ",
          TensorShapeUtils::ShapeListString(barrier->component_shapes()),
          " but requested component shapes were ",
          TensorShapeUtils::ShapeListString(value_component_shapes_));
    }
    return Status::OK();
  }

  DataTypeVector value_component_types_;
  std::vector<TensorShape> value_component_shapes_;

  TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp);
};

REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp);

class BarrierOpKernel : public AsyncOpKernel {
 public:
  explicit BarrierOpKernel(OpKernelConstruction* context)
      : AsyncOpKernel(context) {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final {
    Barrier* barrier = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier),
                         callback);
    ComputeAsync(ctx, barrier, [callback, barrier]() {
      barrier->Unref();
      callback();
    });
  }

 protected:
  virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                            DoneCallback callback) = 0;
};
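
// Each derived kernel runs asynchronously: the base ComputeAsync above takes
// a reference to the Barrier via GetResourceFromContext, and the wrapper
// callback releases it (Unref) once the derived kernel's work has finished.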

template <typename T>
class InsertManyOp : public BarrierOpKernel {
 public:
  explicit InsertManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("component_index", &component_index_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    OP_REQUIRES_ASYNC(
        ctx, component_index_ < barrier->num_components(),
        errors::InvalidArgument("The component ID is out of range ",
                                component_index_, " > num_components",
                                " (= ", barrier->num_components(), ")"),
        callback);
    OP_REQUIRES_OK_ASYNC(
        ctx,
        ctx->MatchSignature({DT_STRING_REF, DT_STRING,
                             barrier->component_type(component_index_)},
                            {}),
        callback);

    const Tensor* keys;
    const Tensor* values;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback);
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback);
    barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback);
  }

 private:
  int component_index_;
  TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp);
};

#define REGISTER_INSERTMANY(T)                                             \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      InsertManyOp<T>);

TF_CALL_ALL_TYPES(REGISTER_INSERTMANY);
#undef REGISTER_INSERTMANY
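
// Note: TF_CALL_ALL_TYPES(REGISTER_INSERTMANY) above expands to roughly one
// kernel registration per standard dtype (e.g. InsertManyOp<float>,
// InsertManyOp<int32>, InsertManyOp<string>), so BarrierInsertMany is
// available for each such T.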

class TakeManyOp : public BarrierOpKernel {
 public:
  explicit TakeManyOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_));
    // TODO(keveman): Enable timeout.
    OP_REQUIRES(context, timeout_ == -1,
                errors::InvalidArgument("Timeout not supported yet."));

    OP_REQUIRES_OK(context,
                   context->GetAttr("allow_small_batch", &allow_small_batch_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    const Tensor* Tnum_elements;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements),
                         callback);
    OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()),
                      errors::InvalidArgument("num_elements must be a scalar."),
                      callback);
    const int32 num_elements = Tnum_elements->scalar<int32>()();

    DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32};
    // The first output is the insertion index, the second output is the key.
    DataTypeVector expected_outputs = {DT_INT64, DT_STRING};
    for (DataType dt : barrier->component_types()) {
      expected_outputs.push_back(dt);
    }
    OP_REQUIRES_OK_ASYNC(
        ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback);

    barrier->TryTakeMany(
        num_elements, allow_small_batch_, timeout_, ctx,
        [ctx, callback](const Tensor& indices, const Tensor& keys,
                        const Barrier::Tuple& values) {
          if (!ctx->status().ok()) {
            callback();
            return;
          }
          // At this point, indices, keys, and values
          // have all been written to successfully.
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices),
                               callback);
          OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback);
          OpOutputList values_output;
          OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output),
                               callback);
          for (size_t i = 0; i < values.size(); ++i) {
            values_output.set(i, values[i]);
          }
          callback();
        });
  }

 private:
  int64 timeout_;
  bool allow_small_batch_;
  TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp);

class BarrierCloseOp : public BarrierOpKernel {
 public:
  explicit BarrierCloseOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues",
                                             &cancel_pending_enqueues_));
  }

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    barrier->Close(ctx, cancel_pending_enqueues_, callback);
  }

 private:
  bool cancel_pending_enqueues_;
  TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp);
};

REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU),
                        BarrierCloseOp);

class BarrierIncompleteSizeOp : public BarrierOpKernel {
 public:
  explicit BarrierIncompleteSizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->incomplete_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU),
                        BarrierIncompleteSizeOp);

class BarrierReadySizeOp : public BarrierOpKernel {
 public:
  explicit BarrierReadySizeOp(OpKernelConstruction* context)
      : BarrierOpKernel(context) {}

 protected:
  void ComputeAsync(OpKernelContext* ctx, Barrier* barrier,
                    DoneCallback callback) override {
    Tensor* Tsize = nullptr;
    OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize),
                         callback);
    Tsize->scalar<int32>().setConstant(barrier->ready_size());
    callback();
  }
};

REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU),
                        BarrierReadySizeOp);

}  // namespace barrier

}  // namespace tensorflow
694