1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#============================================================================== 15"""Data Flow Operations.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import hashlib 23import threading 24 25import six 26 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes as _dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import random_seed 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.lib.io import python_io 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import gen_data_flow_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import resource_variable_ops 39# go/tf-wildcard-import 40# pylint: disable=wildcard-import 41from tensorflow.python.ops.gen_data_flow_ops import * 42from tensorflow.python.util import deprecation 43from tensorflow.python.util.tf_export import tf_export 44 45# pylint: enable=wildcard-import 46 47 48def _as_type_list(dtypes): 49 """Convert dtypes to a list of types.""" 50 assert dtypes is not None 51 if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)): 52 # We have a single type. 53 return [dtypes] 54 else: 55 # We have a list or tuple of types. 56 return list(dtypes) 57 58 59def _as_shape_list(shapes, 60 dtypes, 61 unknown_dim_allowed=False, 62 unknown_rank_allowed=False): 63 """Convert shapes to a list of tuples of int (or None).""" 64 del dtypes 65 if unknown_dim_allowed: 66 if (not isinstance(shapes, collections.Sequence) or not shapes or 67 any(shape is None or isinstance(shape, int) for shape in shapes)): 68 raise ValueError( 69 "When providing partial shapes, a list of shapes must be provided.") 70 if shapes is None: 71 return None 72 if isinstance(shapes, tensor_shape.TensorShape): 73 shapes = [shapes] 74 if not isinstance(shapes, (tuple, list)): 75 raise TypeError( 76 "shapes must be a TensorShape or a list or tuple of TensorShapes.") 77 if all(shape is None or isinstance(shape, int) for shape in shapes): 78 # We have a single shape. 79 shapes = [shapes] 80 shapes = [tensor_shape.as_shape(shape) for shape in shapes] 81 if not unknown_dim_allowed: 82 if any(not shape.is_fully_defined() for shape in shapes): 83 raise ValueError("All shapes must be fully defined: %s" % shapes) 84 if not unknown_rank_allowed: 85 if any([shape.dims is None for shape in shapes]): 86 raise ValueError("All shapes must have a defined rank: %s" % shapes) 87 88 return shapes 89 90 91def _as_name_list(names, dtypes): 92 if names is None: 93 return None 94 if not isinstance(names, (list, tuple)): 95 names = [names] 96 if len(names) != len(dtypes): 97 raise ValueError("List of names must have the same length as the list " 98 "of dtypes") 99 return list(names) 100 101 102def _shape_common(s1, s2): 103 """The greatest lower bound (ordered by specificity) TensorShape.""" 104 s1 = tensor_shape.TensorShape(s1) 105 s2 = tensor_shape.TensorShape(s2) 106 if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims: 107 return tensor_shape.unknown_shape() 108 d = [ 109 d1 if d1 is not None and d1 == d2 else None 110 for (d1, d2) in zip(s1.as_list(), s2.as_list()) 111 ] 112 return tensor_shape.TensorShape(d) 113 114 115# pylint: disable=protected-access 116@tf_export("queue.QueueBase", 117 v1=["queue.QueueBase", "io.QueueBase", "QueueBase"]) 118@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"]) 119class QueueBase(object): 120 """Base class for queue implementations. 121 122 A queue is a TensorFlow data structure that stores tensors across 123 multiple steps, and exposes operations that enqueue and dequeue 124 tensors. 125 126 Each queue element is a tuple of one or more tensors, where each 127 tuple component has a static dtype, and may have a static shape. The 128 queue implementations support versions of enqueue and dequeue that 129 handle single elements, versions that support enqueuing and 130 dequeuing a batch of elements at once. 131 132 See `tf.FIFOQueue` and 133 `tf.RandomShuffleQueue` for concrete 134 implementations of this class, and instructions on how to create 135 them. 136 """ 137 138 def __init__(self, dtypes, shapes, names, queue_ref): 139 """Constructs a queue object from a queue reference. 140 141 The two optional lists, `shapes` and `names`, must be of the same length 142 as `dtypes` if provided. The values at a given index `i` indicate the 143 shape and name to use for the corresponding queue component in `dtypes`. 144 145 Args: 146 dtypes: A list of types. The length of dtypes must equal the number 147 of tensors in each element. 148 shapes: Constraints on the shapes of tensors in an element: 149 A list of shape tuples or None. This list is the same length 150 as dtypes. If the shape of any tensors in the element are constrained, 151 all must be; shapes can be None if the shapes should not be constrained. 152 names: Optional list of names. If provided, the `enqueue()` and 153 `dequeue()` methods will use dictionaries with these names as keys. 154 Must be None or a list or tuple of the same length as `dtypes`. 155 queue_ref: The queue reference, i.e. the output of the queue op. 156 157 Raises: 158 ValueError: If one of the arguments is invalid. 159 """ 160 self._dtypes = dtypes 161 if shapes is not None: 162 if len(shapes) != len(dtypes): 163 raise ValueError("Queue shapes must have the same length as dtypes") 164 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 165 else: 166 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 167 if names is not None: 168 if len(names) != len(dtypes): 169 raise ValueError("Queue names must have the same length as dtypes") 170 self._names = names 171 else: 172 self._names = None 173 self._queue_ref = queue_ref 174 if context.executing_eagerly(): 175 if context.context().scope_name: 176 self._name = context.context().scope_name 177 else: 178 self._name = "Empty" 179 self._resource_deleter = resource_variable_ops.EagerResourceDeleter( 180 queue_ref, None) 181 else: 182 self._name = self._queue_ref.op.name.split("/")[-1] 183 184 @staticmethod 185 def from_list(index, queues): 186 """Create a queue using the queue reference from `queues[index]`. 187 188 Args: 189 index: An integer scalar tensor that determines the input that gets 190 selected. 191 queues: A list of `QueueBase` objects. 192 193 Returns: 194 A `QueueBase` object. 195 196 Raises: 197 TypeError: When `queues` is not a list of `QueueBase` objects, 198 or when the data types of `queues` are not all the same. 199 """ 200 if ((not queues) or (not isinstance(queues, list)) or 201 (not all(isinstance(x, QueueBase) for x in queues))): 202 raise TypeError("A list of queues expected") 203 204 dtypes = queues[0].dtypes 205 if not all(dtypes == q.dtypes for q in queues[1:]): 206 raise TypeError("Queues do not have matching component dtypes.") 207 208 names = queues[0].names 209 if not all(names == q.names for q in queues[1:]): 210 raise TypeError("Queues do not have matching component names.") 211 212 queue_shapes = [q.shapes for q in queues] 213 reduced_shapes = [ 214 six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes) 215 ] 216 217 queue_refs = array_ops.stack([x.queue_ref for x in queues]) 218 selected_queue = array_ops.gather(queue_refs, index) 219 return QueueBase( 220 dtypes=dtypes, 221 shapes=reduced_shapes, 222 names=names, 223 queue_ref=selected_queue) 224 225 @property 226 def queue_ref(self): 227 """The underlying queue reference.""" 228 return self._queue_ref 229 230 @property 231 def name(self): 232 """The name of the underlying queue.""" 233 if context.executing_eagerly(): 234 return self._name 235 return self._queue_ref.op.name 236 237 @property 238 def dtypes(self): 239 """The list of dtypes for each component of a queue element.""" 240 return self._dtypes 241 242 @property 243 def shapes(self): 244 """The list of shapes for each component of a queue element.""" 245 return self._shapes 246 247 @property 248 def names(self): 249 """The list of names for each component of a queue element.""" 250 return self._names 251 252 def _check_enqueue_dtypes(self, vals): 253 """Validate and convert `vals` to a list of `Tensor`s. 254 255 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 256 dictionary with tensor values. 257 258 If it is a dictionary, the queue must have been constructed with a 259 `names` attribute and the dictionary keys must match the queue names. 260 If the queue was constructed with a `names` attribute, `vals` must 261 be a dictionary. 262 263 Args: 264 vals: A tensor, a list or tuple of tensors, or a dictionary.. 265 266 Returns: 267 A list of `Tensor` objects. 268 269 Raises: 270 ValueError: If `vals` is invalid. 271 """ 272 if isinstance(vals, dict): 273 if not self._names: 274 raise ValueError("Queue must have names to enqueue a dictionary") 275 if sorted(self._names, key=str) != sorted(vals.keys(), key=str): 276 raise ValueError("Keys in dictionary to enqueue do not match " 277 "names of Queue. Dictionary: (%s), Queue: (%s)" % 278 (sorted(vals.keys()), sorted(self._names))) 279 # The order of values in `self._names` indicates the order in which the 280 # tensors in the dictionary `vals` must be listed. 281 vals = [vals[k] for k in self._names] 282 else: 283 if self._names: 284 raise ValueError("You must enqueue a dictionary in a Queue with names") 285 if not isinstance(vals, (list, tuple)): 286 vals = [vals] 287 288 tensors = [] 289 for i, (val, dtype) in enumerate(zip(vals, self._dtypes)): 290 tensors.append( 291 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 292 293 return tensors 294 295 def _scope_vals(self, vals): 296 """Return a list of values to pass to `name_scope()`. 297 298 Args: 299 vals: A tensor, a list or tuple of tensors, or a dictionary. 300 301 Returns: 302 The values in vals as a list. 303 """ 304 if isinstance(vals, (list, tuple)): 305 return vals 306 elif isinstance(vals, dict): 307 return vals.values() 308 else: 309 return [vals] 310 311 def enqueue(self, vals, name=None): 312 """Enqueues one element to this queue. 313 314 If the queue is full when this operation executes, it will block 315 until the element has been enqueued. 316 317 At runtime, this operation may raise an error if the queue is 318 `tf.QueueBase.close` before or during its execution. If the 319 queue is closed before this operation runs, 320 `tf.errors.CancelledError` will be raised. If this operation is 321 blocked, and either (i) the queue is closed by a close operation 322 with `cancel_pending_enqueues=True`, or (ii) the session is 323 `tf.Session.close`, 324 `tf.errors.CancelledError` will be raised. 325 326 Args: 327 vals: A tensor, a list or tuple of tensors, or a dictionary containing 328 the values to enqueue. 329 name: A name for the operation (optional). 330 331 Returns: 332 The operation that enqueues a new tuple of tensors to the queue. 333 """ 334 with ops.name_scope(name, "%s_enqueue" % self._name, 335 self._scope_vals(vals)) as scope: 336 vals = self._check_enqueue_dtypes(vals) 337 338 # NOTE(mrry): Not using a shape function because we need access to 339 # the `QueueBase` object. 340 for val, shape in zip(vals, self._shapes): 341 val.get_shape().assert_is_compatible_with(shape) 342 343 if self._queue_ref.dtype == _dtypes.resource: 344 return gen_data_flow_ops.queue_enqueue_v2( 345 self._queue_ref, vals, name=scope) 346 else: 347 return gen_data_flow_ops.queue_enqueue( 348 self._queue_ref, vals, name=scope) 349 350 def enqueue_many(self, vals, name=None): 351 """Enqueues zero or more elements to this queue. 352 353 This operation slices each component tensor along the 0th dimension to 354 make multiple queue elements. All of the tensors in `vals` must have the 355 same size in the 0th dimension. 356 357 If the queue is full when this operation executes, it will block 358 until all of the elements have been enqueued. 359 360 At runtime, this operation may raise an error if the queue is 361 `tf.QueueBase.close` before or during its execution. If the 362 queue is closed before this operation runs, 363 `tf.errors.CancelledError` will be raised. If this operation is 364 blocked, and either (i) the queue is closed by a close operation 365 with `cancel_pending_enqueues=True`, or (ii) the session is 366 `tf.Session.close`, 367 `tf.errors.CancelledError` will be raised. 368 369 Args: 370 vals: A tensor, a list or tuple of tensors, or a dictionary 371 from which the queue elements are taken. 372 name: A name for the operation (optional). 373 374 Returns: 375 The operation that enqueues a batch of tuples of tensors to the queue. 376 """ 377 with ops.name_scope(name, "%s_EnqueueMany" % self._name, 378 self._scope_vals(vals)) as scope: 379 vals = self._check_enqueue_dtypes(vals) 380 381 # NOTE(mrry): Not using a shape function because we need access to 382 # the `QueueBase` object. 383 # NOTE(fchollet): the code that follow is verbose because it needs to be 384 # compatible with both TF v1 TensorShape behavior and TF v2 behavior. 385 batch_dim = tensor_shape.dimension_value( 386 vals[0].get_shape().with_rank_at_least(1)[0]) 387 batch_dim = tensor_shape.Dimension(batch_dim) 388 for val, shape in zip(vals, self._shapes): 389 val_batch_dim = tensor_shape.dimension_value( 390 val.get_shape().with_rank_at_least(1)[0]) 391 val_batch_dim = tensor_shape.Dimension(val_batch_dim) 392 batch_dim = batch_dim.merge_with(val_batch_dim) 393 val.get_shape()[1:].assert_is_compatible_with(shape) 394 395 return gen_data_flow_ops.queue_enqueue_many_v2( 396 self._queue_ref, vals, name=scope) 397 398 def _dequeue_return_value(self, tensors): 399 """Return the value to return from a dequeue op. 400 401 If the queue has names, return a dictionary with the 402 names as keys. Otherwise return either a single tensor 403 or a list of tensors depending on the length of `tensors`. 404 405 Args: 406 tensors: List of tensors from the dequeue op. 407 408 Returns: 409 A single tensor, a list of tensors, or a dictionary 410 of tensors. 411 """ 412 if self._names: 413 # The returned values in `tensors` are in the same order as 414 # the names in `self._names`. 415 return {n: tensors[i] for i, n in enumerate(self._names)} 416 elif len(tensors) == 1: 417 return tensors[0] 418 else: 419 return tensors 420 421 def dequeue(self, name=None): 422 """Dequeues one element from this queue. 423 424 If the queue is empty when this operation executes, it will block 425 until there is an element to dequeue. 426 427 At runtime, this operation may raise an error if the queue is 428 `tf.QueueBase.close` before or during its execution. If the 429 queue is closed, the queue is empty, and there are no pending 430 enqueue operations that can fulfill this request, 431 `tf.errors.OutOfRangeError` will be raised. If the session is 432 `tf.Session.close`, 433 `tf.errors.CancelledError` will be raised. 434 435 Args: 436 name: A name for the operation (optional). 437 438 Returns: 439 The tuple of tensors that was dequeued. 440 """ 441 if name is None: 442 name = "%s_Dequeue" % self._name 443 if self._queue_ref.dtype == _dtypes.resource: 444 ret = gen_data_flow_ops.queue_dequeue_v2( 445 self._queue_ref, self._dtypes, name=name) 446 else: 447 ret = gen_data_flow_ops.queue_dequeue( 448 self._queue_ref, self._dtypes, name=name) 449 450 # NOTE(mrry): Not using a shape function because we need access to 451 # the `QueueBase` object. 452 if not context.executing_eagerly(): 453 op = ret[0].op 454 for output, shape in zip(op.values(), self._shapes): 455 output.set_shape(shape) 456 457 return self._dequeue_return_value(ret) 458 459 def dequeue_many(self, n, name=None): 460 """Dequeues and concatenates `n` elements from this queue. 461 462 This operation concatenates queue-element component tensors along 463 the 0th dimension to make a single component tensor. All of the 464 components in the dequeued tuple will have size `n` in the 0th dimension. 465 466 If the queue is closed and there are less than `n` elements left, then an 467 `OutOfRange` exception is raised. 468 469 At runtime, this operation may raise an error if the queue is 470 `tf.QueueBase.close` before or during its execution. If the 471 queue is closed, the queue contains fewer than `n` elements, and 472 there are no pending enqueue operations that can fulfill this 473 request, `tf.errors.OutOfRangeError` will be raised. If the 474 session is `tf.Session.close`, 475 `tf.errors.CancelledError` will be raised. 476 477 Args: 478 n: A scalar `Tensor` containing the number of elements to dequeue. 479 name: A name for the operation (optional). 480 481 Returns: 482 The list of concatenated tensors that was dequeued. 483 """ 484 if name is None: 485 name = "%s_DequeueMany" % self._name 486 487 ret = gen_data_flow_ops.queue_dequeue_many_v2( 488 self._queue_ref, n=n, component_types=self._dtypes, name=name) 489 490 # NOTE(mrry): Not using a shape function because we need access to 491 # the Queue object. 492 if not context.executing_eagerly(): 493 op = ret[0].op 494 batch_dim = tensor_shape.Dimension( 495 tensor_util.constant_value(op.inputs[1])) 496 for output, shape in zip(op.values(), self._shapes): 497 output.set_shape( 498 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 499 500 return self._dequeue_return_value(ret) 501 502 def dequeue_up_to(self, n, name=None): 503 """Dequeues and concatenates `n` elements from this queue. 504 505 **Note** This operation is not supported by all queues. If a queue does not 506 support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised. 507 508 This operation concatenates queue-element component tensors along 509 the 0th dimension to make a single component tensor. If the queue 510 has not been closed, all of the components in the dequeued tuple 511 will have size `n` in the 0th dimension. 512 513 If the queue is closed and there are more than `0` but fewer than 514 `n` elements remaining, then instead of raising a 515 `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`, 516 less than `n` elements are returned immediately. If the queue is 517 closed and there are `0` elements left in the queue, then a 518 `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. 519 Otherwise the behavior is identical to `dequeue_many`. 520 521 Args: 522 n: A scalar `Tensor` containing the number of elements to dequeue. 523 name: A name for the operation (optional). 524 525 Returns: 526 The tuple of concatenated tensors that was dequeued. 527 """ 528 if name is None: 529 name = "%s_DequeueUpTo" % self._name 530 531 ret = gen_data_flow_ops.queue_dequeue_up_to_v2( 532 self._queue_ref, n=n, component_types=self._dtypes, name=name) 533 534 # NOTE(mrry): Not using a shape function because we need access to 535 # the Queue object. 536 if not context.executing_eagerly(): 537 op = ret[0].op 538 for output, shape in zip(op.values(), self._shapes): 539 output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) 540 541 return self._dequeue_return_value(ret) 542 543 def close(self, cancel_pending_enqueues=False, name=None): 544 """Closes this queue. 545 546 This operation signals that no more elements will be enqueued in 547 the given queue. Subsequent `enqueue` and `enqueue_many` 548 operations will fail. Subsequent `dequeue` and `dequeue_many` 549 operations will continue to succeed if sufficient elements remain 550 in the queue. Subsequently dequeue and dequeue_many operations 551 that would otherwise block waiting for more elements (if close 552 hadn't been called) will now fail immediately. 553 554 If `cancel_pending_enqueues` is `True`, all pending requests will also 555 be canceled. 556 557 Args: 558 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 559 `False` (described above). 560 name: A name for the operation (optional). 561 562 Returns: 563 The operation that closes the queue. 564 """ 565 if name is None: 566 name = "%s_Close" % self._name 567 if self._queue_ref.dtype == _dtypes.resource: 568 return gen_data_flow_ops.queue_close_v2( 569 self._queue_ref, 570 cancel_pending_enqueues=cancel_pending_enqueues, 571 name=name) 572 else: 573 return gen_data_flow_ops.queue_close( 574 self._queue_ref, 575 cancel_pending_enqueues=cancel_pending_enqueues, 576 name=name) 577 578 def is_closed(self, name=None): 579 """Returns true if queue is closed. 580 581 This operation returns true if the queue is closed and false if the queue 582 is open. 583 584 Args: 585 name: A name for the operation (optional). 586 587 Returns: 588 True if the queue is closed and false if the queue is open. 589 """ 590 if name is None: 591 name = "%s_Is_Closed" % self._name 592 if self._queue_ref.dtype == _dtypes.resource: 593 return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name) 594 else: 595 return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name) 596 597 def size(self, name=None): 598 """Compute the number of elements in this queue. 599 600 Args: 601 name: A name for the operation (optional). 602 603 Returns: 604 A scalar tensor containing the number of elements in this queue. 605 """ 606 if name is None: 607 name = "%s_Size" % self._name 608 if self._queue_ref.dtype == _dtypes.resource: 609 return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name) 610 else: 611 return gen_data_flow_ops.queue_size(self._queue_ref, name=name) 612 613def _shared_name(shared_name): 614 if context.executing_eagerly(): 615 return str(ops.uid()) 616 return shared_name 617 618 619@tf_export( 620 "queue.RandomShuffleQueue", 621 v1=["queue.RandomShuffleQueue", 622 "io.RandomShuffleQueue", "RandomShuffleQueue"]) 623@deprecation.deprecated_endpoints( 624 ["io.RandomShuffleQueue", "RandomShuffleQueue"]) 625class RandomShuffleQueue(QueueBase): 626 """A queue implementation that dequeues elements in a random order. 627 628 See `tf.QueueBase` for a description of the methods on 629 this class. 630 """ 631 632 def __init__(self, 633 capacity, 634 min_after_dequeue, 635 dtypes, 636 shapes=None, 637 names=None, 638 seed=None, 639 shared_name=None, 640 name="random_shuffle_queue"): 641 """Create a queue that dequeues elements in a random order. 642 643 A `RandomShuffleQueue` has bounded capacity; supports multiple 644 concurrent producers and consumers; and provides exactly-once 645 delivery. 646 647 A `RandomShuffleQueue` holds a list of up to `capacity` 648 elements. Each element is a fixed-length tuple of tensors whose 649 dtypes are described by `dtypes`, and whose shapes are optionally 650 described by the `shapes` argument. 651 652 If the `shapes` argument is specified, each component of a queue 653 element must have the respective fixed shape. If it is 654 unspecified, different queue elements may have different shapes, 655 but the use of `dequeue_many` is disallowed. 656 657 The `min_after_dequeue` argument allows the caller to specify a 658 minimum number of elements that will remain in the queue after a 659 `dequeue` or `dequeue_many` operation completes, to ensure a 660 minimum level of mixing of elements. This invariant is maintained 661 by blocking those operations until sufficient elements have been 662 enqueued. The `min_after_dequeue` argument is ignored after the 663 queue has been closed. 664 665 Args: 666 capacity: An integer. The upper bound on the number of elements 667 that may be stored in this queue. 668 min_after_dequeue: An integer (described above). 669 dtypes: A list of `DType` objects. The length of `dtypes` must equal 670 the number of tensors in each queue element. 671 shapes: (Optional.) A list of fully-defined `TensorShape` objects 672 with the same length as `dtypes`, or `None`. 673 names: (Optional.) A list of string naming the components in the queue 674 with the same length as `dtypes`, or `None`. If specified the dequeue 675 methods return a dictionary with the names as keys. 676 seed: A Python integer. Used to create a random seed. See 677 `tf.set_random_seed` 678 for behavior. 679 shared_name: (Optional.) If non-empty, this queue will be shared under 680 the given name across multiple sessions. 681 name: Optional name for the queue operation. 682 """ 683 dtypes = _as_type_list(dtypes) 684 shapes = _as_shape_list(shapes, dtypes) 685 names = _as_name_list(names, dtypes) 686 seed1, seed2 = random_seed.get_seed(seed) 687 if seed1 is None and seed2 is None: 688 seed1, seed2 = 0, 0 689 elif seed is None and shared_name is not None: 690 # This means that graph seed is provided but op seed is not provided. 691 # If shared_name is also provided, make seed2 depend only on the graph 692 # seed and shared_name. (seed2 from get_seed() is generally dependent on 693 # the id of the last op created.) 694 string = (str(seed1) + shared_name).encode("utf-8") 695 seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 696 queue_ref = gen_data_flow_ops.random_shuffle_queue_v2( 697 component_types=dtypes, 698 shapes=shapes, 699 capacity=capacity, 700 min_after_dequeue=min_after_dequeue, 701 seed=seed1, 702 seed2=seed2, 703 shared_name=_shared_name(shared_name), 704 name=name) 705 706 super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) 707 708 709@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"]) 710@deprecation.deprecated_endpoints("FIFOQueue") 711class FIFOQueue(QueueBase): 712 """A queue implementation that dequeues elements in first-in first-out order. 713 714 See `tf.QueueBase` for a description of the methods on 715 this class. 716 """ 717 718 def __init__(self, 719 capacity, 720 dtypes, 721 shapes=None, 722 names=None, 723 shared_name=None, 724 name="fifo_queue"): 725 """Creates a queue that dequeues elements in a first-in first-out order. 726 727 A `FIFOQueue` has bounded capacity; supports multiple concurrent 728 producers and consumers; and provides exactly-once delivery. 729 730 A `FIFOQueue` holds a list of up to `capacity` elements. Each 731 element is a fixed-length tuple of tensors whose dtypes are 732 described by `dtypes`, and whose shapes are optionally described 733 by the `shapes` argument. 734 735 If the `shapes` argument is specified, each component of a queue 736 element must have the respective fixed shape. If it is 737 unspecified, different queue elements may have different shapes, 738 but the use of `dequeue_many` is disallowed. 739 740 Args: 741 capacity: An integer. The upper bound on the number of elements 742 that may be stored in this queue. 743 dtypes: A list of `DType` objects. The length of `dtypes` must equal 744 the number of tensors in each queue element. 745 shapes: (Optional.) A list of fully-defined `TensorShape` objects 746 with the same length as `dtypes`, or `None`. 747 names: (Optional.) A list of string naming the components in the queue 748 with the same length as `dtypes`, or `None`. If specified the dequeue 749 methods return a dictionary with the names as keys. 750 shared_name: (Optional.) If non-empty, this queue will be shared under 751 the given name across multiple sessions. 752 name: Optional name for the queue operation. 753 """ 754 dtypes = _as_type_list(dtypes) 755 shapes = _as_shape_list(shapes, dtypes) 756 names = _as_name_list(names, dtypes) 757 queue_ref = gen_data_flow_ops.fifo_queue_v2( 758 component_types=dtypes, 759 shapes=shapes, 760 capacity=capacity, 761 shared_name=_shared_name(shared_name), 762 name=name) 763 764 super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 765 766 767@tf_export( 768 "queue.PaddingFIFOQueue", 769 v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"]) 770@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"]) 771class PaddingFIFOQueue(QueueBase): 772 """A FIFOQueue that supports batching variable-sized tensors by padding. 773 774 A `PaddingFIFOQueue` may contain components with dynamic shape, while also 775 supporting `dequeue_many`. See the constructor for more details. 776 777 See `tf.QueueBase` for a description of the methods on 778 this class. 779 """ 780 781 def __init__(self, 782 capacity, 783 dtypes, 784 shapes, 785 names=None, 786 shared_name=None, 787 name="padding_fifo_queue"): 788 """Creates a queue that dequeues elements in a first-in first-out order. 789 790 A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent 791 producers and consumers; and provides exactly-once delivery. 792 793 A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each 794 element is a fixed-length tuple of tensors whose dtypes are 795 described by `dtypes`, and whose shapes are described by the `shapes` 796 argument. 797 798 The `shapes` argument must be specified; each component of a queue 799 element must have the respective shape. Shapes of fixed 800 rank but variable size are allowed by setting any shape dimension to None. 801 In this case, the inputs' shape may vary along the given dimension, and 802 `dequeue_many` will pad the given dimension with zeros up to the maximum 803 shape of all elements in the given batch. 804 805 Args: 806 capacity: An integer. The upper bound on the number of elements 807 that may be stored in this queue. 808 dtypes: A list of `DType` objects. The length of `dtypes` must equal 809 the number of tensors in each queue element. 810 shapes: A list of `TensorShape` objects, with the same length as 811 `dtypes`. Any dimension in the `TensorShape` containing value 812 `None` is dynamic and allows values to be enqueued with 813 variable size in that dimension. 814 names: (Optional.) A list of string naming the components in the queue 815 with the same length as `dtypes`, or `None`. If specified the dequeue 816 methods return a dictionary with the names as keys. 817 shared_name: (Optional.) If non-empty, this queue will be shared under 818 the given name across multiple sessions. 819 name: Optional name for the queue operation. 820 821 Raises: 822 ValueError: If shapes is not a list of shapes, or the lengths of dtypes 823 and shapes do not match, or if names is specified and the lengths of 824 dtypes and names do not match. 825 """ 826 dtypes = _as_type_list(dtypes) 827 shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True) 828 names = _as_name_list(names, dtypes) 829 if len(dtypes) != len(shapes): 830 raise ValueError("Shapes must be provided for all components, " 831 "but received %d dtypes and %d shapes." % (len(dtypes), 832 len(shapes))) 833 834 queue_ref = gen_data_flow_ops.padding_fifo_queue_v2( 835 component_types=dtypes, 836 shapes=shapes, 837 capacity=capacity, 838 shared_name=_shared_name(shared_name), 839 name=name) 840 841 super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 842 843 844@tf_export("queue.PriorityQueue", 845 v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"]) 846@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"]) 847class PriorityQueue(QueueBase): 848 """A queue implementation that dequeues elements in prioritized order. 849 850 See `tf.QueueBase` for a description of the methods on 851 this class. 852 """ 853 854 def __init__(self, 855 capacity, 856 types, 857 shapes=None, 858 names=None, 859 shared_name=None, 860 name="priority_queue"): 861 """Creates a queue that dequeues elements in a first-in first-out order. 862 863 A `PriorityQueue` has bounded capacity; supports multiple concurrent 864 producers and consumers; and provides exactly-once delivery. 865 866 A `PriorityQueue` holds a list of up to `capacity` elements. Each 867 element is a fixed-length tuple of tensors whose dtypes are 868 described by `types`, and whose shapes are optionally described 869 by the `shapes` argument. 870 871 If the `shapes` argument is specified, each component of a queue 872 element must have the respective fixed shape. If it is 873 unspecified, different queue elements may have different shapes, 874 but the use of `dequeue_many` is disallowed. 875 876 Enqueues and Dequeues to the `PriorityQueue` must include an additional 877 tuple entry at the beginning: the `priority`. The priority must be 878 an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`). 879 880 Args: 881 capacity: An integer. The upper bound on the number of elements 882 that may be stored in this queue. 883 types: A list of `DType` objects. The length of `types` must equal 884 the number of tensors in each queue element, except the first priority 885 element. The first tensor in each element is the priority, 886 which must be type int64. 887 shapes: (Optional.) A list of fully-defined `TensorShape` objects, 888 with the same length as `types`, or `None`. 889 names: (Optional.) A list of strings naming the components in the queue 890 with the same length as `dtypes`, or `None`. If specified, the dequeue 891 methods return a dictionary with the names as keys. 892 shared_name: (Optional.) If non-empty, this queue will be shared under 893 the given name across multiple sessions. 894 name: Optional name for the queue operation. 895 """ 896 types = _as_type_list(types) 897 shapes = _as_shape_list(shapes, types) 898 899 queue_ref = gen_data_flow_ops.priority_queue_v2( 900 component_types=types, 901 shapes=shapes, 902 capacity=capacity, 903 shared_name=_shared_name(shared_name), 904 name=name) 905 906 priority_dtypes = [_dtypes.int64] + types 907 priority_shapes = [()] + shapes if shapes else shapes 908 909 super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names, 910 queue_ref) 911 912 913# TODO(josh11b): class BatchQueue(QueueBase): 914 915 916class Barrier(object): 917 """Represents a key-value map that persists across graph executions.""" 918 919 def __init__(self, types, shapes=None, shared_name=None, name="barrier"): 920 """Creates a barrier that persists across different graph executions. 921 922 A barrier represents a key-value map, where each key is a string, and 923 each value is a tuple of tensors. 924 925 At runtime, the barrier contains 'complete' and 'incomplete' 926 elements. A complete element has defined tensors for all 927 components of its value tuple, and may be accessed using 928 take_many. An incomplete element has some undefined components in 929 its value tuple, and may be updated using insert_many. 930 931 The barrier call `take_many` outputs values in a particular order. 932 First, it only outputs completed values. Second, the order in which 933 completed values are returned matches the order in which their very 934 first component was inserted into the barrier. So, for example, for this 935 sequence of insertions and removals: 936 937 barrier = Barrier((tf.string, tf.int32), shapes=((), ())) 938 barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run() 939 barrier.insert_many(1, keys=["k1"], values=[1]).run() 940 barrier.insert_many(0, keys=["k3"], values=["c"]).run() 941 barrier.insert_many(1, keys=["k3"], values=[3]).run() 942 barrier.insert_many(1, keys=["k2"], values=[2]).run() 943 944 (indices, keys, values) = barrier.take_many(2) 945 (indices_val, keys_val, values0_val, values1_val) = 946 session.run([indices, keys, values[0], values[1]]) 947 948 The output will be (up to permutation of "k1" and "k2"): 949 950 indices_val == (-2**63, -2**63) 951 keys_val == ("k1", "k2") 952 values0_val == ("a", "b") 953 values1_val == (1, 2) 954 955 Note the key "k2" was inserted into the barrier before "k3". Even though 956 "k3" was completed first, both are complete by the time 957 take_many is called. As a result, "k2" is prioritized and "k1" and "k2" 958 are returned first. "k3" remains in the barrier until the next execution 959 of `take_many`. Since "k1" and "k2" had their first insertions into 960 the barrier together, their indices are the same (-2**63). The index 961 of "k3" will be -2**63 + 1, because it was the next new inserted key. 962 963 Args: 964 types: A single dtype or a tuple of dtypes, corresponding to the 965 dtypes of the tensor elements that comprise a value in this barrier. 966 shapes: Optional. Constraints on the shapes of tensors in the values: 967 a single tensor shape tuple; a tuple of tensor shape tuples 968 for each barrier-element tuple component; or None if the shape should 969 not be constrained. 970 shared_name: Optional. If non-empty, this barrier will be shared under 971 the given name across multiple sessions. 972 name: Optional name for the barrier op. 973 974 Raises: 975 ValueError: If one of the `shapes` indicate no elements. 976 """ 977 self._types = _as_type_list(types) 978 979 if shapes is not None: 980 shapes = _as_shape_list(shapes, self._types) 981 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 982 for i, shape in enumerate(self._shapes): 983 if shape.num_elements() == 0: 984 raise ValueError("Empty tensors are not supported, but received " 985 "shape '%s' at index %d" % (shape, i)) 986 else: 987 self._shapes = [tensor_shape.unknown_shape() for _ in self._types] 988 989 self._barrier_ref = gen_data_flow_ops.barrier( 990 component_types=self._types, 991 shapes=self._shapes, 992 shared_name=shared_name, 993 name=name) 994 if context.executing_eagerly(): 995 self._name = context.context().scope_name 996 else: 997 self._name = self._barrier_ref.op.name.split("/")[-1] 998 999 @property 1000 def barrier_ref(self): 1001 """Get the underlying barrier reference.""" 1002 return self._barrier_ref 1003 1004 @property 1005 def name(self): 1006 """The name of the underlying barrier.""" 1007 if context.executing_eagerly(): 1008 return self._name 1009 return self._barrier_ref.op.name 1010 1011 def insert_many(self, component_index, keys, values, name=None): 1012 """For each key, assigns the respective value to the specified component. 1013 1014 This operation updates each element at component_index. 1015 1016 Args: 1017 component_index: The component of the value that is being assigned. 1018 keys: A vector of keys, with length n. 1019 values: An any-dimensional tensor of values, which are associated with the 1020 respective keys. The first dimension must have length n. 1021 name: Optional name for the op. 1022 1023 Returns: 1024 The operation that performs the insertion. 1025 Raises: 1026 InvalidArgumentsError: If inserting keys and values without elements. 1027 """ 1028 if name is None: 1029 name = "%s_BarrierInsertMany" % self._name 1030 return gen_data_flow_ops.barrier_insert_many( 1031 self._barrier_ref, keys, values, component_index, name=name) 1032 1033 def take_many(self, 1034 num_elements, 1035 allow_small_batch=False, 1036 timeout=None, 1037 name=None): 1038 """Takes the given number of completed elements from this barrier. 1039 1040 This operation concatenates completed-element component tensors along 1041 the 0th dimension to make a single component tensor. 1042 1043 If barrier has no completed elements, this operation will block 1044 until there are 'num_elements' elements to take. 1045 1046 TODO(b/25743580): the semantics of `allow_small_batch` are experimental 1047 and may be extended to other cases in the future. 1048 1049 TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking 1050 already when the barrier is closed, it will block for ever. Fix this 1051 by using asynchronous operations. 1052 1053 Args: 1054 num_elements: The number of elements to take. 1055 allow_small_batch: If the barrier is closed, don't block if there are less 1056 completed elements than requested, but instead return all available 1057 completed elements. 1058 timeout: This specifies the number of milliseconds to block 1059 before returning with DEADLINE_EXCEEDED. (This option is not 1060 supported yet.) 1061 name: A name for the operation (optional). 1062 1063 Returns: 1064 A tuple of (index, key, value_list). 1065 "index" is a int64 tensor of length num_elements containing the 1066 index of the insert_many call for which the very first component of 1067 the given element was inserted into the Barrier, starting with 1068 the value -2**63. Note, this value is different from the 1069 index of the insert_many call for which the element was completed. 1070 "key" is a string tensor of length num_elements containing the keys. 1071 "value_list" is a tuple of tensors, each one with size num_elements 1072 in the 0th dimension for each component in the barrier's values. 1073 1074 """ 1075 if name is None: 1076 name = "%s_BarrierTakeMany" % self._name 1077 ret = gen_data_flow_ops.barrier_take_many( 1078 self._barrier_ref, 1079 num_elements, 1080 self._types, 1081 allow_small_batch, 1082 timeout, 1083 name=name) 1084 1085 # NOTE(mrry): Not using a shape function because we need access to 1086 # the Barrier object. 1087 if not context.executing_eagerly(): 1088 op = ret[0].op 1089 if allow_small_batch: 1090 batch_dim = None 1091 else: 1092 batch_dim = tensor_shape.Dimension( 1093 tensor_util.constant_value(op.inputs[1])) 1094 op.outputs[0].set_shape(tensor_shape.vector(batch_dim)) # indices 1095 op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys 1096 for output, shape in zip(op.outputs[2:], self._shapes): # value_list 1097 output.set_shape( 1098 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 1099 1100 return ret 1101 1102 def close(self, cancel_pending_enqueues=False, name=None): 1103 """Closes this barrier. 1104 1105 This operation signals that no more new key values will be inserted in the 1106 given barrier. Subsequent InsertMany operations with new keys will fail. 1107 InsertMany operations that just complement already existing keys with other 1108 components, will continue to succeed. Subsequent TakeMany operations will 1109 continue to succeed if sufficient elements remain in the barrier. Subsequent 1110 TakeMany operations that would block will fail immediately. 1111 1112 If `cancel_pending_enqueues` is `True`, all pending requests to the 1113 underlying queue will also be canceled, and completing of already 1114 started values is also not acceptable anymore. 1115 1116 Args: 1117 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 1118 `False` (described above). 1119 name: Optional name for the op. 1120 1121 Returns: 1122 The operation that closes the barrier. 1123 """ 1124 if name is None: 1125 name = "%s_BarrierClose" % self._name 1126 return gen_data_flow_ops.barrier_close( 1127 self._barrier_ref, 1128 cancel_pending_enqueues=cancel_pending_enqueues, 1129 name=name) 1130 1131 def ready_size(self, name=None): 1132 """Compute the number of complete elements in the given barrier. 1133 1134 Args: 1135 name: A name for the operation (optional). 1136 1137 Returns: 1138 A single-element tensor containing the number of complete elements in the 1139 given barrier. 1140 """ 1141 if name is None: 1142 name = "%s_BarrierReadySize" % self._name 1143 return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name) 1144 1145 def incomplete_size(self, name=None): 1146 """Compute the number of incomplete elements in the given barrier. 1147 1148 Args: 1149 name: A name for the operation (optional). 1150 1151 Returns: 1152 A single-element tensor containing the number of incomplete elements in 1153 the given barrier. 1154 """ 1155 if name is None: 1156 name = "%s_BarrierIncompleteSize" % self._name 1157 return gen_data_flow_ops.barrier_incomplete_size( 1158 self._barrier_ref, name=name) 1159 1160 1161@tf_export(v1=["ConditionalAccumulatorBase"]) 1162class ConditionalAccumulatorBase(object): 1163 """A conditional accumulator for aggregating gradients. 1164 1165 Up-to-date gradients (i.e., time step at which gradient was computed is 1166 equal to the accumulator's time step) are added to the accumulator. 1167 1168 Extraction of the average gradient is blocked until the required number of 1169 gradients has been accumulated. 1170 """ 1171 1172 def __init__(self, dtype, shape, accumulator_ref): 1173 """Creates a new ConditionalAccumulator. 1174 1175 Args: 1176 dtype: Datatype of the accumulated gradients. 1177 shape: Shape of the accumulated gradients. 1178 accumulator_ref: A handle to the conditional accumulator, created by sub- 1179 classes 1180 """ 1181 self._dtype = dtype 1182 if shape is not None: 1183 self._shape = tensor_shape.TensorShape(shape) 1184 else: 1185 self._shape = tensor_shape.unknown_shape() 1186 self._accumulator_ref = accumulator_ref 1187 if context.executing_eagerly(): 1188 self._name = context.context().scope_name 1189 else: 1190 self._name = self._accumulator_ref.op.name.split("/")[-1] 1191 1192 @property 1193 def accumulator_ref(self): 1194 """The underlying accumulator reference.""" 1195 return self._accumulator_ref 1196 1197 @property 1198 def name(self): 1199 """The name of the underlying accumulator.""" 1200 return self._name 1201 1202 @property 1203 def dtype(self): 1204 """The datatype of the gradients accumulated by this accumulator.""" 1205 return self._dtype 1206 1207 def num_accumulated(self, name=None): 1208 """Number of gradients that have currently been aggregated in accumulator. 1209 1210 Args: 1211 name: Optional name for the operation. 1212 1213 Returns: 1214 Number of accumulated gradients currently in accumulator. 1215 """ 1216 if name is None: 1217 name = "%s_NumAccumulated" % self._name 1218 return gen_data_flow_ops.accumulator_num_accumulated( 1219 self._accumulator_ref, name=name) 1220 1221 def set_global_step(self, new_global_step, name=None): 1222 """Sets the global time step of the accumulator. 1223 1224 The operation logs a warning if we attempt to set to a time step that is 1225 lower than the accumulator's own time step. 1226 1227 Args: 1228 new_global_step: Value of new time step. Can be a variable or a constant 1229 name: Optional name for the operation. 1230 1231 Returns: 1232 Operation that sets the accumulator's time step. 1233 """ 1234 return gen_data_flow_ops.accumulator_set_global_step( 1235 self._accumulator_ref, 1236 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64), 1237 name=name) 1238 1239 1240@tf_export(v1=["ConditionalAccumulator"]) 1241class ConditionalAccumulator(ConditionalAccumulatorBase): 1242 """A conditional accumulator for aggregating gradients. 1243 1244 Up-to-date gradients (i.e., time step at which gradient was computed is 1245 equal to the accumulator's time step) are added to the accumulator. 1246 1247 Extraction of the average gradient is blocked until the required number of 1248 gradients has been accumulated. 1249 """ 1250 1251 def __init__(self, 1252 dtype, 1253 shape=None, 1254 shared_name=None, 1255 name="conditional_accumulator", 1256 reduction_type="MEAN"): 1257 """Creates a new ConditionalAccumulator. 1258 1259 Args: 1260 dtype: Datatype of the accumulated gradients. 1261 shape: Shape of the accumulated gradients. 1262 shared_name: Optional. If non-empty, this accumulator will be shared under 1263 the given name across multiple sessions. 1264 name: Optional name for the accumulator. 1265 reduction_type: Reduction type to use when taking the gradient. 1266 """ 1267 accumulator_ref = gen_data_flow_ops.conditional_accumulator( 1268 dtype=dtype, 1269 shape=shape, 1270 shared_name=shared_name, 1271 name=name, 1272 reduction_type=reduction_type) 1273 super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) 1274 1275 def apply_grad(self, grad, local_step=0, name=None): 1276 """Attempts to apply a gradient to the accumulator. 1277 1278 The attempt is silently dropped if the gradient is stale, i.e., local_step 1279 is less than the accumulator's global time step. 1280 1281 Args: 1282 grad: The gradient tensor to be applied. 1283 local_step: Time step at which the gradient was computed. 1284 name: Optional name for the operation. 1285 1286 Returns: 1287 The operation that (conditionally) applies a gradient to the accumulator. 1288 1289 Raises: 1290 ValueError: If grad is of the wrong shape 1291 """ 1292 grad = ops.convert_to_tensor(grad, self._dtype) 1293 grad.get_shape().assert_is_compatible_with(self._shape) 1294 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64) 1295 return gen_data_flow_ops.accumulator_apply_gradient( 1296 self._accumulator_ref, local_step=local_step, gradient=grad, name=name) 1297 1298 def take_grad(self, num_required, name=None): 1299 """Attempts to extract the average gradient from the accumulator. 1300 1301 The operation blocks until sufficient number of gradients have been 1302 successfully applied to the accumulator. 1303 1304 Once successful, the following actions are also triggered: 1305 1306 - Counter of accumulated gradients is reset to 0. 1307 - Aggregated gradient is reset to 0 tensor. 1308 - Accumulator's internal time step is incremented by 1. 1309 1310 Args: 1311 num_required: Number of gradients that needs to have been aggregated 1312 name: Optional name for the operation 1313 1314 Returns: 1315 A tensor holding the value of the average gradient. 1316 1317 Raises: 1318 InvalidArgumentError: If num_required < 1 1319 """ 1320 out = gen_data_flow_ops.accumulator_take_gradient( 1321 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1322 out.set_shape(self._shape) 1323 return out 1324 1325 1326@tf_export( 1327 "sparse.SparseConditionalAccumulator", 1328 v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"]) 1329@deprecation.deprecated_endpoints("SparseConditionalAccumulator") 1330class SparseConditionalAccumulator(ConditionalAccumulatorBase): 1331 """A conditional accumulator for aggregating sparse gradients. 1332 1333 Sparse gradients are represented by IndexedSlices. 1334 1335 Up-to-date gradients (i.e., time step at which gradient was computed is 1336 equal to the accumulator's time step) are added to the accumulator. 1337 1338 Extraction of the average gradient is blocked until the required number of 1339 gradients has been accumulated. 1340 1341 Args: 1342 dtype: Datatype of the accumulated gradients. 1343 shape: Shape of the accumulated gradients. 1344 shared_name: Optional. If non-empty, this accumulator will be shared under 1345 the given name across multiple sessions. 1346 name: Optional name for the accumulator. 1347 reduction_type: Reduction type to use when taking the gradient. 1348 """ 1349 1350 def __init__(self, 1351 dtype, 1352 shape=None, 1353 shared_name=None, 1354 name="sparse_conditional_accumulator", 1355 reduction_type="MEAN"): 1356 accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( 1357 dtype=dtype, 1358 shape=shape, 1359 shared_name=shared_name, 1360 name=name, 1361 reduction_type=reduction_type) 1362 super(SparseConditionalAccumulator, self).__init__(dtype, shape, 1363 accumulator_ref) 1364 1365 def apply_indexed_slices_grad(self, grad, local_step=0, name=None): 1366 """Attempts to apply a gradient to the accumulator. 1367 1368 The attempt is silently dropped if the gradient is stale, i.e., local_step 1369 is less than the accumulator's global time step. 1370 1371 Args: 1372 grad: The gradient IndexedSlices to be applied. 1373 local_step: Time step at which the gradient was computed. 1374 name: Optional name for the operation. 1375 1376 Returns: 1377 The operation that (conditionally) applies a gradient to the accumulator. 1378 1379 Raises: 1380 InvalidArgumentError: If grad is of the wrong shape 1381 """ 1382 return self.apply_grad( 1383 grad_indices=grad.indices, 1384 grad_values=grad.values, 1385 grad_shape=grad.dense_shape, 1386 local_step=local_step, 1387 name=name) 1388 1389 def apply_grad(self, 1390 grad_indices, 1391 grad_values, 1392 grad_shape=None, 1393 local_step=0, 1394 name=None): 1395 """Attempts to apply a sparse gradient to the accumulator. 1396 1397 The attempt is silently dropped if the gradient is stale, i.e., local_step 1398 is less than the accumulator's global time step. 1399 1400 A sparse gradient is represented by its indices, values and possibly empty 1401 or None shape. Indices must be a vector representing the locations of 1402 non-zero entries in the tensor. Values are the non-zero slices of the 1403 gradient, and must have the same first dimension as indices, i.e., the nnz 1404 represented by indices and values must be consistent. Shape, if not empty or 1405 None, must be consistent with the accumulator's shape (if also provided). 1406 1407 Example: 1408 A tensor [[0, 0], [0. 1], [2, 3]] can be represented 1409 indices: [1,2] 1410 values: [[0,1],[2,3]] 1411 shape: [3, 2] 1412 1413 Args: 1414 grad_indices: Indices of the sparse gradient to be applied. 1415 grad_values: Values of the sparse gradient to be applied. 1416 grad_shape: Shape of the sparse gradient to be applied. 1417 local_step: Time step at which the gradient was computed. 1418 name: Optional name for the operation. 1419 1420 Returns: 1421 The operation that (conditionally) applies a gradient to the accumulator. 1422 1423 Raises: 1424 InvalidArgumentError: If grad is of the wrong shape 1425 """ 1426 local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64) 1427 return gen_data_flow_ops.sparse_accumulator_apply_gradient( 1428 self._accumulator_ref, 1429 local_step=local_step, 1430 gradient_indices=math_ops.cast(grad_indices, _dtypes.int64), 1431 gradient_values=grad_values, 1432 gradient_shape=math_ops.cast( 1433 [] if grad_shape is None else grad_shape, _dtypes.int64), 1434 has_known_shape=(grad_shape is not None), 1435 name=name) 1436 1437 def take_grad(self, num_required, name=None): 1438 """Attempts to extract the average gradient from the accumulator. 1439 1440 The operation blocks until sufficient number of gradients have been 1441 successfully applied to the accumulator. 1442 1443 Once successful, the following actions are also triggered: 1444 - Counter of accumulated gradients is reset to 0. 1445 - Aggregated gradient is reset to 0 tensor. 1446 - Accumulator's internal time step is incremented by 1. 1447 1448 Args: 1449 num_required: Number of gradients that needs to have been aggregated 1450 name: Optional name for the operation 1451 1452 Returns: 1453 A tuple of indices, values, and shape representing the average gradient. 1454 1455 Raises: 1456 InvalidArgumentError: If num_required < 1 1457 """ 1458 return gen_data_flow_ops.sparse_accumulator_take_gradient( 1459 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1460 1461 def take_indexed_slices_grad(self, num_required, name=None): 1462 """Attempts to extract the average gradient from the accumulator. 1463 1464 The operation blocks until sufficient number of gradients have been 1465 successfully applied to the accumulator. 1466 1467 Once successful, the following actions are also triggered: 1468 - Counter of accumulated gradients is reset to 0. 1469 - Aggregated gradient is reset to 0 tensor. 1470 - Accumulator's internal time step is incremented by 1. 1471 1472 Args: 1473 num_required: Number of gradients that needs to have been aggregated 1474 name: Optional name for the operation 1475 1476 Returns: 1477 An IndexedSlices holding the value of the average gradient. 1478 1479 Raises: 1480 InvalidArgumentError: If num_required < 1 1481 """ 1482 return_val = gen_data_flow_ops.sparse_accumulator_take_gradient( 1483 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1484 return ops.IndexedSlices( 1485 indices=return_val.indices, 1486 values=return_val.values, 1487 dense_shape=return_val.shape) 1488 1489 1490class BaseStagingArea(object): 1491 """Base class for Staging Areas.""" 1492 _identifier = 0 1493 _lock = threading.Lock() 1494 1495 def __init__(self, 1496 dtypes, 1497 shapes=None, 1498 names=None, 1499 shared_name=None, 1500 capacity=0, 1501 memory_limit=0): 1502 if shared_name is None: 1503 self._name = ( 1504 ops.get_default_graph().unique_name(self.__class__.__name__)) 1505 elif isinstance(shared_name, six.string_types): 1506 self._name = shared_name 1507 else: 1508 raise ValueError("shared_name must be a string") 1509 1510 self._dtypes = dtypes 1511 1512 if shapes is not None: 1513 if len(shapes) != len(dtypes): 1514 raise ValueError("StagingArea shapes must be the same length as dtypes") 1515 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 1516 else: 1517 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 1518 1519 if names is not None: 1520 if len(names) != len(dtypes): 1521 raise ValueError("StagingArea names must be the same length as dtypes") 1522 self._names = names 1523 else: 1524 self._names = None 1525 1526 self._capacity = capacity 1527 self._memory_limit = memory_limit 1528 1529 # all get and put ops must colocate with this op 1530 with ops.name_scope("%s_root" % self._name): 1531 self._coloc_op = control_flow_ops.no_op() 1532 1533 @property 1534 def name(self): 1535 """The name of the staging area.""" 1536 return self._name 1537 1538 @property 1539 def dtypes(self): 1540 """The list of dtypes for each component of a staging area element.""" 1541 return self._dtypes 1542 1543 @property 1544 def shapes(self): 1545 """The list of shapes for each component of a staging area element.""" 1546 return self._shapes 1547 1548 @property 1549 def names(self): 1550 """The list of names for each component of a staging area element.""" 1551 return self._names 1552 1553 @property 1554 def capacity(self): 1555 """The maximum number of elements of this staging area.""" 1556 return self._capacity 1557 1558 @property 1559 def memory_limit(self): 1560 """The maximum number of bytes of this staging area.""" 1561 return self._memory_limit 1562 1563 def _check_put_dtypes(self, vals, indices=None): 1564 """Validate and convert `vals` to a list of `Tensor`s. 1565 1566 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 1567 dictionary with tensor values. 1568 1569 If `vals` is a list, then the appropriate indices associated with the 1570 values must be provided. 1571 1572 If it is a dictionary, the staging area must have been constructed with a 1573 `names` attribute and the dictionary keys must match the staging area names. 1574 `indices` will be inferred from the dictionary keys. 1575 If the staging area was constructed with a `names` attribute, `vals` must 1576 be a dictionary. 1577 1578 Checks that the dtype and shape of each value matches that 1579 of the staging area. 1580 1581 Args: 1582 vals: A tensor, a list or tuple of tensors, or a dictionary. 1583 1584 Returns: 1585 A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects 1586 and `indices` is a list of indices associed with the tensors. 1587 1588 Raises: 1589 ValueError: If `vals` or `indices` is invalid. 1590 """ 1591 if isinstance(vals, dict): 1592 if not self._names: 1593 raise ValueError( 1594 "Staging areas must have names to enqueue a dictionary") 1595 if not set(vals.keys()).issubset(self._names): 1596 raise ValueError("Keys in dictionary to put do not match names " 1597 "of staging area. Dictionary: (%s), Queue: (%s)" % 1598 (sorted(vals.keys()), sorted(self._names))) 1599 # The order of values in `self._names` indicates the order in which the 1600 # tensors in the dictionary `vals` must be listed. 1601 vals, indices, _ = zip(*[(vals[k], i, k) 1602 for i, k in enumerate(self._names) 1603 if k in vals]) 1604 else: 1605 if self._names: 1606 raise ValueError("You must enqueue a dictionary in a staging area " 1607 "with names") 1608 1609 if indices is None: 1610 raise ValueError("Indices must be supplied when inserting a list " 1611 "of tensors") 1612 1613 if len(indices) != len(vals): 1614 raise ValueError("Number of indices '%s' doesn't match " 1615 "number of values '%s'") 1616 1617 if not isinstance(vals, (list, tuple)): 1618 vals = [vals] 1619 indices = [0] 1620 1621 # Sanity check number of values 1622 if not len(vals) <= len(self._dtypes): 1623 raise ValueError("Unexpected number of inputs '%s' vs '%s'" % 1624 (len(vals), len(self._dtypes))) 1625 1626 tensors = [] 1627 1628 for val, i in zip(vals, indices): 1629 dtype, shape = self._dtypes[i], self._shapes[i] 1630 # Check dtype 1631 if val.dtype != dtype: 1632 raise ValueError("Datatypes do not match. '%s' != '%s'" % 1633 (str(val.dtype), str(dtype))) 1634 1635 # Check shape 1636 val.get_shape().assert_is_compatible_with(shape) 1637 1638 tensors.append( 1639 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 1640 1641 return tensors, indices 1642 1643 def _create_device_transfers(self, tensors): 1644 """Encode inter-device transfers if the current device 1645 is not the same as the Staging Area's device. 1646 """ 1647 1648 if not isinstance(tensors, (tuple, list)): 1649 tensors = [tensors] 1650 1651 curr_device_scope = control_flow_ops.no_op().device 1652 1653 if curr_device_scope != self._coloc_op.device: 1654 tensors = [array_ops.identity(t) for t in tensors] 1655 1656 return tensors 1657 1658 def _get_return_value(self, tensors, indices): 1659 """Return the value to return from a get op. 1660 1661 If the staging area has names, return a dictionary with the 1662 names as keys. Otherwise return either a single tensor 1663 or a list of tensors depending on the length of `tensors`. 1664 1665 Args: 1666 tensors: List of tensors from the get op. 1667 indices: Indices of associated names and shapes 1668 1669 Returns: 1670 A single tensor, a list of tensors, or a dictionary 1671 of tensors. 1672 """ 1673 1674 tensors = self._create_device_transfers(tensors) 1675 1676 # Sets shape 1677 for output, i in zip(tensors, indices): 1678 output.set_shape(self._shapes[i]) 1679 1680 if self._names: 1681 # The returned values in `tensors` are in the same order as 1682 # the names in `self._names`. 1683 return {self._names[i]: t for t, i in zip(tensors, indices)} 1684 return tensors 1685 1686 def _scope_vals(self, vals): 1687 """Return a list of values to pass to `name_scope()`. 1688 1689 Args: 1690 vals: A tensor, a list or tuple of tensors, or a dictionary. 1691 1692 Returns: 1693 The values in vals as a list. 1694 """ 1695 if isinstance(vals, (list, tuple)): 1696 return vals 1697 elif isinstance(vals, dict): 1698 return vals.values() 1699 else: 1700 return [vals] 1701 1702 1703class StagingArea(BaseStagingArea): 1704 """Class for staging inputs. No ordering guarantees. 1705 1706 A `StagingArea` is a TensorFlow data structure that stores tensors across 1707 multiple steps, and exposes operations that can put and get tensors. 1708 1709 Each `StagingArea` element is a tuple of one or more tensors, where each 1710 tuple component has a static dtype, and may have a static shape. 1711 1712 The capacity of a `StagingArea` may be bounded or unbounded. 1713 It supports multiple concurrent producers and consumers; and 1714 provides exactly-once delivery. 1715 1716 Each element of a `StagingArea` is a fixed-length tuple of tensors whose 1717 dtypes are described by `dtypes`, and whose shapes are optionally described 1718 by the `shapes` argument. 1719 1720 If the `shapes` argument is specified, each component of a staging area 1721 element must have the respective fixed shape. If it is 1722 unspecified, different elements may have different shapes, 1723 1724 It can be configured with a capacity in which case 1725 put(values) will block until space becomes available. 1726 1727 Similarly, it can be configured with a memory limit which 1728 will block put(values) until space is available. 1729 This is mostly useful for limiting the number of tensors on 1730 devices such as GPUs. 1731 1732 All get() and peek() commands block if the requested data 1733 is not present in the Staging Area. 1734 1735 """ 1736 1737 def __init__(self, 1738 dtypes, 1739 shapes=None, 1740 names=None, 1741 shared_name=None, 1742 capacity=0, 1743 memory_limit=0): 1744 """Constructs a staging area object. 1745 1746 The two optional lists, `shapes` and `names`, must be of the same length 1747 as `dtypes` if provided. The values at a given index `i` indicate the 1748 shape and name to use for the corresponding queue component in `dtypes`. 1749 1750 The device scope at the time of object creation determines where the 1751 storage for the `StagingArea` will reside. Calls to `put` will incur a copy 1752 to this memory space, if necessary. Tensors returned by `get` will be 1753 placed according to the device scope when `get` is called. 1754 1755 Args: 1756 dtypes: A list of types. The length of dtypes must equal the number 1757 of tensors in each element. 1758 shapes: (Optional.) Constraints on the shapes of tensors in an element. 1759 A list of shape tuples or None. This list is the same length 1760 as dtypes. If the shape of any tensors in the element are constrained, 1761 all must be; shapes can be None if the shapes should not be constrained. 1762 names: (Optional.) If provided, the `get()` and 1763 `put()` methods will use dictionaries with these names as keys. 1764 Must be None or a list or tuple of the same length as `dtypes`. 1765 shared_name: (Optional.) A name to be used for the shared object. By 1766 passing the same name to two different python objects they will share 1767 the underlying staging area. Must be a string. 1768 capacity: (Optional.) Maximum number of elements. 1769 An integer. If zero, the Staging Area is unbounded 1770 memory_limit: (Optional.) Maximum number of bytes of all tensors 1771 in the Staging Area. 1772 An integer. If zero, the Staging Area is unbounded 1773 1774 Raises: 1775 ValueError: If one of the arguments is invalid. 1776 """ 1777 1778 super(StagingArea, self).__init__(dtypes, shapes, names, shared_name, 1779 capacity, memory_limit) 1780 1781 def put(self, values, name=None): 1782 """Create an op that places a value into the staging area. 1783 1784 This operation will block if the `StagingArea` has reached 1785 its capacity. 1786 1787 Args: 1788 values: A single tensor, a list or tuple of tensors, or a dictionary with 1789 tensor values. The number of elements must match the length of the 1790 list provided to the dtypes argument when creating the StagingArea. 1791 name: A name for the operation (optional). 1792 1793 Returns: 1794 The created op. 1795 1796 Raises: 1797 ValueError: If the number or type of inputs don't match the staging area. 1798 """ 1799 with ops.name_scope(name, "%s_put" % self._name, 1800 self._scope_vals(values)) as scope: 1801 1802 if not isinstance(values, (list, tuple, dict)): 1803 values = [values] 1804 1805 # Hard-code indices for this staging area 1806 indices = list(six.moves.range(len(values))) 1807 vals, _ = self._check_put_dtypes(values, indices) 1808 1809 with ops.colocate_with(self._coloc_op): 1810 op = gen_data_flow_ops.stage( 1811 values=vals, 1812 shared_name=self._name, 1813 name=scope, 1814 capacity=self._capacity, 1815 memory_limit=self._memory_limit) 1816 1817 return op 1818 1819 def __internal_get(self, get_fn, name): 1820 with ops.colocate_with(self._coloc_op): 1821 ret = get_fn() 1822 1823 indices = list(six.moves.range(len(self._dtypes))) # Hard coded 1824 return self._get_return_value(ret, indices) 1825 1826 def get(self, name=None): 1827 """Gets one element from this staging area. 1828 1829 If the staging area is empty when this operation executes, it will block 1830 until there is an element to dequeue. 1831 1832 Note that unlike others ops that can block, like the queue Dequeue 1833 operations, this can stop other work from happening. To avoid this, the 1834 intended use is for this to be called only when there will be an element 1835 already available. One method for doing this in a training loop would be to 1836 run a `put()` call during a warmup session.run call, and then call both 1837 `get()` and `put()` in each subsequent step. 1838 1839 The placement of the returned tensor will be determined by the current 1840 device scope when this function is called. 1841 1842 Args: 1843 name: A name for the operation (optional). 1844 1845 Returns: 1846 The tuple of tensors that was gotten. 1847 """ 1848 if name is None: 1849 name = "%s_get" % self._name 1850 1851 # pylint: disable=bad-continuation 1852 fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, 1853 shared_name=self._name, name=name, 1854 capacity=self._capacity, 1855 memory_limit=self._memory_limit) 1856 # pylint: enable=bad-continuation 1857 1858 return self.__internal_get(fn, name) 1859 1860 def peek(self, index, name=None): 1861 """Peeks at an element in the staging area. 1862 1863 If the staging area is too small to contain the element at 1864 the specified index, it will block until enough elements 1865 are inserted to complete the operation. 1866 1867 The placement of the returned tensor will be determined by 1868 the current device scope when this function is called. 1869 1870 Args: 1871 index: The index of the tensor within the staging area 1872 to look up. 1873 name: A name for the operation (optional). 1874 1875 Returns: 1876 The tuple of tensors that was gotten. 1877 """ 1878 if name is None: 1879 name = "%s_peek" % self._name 1880 1881 # pylint: disable=bad-continuation 1882 fn = lambda: gen_data_flow_ops.stage_peek(index, 1883 dtypes=self._dtypes, shared_name=self._name, 1884 name=name, capacity=self._capacity, 1885 memory_limit=self._memory_limit) 1886 # pylint: enable=bad-continuation 1887 1888 return self.__internal_get(fn, name) 1889 1890 def size(self, name=None): 1891 """Returns the number of elements in the staging area. 1892 1893 Args: 1894 name: A name for the operation (optional) 1895 1896 Returns: 1897 The created op 1898 """ 1899 if name is None: 1900 name = "%s_size" % self._name 1901 1902 return gen_data_flow_ops.stage_size( 1903 name=name, 1904 shared_name=self._name, 1905 dtypes=self._dtypes, 1906 capacity=self._capacity, 1907 memory_limit=self._memory_limit) 1908 1909 def clear(self, name=None): 1910 """Clears the staging area. 1911 1912 Args: 1913 name: A name for the operation (optional) 1914 1915 Returns: 1916 The created op 1917 """ 1918 if name is None: 1919 name = "%s_clear" % self._name 1920 1921 return gen_data_flow_ops.stage_clear( 1922 name=name, 1923 shared_name=self._name, 1924 dtypes=self._dtypes, 1925 capacity=self._capacity, 1926 memory_limit=self._memory_limit) 1927 1928 1929class MapStagingArea(BaseStagingArea): 1930 """A `MapStagingArea` is a TensorFlow data structure that stores tensors 1931 across multiple steps, and exposes operations that can put and get tensors. 1932 1933 Each `MapStagingArea` element is a (key, value) pair. 1934 Only int64 keys are supported, other types should be 1935 hashed to produce a key. 1936 Values are a tuple of one or more tensors. 1937 Each tuple component has a static dtype, 1938 and may have a static shape. 1939 1940 The capacity of a `MapStagingArea` may be bounded or unbounded. 1941 It supports multiple concurrent producers and consumers; and 1942 provides exactly-once delivery. 1943 1944 Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors 1945 whose 1946 dtypes are described by `dtypes`, and whose shapes are optionally described 1947 by the `shapes` argument. 1948 1949 If the `shapes` argument is specified, each component of a staging area 1950 element must have the respective fixed shape. If it is 1951 unspecified, different elements may have different shapes, 1952 1953 It behaves like an associative container with support for: 1954 1955 - put(key, values) 1956 - peek(key) like dict.get(key) 1957 - get(key) like dict.pop(key) 1958 - get(key=None) like dict.popitem() 1959 - size() 1960 - clear() 1961 1962 If ordered a tree structure ordered by key will be used and 1963 get(key=None) will remove (key, value) pairs in increasing key order. 1964 Otherwise a hashtable 1965 1966 It can be configured with a capacity in which case 1967 put(key, values) will block until space becomes available. 1968 1969 Similarly, it can be configured with a memory limit which 1970 will block put(key, values) until space is available. 1971 This is mostly useful for limiting the number of tensors on 1972 devices such as GPUs. 1973 1974 All get() and peek() commands block if the requested 1975 (key, value) pair is not present in the staging area. 1976 1977 Partial puts are supported and will be placed in an incomplete 1978 map until such time as all values associated with the key have 1979 been inserted. Once completed, this (key, value) pair will be 1980 inserted into the map. Data in the incomplete map 1981 counts towards the memory limit, but not towards capacity limit. 1982 1983 Partial gets from the map are also supported. 1984 This removes the partially requested tensors from the entry, 1985 but the entry is only removed from the map once all tensors 1986 associated with it are removed. 1987 """ 1988 1989 def __init__(self, 1990 dtypes, 1991 shapes=None, 1992 names=None, 1993 shared_name=None, 1994 ordered=False, 1995 capacity=0, 1996 memory_limit=0): 1997 """Args: 1998 1999 dtypes: A list of types. The length of dtypes must equal the number 2000 of tensors in each element. 2001 capacity: (Optional.) Maximum number of elements. 2002 An integer. If zero, the Staging Area is unbounded 2003 memory_limit: (Optional.) Maximum number of bytes of all tensors 2004 in the Staging Area (excluding keys). 2005 An integer. If zero, the Staging Area is unbounded 2006 ordered: (Optional.) If True the underlying data structure 2007 is a tree ordered on key. Otherwise assume a hashtable. 2008 shapes: (Optional.) Constraints on the shapes of tensors in an element. 2009 A list of shape tuples or None. This list is the same length 2010 as dtypes. If the shape of any tensors in the element are constrained, 2011 all must be; shapes can be None if the shapes should not be constrained. 2012 names: (Optional.) If provided, the `get()` and 2013 `put()` methods will use dictionaries with these names as keys. 2014 Must be None or a list or tuple of the same length as `dtypes`. 2015 shared_name: (Optional.) A name to be used for the shared object. By 2016 passing the same name to two different python objects they will share 2017 the underlying staging area. Must be a string. 2018 2019 Raises: 2020 ValueError: If one of the arguments is invalid. 2021 2022 """ 2023 2024 super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name, 2025 capacity, memory_limit) 2026 2027 # Defer to different methods depending if the map is ordered 2028 self._ordered = ordered 2029 2030 if ordered: 2031 self._put_fn = gen_data_flow_ops.ordered_map_stage 2032 self._pop_fn = gen_data_flow_ops.ordered_map_unstage 2033 self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key 2034 self._peek_fn = gen_data_flow_ops.ordered_map_peek 2035 self._size_fn = gen_data_flow_ops.ordered_map_size 2036 self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size 2037 self._clear_fn = gen_data_flow_ops.ordered_map_clear 2038 else: 2039 self._put_fn = gen_data_flow_ops.map_stage 2040 self._pop_fn = gen_data_flow_ops.map_unstage 2041 self._popitem_fn = gen_data_flow_ops.map_unstage_no_key 2042 self._peek_fn = gen_data_flow_ops.map_peek 2043 self._size_fn = gen_data_flow_ops.map_size 2044 self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size 2045 self._clear_fn = gen_data_flow_ops.map_clear 2046 2047 def put(self, key, vals, indices=None, name=None): 2048 """Create an op that stores the (key, vals) pair in the staging area. 2049 2050 Incomplete puts are possible, preferably using a dictionary for vals 2051 as the appropriate dtypes and shapes can be inferred from the value names 2052 dictionary key values. If vals is a list or tuple, indices must 2053 also be specified so that the op knows at which element position 2054 to perform the insert. 2055 2056 This operation will block if the capacity or memory limit of this 2057 container is reached. 2058 2059 Args: 2060 key: Key associated with the data 2061 vals: Tensor (or a dict/tuple of Tensors) to place 2062 into the staging area. 2063 indices: (Optional) if vals is a tuple/list, this is required. 2064 name: A name for the operation (optional) 2065 2066 Returns: 2067 The created op 2068 2069 Raises: 2070 ValueError: If the number or type of inputs don't match the staging 2071 area. 2072 """ 2073 2074 with ops.name_scope(name, "%s_put" % self._name, 2075 self._scope_vals(vals)) as scope: 2076 2077 vals, indices = self._check_put_dtypes(vals, indices) 2078 2079 with ops.colocate_with(self._coloc_op): 2080 op = self._put_fn( 2081 key, 2082 indices, 2083 vals, 2084 dtypes=self._dtypes, 2085 shared_name=self._name, 2086 name=scope, 2087 capacity=self._capacity, 2088 memory_limit=self._memory_limit) 2089 return op 2090 2091 def _get_indices_and_dtypes(self, indices=None): 2092 if indices is None: 2093 indices = list(six.moves.range(len(self._dtypes))) 2094 2095 if not isinstance(indices, (tuple, list)): 2096 raise TypeError("Invalid indices type '%s'" % type(indices)) 2097 2098 if len(indices) == 0: 2099 raise ValueError("Empty indices") 2100 2101 if all(isinstance(i, str) for i in indices): 2102 if self._names is None: 2103 raise ValueError("String indices provided '%s', but this Staging Area " 2104 "was not created with names." % indices) 2105 2106 try: 2107 indices = [self._names.index(n) for n in indices] 2108 except ValueError: 2109 raise ValueError("Named index '%s' not in " 2110 "Staging Area names '%s'" % (n, self._names)) 2111 elif all(isinstance(i, int) for i in indices): 2112 pass 2113 else: 2114 raise TypeError("Mixed types in indices '%s'. " 2115 "May only be str or int" % indices) 2116 2117 dtypes = [self._dtypes[i] for i in indices] 2118 2119 return indices, dtypes 2120 2121 def peek(self, key, indices=None, name=None): 2122 """Peeks at staging area data associated with the key. 2123 2124 If the key is not in the staging area, it will block 2125 until the associated (key, value) is inserted. 2126 2127 Args: 2128 key: Key associated with the required data 2129 indices: Partial list of tensors to retrieve (optional). 2130 A list of integer or string indices. 2131 String indices are only valid if the Staging Area 2132 has names associated with it. 2133 name: A name for the operation (optional) 2134 2135 Returns: 2136 The created op 2137 """ 2138 2139 if name is None: 2140 name = "%s_pop" % self._name 2141 2142 indices, dtypes = self._get_indices_and_dtypes(indices) 2143 2144 with ops.colocate_with(self._coloc_op): 2145 result = self._peek_fn( 2146 key, 2147 shared_name=self._name, 2148 indices=indices, 2149 dtypes=dtypes, 2150 name=name, 2151 capacity=self._capacity, 2152 memory_limit=self._memory_limit) 2153 2154 return self._get_return_value(result, indices) 2155 2156 def get(self, key=None, indices=None, name=None): 2157 """If the key is provided, the associated (key, value) is returned from the staging area. 2158 2159 If the key is not in the staging area, this method will block until 2160 the associated (key, value) is inserted. 2161 If no key is provided and the staging area is ordered, 2162 the (key, value) with the smallest key will be returned. 2163 Otherwise, a random (key, value) will be returned. 2164 2165 If the staging area is empty when this operation executes, 2166 it will block until there is an element to dequeue. 2167 2168 Args: 2169 key: Key associated with the required data (Optional) 2170 indices: Partial list of tensors to retrieve (optional). 2171 A list of integer or string indices. 2172 String indices are only valid if the Staging Area 2173 has names associated with it. 2174 name: A name for the operation (optional) 2175 2176 Returns: 2177 The created op 2178 """ 2179 if key is None: 2180 return self._popitem(indices=indices, name=name) 2181 else: 2182 return self._pop(key, indices=indices, name=name) 2183 2184 def _pop(self, key, indices=None, name=None): 2185 """Remove and return the associated (key, value) is returned from the staging area. 2186 2187 If the key is not in the staging area, this method will block until 2188 the associated (key, value) is inserted. 2189 Args: 2190 key: Key associated with the required data 2191 indices: Partial list of tensors to retrieve (optional). 2192 A list of integer or string indices. 2193 String indices are only valid if the Staging Area 2194 has names associated with it. 2195 name: A name for the operation (optional) 2196 2197 Returns: 2198 The created op 2199 """ 2200 if name is None: 2201 name = "%s_get" % self._name 2202 2203 indices, dtypes = self._get_indices_and_dtypes(indices) 2204 2205 with ops.colocate_with(self._coloc_op): 2206 result = self._pop_fn( 2207 key, 2208 shared_name=self._name, 2209 indices=indices, 2210 dtypes=dtypes, 2211 name=name, 2212 capacity=self._capacity, 2213 memory_limit=self._memory_limit) 2214 2215 return key, self._get_return_value(result, indices) 2216 2217 def _popitem(self, indices=None, name=None): 2218 """If the staging area is ordered, the (key, value) with the smallest key will be returned. 2219 2220 Otherwise, a random (key, value) will be returned. 2221 If the staging area is empty when this operation executes, 2222 it will block until there is an element to dequeue. 2223 2224 Args: 2225 key: Key associated with the required data 2226 indices: Partial list of tensors to retrieve (optional). 2227 A list of integer or string indices. 2228 String indices are only valid if the Staging Area 2229 has names associated with it. 2230 name: A name for the operation (optional) 2231 2232 Returns: 2233 The created op 2234 """ 2235 if name is None: 2236 name = "%s_get_nokey" % self._name 2237 2238 indices, dtypes = self._get_indices_and_dtypes(indices) 2239 2240 with ops.colocate_with(self._coloc_op): 2241 key, result = self._popitem_fn( 2242 shared_name=self._name, 2243 indices=indices, 2244 dtypes=dtypes, 2245 name=name, 2246 capacity=self._capacity, 2247 memory_limit=self._memory_limit) 2248 2249 # Separate keys and results out from 2250 # underlying namedtuple 2251 key = self._create_device_transfers(key)[0] 2252 result = self._get_return_value(result, indices) 2253 2254 return key, result 2255 2256 def size(self, name=None): 2257 """Returns the number of elements in the staging area. 2258 2259 Args: 2260 name: A name for the operation (optional) 2261 2262 Returns: 2263 The created op 2264 """ 2265 if name is None: 2266 name = "%s_size" % self._name 2267 2268 return self._size_fn( 2269 shared_name=self._name, 2270 name=name, 2271 dtypes=self._dtypes, 2272 capacity=self._capacity, 2273 memory_limit=self._memory_limit) 2274 2275 def incomplete_size(self, name=None): 2276 """Returns the number of incomplete elements in the staging area. 2277 2278 Args: 2279 name: A name for the operation (optional) 2280 2281 Returns: 2282 The created op 2283 """ 2284 if name is None: 2285 name = "%s_incomplete_size" % self._name 2286 2287 return self._incomplete_size_fn( 2288 shared_name=self._name, 2289 name=name, 2290 dtypes=self._dtypes, 2291 capacity=self._capacity, 2292 memory_limit=self._memory_limit) 2293 2294 def clear(self, name=None): 2295 """Clears the staging area. 2296 2297 Args: 2298 name: A name for the operation (optional) 2299 2300 Returns: 2301 The created op 2302 """ 2303 if name is None: 2304 name = "%s_clear" % self._name 2305 2306 return self._clear_fn( 2307 shared_name=self._name, 2308 name=name, 2309 dtypes=self._dtypes, 2310 capacity=self._capacity, 2311 memory_limit=self._memory_limit) 2312 2313 2314class RecordInput(object): 2315 """RecordInput asynchronously reads and randomly yields TFRecords. 2316 2317 A RecordInput Op will continuously read a batch of records asynchronously 2318 into a buffer of some fixed capacity. It can also asynchronously yield 2319 random records from this buffer. 2320 2321 It will not start yielding until at least `buffer_size / 2` elements have been 2322 placed into the buffer so that sufficient randomization can take place. 2323 2324 The order the files are read will be shifted each epoch by `shift_amount` so 2325 that the data is presented in a different order every epoch. 2326 """ 2327 2328 def __init__(self, 2329 file_pattern, 2330 batch_size=1, 2331 buffer_size=1, 2332 parallelism=1, 2333 shift_ratio=0, 2334 seed=0, 2335 name=None, 2336 batches=None, 2337 compression_type=None): 2338 """Constructs a RecordInput Op. 2339 2340 Args: 2341 file_pattern: File path to the dataset, possibly containing wildcards. 2342 All matching files will be iterated over each epoch. 2343 batch_size: How many records to return at a time. 2344 buffer_size: The maximum number of records the buffer will contain. 2345 parallelism: How many reader threads to use for reading from files. 2346 shift_ratio: What percentage of the total number files to move the start 2347 file forward by each epoch. 2348 seed: Specify the random number seed used by generator that randomizes 2349 records. 2350 name: Optional name for the operation. 2351 batches: None by default, creating a single batch op. Otherwise specifies 2352 how many batches to create, which are returned as a list when 2353 `get_yield_op()` is called. An example use case is to split processing 2354 between devices on one computer. 2355 compression_type: The type of compression for the file. Currently ZLIB and 2356 GZIP are supported. Defaults to none. 2357 2358 Raises: 2359 ValueError: If one of the arguments is invalid. 2360 """ 2361 self._batch_size = batch_size 2362 if batches is not None: 2363 self._batch_size *= batches 2364 self._batches = batches 2365 self._file_pattern = file_pattern 2366 self._buffer_size = buffer_size 2367 self._parallelism = parallelism 2368 self._shift_ratio = shift_ratio 2369 self._seed = seed 2370 self._name = name 2371 self._compression_type = python_io.TFRecordCompressionType.NONE 2372 if compression_type is not None: 2373 self._compression_type = compression_type 2374 2375 def get_yield_op(self): 2376 """Adds a node that yields a group of records every time it is executed. 2377 If RecordInput `batches` parameter is not None, it yields a list of 2378 record batches with the specified `batch_size`. 2379 """ 2380 compression_type = python_io.TFRecordOptions.get_compression_type_string( 2381 python_io.TFRecordOptions(self._compression_type)) 2382 records = gen_data_flow_ops.record_input( 2383 file_pattern=self._file_pattern, 2384 file_buffer_size=self._buffer_size, 2385 file_parallelism=self._parallelism, 2386 file_shuffle_shift_ratio=self._shift_ratio, 2387 batch_size=self._batch_size, 2388 file_random_seed=self._seed, 2389 compression_type=compression_type, 2390 name=self._name) 2391 if self._batches is None: 2392 return records 2393 else: 2394 with ops.name_scope(self._name): 2395 batch_list = [[] for _ in six.moves.range(self._batches)] 2396 records = array_ops.split(records, self._batch_size, 0) 2397 records = [array_ops.reshape(record, []) for record in records] 2398 for index, protobuf in zip(six.moves.range(len(records)), records): 2399 batch_index = index % self._batches 2400 batch_list[batch_index].append(protobuf) 2401 return batch_list 2402