1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // See docs in ../ops/data_flow_ops.cc. 16 17 #include <limits.h> 18 #include <unordered_map> 19 #include <vector> 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/register_types.h" 23 #include "tensorflow/core/framework/resource_mgr.h" 24 #include "tensorflow/core/framework/resource_op_kernel.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/kernels/priority_queue.h" 29 #include "tensorflow/core/kernels/queue_base.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/notification.h" 32 #include "tensorflow/core/lib/gtl/map_util.h" 33 #include "tensorflow/core/platform/logging.h" 34 #include "tensorflow/core/platform/macros.h" 35 #include "tensorflow/core/platform/mutex.h" 36 #include "tensorflow/core/platform/thread_annotations.h" 37 #include "tensorflow/core/platform/types.h" 38 39 namespace tensorflow { 40 41 namespace barrier { 42 43 class Barrier : public ResourceBase { 44 public: 45 typedef std::vector<Tensor> Tuple; 46 typedef std::function<void()> DoneCallback; 47 typedef std::function<void(const Tensor&, const Tensor&, const Tuple&)> 48 IndicesKeysValuesCallback; 49 Barrier(const DataTypeVector & value_component_types,const std::vector<TensorShape> & value_component_shapes,const string & name)50 Barrier(const DataTypeVector& value_component_types, 51 const std::vector<TensorShape>& value_component_shapes, 52 const string& name) 53 : closed_(false), 54 queue_closed_(false), 55 queue_cancelled_(false), 56 cancel_pending_enqueues_(false), 57 value_component_types_(value_component_types), 58 value_component_shapes_(value_component_shapes), 59 name_(name), 60 input_index_(std::numeric_limits<int64>::min()) { 61 DataTypeVector queue_component_types; 62 std::vector<TensorShape> queue_component_shapes; 63 64 // First queue component is for the input index; 65 // Second queue component is for the key; 66 // remaining queue components are for the value. 67 queue_component_types.push_back(DT_INT64); 68 queue_component_types.push_back(DT_STRING); 69 for (DataType dt : value_component_types) { 70 queue_component_types.push_back(dt); 71 } 72 73 // NOTE(mrry): PriorityQueue expects all shapes specified because 74 // we'll be issuing TakeMany. 75 queue_component_shapes.push_back(TensorShape({})); 76 queue_component_shapes.push_back(TensorShape({})); 77 queue_component_shapes.insert(queue_component_shapes.end(), 78 value_component_shapes.begin(), 79 value_component_shapes.end()); 80 81 ready_queue_ = new PriorityQueue( 82 QueueBase::kUnbounded /* capacity */, queue_component_types, 83 queue_component_shapes, strings::StrCat(name_, "_queue")); 84 } 85 Initialize()86 Status Initialize() { return ready_queue_->Initialize(); } 87 88 template <typename T> TryInsertMany(const Tensor & keys,int component_index,const Tensor & values,OpKernelContext * ctx,const DoneCallback & callback)89 void TryInsertMany(const Tensor& keys, int component_index, 90 const Tensor& values, OpKernelContext* ctx, 91 const DoneCallback& callback) { 92 TensorShape element_shape = values.shape(); 93 OP_REQUIRES_ASYNC( 94 ctx, keys.NumElements() == 0 || element_shape.num_elements() > 0, 95 errors::InvalidArgument("Tensors with no elements are not supported ", 96 name_, ": received shape ", 97 element_shape.DebugString()), 98 callback); 99 if (element_shape.dims() > 0) element_shape.RemoveDim(0); 100 const std::size_t num_inserted = keys.NumElements(); 101 102 // For each key, update the corresponding incomplete tuple with the 103 // the corresponding given value at component_index. 104 // This will be passed to the final callback at the very end. 105 bool new_elements = false; 106 107 // Will be used for the final insert into the queue. 108 Tuple insert_tuple; 109 110 { 111 mutex_lock lock(mu_); 112 if (closed_) { 113 OP_REQUIRES_ASYNC( 114 ctx, 115 !cancel_pending_enqueues_ && 116 (num_inserted == 0 || !incomplete_.empty()), 117 errors::Cancelled( 118 "Barrier ", name_, " is closed. Pending enqueues cancelled: ", 119 cancel_pending_enqueues_, 120 ". Number of new insertions: ", num_inserted, 121 ". Number of incomplete keys: ", incomplete_.size(), "."), 122 callback); 123 } 124 125 // Step 1: insert into the incomplete map and identify which 126 // entries are, in fact, complete and ready for enqueueing. Store 127 // them in a vector 128 std::vector<Tuple> ready_tuples; 129 130 for (int i = 0; i < num_inserted; ++i) { 131 OP_REQUIRES_OK_ASYNC( 132 ctx, 133 InsertOneLocked<T>(ctx, keys, values, element_shape, 134 component_index, i, &ready_tuples, 135 &new_elements), 136 callback); 137 } 138 139 if (new_elements) ++input_index_; 140 141 // This probably won't happen before the heat death of the 142 // universe, but who knows? Moore's law FTW. 143 OP_REQUIRES_ASYNC( 144 ctx, input_index_ != std::numeric_limits<int64>::max(), 145 errors::Internal( 146 "Barrier has had ", input_index_, 147 " insertions and can no longer keep track of new ones."), 148 callback); 149 150 if (ready_tuples.empty()) { 151 // Nothing to insert into the queue - so return early. 152 callback(); 153 return; 154 } 155 156 // We have something to Enqueue. Convert the Tuples into a single 157 // tuple by slicing entries into new Tensors. This part is slow 158 // but seems the cleanest solution for now. 159 insert_tuple.reserve(2 + num_components()); // indices, keys, rest 160 int insertion_size = ready_tuples.size(); 161 for (int i = 0; i < 2 + num_components(); ++i) { 162 TensorShape component_shape(ready_tuples[0][i].shape()); 163 component_shape.InsertDim(0, insertion_size); 164 Tensor component(ready_tuples[0][i].dtype(), component_shape); 165 for (int b = 0; b < insertion_size; ++b) { 166 OP_REQUIRES_OK_ASYNC( 167 ctx, 168 batch_util::CopyElementToSlice(std::move(ready_tuples[b][i]), 169 &component, b), 170 callback); 171 } 172 insert_tuple.push_back(component); 173 } 174 } 175 176 // Update the input index for the next batch. 177 ready_queue_->TryEnqueueMany( 178 insert_tuple, ctx, 179 // To avoid early closing of the queue, only close it if the 180 // SQSS is closed, nothing is left in the incomplete set, 181 // the queue is not already marked as closed, and (most 182 // importantly), the queue has entries in it. 183 [this, ctx, callback]() { 184 if (!ctx->status().ok()) { 185 callback(); 186 return; 187 } 188 { 189 mutex_lock lock(mu_); 190 int32 ready = ready_size(); 191 if (closed_ && incomplete_.empty() && queue_closed_ && ready > 0) { 192 CloseQueueLocked(ctx, false, callback); 193 } else { 194 callback(); 195 } 196 return; 197 } 198 }); 199 } 200 TryTakeMany(int num_elements,bool allow_small_batch,int64 timeout,OpKernelContext * ctx,const IndicesKeysValuesCallback & callback)201 void TryTakeMany(int num_elements, bool allow_small_batch, int64 timeout, 202 OpKernelContext* ctx, 203 const IndicesKeysValuesCallback& callback) { 204 int num_elements_to_deliver = num_elements; 205 { 206 mutex_lock lock(mu_); 207 if (closed_) { 208 int available_elements = ready_size(); 209 if (allow_small_batch) { 210 // We want to deliver a maximum of num_elements, if there are less 211 // elements available, we deliver at most the available_elements. If 212 // there are no 213 // elements available, a call to TryTakeMany should fail with 214 // OutOfRange. We trigger this error by setting the request here to 1. 215 num_elements_to_deliver = std::min(num_elements, available_elements); 216 } else { 217 // We're happy to wait for additional elements to be completed. 218 available_elements += incomplete_.size(); 219 } 220 // If there are 0 available elements or less elements than the 221 // number we can deliver, then we are done. 222 if (available_elements < std::max(num_elements_to_deliver, 1)) { 223 ctx->SetStatus(errors::OutOfRange( 224 "Barrier '", name_, "' is closed and has ", 225 "insufficient elements (requested ", num_elements_to_deliver, 226 ", total size ", available_elements, ")")); 227 callback(Tensor(DT_INT64), Tensor(DT_STRING), Tuple()); 228 return; 229 } 230 } 231 } 232 233 ready_queue_->TryDequeueMany( 234 num_elements_to_deliver, ctx, allow_small_batch, 235 [this, ctx, callback](const Tuple& t) { 236 Tensor indices(DT_INT64); 237 Tensor keys(DT_STRING); 238 Tuple values; 239 240 if (!ctx->status().ok()) { 241 callback(indices, keys, values); 242 return; 243 } 244 245 CHECK_EQ(t.size(), 2 + num_components()); 246 indices = t[0]; 247 keys = t[1]; 248 values.insert(values.begin(), t.begin() + 2, t.end()); 249 callback(indices, keys, values); 250 }); 251 } 252 Close(OpKernelContext * ctx,bool cancel_pending_enqueues,const DoneCallback & callback)253 void Close(OpKernelContext* ctx, bool cancel_pending_enqueues, 254 const DoneCallback& callback) { 255 mutex_lock lock(mu_); 256 // We're allowed to close twice if the first close wasn't a 257 // cancel but the second one is. 258 if (closed_ && (cancel_pending_enqueues_ || !cancel_pending_enqueues)) { 259 ctx->SetStatus( 260 errors::Cancelled("Barrier '", name_, "' is already closed.")); 261 callback(); 262 return; 263 } 264 cancel_pending_enqueues_ = cancel_pending_enqueues; 265 closed_ = true; 266 if (cancel_pending_enqueues_ || incomplete_.empty()) { 267 incomplete_.clear(); 268 // CloseQueueLocked runs the callback 269 CloseQueueLocked(ctx, cancel_pending_enqueues_, callback); 270 return; 271 } 272 callback(); 273 } 274 ready_size()275 int32 ready_size() { return ready_queue_->size(); } 276 incomplete_size()277 int32 incomplete_size() { 278 mutex_lock lock(mu_); 279 return incomplete_.size(); 280 } 281 name() const282 const string& name() const { return name_; } num_components() const283 int num_components() const { return value_component_types_.size(); } component_type(int i) const284 DataType component_type(int i) const { 285 CHECK_GE(i, 0); 286 CHECK_LT(static_cast<size_t>(i), value_component_types_.size()); 287 return value_component_types_[i]; 288 } component_types() const289 const DataTypeVector component_types() const { 290 return value_component_types_; 291 } component_shapes() const292 const gtl::ArraySlice<TensorShape> component_shapes() const { 293 return value_component_shapes_; 294 } 295 ~Barrier()296 ~Barrier() override EXCLUSIVE_LOCKS_REQUIRED(mu_) { 297 mutex_lock lock(mu_); 298 incomplete_.clear(); 299 ready_queue_->Unref(); 300 } 301 DebugString() const302 string DebugString() const override { return "A barrier"; } 303 304 protected: 305 template <typename T> InsertOneLocked(OpKernelContext * ctx,const Tensor & keys,const Tensor & values,const TensorShape & element_shape,int component_index,int i,std::vector<Tuple> * ready_tuples,bool * new_elements)306 Status InsertOneLocked(OpKernelContext* ctx, const Tensor& keys, 307 const Tensor& values, const TensorShape& element_shape, 308 int component_index, int i, 309 std::vector<Tuple>* ready_tuples, bool* new_elements) 310 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 311 auto keys_vec = keys.flat<string>(); 312 auto values_matrix = values.flat_outer_dims<T>(); 313 314 PersistentTuple* element_ptr; 315 if (closed_) { 316 element_ptr = gtl::FindOrNull(incomplete_, keys_vec(i)); 317 if (element_ptr == nullptr) { 318 return errors::Cancelled( 319 "Barrier ", name_, 320 " is closed, but attempted to insert a brand new key: ", 321 keys_vec(i), 322 ". Pending enqueues cancelled: ", cancel_pending_enqueues_, 323 ". Insertion index: ", i, 324 ". Number of incomplete keys: ", incomplete_.size(), "."); 325 } 326 } else { 327 element_ptr = 328 >l::LookupOrInsert(&incomplete_, keys_vec(i), PersistentTuple()); 329 } 330 PersistentTuple& element = *element_ptr; 331 332 if (element.empty()) { // Never seen before key 333 // Added a new element, for keeping track of the insertion index 334 *new_elements = true; 335 336 // Initialize the incomplete tuple for a new key. 337 element.reserve(1 + num_components()); 338 339 // The first entry in element is the priority: the 340 // input_index_, so that tensors that entered the Barrier 341 // earlier have higher priority in the queue. 342 PersistentTensor index_persistent_tensor; 343 Tensor* allocate_index_tensor; 344 TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_INT64, TensorShape({}), 345 &index_persistent_tensor, 346 &allocate_index_tensor)); 347 348 Tensor index_tensor(DT_INT64, TensorShape({})); 349 allocate_index_tensor->scalar<int64>()() = input_index_; 350 element.push_back(index_persistent_tensor); 351 352 // The rest of the element stores uninitialized Tensors with 353 // the appropriate dtype. 354 for (int j = 0; j < num_components(); ++j) { 355 Tensor uninitialized(component_type(j)); 356 element.push_back(PersistentTensor(uninitialized)); 357 } 358 } 359 const PersistentTensor& component = element[1 + component_index]; 360 if (component.IsInitialized() && component.NumElements() > 0) { 361 return errors::InvalidArgument("Key ", keys_vec(i), 362 " already has a value for component ", 363 component_index, " in barrier ", name()); 364 } 365 366 // Extract the slice corresponding to the value from the value Tensor, 367 // and store it in the incomplete tuple at component_index. 368 PersistentTensor next_element; 369 Tensor* allocated_element; 370 TF_RETURN_IF_ERROR(ctx->allocate_persistent( 371 values.dtype(), element_shape, &next_element, &allocated_element)); 372 element[1 + component_index] = next_element; 373 allocated_element->flat<T>() = values_matrix.template chip<0>(i); 374 375 // Check the components of the tuple to see if it has become complete 376 // (i.e. all of its components are initialized). If so, add it to the 377 // ready queue. 378 bool is_complete = true; 379 for (int j = 0; is_complete && j < element.size(); ++j) { 380 is_complete = element[j].IsInitialized() && element[j].NumElements() > 0; 381 } 382 if (is_complete) { 383 // Add tuple to the ready queue. A queue tuple has the index 384 // as the first element and the key as the second element, 385 // followed by the value components. 386 Tuple ready_tuple; 387 ready_tuple.reserve(2 + num_components()); // index, key, rest 388 // Build a tensor for the key. TODO(mrry): Something more efficient. 389 PersistentTensor key; 390 Tensor* allocated_key; 391 TF_RETURN_IF_ERROR(ctx->allocate_persistent(DT_STRING, TensorShape({}), 392 &key, &allocated_key)); 393 ready_tuple.push_back(*element[0].AccessTensor(ctx)); // index 394 ready_tuple.push_back(*allocated_key); // key 395 ready_tuple[1].scalar<string>()() = keys_vec(i); // set the key 396 for (int j = 1; j < num_components() + 1; ++j) { 397 ready_tuple.push_back(*element[j].AccessTensor(ctx)); 398 } 399 incomplete_.erase(incomplete_.find(keys_vec(i))); 400 TF_RETURN_IF_ERROR(ready_queue_->ValidateTuple(ready_tuple)); 401 ready_tuples->push_back(ready_tuple); 402 } 403 return Status::OK(); 404 } 405 CloseQueueLocked(OpKernelContext * ctx,bool cancel_pending_enqueues,const DoneCallback & callback)406 void CloseQueueLocked(OpKernelContext* ctx, bool cancel_pending_enqueues, 407 const DoneCallback& callback) 408 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 409 // CloseQueueLocked may only be called with mu_ held. 410 if (!cancel_pending_enqueues && queue_closed_) { 411 callback(); 412 return; 413 } 414 if (cancel_pending_enqueues && queue_cancelled_) { 415 callback(); 416 return; 417 } 418 queue_closed_ = true; 419 if (cancel_pending_enqueues) queue_cancelled_ = true; 420 if (!ready_queue_->is_closed()) { 421 ready_queue_->Close(ctx, cancel_pending_enqueues, callback); 422 } 423 } 424 425 private: 426 typedef std::vector<PersistentTensor> PersistentTuple; 427 mutex mu_; 428 bool closed_ GUARDED_BY(mu_); 429 bool queue_closed_ GUARDED_BY(mu_); 430 bool queue_cancelled_ GUARDED_BY(mu_); 431 bool cancel_pending_enqueues_ GUARDED_BY(mu_); 432 const DataTypeVector value_component_types_; 433 const std::vector<TensorShape>& value_component_shapes_; 434 const string name_; 435 int64 input_index_ GUARDED_BY(mu_); 436 std::unordered_map<string, PersistentTuple> incomplete_ GUARDED_BY(mu_); 437 PriorityQueue* ready_queue_; 438 439 TF_DISALLOW_COPY_AND_ASSIGN(Barrier); 440 }; 441 442 class BarrierOp : public ResourceOpKernel<Barrier> { 443 public: BarrierOp(OpKernelConstruction * context)444 explicit BarrierOp(OpKernelConstruction* context) 445 : ResourceOpKernel(context) { 446 OP_REQUIRES_OK( 447 context, context->GetAttr("component_types", &value_component_types_)); 448 OP_REQUIRES_OK(context, 449 context->GetAttr("shapes", &value_component_shapes_)); 450 OP_REQUIRES(context, 451 value_component_shapes_.size() == value_component_types_.size(), 452 errors::InvalidArgument( 453 "All of the component shapes must be specified")); 454 455 int32 value_capacity; 456 OP_REQUIRES_OK(context, context->GetAttr("capacity", &value_capacity)); 457 OP_REQUIRES(context, value_capacity == -1, 458 errors::InvalidArgument( 459 "Barrier only accepts capacity=-1. Feed the " 460 "inputs to your Barrier through a queue to enforce a " 461 "limited capacity.")); 462 } 463 464 private: CreateResource(Barrier ** barrier)465 Status CreateResource(Barrier** barrier) override 466 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 467 *barrier = new Barrier(value_component_types_, value_component_shapes_, 468 cinfo_.name()); 469 if (*barrier == nullptr) { 470 return errors::ResourceExhausted("Failed to allocate barrier"); 471 } 472 return (*barrier)->Initialize(); 473 } 474 VerifyResource(Barrier * barrier)475 Status VerifyResource(Barrier* barrier) override 476 EXCLUSIVE_LOCKS_REQUIRED(mu_) { 477 if (barrier->component_types() != value_component_types_) { 478 return errors::InvalidArgument( 479 "Shared barrier '", cinfo_.name(), "' has component types ", 480 DataTypeSliceString(barrier->component_types()), 481 " but requested component types were ", 482 DataTypeSliceString(value_component_types_)); 483 } 484 if (barrier->component_shapes() != value_component_shapes_) { 485 return errors::InvalidArgument( 486 "Shared barrier '", cinfo_.name(), "' has component shapes ", 487 TensorShapeUtils::ShapeListString(barrier->component_shapes()), 488 " but requested component shapes were ", 489 TensorShapeUtils::ShapeListString(value_component_shapes_)); 490 } 491 return Status::OK(); 492 } 493 494 DataTypeVector value_component_types_; 495 std::vector<TensorShape> value_component_shapes_; 496 497 TF_DISALLOW_COPY_AND_ASSIGN(BarrierOp); 498 }; 499 500 REGISTER_KERNEL_BUILDER(Name("Barrier").Device(DEVICE_CPU), BarrierOp); 501 502 class BarrierOpKernel : public AsyncOpKernel { 503 public: BarrierOpKernel(OpKernelConstruction * context)504 explicit BarrierOpKernel(OpKernelConstruction* context) 505 : AsyncOpKernel(context) {} 506 ComputeAsync(OpKernelContext * ctx,DoneCallback callback)507 void ComputeAsync(OpKernelContext* ctx, DoneCallback callback) final { 508 Barrier* barrier = nullptr; 509 OP_REQUIRES_OK_ASYNC(ctx, GetResourceFromContext(ctx, "handle", &barrier), 510 callback); 511 ComputeAsync(ctx, barrier, [callback, barrier]() { 512 barrier->Unref(); 513 callback(); 514 }); 515 } 516 517 protected: 518 virtual void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 519 DoneCallback callback) = 0; 520 }; 521 522 template <typename T> 523 class InsertManyOp : public BarrierOpKernel { 524 public: InsertManyOp(OpKernelConstruction * context)525 explicit InsertManyOp(OpKernelConstruction* context) 526 : BarrierOpKernel(context) { 527 OP_REQUIRES_OK(context, 528 context->GetAttr("component_index", &component_index_)); 529 } 530 531 protected: ComputeAsync(OpKernelContext * ctx,Barrier * barrier,DoneCallback callback)532 void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 533 DoneCallback callback) override { 534 OP_REQUIRES_ASYNC( 535 ctx, component_index_ < barrier->num_components(), 536 errors::InvalidArgument("The component ID is out of range ", 537 component_index_, " > num_components", 538 " (= ", barrier->num_components(), ")"), 539 callback); 540 OP_REQUIRES_OK_ASYNC( 541 ctx, 542 ctx->MatchSignature({DT_STRING_REF, DT_STRING, 543 barrier->component_type(component_index_)}, 544 {}), 545 callback); 546 547 const Tensor* keys; 548 const Tensor* values; 549 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("keys", &keys), callback); 550 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("values", &values), callback); 551 barrier->TryInsertMany<T>(*keys, component_index_, *values, ctx, callback); 552 } 553 554 private: 555 int component_index_; 556 TF_DISALLOW_COPY_AND_ASSIGN(InsertManyOp); 557 }; 558 559 #define REGISTER_INSERTMANY(T) \ 560 REGISTER_KERNEL_BUILDER( \ 561 Name("BarrierInsertMany").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ 562 InsertManyOp<T>); 563 564 TF_CALL_ALL_TYPES(REGISTER_INSERTMANY); 565 #undef REGISTER_INSERTMANY 566 567 class TakeManyOp : public BarrierOpKernel { 568 public: TakeManyOp(OpKernelConstruction * context)569 explicit TakeManyOp(OpKernelConstruction* context) 570 : BarrierOpKernel(context) { 571 OP_REQUIRES_OK(context, context->GetAttr("timeout_ms", &timeout_)); 572 // TODO(keveman): Enable timeout. 573 OP_REQUIRES(context, timeout_ == -1, 574 errors::InvalidArgument("Timeout not supported yet.")); 575 576 OP_REQUIRES_OK(context, 577 context->GetAttr("allow_small_batch", &allow_small_batch_)); 578 } 579 580 protected: ComputeAsync(OpKernelContext * ctx,Barrier * barrier,DoneCallback callback)581 void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 582 DoneCallback callback) override { 583 const Tensor* Tnum_elements; 584 OP_REQUIRES_OK_ASYNC(ctx, ctx->input("num_elements", &Tnum_elements), 585 callback); 586 OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(Tnum_elements->shape()), 587 errors::InvalidArgument("num_elements must be a scalar."), 588 callback); 589 const int32 num_elements = Tnum_elements->scalar<int32>()(); 590 591 DataTypeVector expected_inputs = {DT_STRING_REF, DT_INT32}; 592 // The first output is the insertion index, the second output is the key. 593 DataTypeVector expected_outputs = {DT_INT64, DT_STRING}; 594 for (DataType dt : barrier->component_types()) { 595 expected_outputs.push_back(dt); 596 } 597 OP_REQUIRES_OK_ASYNC( 598 ctx, ctx->MatchSignature(expected_inputs, expected_outputs), callback); 599 600 barrier->TryTakeMany( 601 num_elements, allow_small_batch_, timeout_, ctx, 602 [ctx, callback](const Tensor& indices, const Tensor& keys, 603 const Barrier::Tuple& values) { 604 if (!ctx->status().ok()) { 605 callback(); 606 return; 607 } 608 // At this point, indices, keys, and values 609 // have all been written to successfully. 610 OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("indices", indices), 611 callback); 612 OP_REQUIRES_OK_ASYNC(ctx, ctx->set_output("keys", keys), callback); 613 OpOutputList values_output; 614 OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("values", &values_output), 615 callback); 616 for (size_t i = 0; i < values.size(); ++i) { 617 values_output.set(i, values[i]); 618 } 619 callback(); 620 }); 621 } 622 623 private: 624 int64 timeout_; 625 bool allow_small_batch_; 626 TF_DISALLOW_COPY_AND_ASSIGN(TakeManyOp); 627 }; 628 629 REGISTER_KERNEL_BUILDER(Name("BarrierTakeMany").Device(DEVICE_CPU), TakeManyOp); 630 631 class BarrierCloseOp : public BarrierOpKernel { 632 public: BarrierCloseOp(OpKernelConstruction * context)633 explicit BarrierCloseOp(OpKernelConstruction* context) 634 : BarrierOpKernel(context) { 635 OP_REQUIRES_OK(context, context->GetAttr("cancel_pending_enqueues", 636 &cancel_pending_enqueues_)); 637 } 638 639 protected: ComputeAsync(OpKernelContext * ctx,Barrier * barrier,DoneCallback callback)640 void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 641 DoneCallback callback) override { 642 barrier->Close(ctx, cancel_pending_enqueues_, callback); 643 } 644 645 private: 646 bool cancel_pending_enqueues_; 647 TF_DISALLOW_COPY_AND_ASSIGN(BarrierCloseOp); 648 }; 649 650 REGISTER_KERNEL_BUILDER(Name("BarrierClose").Device(DEVICE_CPU), 651 BarrierCloseOp); 652 653 class BarrierIncompleteSizeOp : public BarrierOpKernel { 654 public: BarrierIncompleteSizeOp(OpKernelConstruction * context)655 explicit BarrierIncompleteSizeOp(OpKernelConstruction* context) 656 : BarrierOpKernel(context) {} 657 658 protected: ComputeAsync(OpKernelContext * ctx,Barrier * barrier,DoneCallback callback)659 void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 660 DoneCallback callback) override { 661 Tensor* Tsize = nullptr; 662 OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize), 663 callback); 664 Tsize->scalar<int32>().setConstant(barrier->incomplete_size()); 665 callback(); 666 } 667 }; 668 669 REGISTER_KERNEL_BUILDER(Name("BarrierIncompleteSize").Device(DEVICE_CPU), 670 BarrierIncompleteSizeOp); 671 672 class BarrierReadySizeOp : public BarrierOpKernel { 673 public: BarrierReadySizeOp(OpKernelConstruction * context)674 explicit BarrierReadySizeOp(OpKernelConstruction* context) 675 : BarrierOpKernel(context) {} 676 677 protected: ComputeAsync(OpKernelContext * ctx,Barrier * barrier,DoneCallback callback)678 void ComputeAsync(OpKernelContext* ctx, Barrier* barrier, 679 DoneCallback callback) override { 680 Tensor* Tsize = nullptr; 681 OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &Tsize), 682 callback); 683 Tsize->scalar<int32>().setConstant(barrier->ready_size()); 684 callback(); 685 } 686 }; 687 688 REGISTER_KERNEL_BUILDER(Name("BarrierReadySize").Device(DEVICE_CPU), 689 BarrierReadySizeOp); 690 691 } // namespace barrier 692 693 } // namespace tensorflow 694