1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14#============================================================================== 15"""Data Flow Operations.""" 16# pylint: disable=g-bad-name 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import hashlib 23import threading 24 25import six 26 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes as _dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import random_seed 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.lib.io import python_io 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import control_flow_ops 36from tensorflow.python.ops import gen_data_flow_ops 37from tensorflow.python.ops import math_ops 38# go/tf-wildcard-import 39# pylint: disable=wildcard-import 40from tensorflow.python.ops.gen_data_flow_ops import * 41from tensorflow.python.util.tf_export import tf_export 42 43# pylint: enable=wildcard-import 44 45 46def _as_type_list(dtypes): 47 """Convert dtypes to a list of types.""" 48 assert dtypes is not None 49 if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)): 50 # We have a single type. 51 return [dtypes] 52 else: 53 # We have a list or tuple of types. 54 return list(dtypes) 55 56 57def _as_shape_list(shapes, 58 dtypes, 59 unknown_dim_allowed=False, 60 unknown_rank_allowed=False): 61 """Convert shapes to a list of tuples of int (or None).""" 62 del dtypes 63 if unknown_dim_allowed: 64 if (not isinstance(shapes, collections.Sequence) or not shapes or 65 any(shape is None or isinstance(shape, int) for shape in shapes)): 66 raise ValueError( 67 "When providing partial shapes, a list of shapes must be provided.") 68 if shapes is None: 69 return None 70 if isinstance(shapes, tensor_shape.TensorShape): 71 shapes = [shapes] 72 if not isinstance(shapes, (tuple, list)): 73 raise TypeError( 74 "shapes must be a TensorShape or a list or tuple of TensorShapes.") 75 if all(shape is None or isinstance(shape, int) for shape in shapes): 76 # We have a single shape. 
77 shapes = [shapes] 78 shapes = [tensor_shape.as_shape(shape) for shape in shapes] 79 if not unknown_dim_allowed: 80 if any([not shape.is_fully_defined() for shape in shapes]): 81 raise ValueError("All shapes must be fully defined: %s" % shapes) 82 if not unknown_rank_allowed: 83 if any([shape.dims is None for shape in shapes]): 84 raise ValueError("All shapes must have a defined rank: %s" % shapes) 85 86 return shapes 87 88 89def _as_name_list(names, dtypes): 90 if names is None: 91 return None 92 if not isinstance(names, (list, tuple)): 93 names = [names] 94 if len(names) != len(dtypes): 95 raise ValueError("List of names must have the same length as the list " 96 "of dtypes") 97 return list(names) 98 99 100def _shape_common(s1, s2): 101 """The greatest lower bound (ordered by specificity) TensorShape.""" 102 s1 = tensor_shape.TensorShape(s1) 103 s2 = tensor_shape.TensorShape(s2) 104 if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims: 105 return tensor_shape.unknown_shape() 106 d = [ 107 d1 if d1 is not None and d1 == d2 else None 108 for (d1, d2) in zip(s1.as_list(), s2.as_list()) 109 ] 110 return tensor_shape.TensorShape(d) 111 112 113# pylint: disable=protected-access 114@tf_export("QueueBase") 115class QueueBase(object): 116 """Base class for queue implementations. 117 118 A queue is a TensorFlow data structure that stores tensors across 119 multiple steps, and exposes operations that enqueue and dequeue 120 tensors. 121 122 Each queue element is a tuple of one or more tensors, where each 123 tuple component has a static dtype, and may have a static shape. The 124 queue implementations support versions of enqueue and dequeue that 125 handle single elements, versions that support enqueuing and 126 dequeuing a batch of elements at once. 127 128 See @{tf.FIFOQueue} and 129 @{tf.RandomShuffleQueue} for concrete 130 implementations of this class, and instructions on how to create 131 them. 132 133 @compatibility(eager) 134 Queues are not compatible with eager execution. Instead, please 135 use `tf.data` to get data into your model. 136 @end_compatibility 137 """ 138 139 def __init__(self, dtypes, shapes, names, queue_ref): 140 """Constructs a queue object from a queue reference. 141 142 The two optional lists, `shapes` and `names`, must be of the same length 143 as `dtypes` if provided. The values at a given index `i` indicate the 144 shape and name to use for the corresponding queue component in `dtypes`. 145 146 Args: 147 dtypes: A list of types. The length of dtypes must equal the number 148 of tensors in each element. 149 shapes: Constraints on the shapes of tensors in an element: 150 A list of shape tuples or None. This list is the same length 151 as dtypes. If the shape of any tensors in the element are constrained, 152 all must be; shapes can be None if the shapes should not be constrained. 153 names: Optional list of names. If provided, the `enqueue()` and 154 `dequeue()` methods will use dictionaries with these names as keys. 155 Must be None or a list or tuple of the same length as `dtypes`. 156 queue_ref: The queue reference, i.e. the output of the queue op. 157 158 Raises: 159 ValueError: If one of the arguments is invalid. 160 RuntimeError: If eager execution is enabled. 161 """ 162 if context.in_eager_mode(): 163 raise RuntimeError( 164 "Queues are not supported when eager execution is enabled. 
" 165 "Instead, please use tf.data to get data into your model.") 166 self._dtypes = dtypes 167 if shapes is not None: 168 if len(shapes) != len(dtypes): 169 raise ValueError("Queue shapes must have the same length as dtypes") 170 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 171 else: 172 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 173 if names is not None: 174 if len(names) != len(dtypes): 175 raise ValueError("Queue names must have the same length as dtypes") 176 self._names = names 177 else: 178 self._names = None 179 self._queue_ref = queue_ref 180 if context.in_graph_mode(): 181 self._name = self._queue_ref.op.name.split("/")[-1] 182 else: 183 self._name = context.context().scope_name 184 185 @staticmethod 186 def from_list(index, queues): 187 """Create a queue using the queue reference from `queues[index]`. 188 189 Args: 190 index: An integer scalar tensor that determines the input that gets 191 selected. 192 queues: A list of `QueueBase` objects. 193 194 Returns: 195 A `QueueBase` object. 196 197 Raises: 198 TypeError: When `queues` is not a list of `QueueBase` objects, 199 or when the data types of `queues` are not all the same. 200 """ 201 if ((not queues) or (not isinstance(queues, list)) or 202 (not all(isinstance(x, QueueBase) for x in queues))): 203 raise TypeError("A list of queues expected") 204 205 dtypes = queues[0].dtypes 206 if not all([dtypes == q.dtypes for q in queues[1:]]): 207 raise TypeError("Queues do not have matching component dtypes.") 208 209 names = queues[0].names 210 if not all([names == q.names for q in queues[1:]]): 211 raise TypeError("Queues do not have matching component names.") 212 213 queue_shapes = [q.shapes for q in queues] 214 reduced_shapes = [ 215 six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes) 216 ] 217 218 queue_refs = array_ops.stack([x.queue_ref for x in queues]) 219 selected_queue = array_ops.gather(queue_refs, index) 220 return QueueBase( 221 dtypes=dtypes, 222 shapes=reduced_shapes, 223 names=names, 224 queue_ref=selected_queue) 225 226 @property 227 def queue_ref(self): 228 """The underlying queue reference.""" 229 return self._queue_ref 230 231 @property 232 def name(self): 233 """The name of the underlying queue.""" 234 if context.in_graph_mode(): 235 return self._queue_ref.op.name 236 return self._name 237 238 @property 239 def dtypes(self): 240 """The list of dtypes for each component of a queue element.""" 241 return self._dtypes 242 243 @property 244 def shapes(self): 245 """The list of shapes for each component of a queue element.""" 246 return self._shapes 247 248 @property 249 def names(self): 250 """The list of names for each component of a queue element.""" 251 return self._names 252 253 def _check_enqueue_dtypes(self, vals): 254 """Validate and convert `vals` to a list of `Tensor`s. 255 256 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 257 dictionary with tensor values. 258 259 If it is a dictionary, the queue must have been constructed with a 260 `names` attribute and the dictionary keys must match the queue names. 261 If the queue was constructed with a `names` attribute, `vals` must 262 be a dictionary. 263 264 Args: 265 vals: A tensor, a list or tuple of tensors, or a dictionary.. 266 267 Returns: 268 A list of `Tensor` objects. 269 270 Raises: 271 ValueError: If `vals` is invalid. 
272 """ 273 if isinstance(vals, dict): 274 if not self._names: 275 raise ValueError("Queue must have names to enqueue a dictionary") 276 if sorted(self._names, key=str) != sorted(vals.keys(), key=str): 277 raise ValueError("Keys in dictionary to enqueue do not match " 278 "names of Queue. Dictionary: (%s), Queue: (%s)" % 279 (sorted(vals.keys()), sorted(self._names))) 280 # The order of values in `self._names` indicates the order in which the 281 # tensors in the dictionary `vals` must be listed. 282 vals = [vals[k] for k in self._names] 283 else: 284 if self._names: 285 raise ValueError("You must enqueue a dictionary in a Queue with names") 286 if not isinstance(vals, (list, tuple)): 287 vals = [vals] 288 289 tensors = [] 290 for i, (val, dtype) in enumerate(zip(vals, self._dtypes)): 291 tensors.append( 292 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 293 294 return tensors 295 296 def _scope_vals(self, vals): 297 """Return a list of values to pass to `name_scope()`. 298 299 Args: 300 vals: A tensor, a list or tuple of tensors, or a dictionary. 301 302 Returns: 303 The values in vals as a list. 304 """ 305 if isinstance(vals, (list, tuple)): 306 return vals 307 elif isinstance(vals, dict): 308 return vals.values() 309 else: 310 return [vals] 311 312 def enqueue(self, vals, name=None): 313 """Enqueues one element to this queue. 314 315 If the queue is full when this operation executes, it will block 316 until the element has been enqueued. 317 318 At runtime, this operation may raise an error if the queue is 319 @{tf.QueueBase.close} before or during its execution. If the 320 queue is closed before this operation runs, 321 `tf.errors.CancelledError` will be raised. If this operation is 322 blocked, and either (i) the queue is closed by a close operation 323 with `cancel_pending_enqueues=True`, or (ii) the session is 324 @{tf.Session.close}, 325 `tf.errors.CancelledError` will be raised. 326 327 Args: 328 vals: A tensor, a list or tuple of tensors, or a dictionary containing 329 the values to enqueue. 330 name: A name for the operation (optional). 331 332 Returns: 333 The operation that enqueues a new tuple of tensors to the queue. 334 """ 335 with ops.name_scope(name, "%s_enqueue" % self._name, 336 self._scope_vals(vals)) as scope: 337 vals = self._check_enqueue_dtypes(vals) 338 339 # NOTE(mrry): Not using a shape function because we need access to 340 # the `QueueBase` object. 341 for val, shape in zip(vals, self._shapes): 342 val.get_shape().assert_is_compatible_with(shape) 343 344 if self._queue_ref.dtype == _dtypes.resource: 345 return gen_data_flow_ops._queue_enqueue_v2( 346 self._queue_ref, vals, name=scope) 347 else: 348 return gen_data_flow_ops._queue_enqueue( 349 self._queue_ref, vals, name=scope) 350 351 def enqueue_many(self, vals, name=None): 352 """Enqueues zero or more elements to this queue. 353 354 This operation slices each component tensor along the 0th dimension to 355 make multiple queue elements. All of the tensors in `vals` must have the 356 same size in the 0th dimension. 357 358 If the queue is full when this operation executes, it will block 359 until all of the elements have been enqueued. 360 361 At runtime, this operation may raise an error if the queue is 362 @{tf.QueueBase.close} before or during its execution. If the 363 queue is closed before this operation runs, 364 `tf.errors.CancelledError` will be raised. 
If this operation is 365 blocked, and either (i) the queue is closed by a close operation 366 with `cancel_pending_enqueues=True`, or (ii) the session is 367 @{tf.Session.close}, 368 `tf.errors.CancelledError` will be raised. 369 370 Args: 371 vals: A tensor, a list or tuple of tensors, or a dictionary 372 from which the queue elements are taken. 373 name: A name for the operation (optional). 374 375 Returns: 376 The operation that enqueues a batch of tuples of tensors to the queue. 377 """ 378 with ops.name_scope(name, "%s_EnqueueMany" % self._name, 379 self._scope_vals(vals)) as scope: 380 vals = self._check_enqueue_dtypes(vals) 381 382 # NOTE(mrry): Not using a shape function because we need access to 383 # the `QueueBase` object. 384 batch_dim = vals[0].get_shape().with_rank_at_least(1)[0] 385 for val, shape in zip(vals, self._shapes): 386 batch_dim = batch_dim.merge_with( 387 val.get_shape().with_rank_at_least(1)[0]) 388 val.get_shape()[1:].assert_is_compatible_with(shape) 389 390 return gen_data_flow_ops._queue_enqueue_many_v2( 391 self._queue_ref, vals, name=scope) 392 393 def _dequeue_return_value(self, tensors): 394 """Return the value to return from a dequeue op. 395 396 If the queue has names, return a dictionary with the 397 names as keys. Otherwise return either a single tensor 398 or a list of tensors depending on the length of `tensors`. 399 400 Args: 401 tensors: List of tensors from the dequeue op. 402 403 Returns: 404 A single tensor, a list of tensors, or a dictionary 405 of tensors. 406 """ 407 if self._names: 408 # The returned values in `tensors` are in the same order as 409 # the names in `self._names`. 410 return {n: tensors[i] for i, n in enumerate(self._names)} 411 elif len(tensors) == 1: 412 return tensors[0] 413 else: 414 return tensors 415 416 def dequeue(self, name=None): 417 """Dequeues one element from this queue. 418 419 If the queue is empty when this operation executes, it will block 420 until there is an element to dequeue. 421 422 At runtime, this operation may raise an error if the queue is 423 @{tf.QueueBase.close} before or during its execution. If the 424 queue is closed, the queue is empty, and there are no pending 425 enqueue operations that can fulfill this request, 426 `tf.errors.OutOfRangeError` will be raised. If the session is 427 @{tf.Session.close}, 428 `tf.errors.CancelledError` will be raised. 429 430 Args: 431 name: A name for the operation (optional). 432 433 Returns: 434 The tuple of tensors that was dequeued. 435 """ 436 if name is None: 437 name = "%s_Dequeue" % self._name 438 if self._queue_ref.dtype == _dtypes.resource: 439 ret = gen_data_flow_ops._queue_dequeue_v2( 440 self._queue_ref, self._dtypes, name=name) 441 else: 442 ret = gen_data_flow_ops._queue_dequeue( 443 self._queue_ref, self._dtypes, name=name) 444 445 # NOTE(mrry): Not using a shape function because we need access to 446 # the `QueueBase` object. 447 if context.in_graph_mode(): 448 op = ret[0].op 449 for output, shape in zip(op.values(), self._shapes): 450 output.set_shape(shape) 451 452 return self._dequeue_return_value(ret) 453 454 def dequeue_many(self, n, name=None): 455 """Dequeues and concatenates `n` elements from this queue. 456 457 This operation concatenates queue-element component tensors along 458 the 0th dimension to make a single component tensor. All of the 459 components in the dequeued tuple will have size `n` in the 0th dimension. 460 461 If the queue is closed and there are less than `n` elements left, then an 462 `OutOfRange` exception is raised. 
463 464 At runtime, this operation may raise an error if the queue is 465 @{tf.QueueBase.close} before or during its execution. If the 466 queue is closed, the queue contains fewer than `n` elements, and 467 there are no pending enqueue operations that can fulfill this 468 request, `tf.errors.OutOfRangeError` will be raised. If the 469 session is @{tf.Session.close}, 470 `tf.errors.CancelledError` will be raised. 471 472 Args: 473 n: A scalar `Tensor` containing the number of elements to dequeue. 474 name: A name for the operation (optional). 475 476 Returns: 477 The list of concatenated tensors that was dequeued. 478 """ 479 if name is None: 480 name = "%s_DequeueMany" % self._name 481 482 ret = gen_data_flow_ops._queue_dequeue_many_v2( 483 self._queue_ref, n=n, component_types=self._dtypes, name=name) 484 485 # NOTE(mrry): Not using a shape function because we need access to 486 # the Queue object. 487 if context.in_graph_mode(): 488 op = ret[0].op 489 batch_dim = tensor_shape.Dimension( 490 tensor_util.constant_value(op.inputs[1])) 491 for output, shape in zip(op.values(), self._shapes): 492 output.set_shape( 493 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 494 495 return self._dequeue_return_value(ret) 496 497 def dequeue_up_to(self, n, name=None): 498 """Dequeues and concatenates `n` elements from this queue. 499 500 **Note** This operation is not supported by all queues. If a queue does not 501 support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised. 502 503 This operation concatenates queue-element component tensors along 504 the 0th dimension to make a single component tensor. If the queue 505 has not been closed, all of the components in the dequeued tuple 506 will have size `n` in the 0th dimension. 507 508 If the queue is closed and there are more than `0` but fewer than 509 `n` elements remaining, then instead of raising a 510 `tf.errors.OutOfRangeError` like @{tf.QueueBase.dequeue_many}, 511 less than `n` elements are returned immediately. If the queue is 512 closed and there are `0` elements left in the queue, then a 513 `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`. 514 Otherwise the behavior is identical to `dequeue_many`. 515 516 Args: 517 n: A scalar `Tensor` containing the number of elements to dequeue. 518 name: A name for the operation (optional). 519 520 Returns: 521 The tuple of concatenated tensors that was dequeued. 522 """ 523 if name is None: 524 name = "%s_DequeueUpTo" % self._name 525 526 ret = gen_data_flow_ops._queue_dequeue_up_to_v2( 527 self._queue_ref, n=n, component_types=self._dtypes, name=name) 528 529 # NOTE(mrry): Not using a shape function because we need access to 530 # the Queue object. 531 if context.in_graph_mode(): 532 op = ret[0].op 533 for output, shape in zip(op.values(), self._shapes): 534 output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape)) 535 536 return self._dequeue_return_value(ret) 537 538 def close(self, cancel_pending_enqueues=False, name=None): 539 """Closes this queue. 540 541 This operation signals that no more elements will be enqueued in 542 the given queue. Subsequent `enqueue` and `enqueue_many` 543 operations will fail. Subsequent `dequeue` and `dequeue_many` 544 operations will continue to succeed if sufficient elements remain 545 in the queue. Subsequently dequeue and dequeue_many operations 546 that would otherwise block waiting for more elements (if close 547 hadn't been called) will now fail immediately. 
548 549 If `cancel_pending_enqueues` is `True`, all pending requests will also 550 be canceled. 551 552 Args: 553 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 554 `False` (described above). 555 name: A name for the operation (optional). 556 557 Returns: 558 The operation that closes the queue. 559 """ 560 if name is None: 561 name = "%s_Close" % self._name 562 if self._queue_ref.dtype == _dtypes.resource: 563 return gen_data_flow_ops._queue_close_v2( 564 self._queue_ref, 565 cancel_pending_enqueues=cancel_pending_enqueues, 566 name=name) 567 else: 568 return gen_data_flow_ops._queue_close( 569 self._queue_ref, 570 cancel_pending_enqueues=cancel_pending_enqueues, 571 name=name) 572 573 def is_closed(self, name=None): 574 """ Returns true if queue is closed. 575 576 This operation returns true if the queue is closed and false if the queue 577 is open. 578 579 Args: 580 name: A name for the operation (optional). 581 582 Returns: 583 True if the queue is closed and false if the queue is open. 584 """ 585 if name is None: 586 name = "%s_Is_Closed" % self._name 587 if self._queue_ref.dtype == _dtypes.resource: 588 return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name) 589 else: 590 return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name) 591 592 def size(self, name=None): 593 """Compute the number of elements in this queue. 594 595 Args: 596 name: A name for the operation (optional). 597 598 Returns: 599 A scalar tensor containing the number of elements in this queue. 600 """ 601 if name is None: 602 name = "%s_Size" % self._name 603 if self._queue_ref.dtype == _dtypes.resource: 604 return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name) 605 else: 606 return gen_data_flow_ops._queue_size(self._queue_ref, name=name) 607 608 609@tf_export("RandomShuffleQueue") 610class RandomShuffleQueue(QueueBase): 611 """A queue implementation that dequeues elements in a random order. 612 613 See @{tf.QueueBase} for a description of the methods on 614 this class. 615 616 @compatibility(eager) 617 Queues are not compatible with eager execution. Instead, please 618 use `tf.data` to get data into your model. 619 @end_compatibility 620 """ 621 622 def __init__(self, 623 capacity, 624 min_after_dequeue, 625 dtypes, 626 shapes=None, 627 names=None, 628 seed=None, 629 shared_name=None, 630 name="random_shuffle_queue"): 631 """Create a queue that dequeues elements in a random order. 632 633 A `RandomShuffleQueue` has bounded capacity; supports multiple 634 concurrent producers and consumers; and provides exactly-once 635 delivery. 636 637 A `RandomShuffleQueue` holds a list of up to `capacity` 638 elements. Each element is a fixed-length tuple of tensors whose 639 dtypes are described by `dtypes`, and whose shapes are optionally 640 described by the `shapes` argument. 641 642 If the `shapes` argument is specified, each component of a queue 643 element must have the respective fixed shape. If it is 644 unspecified, different queue elements may have different shapes, 645 but the use of `dequeue_many` is disallowed. 646 647 The `min_after_dequeue` argument allows the caller to specify a 648 minimum number of elements that will remain in the queue after a 649 `dequeue` or `dequeue_many` operation completes, to ensure a 650 minimum level of mixing of elements. This invariant is maintained 651 by blocking those operations until sufficient elements have been 652 enqueued. The `min_after_dequeue` argument is ignored after the 653 queue has been closed. 
654 655 Args: 656 capacity: An integer. The upper bound on the number of elements 657 that may be stored in this queue. 658 min_after_dequeue: An integer (described above). 659 dtypes: A list of `DType` objects. The length of `dtypes` must equal 660 the number of tensors in each queue element. 661 shapes: (Optional.) A list of fully-defined `TensorShape` objects 662 with the same length as `dtypes`, or `None`. 663 names: (Optional.) A list of string naming the components in the queue 664 with the same length as `dtypes`, or `None`. If specified the dequeue 665 methods return a dictionary with the names as keys. 666 seed: A Python integer. Used to create a random seed. See 667 @{tf.set_random_seed} 668 for behavior. 669 shared_name: (Optional.) If non-empty, this queue will be shared under 670 the given name across multiple sessions. 671 name: Optional name for the queue operation. 672 """ 673 dtypes = _as_type_list(dtypes) 674 shapes = _as_shape_list(shapes, dtypes) 675 names = _as_name_list(names, dtypes) 676 seed1, seed2 = random_seed.get_seed(seed) 677 if seed1 is None and seed2 is None: 678 seed1, seed2 = 0, 0 679 elif seed is None and shared_name is not None: 680 # This means that graph seed is provided but op seed is not provided. 681 # If shared_name is also provided, make seed2 depend only on the graph 682 # seed and shared_name. (seed2 from get_seed() is generally dependent on 683 # the id of the last op created.) 684 string = (str(seed1) + shared_name).encode("utf-8") 685 seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF 686 queue_ref = gen_data_flow_ops._random_shuffle_queue_v2( 687 component_types=dtypes, 688 shapes=shapes, 689 capacity=capacity, 690 min_after_dequeue=min_after_dequeue, 691 seed=seed1, 692 seed2=seed2, 693 shared_name=shared_name, 694 name=name) 695 696 super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref) 697 698 699@tf_export("FIFOQueue") 700class FIFOQueue(QueueBase): 701 """A queue implementation that dequeues elements in first-in first-out order. 702 703 See @{tf.QueueBase} for a description of the methods on 704 this class. 705 706 @compatibility(eager) 707 Queues are not compatible with eager execution. Instead, please 708 use `tf.data` to get data into your model. 709 @end_compatibility 710 """ 711 712 def __init__(self, 713 capacity, 714 dtypes, 715 shapes=None, 716 names=None, 717 shared_name=None, 718 name="fifo_queue"): 719 """Creates a queue that dequeues elements in a first-in first-out order. 720 721 A `FIFOQueue` has bounded capacity; supports multiple concurrent 722 producers and consumers; and provides exactly-once delivery. 723 724 A `FIFOQueue` holds a list of up to `capacity` elements. Each 725 element is a fixed-length tuple of tensors whose dtypes are 726 described by `dtypes`, and whose shapes are optionally described 727 by the `shapes` argument. 728 729 If the `shapes` argument is specified, each component of a queue 730 element must have the respective fixed shape. If it is 731 unspecified, different queue elements may have different shapes, 732 but the use of `dequeue_many` is disallowed. 733 734 Args: 735 capacity: An integer. The upper bound on the number of elements 736 that may be stored in this queue. 737 dtypes: A list of `DType` objects. The length of `dtypes` must equal 738 the number of tensors in each queue element. 739 shapes: (Optional.) A list of fully-defined `TensorShape` objects 740 with the same length as `dtypes`, or `None`. 741 names: (Optional.) 
A list of string naming the components in the queue 742 with the same length as `dtypes`, or `None`. If specified the dequeue 743 methods return a dictionary with the names as keys. 744 shared_name: (Optional.) If non-empty, this queue will be shared under 745 the given name across multiple sessions. 746 name: Optional name for the queue operation. 747 """ 748 dtypes = _as_type_list(dtypes) 749 shapes = _as_shape_list(shapes, dtypes) 750 names = _as_name_list(names, dtypes) 751 queue_ref = gen_data_flow_ops._fifo_queue_v2( 752 component_types=dtypes, 753 shapes=shapes, 754 capacity=capacity, 755 shared_name=shared_name, 756 name=name) 757 758 super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 759 760 761@tf_export("PaddingFIFOQueue") 762class PaddingFIFOQueue(QueueBase): 763 """A FIFOQueue that supports batching variable-sized tensors by padding. 764 765 A `PaddingFIFOQueue` may contain components with dynamic shape, while also 766 supporting `dequeue_many`. See the constructor for more details. 767 768 See @{tf.QueueBase} for a description of the methods on 769 this class. 770 771 @compatibility(eager) 772 Queues are not compatible with eager execution. Instead, please 773 use `tf.data` to get data into your model. 774 @end_compatibility 775 """ 776 777 def __init__(self, 778 capacity, 779 dtypes, 780 shapes, 781 names=None, 782 shared_name=None, 783 name="padding_fifo_queue"): 784 """Creates a queue that dequeues elements in a first-in first-out order. 785 786 A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent 787 producers and consumers; and provides exactly-once delivery. 788 789 A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each 790 element is a fixed-length tuple of tensors whose dtypes are 791 described by `dtypes`, and whose shapes are described by the `shapes` 792 argument. 793 794 The `shapes` argument must be specified; each component of a queue 795 element must have the respective shape. Shapes of fixed 796 rank but variable size are allowed by setting any shape dimension to None. 797 In this case, the inputs' shape may vary along the given dimension, and 798 `dequeue_many` will pad the given dimension with zeros up to the maximum 799 shape of all elements in the given batch. 800 801 Args: 802 capacity: An integer. The upper bound on the number of elements 803 that may be stored in this queue. 804 dtypes: A list of `DType` objects. The length of `dtypes` must equal 805 the number of tensors in each queue element. 806 shapes: A list of `TensorShape` objects, with the same length as 807 `dtypes`. Any dimension in the `TensorShape` containing value 808 `None` is dynamic and allows values to be enqueued with 809 variable size in that dimension. 810 names: (Optional.) A list of string naming the components in the queue 811 with the same length as `dtypes`, or `None`. If specified the dequeue 812 methods return a dictionary with the names as keys. 813 shared_name: (Optional.) If non-empty, this queue will be shared under 814 the given name across multiple sessions. 815 name: Optional name for the queue operation. 816 817 Raises: 818 ValueError: If shapes is not a list of shapes, or the lengths of dtypes 819 and shapes do not match, or if names is specified and the lengths of 820 dtypes and names do not match. 
821 """ 822 dtypes = _as_type_list(dtypes) 823 shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True) 824 names = _as_name_list(names, dtypes) 825 if len(dtypes) != len(shapes): 826 raise ValueError("Shapes must be provided for all components, " 827 "but received %d dtypes and %d shapes." % (len(dtypes), 828 len(shapes))) 829 830 queue_ref = gen_data_flow_ops._padding_fifo_queue_v2( 831 component_types=dtypes, 832 shapes=shapes, 833 capacity=capacity, 834 shared_name=shared_name, 835 name=name) 836 837 super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref) 838 839 840@tf_export("PriorityQueue") 841class PriorityQueue(QueueBase): 842 """A queue implementation that dequeues elements in prioritized order. 843 844 See @{tf.QueueBase} for a description of the methods on 845 this class. 846 847 @compatibility(eager) 848 Queues are not compatible with eager execution. Instead, please 849 use `tf.data` to get data into your model. 850 @end_compatibility 851 """ 852 853 def __init__(self, 854 capacity, 855 types, 856 shapes=None, 857 names=None, 858 shared_name=None, 859 name="priority_queue"): 860 """Creates a queue that dequeues elements in a first-in first-out order. 861 862 A `PriorityQueue` has bounded capacity; supports multiple concurrent 863 producers and consumers; and provides exactly-once delivery. 864 865 A `PriorityQueue` holds a list of up to `capacity` elements. Each 866 element is a fixed-length tuple of tensors whose dtypes are 867 described by `types`, and whose shapes are optionally described 868 by the `shapes` argument. 869 870 If the `shapes` argument is specified, each component of a queue 871 element must have the respective fixed shape. If it is 872 unspecified, different queue elements may have different shapes, 873 but the use of `dequeue_many` is disallowed. 874 875 Enqueues and Dequeues to the `PriorityQueue` must include an additional 876 tuple entry at the beginning: the `priority`. The priority must be 877 an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`). 878 879 Args: 880 capacity: An integer. The upper bound on the number of elements 881 that may be stored in this queue. 882 types: A list of `DType` objects. The length of `types` must equal 883 the number of tensors in each queue element, except the first priority 884 element. The first tensor in each element is the priority, 885 which must be type int64. 886 shapes: (Optional.) A list of fully-defined `TensorShape` objects, 887 with the same length as `types`, or `None`. 888 names: (Optional.) A list of strings naming the components in the queue 889 with the same length as `dtypes`, or `None`. If specified, the dequeue 890 methods return a dictionary with the names as keys. 891 shared_name: (Optional.) If non-empty, this queue will be shared under 892 the given name across multiple sessions. 893 name: Optional name for the queue operation. 
894 """ 895 types = _as_type_list(types) 896 shapes = _as_shape_list(shapes, types) 897 898 queue_ref = gen_data_flow_ops._priority_queue_v2( 899 component_types=types, 900 shapes=shapes, 901 capacity=capacity, 902 shared_name=shared_name, 903 name=name) 904 905 priority_dtypes = [_dtypes.int64] + types 906 priority_shapes = [()] + shapes if shapes else shapes 907 908 super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names, 909 queue_ref) 910 911 912# TODO(josh11b): class BatchQueue(QueueBase): 913 914 915class Barrier(object): 916 """Represents a key-value map that persists across graph executions.""" 917 918 def __init__(self, types, shapes=None, shared_name=None, name="barrier"): 919 """Creates a barrier that persists across different graph executions. 920 921 A barrier represents a key-value map, where each key is a string, and 922 each value is a tuple of tensors. 923 924 At runtime, the barrier contains 'complete' and 'incomplete' 925 elements. A complete element has defined tensors for all 926 components of its value tuple, and may be accessed using 927 take_many. An incomplete element has some undefined components in 928 its value tuple, and may be updated using insert_many. 929 930 The barrier call `take_many` outputs values in a particular order. 931 First, it only outputs completed values. Second, the order in which 932 completed values are returned matches the order in which their very 933 first component was inserted into the barrier. So, for example, for this 934 sequence of insertions and removals: 935 936 barrier = Barrier((tf.string, tf.int32), shapes=((), ())) 937 barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run() 938 barrier.insert_many(1, keys=["k1"], values=[1]).run() 939 barrier.insert_many(0, keys=["k3"], values=["c"]).run() 940 barrier.insert_many(1, keys=["k3"], values=[3]).run() 941 barrier.insert_many(1, keys=["k2"], values=[2]).run() 942 943 (indices, keys, values) = barrier.take_many(2) 944 (indices_val, keys_val, values0_val, values1_val) = 945 session.run([indices, keys, values[0], values[1]]) 946 947 The output will be (up to permutation of "k1" and "k2"): 948 949 indices_val == (-2**63, -2**63) 950 keys_val == ("k1", "k2") 951 values0_val == ("a", "b") 952 values1_val == (1, 2) 953 954 Note the key "k2" was inserted into the barrier before "k3". Even though 955 "k3" was completed first, both are complete by the time 956 take_many is called. As a result, "k2" is prioritized and "k1" and "k2" 957 are returned first. "k3" remains in the barrier until the next execution 958 of `take_many`. Since "k1" and "k2" had their first insertions into 959 the barrier together, their indices are the same (-2**63). The index 960 of "k3" will be -2**63 + 1, because it was the next new inserted key. 961 962 Args: 963 types: A single dtype or a tuple of dtypes, corresponding to the 964 dtypes of the tensor elements that comprise a value in this barrier. 965 shapes: Optional. Constraints on the shapes of tensors in the values: 966 a single tensor shape tuple; a tuple of tensor shape tuples 967 for each barrier-element tuple component; or None if the shape should 968 not be constrained. 969 shared_name: Optional. If non-empty, this barrier will be shared under 970 the given name across multiple sessions. 971 name: Optional name for the barrier op. 972 973 Raises: 974 ValueError: If one of the `shapes` indicate no elements. 
975 """ 976 self._types = _as_type_list(types) 977 978 if shapes is not None: 979 shapes = _as_shape_list(shapes, self._types) 980 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 981 for i, shape in enumerate(self._shapes): 982 if shape.num_elements() == 0: 983 raise ValueError("Empty tensors are not supported, but received " 984 "shape '%s' at index %d" % (shape, i)) 985 else: 986 self._shapes = [tensor_shape.unknown_shape() for _ in self._types] 987 988 self._barrier_ref = gen_data_flow_ops._barrier( 989 component_types=self._types, 990 shapes=self._shapes, 991 shared_name=shared_name, 992 name=name) 993 if context.in_graph_mode(): 994 self._name = self._barrier_ref.op.name.split("/")[-1] 995 else: 996 self._name = context.context().scope_name 997 998 @property 999 def barrier_ref(self): 1000 """Get the underlying barrier reference.""" 1001 return self._barrier_ref 1002 1003 @property 1004 def name(self): 1005 """The name of the underlying barrier.""" 1006 if context.in_graph_mode(): 1007 return self._barrier_ref.op.name 1008 return self._name 1009 1010 def insert_many(self, component_index, keys, values, name=None): 1011 """For each key, assigns the respective value to the specified component. 1012 1013 This operation updates each element at component_index. 1014 1015 Args: 1016 component_index: The component of the value that is being assigned. 1017 keys: A vector of keys, with length n. 1018 values: An any-dimensional tensor of values, which are associated with the 1019 respective keys. The first dimension must have length n. 1020 name: Optional name for the op. 1021 1022 Returns: 1023 The operation that performs the insertion. 1024 Raises: 1025 InvalidArgumentsError: If inserting keys and values without elements. 1026 """ 1027 if name is None: 1028 name = "%s_BarrierInsertMany" % self._name 1029 return gen_data_flow_ops._barrier_insert_many( 1030 self._barrier_ref, keys, values, component_index, name=name) 1031 1032 def take_many(self, 1033 num_elements, 1034 allow_small_batch=False, 1035 timeout=None, 1036 name=None): 1037 """Takes the given number of completed elements from this barrier. 1038 1039 This operation concatenates completed-element component tensors along 1040 the 0th dimension to make a single component tensor. 1041 1042 If barrier has no completed elements, this operation will block 1043 until there are 'num_elements' elements to take. 1044 1045 TODO(b/25743580): the semantics of `allow_small_batch` are experimental 1046 and may be extended to other cases in the future. 1047 1048 TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking 1049 already when the barrier is closed, it will block for ever. Fix this 1050 by using asynchronous operations. 1051 1052 Args: 1053 num_elements: The number of elements to take. 1054 allow_small_batch: If the barrier is closed, don't block if there are less 1055 completed elements than requested, but instead return all available 1056 completed elements. 1057 timeout: This specifies the number of milliseconds to block 1058 before returning with DEADLINE_EXCEEDED. (This option is not 1059 supported yet.) 1060 name: A name for the operation (optional). 1061 1062 Returns: 1063 A tuple of (index, key, value_list). 1064 "index" is a int64 tensor of length num_elements containing the 1065 index of the insert_many call for which the very first component of 1066 the given element was inserted into the Barrier, starting with 1067 the value -2**63. 
Note, this value is different from the 1068 index of the insert_many call for which the element was completed. 1069 "key" is a string tensor of length num_elements containing the keys. 1070 "value_list" is a tuple of tensors, each one with size num_elements 1071 in the 0th dimension for each component in the barrier's values. 1072 1073 """ 1074 if name is None: 1075 name = "%s_BarrierTakeMany" % self._name 1076 ret = gen_data_flow_ops._barrier_take_many( 1077 self._barrier_ref, 1078 num_elements, 1079 self._types, 1080 allow_small_batch, 1081 timeout, 1082 name=name) 1083 1084 # NOTE(mrry): Not using a shape function because we need access to 1085 # the Barrier object. 1086 if context.in_graph_mode(): 1087 op = ret[0].op 1088 if allow_small_batch: 1089 batch_dim = None 1090 else: 1091 batch_dim = tensor_shape.Dimension( 1092 tensor_util.constant_value(op.inputs[1])) 1093 op.outputs[0].set_shape(tensor_shape.vector(batch_dim)) # indices 1094 op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys 1095 for output, shape in zip(op.outputs[2:], self._shapes): # value_list 1096 output.set_shape( 1097 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 1098 1099 return ret 1100 1101 def close(self, cancel_pending_enqueues=False, name=None): 1102 """Closes this barrier. 1103 1104 This operation signals that no more new key values will be inserted in the 1105 given barrier. Subsequent InsertMany operations with new keys will fail. 1106 InsertMany operations that just complement already existing keys with other 1107 components, will continue to succeed. Subsequent TakeMany operations will 1108 continue to succeed if sufficient elements remain in the barrier. Subsequent 1109 TakeMany operations that would block will fail immediately. 1110 1111 If `cancel_pending_enqueues` is `True`, all pending requests to the 1112 underlying queue will also be canceled, and completing of already 1113 started values is also not acceptable anymore. 1114 1115 Args: 1116 cancel_pending_enqueues: (Optional.) A boolean, defaulting to 1117 `False` (described above). 1118 name: Optional name for the op. 1119 1120 Returns: 1121 The operation that closes the barrier. 1122 """ 1123 if name is None: 1124 name = "%s_BarrierClose" % self._name 1125 return gen_data_flow_ops._barrier_close( 1126 self._barrier_ref, 1127 cancel_pending_enqueues=cancel_pending_enqueues, 1128 name=name) 1129 1130 def ready_size(self, name=None): 1131 """Compute the number of complete elements in the given barrier. 1132 1133 Args: 1134 name: A name for the operation (optional). 1135 1136 Returns: 1137 A single-element tensor containing the number of complete elements in the 1138 given barrier. 1139 """ 1140 if name is None: 1141 name = "%s_BarrierReadySize" % self._name 1142 return gen_data_flow_ops._barrier_ready_size(self._barrier_ref, name=name) 1143 1144 def incomplete_size(self, name=None): 1145 """Compute the number of incomplete elements in the given barrier. 1146 1147 Args: 1148 name: A name for the operation (optional). 1149 1150 Returns: 1151 A single-element tensor containing the number of incomplete elements in 1152 the given barrier. 1153 """ 1154 if name is None: 1155 name = "%s_BarrierIncompleteSize" % self._name 1156 return gen_data_flow_ops._barrier_incomplete_size( 1157 self._barrier_ref, name=name) 1158 1159 1160@tf_export("ConditionalAccumulatorBase") 1161class ConditionalAccumulatorBase(object): 1162 """A conditional accumulator for aggregating gradients. 
1163 1164 Up-to-date gradients (i.e., time step at which gradient was computed is 1165 equal to the accumulator's time step) are added to the accumulator. 1166 1167 Extraction of the average gradient is blocked until the required number of 1168 gradients has been accumulated. 1169 """ 1170 1171 def __init__(self, dtype, shape, accumulator_ref): 1172 """Creates a new ConditionalAccumulator. 1173 1174 Args: 1175 dtype: Datatype of the accumulated gradients. 1176 shape: Shape of the accumulated gradients. 1177 accumulator_ref: A handle to the conditional accumulator, created by sub- 1178 classes 1179 """ 1180 self._dtype = dtype 1181 if shape is not None: 1182 self._shape = tensor_shape.TensorShape(shape) 1183 else: 1184 self._shape = tensor_shape.unknown_shape() 1185 self._accumulator_ref = accumulator_ref 1186 if context.in_graph_mode(): 1187 self._name = self._accumulator_ref.op.name.split("/")[-1] 1188 else: 1189 self._name = context.context().scope_name 1190 1191 @property 1192 def accumulator_ref(self): 1193 """The underlying accumulator reference.""" 1194 return self._accumulator_ref 1195 1196 @property 1197 def name(self): 1198 """The name of the underlying accumulator.""" 1199 return self._name 1200 1201 @property 1202 def dtype(self): 1203 """The datatype of the gradients accumulated by this accumulator.""" 1204 return self._dtype 1205 1206 def num_accumulated(self, name=None): 1207 """Number of gradients that have currently been aggregated in accumulator. 1208 1209 Args: 1210 name: Optional name for the operation. 1211 1212 Returns: 1213 Number of accumulated gradients currently in accumulator. 1214 """ 1215 if name is None: 1216 name = "%s_NumAccumulated" % self._name 1217 return gen_data_flow_ops.accumulator_num_accumulated( 1218 self._accumulator_ref, name=name) 1219 1220 def set_global_step(self, new_global_step, name=None): 1221 """Sets the global time step of the accumulator. 1222 1223 The operation logs a warning if we attempt to set to a time step that is 1224 lower than the accumulator's own time step. 1225 1226 Args: 1227 new_global_step: Value of new time step. Can be a variable or a constant 1228 name: Optional name for the operation. 1229 1230 Returns: 1231 Operation that sets the accumulator's time step. 1232 """ 1233 return gen_data_flow_ops.accumulator_set_global_step( 1234 self._accumulator_ref, 1235 math_ops.to_int64(ops.convert_to_tensor(new_global_step)), 1236 name=name) 1237 1238 1239@tf_export("ConditionalAccumulator") 1240class ConditionalAccumulator(ConditionalAccumulatorBase): 1241 """A conditional accumulator for aggregating gradients. 1242 1243 Up-to-date gradients (i.e., time step at which gradient was computed is 1244 equal to the accumulator's time step) are added to the accumulator. 1245 1246 Extraction of the average gradient is blocked until the required number of 1247 gradients has been accumulated. 1248 """ 1249 1250 def __init__(self, 1251 dtype, 1252 shape=None, 1253 shared_name=None, 1254 name="conditional_accumulator"): 1255 """Creates a new ConditionalAccumulator. 1256 1257 Args: 1258 dtype: Datatype of the accumulated gradients. 1259 shape: Shape of the accumulated gradients. 1260 shared_name: Optional. If non-empty, this accumulator will be shared under 1261 the given name across multiple sessions. 1262 name: Optional name for the accumulator. 
1263 """ 1264 accumulator_ref = gen_data_flow_ops.conditional_accumulator( 1265 dtype=dtype, shape=shape, shared_name=shared_name, name=name) 1266 super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref) 1267 1268 def apply_grad(self, grad, local_step=0, name=None): 1269 """Attempts to apply a gradient to the accumulator. 1270 1271 The attempt is silently dropped if the gradient is stale, i.e., local_step 1272 is less than the accumulator's global time step. 1273 1274 Args: 1275 grad: The gradient tensor to be applied. 1276 local_step: Time step at which the gradient was computed. 1277 name: Optional name for the operation. 1278 1279 Returns: 1280 The operation that (conditionally) applies a gradient to the accumulator. 1281 1282 Raises: 1283 ValueError: If grad is of the wrong shape 1284 """ 1285 grad = ops.convert_to_tensor(grad, self._dtype) 1286 grad.get_shape().assert_is_compatible_with(self._shape) 1287 local_step = math_ops.to_int64(ops.convert_to_tensor(local_step)) 1288 return gen_data_flow_ops.accumulator_apply_gradient( 1289 self._accumulator_ref, local_step=local_step, gradient=grad, name=name) 1290 1291 def take_grad(self, num_required, name=None): 1292 """Attempts to extract the average gradient from the accumulator. 1293 1294 The operation blocks until sufficient number of gradients have been 1295 successfully applied to the accumulator. 1296 1297 Once successful, the following actions are also triggered: 1298 1299 - Counter of accumulated gradients is reset to 0. 1300 - Aggregated gradient is reset to 0 tensor. 1301 - Accumulator's internal time step is incremented by 1. 1302 1303 Args: 1304 num_required: Number of gradients that needs to have been aggregated 1305 name: Optional name for the operation 1306 1307 Returns: 1308 A tensor holding the value of the average gradient. 1309 1310 Raises: 1311 InvalidArgumentError: If num_required < 1 1312 """ 1313 out = gen_data_flow_ops.accumulator_take_gradient( 1314 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1315 out.set_shape(self._shape) 1316 return out 1317 1318 1319@tf_export("SparseConditionalAccumulator") 1320class SparseConditionalAccumulator(ConditionalAccumulatorBase): 1321 """A conditional accumulator for aggregating sparse gradients. 1322 1323 Sparse gradients are represented by IndexedSlices. 1324 1325 Up-to-date gradients (i.e., time step at which gradient was computed is 1326 equal to the accumulator's time step) are added to the accumulator. 1327 1328 Extraction of the average gradient is blocked until the required number of 1329 gradients has been accumulated. 1330 1331 Args: 1332 dtype: Datatype of the accumulated gradients. 1333 shape: Shape of the accumulated gradients. 1334 shared_name: Optional. If non-empty, this accumulator will be shared under 1335 the given name across multiple sessions. 1336 name: Optional name for the accumulator. 1337 """ 1338 1339 def __init__(self, 1340 dtype, 1341 shape=None, 1342 shared_name=None, 1343 name="sparse_conditional_accumulator"): 1344 accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator( 1345 dtype=dtype, shape=shape, shared_name=shared_name, name=name) 1346 super(SparseConditionalAccumulator, self).__init__(dtype, shape, 1347 accumulator_ref) 1348 1349 def apply_indexed_slices_grad(self, grad, local_step=0, name=None): 1350 """Attempts to apply a gradient to the accumulator. 1351 1352 The attempt is silently dropped if the gradient is stale, i.e., local_step 1353 is less than the accumulator's global time step. 
1354 1355 Args: 1356 grad: The gradient IndexedSlices to be applied. 1357 local_step: Time step at which the gradient was computed. 1358 name: Optional name for the operation. 1359 1360 Returns: 1361 The operation that (conditionally) applies a gradient to the accumulator. 1362 1363 Raises: 1364 InvalidArgumentError: If grad is of the wrong shape 1365 """ 1366 return self.apply_grad( 1367 grad_indices=grad.indices, 1368 grad_values=grad.values, 1369 grad_shape=grad.dense_shape, 1370 local_step=local_step, 1371 name=name) 1372 1373 def apply_grad(self, 1374 grad_indices, 1375 grad_values, 1376 grad_shape=None, 1377 local_step=0, 1378 name=None): 1379 """Attempts to apply a sparse gradient to the accumulator. 1380 1381 The attempt is silently dropped if the gradient is stale, i.e., local_step 1382 is less than the accumulator's global time step. 1383 1384 A sparse gradient is represented by its indices, values and possibly empty 1385 or None shape. Indices must be a vector representing the locations of 1386 non-zero entries in the tensor. Values are the non-zero slices of the 1387 gradient, and must have the same first dimension as indices, i.e., the nnz 1388 represented by indices and values must be consistent. Shape, if not empty or 1389 None, must be consistent with the accumulator's shape (if also provided). 1390 1391 Example: 1392 A tensor [[0, 0], [0. 1], [2, 3]] can be represented 1393 indices: [1,2] 1394 values: [[0,1],[2,3]] 1395 shape: [3, 2] 1396 1397 Args: 1398 grad_indices: Indices of the sparse gradient to be applied. 1399 grad_values: Values of the sparse gradient to be applied. 1400 grad_shape: Shape of the sparse gradient to be applied. 1401 local_step: Time step at which the gradient was computed. 1402 name: Optional name for the operation. 1403 1404 Returns: 1405 The operation that (conditionally) applies a gradient to the accumulator. 1406 1407 Raises: 1408 InvalidArgumentError: If grad is of the wrong shape 1409 """ 1410 local_step = math_ops.to_int64(ops.convert_to_tensor(local_step)) 1411 return gen_data_flow_ops.sparse_accumulator_apply_gradient( 1412 self._accumulator_ref, 1413 local_step=local_step, 1414 gradient_indices=math_ops.to_int64(grad_indices), 1415 gradient_values=grad_values, 1416 gradient_shape=math_ops.to_int64([] 1417 if grad_shape is None else grad_shape), 1418 has_known_shape=(grad_shape is not None), 1419 name=name) 1420 1421 def take_grad(self, num_required, name=None): 1422 """Attempts to extract the average gradient from the accumulator. 1423 1424 The operation blocks until sufficient number of gradients have been 1425 successfully applied to the accumulator. 1426 1427 Once successful, the following actions are also triggered: 1428 - Counter of accumulated gradients is reset to 0. 1429 - Aggregated gradient is reset to 0 tensor. 1430 - Accumulator's internal time step is incremented by 1. 1431 1432 Args: 1433 num_required: Number of gradients that needs to have been aggregated 1434 name: Optional name for the operation 1435 1436 Returns: 1437 A tuple of indices, values, and shape representing the average gradient. 1438 1439 Raises: 1440 InvalidArgumentError: If num_required < 1 1441 """ 1442 return gen_data_flow_ops.sparse_accumulator_take_gradient( 1443 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1444 1445 def take_indexed_slices_grad(self, num_required, name=None): 1446 """Attempts to extract the average gradient from the accumulator. 
1447 1448 The operation blocks until sufficient number of gradients have been 1449 successfully applied to the accumulator. 1450 1451 Once successful, the following actions are also triggered: 1452 - Counter of accumulated gradients is reset to 0. 1453 - Aggregated gradient is reset to 0 tensor. 1454 - Accumulator's internal time step is incremented by 1. 1455 1456 Args: 1457 num_required: Number of gradients that needs to have been aggregated 1458 name: Optional name for the operation 1459 1460 Returns: 1461 An IndexedSlices holding the value of the average gradient. 1462 1463 Raises: 1464 InvalidArgumentError: If num_required < 1 1465 """ 1466 return_val = gen_data_flow_ops.sparse_accumulator_take_gradient( 1467 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1468 return ops.IndexedSlices( 1469 indices=return_val.indices, 1470 values=return_val.values, 1471 dense_shape=return_val.shape) 1472 1473 1474class BaseStagingArea(object): 1475 """Base class for Staging Areas.""" 1476 _identifier = 0 1477 _lock = threading.Lock() 1478 1479 def __init__(self, 1480 dtypes, 1481 shapes=None, 1482 names=None, 1483 shared_name=None, 1484 capacity=0, 1485 memory_limit=0): 1486 if shared_name is None: 1487 self._name = ( 1488 ops.get_default_graph().unique_name(self.__class__.__name__)) 1489 elif isinstance(shared_name, six.string_types): 1490 self._name = shared_name 1491 else: 1492 raise ValueError("shared_name must be a string") 1493 1494 self._dtypes = dtypes 1495 1496 if shapes is not None: 1497 if len(shapes) != len(dtypes): 1498 raise ValueError("StagingArea shapes must be the same length as dtypes") 1499 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 1500 else: 1501 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 1502 1503 if names is not None: 1504 if len(names) != len(dtypes): 1505 raise ValueError("StagingArea names must be the same length as dtypes") 1506 self._names = names 1507 else: 1508 self._names = None 1509 1510 self._capacity = capacity 1511 self._memory_limit = memory_limit 1512 1513 # all get and put ops must colocate with this op 1514 with ops.name_scope("%s_root" % self._name): 1515 self._coloc_op = control_flow_ops.no_op() 1516 1517 @property 1518 def name(self): 1519 """The name of the staging area.""" 1520 return self._name 1521 1522 @property 1523 def dtypes(self): 1524 """The list of dtypes for each component of a staging area element.""" 1525 return self._dtypes 1526 1527 @property 1528 def shapes(self): 1529 """The list of shapes for each component of a staging area element.""" 1530 return self._shapes 1531 1532 @property 1533 def names(self): 1534 """The list of names for each component of a staging area element.""" 1535 return self._names 1536 1537 @property 1538 def capacity(self): 1539 """The maximum number of elements of this staging area.""" 1540 return self._capacity 1541 1542 @property 1543 def memory_limit(self): 1544 """The maximum number of bytes of this staging area.""" 1545 return self._memory_limit 1546 1547 def _check_put_dtypes(self, vals, indices=None): 1548 """Validate and convert `vals` to a list of `Tensor`s. 1549 1550 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 1551 dictionary with tensor values. 1552 1553 If `vals` is a list, then the appropriate indices associated with the 1554 values must be provided. 1555 1556 If it is a dictionary, the staging area must have been constructed with a 1557 `names` attribute and the dictionary keys must match the staging area names. 
      `indices` will be inferred from the dictionary keys.
      If the staging area was constructed with a `names` attribute, `vals` must
      be a dictionary.

    Checks that the dtype and shape of each value match those
    of the staging area.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
      indices: An optional list of indices associated with `vals`, required
        when `vals` is a list or tuple.

    Returns:
      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
      and `indices` is a list of indices associated with the tensors.

    Raises:
      ValueError: If `vals` or `indices` is invalid.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError(
            "Staging areas must have names to enqueue a dictionary")
      if not set(vals.keys()).issubset(self._names):
        raise ValueError("Keys in dictionary to put do not match names "
                         "of staging area. Dictionary: (%s), Queue: (%s)" %
                         (sorted(vals.keys()), sorted(self._names)))
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals, indices, _ = zip(*[(vals[k], i, k)
                               for i, k in enumerate(self._names)
                               if k in vals])
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a staging area "
                         "with names")

      if indices is None:
        raise ValueError("Indices must be supplied when inserting a list "
                         "of tensors")

      if len(indices) != len(vals):
        raise ValueError("Number of indices '%s' doesn't match "
                         "number of values '%s'" % (len(indices), len(vals)))

    if not isinstance(vals, (list, tuple)):
      vals = [vals]
      indices = [0]

    # Sanity check number of values
    if not len(vals) <= len(self._dtypes):
      raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
                       (len(vals), len(self._dtypes)))

    tensors = []

    for val, i in zip(vals, indices):
      dtype, shape = self._dtypes[i], self._shapes[i]
      # Check dtype
      if val.dtype != dtype:
        raise ValueError("Datatypes do not match. '%s' != '%s'" %
                         (str(val.dtype), str(dtype)))

      # Check shape
      val.get_shape().assert_is_compatible_with(shape)

      tensors.append(
          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))

    return tensors, indices

  def _create_device_transfers(self, tensors):
    """Encodes inter-device transfers if the current device
    is not the same as the staging area's device.
    """

    if not isinstance(tensors, (tuple, list)):
      tensors = [tensors]

    curr_device_scope = control_flow_ops.no_op().device

    if curr_device_scope != self._coloc_op.device:
      tensors = [array_ops.identity(t) for t in tensors]

    return tensors

  def _get_return_value(self, tensors, indices):
    """Return the value to return from a get op.

    If the staging area has names, return a dictionary with the
    names as keys. Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the get op.
      indices: Indices of associated names and shapes.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
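
    For example (a hypothetical sketch, graph construction only; matching
    `put()` ops would still have to run before these get ops are executed), a
    staging area built with `names` hands back a dictionary, while an unnamed
    one hands back a list:

    ```python
    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    named = data_flow_ops.StagingArea([tf.float32, tf.int32], names=['x', 'y'])
    named.get()    # -> {'x': <Tensor>, 'y': <Tensor>}

    unnamed = data_flow_ops.StagingArea([tf.float32, tf.int32])
    unnamed.get()  # -> [<Tensor>, <Tensor>]
    ```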
1656 """ 1657 1658 tensors = self._create_device_transfers(tensors) 1659 1660 # Sets shape 1661 for output, i in zip(tensors, indices): 1662 output.set_shape(self._shapes[i]) 1663 1664 if self._names: 1665 # The returned values in `tensors` are in the same order as 1666 # the names in `self._names`. 1667 return {self._names[i]: t for t, i in zip(tensors, indices)} 1668 return tensors 1669 1670 def _scope_vals(self, vals): 1671 """Return a list of values to pass to `name_scope()`. 1672 1673 Args: 1674 vals: A tensor, a list or tuple of tensors, or a dictionary. 1675 1676 Returns: 1677 The values in vals as a list. 1678 """ 1679 if isinstance(vals, (list, tuple)): 1680 return vals 1681 elif isinstance(vals, dict): 1682 return vals.values() 1683 else: 1684 return [vals] 1685 1686 1687class StagingArea(BaseStagingArea): 1688 """Class for staging inputs. No ordering guarantees. 1689 1690 A `StagingArea` is a TensorFlow data structure that stores tensors across 1691 multiple steps, and exposes operations that can put and get tensors. 1692 1693 Each `StagingArea` element is a tuple of one or more tensors, where each 1694 tuple component has a static dtype, and may have a static shape. 1695 1696 The capacity of a `StagingArea` may be bounded or unbounded. 1697 It supports multiple concurrent producers and consumers; and 1698 provides exactly-once delivery. 1699 1700 Each element of a `StagingArea` is a fixed-length tuple of tensors whose 1701 dtypes are described by `dtypes`, and whose shapes are optionally described 1702 by the `shapes` argument. 1703 1704 If the `shapes` argument is specified, each component of a staging area 1705 element must have the respective fixed shape. If it is 1706 unspecified, different elements may have different shapes, 1707 1708 It can be configured with a capacity in which case 1709 put(values) will block until space becomes available. 1710 1711 Similarly, it can be configured with a memory limit which 1712 will block put(values) until space is available. 1713 This is mostly useful for limiting the number of tensors on 1714 devices such as GPUs. 1715 1716 All get() and peek() commands block if the requested data 1717 is not present in the Staging Area. 1718 1719 """ 1720 1721 def __init__(self, 1722 dtypes, 1723 shapes=None, 1724 names=None, 1725 shared_name=None, 1726 capacity=0, 1727 memory_limit=0): 1728 """Constructs a staging area object. 1729 1730 The two optional lists, `shapes` and `names`, must be of the same length 1731 as `dtypes` if provided. The values at a given index `i` indicate the 1732 shape and name to use for the corresponding queue component in `dtypes`. 1733 1734 The device scope at the time of object creation determines where the 1735 storage for the `StagingArea` will reside. Calls to `put` will incur a copy 1736 to this memory space, if necessary. Tensors returned by `get` will be 1737 placed according to the device scope when `get` is called. 1738 1739 Args: 1740 dtypes: A list of types. The length of dtypes must equal the number 1741 of tensors in each element. 1742 capacity: (Optional.) Maximum number of elements. 1743 An integer. If zero, the Staging Area is unbounded 1744 memory_limit: (Optional.) Maximum number of bytes of all tensors 1745 in the Staging Area. 1746 An integer. If zero, the Staging Area is unbounded 1747 shapes: (Optional.) Constraints on the shapes of tensors in an element. 1748 A list of shape tuples or None. This list is the same length 1749 as dtypes. 
If the shape of any tensors in the element are constrained, 1750 all must be; shapes can be None if the shapes should not be constrained. 1751 names: (Optional.) If provided, the `get()` and 1752 `put()` methods will use dictionaries with these names as keys. 1753 Must be None or a list or tuple of the same length as `dtypes`. 1754 shared_name: (Optional.) A name to be used for the shared object. By 1755 passing the same name to two different python objects they will share 1756 the underlying staging area. Must be a string. 1757 1758 Raises: 1759 ValueError: If one of the arguments is invalid. 1760 """ 1761 1762 super(StagingArea, self).__init__(dtypes, shapes, names, shared_name, 1763 capacity, memory_limit) 1764 1765 def put(self, values, name=None): 1766 """Create an op that places a value into the staging area. 1767 1768 This operation will block if the `StagingArea` has reached 1769 its capacity. 1770 1771 Args: 1772 values: Tensor (or a tuple of Tensors) to place into the staging area. 1773 name: A name for the operation (optional). 1774 1775 Returns: 1776 The created op. 1777 1778 Raises: 1779 ValueError: If the number or type of inputs don't match the staging area. 1780 """ 1781 with ops.name_scope(name, "%s_put" % self._name, 1782 self._scope_vals(values)) as scope: 1783 1784 # Hard-code indices for this staging area 1785 indices = ( 1786 list(six.moves.range(len(values))) 1787 if isinstance(values, (list, tuple)) else None) 1788 vals, _ = self._check_put_dtypes(values, indices) 1789 1790 with ops.colocate_with(self._coloc_op): 1791 op = gen_data_flow_ops.stage( 1792 values=vals, 1793 shared_name=self._name, 1794 name=scope, 1795 capacity=self._capacity, 1796 memory_limit=self._memory_limit) 1797 1798 return op 1799 1800 def __internal_get(self, get_fn, name): 1801 with ops.colocate_with(self._coloc_op): 1802 ret = get_fn() 1803 1804 indices = list(six.moves.range(len(self._dtypes))) # Hard coded 1805 return self._get_return_value(ret, indices) 1806 1807 def get(self, name=None): 1808 """Gets one element from this staging area. 1809 1810 If the staging area is empty when this operation executes, it will block 1811 until there is an element to dequeue. 1812 1813 Note that unlike others ops that can block, like the queue Dequeue 1814 operations, this can stop other work from happening. To avoid this, the 1815 intended use is for this to be called only when there will be an element 1816 already available. One method for doing this in a training loop would be to 1817 run a `put()` call during a warmup session.run call, and then call both 1818 `get()` and `put()` in each subsequent step. 1819 1820 The placement of the returned tensor will be determined by the current 1821 device scope when this function is called. 1822 1823 Args: 1824 name: A name for the operation (optional). 1825 1826 Returns: 1827 The tuple of tensors that was gotten. 1828 """ 1829 if name is None: 1830 name = "%s_get" % self._name 1831 1832 # pylint: disable=bad-continuation 1833 fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, 1834 shared_name=self._name, name=name, 1835 capacity=self._capacity, 1836 memory_limit=self._memory_limit) 1837 # pylint: enable=bad-continuation 1838 1839 return self.__internal_get(fn, name) 1840 1841 def peek(self, index, name=None): 1842 """Peeks at an element in the staging area. 1843 1844 If the staging area is too small to contain the element at 1845 the specified index, it will block until enough elements 1846 are inserted to complete the operation. 
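
    For illustration, a minimal put/peek/get round trip (hypothetical values,
    graph mode; `StagingArea` here is the class defined in this module):

    ```python
    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    area = data_flow_ops.StagingArea([tf.float32])
    put_op = area.put([tf.constant([1.0, 2.0])])
    peek_op = area.peek(0)  # looks at the oldest element without removing it
    get_op = area.get()     # removes it

    with tf.Session() as sess:
      sess.run(put_op)
      sess.run(peek_op)
      sess.run(get_op)
    ```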
    The placement of the returned tensor will be determined by
    the current device scope when this function is called.

    Args:
      index: The index of the tensor within the staging area
        to look up.
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was gotten.
    """
    if name is None:
      name = "%s_peek" % self._name

    # pylint: disable=bad-continuation
    fn = lambda: gen_data_flow_ops.stage_peek(index,
                    dtypes=self._dtypes, shared_name=self._name,
                    name=name, capacity=self._capacity,
                    memory_limit=self._memory_limit)
    # pylint: enable=bad-continuation

    return self.__internal_get(fn, name)

  def size(self, name=None):
    """Returns the number of elements in the staging area.

    Args:
      name: A name for the operation (optional)

    Returns:
      The created op
    """
    if name is None:
      name = "%s_size" % self._name

    return gen_data_flow_ops.stage_size(
        name=name,
        shared_name=self._name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)

  def clear(self, name=None):
    """Clears the staging area.

    Args:
      name: A name for the operation (optional)

    Returns:
      The created op
    """
    if name is None:
      name = "%s_clear" % self._name

    return gen_data_flow_ops.stage_clear(
        name=name,
        shared_name=self._name,
        dtypes=self._dtypes,
        capacity=self._capacity,
        memory_limit=self._memory_limit)


class MapStagingArea(BaseStagingArea):
  """A `MapStagingArea` is a TensorFlow data structure that stores tensors across multiple steps, and exposes operations that can put and get tensors.

  Each `MapStagingArea` element is a (key, value) pair.
  Only int64 keys are supported; other types should be
  hashed to produce a key.
  Values are a tuple of one or more tensors.
  Each tuple component has a static dtype,
  and may have a static shape.

  The capacity of a `MapStagingArea` may be bounded or unbounded.
  It supports multiple concurrent producers and consumers, and
  provides exactly-once delivery.

  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
  whose dtypes are described by `dtypes`, and whose shapes are optionally
  described by the `shapes` argument.

  If the `shapes` argument is specified, each component of a staging area
  element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.

  It behaves like an associative container with support for:

   - put(key, values)
   - peek(key), like dict.get(key)
   - get(key), like dict.pop(key)
   - get(key=None), like dict.popitem()
   - size()
   - clear()

  If ordered, a tree structure ordered by key will be used and
  get(key=None) will remove (key, value) pairs in increasing key order.
  Otherwise, a hashtable is used and get(key=None) will remove a random
  (key, value) pair.

  It can be configured with a capacity, in which case
  put(key, values) will block until space becomes available.

  Similarly, it can be configured with a memory limit, which
  will block put(key, values) until space is available.
  This is mostly useful for limiting the number of tensors on
  devices such as GPUs.

  All get() and peek() commands block if the requested
  (key, value) pair is not present in the staging area.
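
  For illustration, a minimal keyed put/get round trip (graph mode; the key,
  names and values below are arbitrary examples):

  ```python
  import tensorflow as tf
  from tensorflow.python.ops import data_flow_ops

  area = data_flow_ops.MapStagingArea([tf.float32], names=['batch'])
  key = tf.constant(0, dtype=tf.int64)
  put_op = area.put(key, {'batch': tf.constant([1.0, 2.0])})
  got_key, got_vals = area.get(key)  # got_vals is {'batch': <Tensor>}

  with tf.Session() as sess:
    sess.run(put_op)
    sess.run([got_key, got_vals['batch']])
  ```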
1956 1957 Partial puts are supported and will be placed in an incomplete 1958 map until such time as all values associated with the key have 1959 been inserted. Once completed, this (key, value) pair will be 1960 inserted into the map. Data in the incomplete map 1961 counts towards the memory limit, but not towards capacity limit. 1962 1963 Partial gets from the map are also supported. 1964 This removes the partially requested tensors from the entry, 1965 but the entry is only removed from the map once all tensors 1966 associated with it are removed. 1967 """ 1968 1969 def __init__(self, 1970 dtypes, 1971 shapes=None, 1972 names=None, 1973 shared_name=None, 1974 ordered=False, 1975 capacity=0, 1976 memory_limit=0): 1977 """Args: 1978 1979 dtypes: A list of types. The length of dtypes must equal the number 1980 of tensors in each element. 1981 capacity: (Optional.) Maximum number of elements. 1982 An integer. If zero, the Staging Area is unbounded 1983 memory_limit: (Optional.) Maximum number of bytes of all tensors 1984 in the Staging Area (excluding keys). 1985 An integer. If zero, the Staging Area is unbounded 1986 ordered: (Optional.) If True the underlying data structure 1987 is a tree ordered on key. Otherwise assume a hashtable. 1988 shapes: (Optional.) Constraints on the shapes of tensors in an element. 1989 A list of shape tuples or None. This list is the same length 1990 as dtypes. If the shape of any tensors in the element are constrained, 1991 all must be; shapes can be None if the shapes should not be constrained. 1992 names: (Optional.) If provided, the `get()` and 1993 `put()` methods will use dictionaries with these names as keys. 1994 Must be None or a list or tuple of the same length as `dtypes`. 1995 shared_name: (Optional.) A name to be used for the shared object. By 1996 passing the same name to two different python objects they will share 1997 the underlying staging area. Must be a string. 1998 1999 Raises: 2000 ValueError: If one of the arguments is invalid. 2001 2002 """ 2003 2004 super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name, 2005 capacity, memory_limit) 2006 2007 # Defer to different methods depending if the map is ordered 2008 self._ordered = ordered 2009 2010 if ordered: 2011 self._put_fn = gen_data_flow_ops.ordered_map_stage 2012 self._pop_fn = gen_data_flow_ops.ordered_map_unstage 2013 self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key 2014 self._peek_fn = gen_data_flow_ops.ordered_map_peek 2015 self._size_fn = gen_data_flow_ops.ordered_map_size 2016 self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size 2017 self._clear_fn = gen_data_flow_ops.ordered_map_clear 2018 else: 2019 self._put_fn = gen_data_flow_ops.map_stage 2020 self._pop_fn = gen_data_flow_ops.map_unstage 2021 self._popitem_fn = gen_data_flow_ops.map_unstage_no_key 2022 self._peek_fn = gen_data_flow_ops.map_peek 2023 self._size_fn = gen_data_flow_ops.map_size 2024 self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size 2025 self._clear_fn = gen_data_flow_ops.map_clear 2026 2027 def put(self, key, vals, indices=None, name=None): 2028 """Create an op that stores the (key, vals) pair in the staging area. 2029 2030 Incomplete puts are possible, preferably using a dictionary for vals 2031 as the appropriate dtypes and shapes can be inferred from the value names 2032 dictionary key values. If vals is a list or tuple, indices must 2033 also be specified so that the op knows at which element position 2034 to perform the insert. 
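
    For illustration, a hedged sketch of a partial put that is completed by a
    second put (the names, key and values are arbitrary; graph mode):

    ```python
    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    area = data_flow_ops.MapStagingArea([tf.float32, tf.int32],
                                        names=['x', 'y'])
    key = tf.constant(7, dtype=tf.int64)
    put_x = area.put(key, {'x': tf.constant(1.0)})  # incomplete: 'y' missing
    put_y = area.put(key, {'y': tf.constant(2)})    # completes the element
    _, vals = area.get(key)
    incomplete = area.incomplete_size()

    with tf.Session() as sess:
      sess.run(put_x)
      sess.run(incomplete)  # -> 1 while the element is incomplete
      sess.run(put_y)
      sess.run(vals)        # -> {'x': 1.0, 'y': 2}
    ```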
    This operation will block if the capacity or memory limit of this
    container is reached.

    Args:
      key: Key associated with the data
      vals: Tensor (or a dict/tuple of Tensors) to place
        into the staging area.
      indices: (Optional) A list of indices; required if `vals` is a list
        or tuple.
      name: A name for the operation (optional)

    Returns:
      The created op

    Raises:
      ValueError: If the number or type of inputs don't match the staging
        area.
    """

    with ops.name_scope(name, "%s_put" % self._name,
                        self._scope_vals(vals)) as scope:

      vals, indices = self._check_put_dtypes(vals, indices)

      with ops.colocate_with(self._coloc_op):
        op = self._put_fn(
            key,
            indices,
            vals,
            dtypes=self._dtypes,
            shared_name=self._name,
            name=scope,
            capacity=self._capacity,
            memory_limit=self._memory_limit)
    return op

  def _get_indices_and_dtypes(self, indices=None):
    """Resolves `indices` to a list of ints and their corresponding dtypes."""
    if indices is None:
      indices = list(six.moves.range(len(self._dtypes)))

    if not isinstance(indices, (tuple, list)):
      raise TypeError("Invalid indices type '%s'" % type(indices))

    if len(indices) == 0:
      raise ValueError("Empty indices")

    if all(isinstance(i, str) for i in indices):
      if self._names is None:
        raise ValueError("String indices provided '%s', but this Staging Area "
                         "was not created with names." % indices)

      try:
        indices = [self._names.index(n) for n in indices]
      except ValueError:
        raise ValueError("One of the named indices '%s' is not in "
                         "Staging Area names '%s'" % (indices, self._names))
    elif all(isinstance(i, int) for i in indices):
      pass
    else:
      raise TypeError("Mixed types in indices '%s'. "
                      "May only be str or int" % indices)

    dtypes = [self._dtypes[i] for i in indices]

    return indices, dtypes

  def peek(self, key, indices=None, name=None):
    """Peeks at staging area data associated with the key.

    If the key is not in the staging area, it will block
    until the associated (key, value) is inserted.

    Args:
      key: Key associated with the required data
      indices: Partial list of tensors to retrieve (optional).
        A list of integer or string indices.
        String indices are only valid if the Staging Area
        has names associated with it.
      name: A name for the operation (optional)

    Returns:
      The created op
    """

    if name is None:
      name = "%s_pop" % self._name

    indices, dtypes = self._get_indices_and_dtypes(indices)

    with ops.colocate_with(self._coloc_op):
      result = self._peek_fn(
          key,
          shared_name=self._name,
          indices=indices,
          dtypes=dtypes,
          name=name,
          capacity=self._capacity,
          memory_limit=self._memory_limit)

    return self._get_return_value(result, indices)

  def get(self, key=None, indices=None, name=None):
    """If the key is provided, the associated (key, value) is returned from the staging area.

    If the key is not in the staging area, this method will block until
    the associated (key, value) is inserted.
    If no key is provided and the staging area is ordered,
    the (key, value) with the smallest key will be returned.
    Otherwise, a random (key, value) will be returned.

    If the staging area is empty when this operation executes,
    it will block until there is an element to dequeue.
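
    For illustration, a keyless get on an ordered map (arbitrary keys and
    values; graph mode):

    ```python
    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    area = data_flow_ops.MapStagingArea([tf.float32], ordered=True)
    put_a = area.put(tf.constant(2, tf.int64), [tf.constant(20.0)], indices=[0])
    put_b = area.put(tf.constant(1, tf.int64), [tf.constant(10.0)], indices=[0])
    key, vals = area.get()  # no key: the smallest key (1) is removed first

    with tf.Session() as sess:
      sess.run([put_a, put_b])
      sess.run([key, vals])  # roughly [1, [10.0]]
    ```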
2147 2148 Args: 2149 key: Key associated with the required data (Optional) 2150 indices: Partial list of tensors to retrieve (optional). 2151 A list of integer or string indices. 2152 String indices are only valid if the Staging Area 2153 has names associated with it. 2154 name: A name for the operation (optional) 2155 2156 Returns: 2157 The created op 2158 """ 2159 if key is None: 2160 return self._popitem(indices=indices, name=name) 2161 else: 2162 return self._pop(key, indices=indices, name=name) 2163 2164 def _pop(self, key, indices=None, name=None): 2165 """Remove and return the associated (key, value) is returned from the staging area. 2166 2167 If the key is not in the staging area, this method will block until 2168 the associated (key, value) is inserted. 2169 Args: 2170 key: Key associated with the required data 2171 indices: Partial list of tensors to retrieve (optional). 2172 A list of integer or string indices. 2173 String indices are only valid if the Staging Area 2174 has names associated with it. 2175 name: A name for the operation (optional) 2176 2177 Returns: 2178 The created op 2179 """ 2180 if name is None: 2181 name = "%s_get" % self._name 2182 2183 indices, dtypes = self._get_indices_and_dtypes(indices) 2184 2185 with ops.colocate_with(self._coloc_op): 2186 result = self._pop_fn( 2187 key, 2188 shared_name=self._name, 2189 indices=indices, 2190 dtypes=dtypes, 2191 name=name, 2192 capacity=self._capacity, 2193 memory_limit=self._memory_limit) 2194 2195 return key, self._get_return_value(result, indices) 2196 2197 def _popitem(self, indices=None, name=None): 2198 """If the staging area is ordered, the (key, value) with the smallest key will be returned. 2199 2200 Otherwise, a random (key, value) will be returned. 2201 If the staging area is empty when this operation executes, 2202 it will block until there is an element to dequeue. 2203 2204 Args: 2205 key: Key associated with the required data 2206 indices: Partial list of tensors to retrieve (optional). 2207 A list of integer or string indices. 2208 String indices are only valid if the Staging Area 2209 has names associated with it. 2210 name: A name for the operation (optional) 2211 2212 Returns: 2213 The created op 2214 """ 2215 if name is None: 2216 name = "%s_get_nokey" % self._name 2217 2218 indices, dtypes = self._get_indices_and_dtypes(indices) 2219 2220 with ops.colocate_with(self._coloc_op): 2221 key, result = self._popitem_fn( 2222 shared_name=self._name, 2223 indices=indices, 2224 dtypes=dtypes, 2225 name=name, 2226 capacity=self._capacity, 2227 memory_limit=self._memory_limit) 2228 2229 # Separate keys and results out from 2230 # underlying namedtuple 2231 key = self._create_device_transfers(key)[0] 2232 result = self._get_return_value(result, indices) 2233 2234 return key, result 2235 2236 def size(self, name=None): 2237 """Returns the number of elements in the staging area. 2238 2239 Args: 2240 name: A name for the operation (optional) 2241 2242 Returns: 2243 The created op 2244 """ 2245 if name is None: 2246 name = "%s_size" % self._name 2247 2248 return self._size_fn( 2249 shared_name=self._name, 2250 name=name, 2251 dtypes=self._dtypes, 2252 capacity=self._capacity, 2253 memory_limit=self._memory_limit) 2254 2255 def incomplete_size(self, name=None): 2256 """Returns the number of incomplete elements in the staging area. 
2257 2258 Args: 2259 name: A name for the operation (optional) 2260 2261 Returns: 2262 The created op 2263 """ 2264 if name is None: 2265 name = "%s_incomplete_size" % self._name 2266 2267 return self._incomplete_size_fn( 2268 shared_name=self._name, 2269 name=name, 2270 dtypes=self._dtypes, 2271 capacity=self._capacity, 2272 memory_limit=self._memory_limit) 2273 2274 def clear(self, name=None): 2275 """Clears the staging area. 2276 2277 Args: 2278 name: A name for the operation (optional) 2279 2280 Returns: 2281 The created op 2282 """ 2283 if name is None: 2284 name = "%s_clear" % self._name 2285 2286 return self._clear_fn( 2287 shared_name=self._name, 2288 name=name, 2289 dtypes=self._dtypes, 2290 capacity=self._capacity, 2291 memory_limit=self._memory_limit) 2292 2293 2294class RecordInput(object): 2295 """RecordInput asynchronously reads and randomly yields TFRecords. 2296 2297 A RecordInput Op will continuously read a batch of records asynchronously 2298 into a buffer of some fixed capacity. It can also asynchronously yield 2299 random records from this buffer. 2300 2301 It will not start yielding until at least `buffer_size / 2` elements have been 2302 placed into the buffer so that sufficient randomization can take place. 2303 2304 The order the files are read will be shifted each epoch by `shift_amount` so 2305 that the data is presented in a different order every epoch. 2306 """ 2307 2308 def __init__(self, 2309 file_pattern, 2310 batch_size=1, 2311 buffer_size=1, 2312 parallelism=1, 2313 shift_ratio=0, 2314 seed=0, 2315 name=None, 2316 batches=None, 2317 compression_type=None): 2318 """Constructs a RecordInput Op. 2319 2320 Args: 2321 file_pattern: File path to the dataset, possibly containing wildcards. 2322 All matching files will be iterated over each epoch. 2323 batch_size: How many records to return at a time. 2324 buffer_size: The maximum number of records the buffer will contain. 2325 parallelism: How many reader threads to use for reading from files. 2326 shift_ratio: What percentage of the total number files to move the start 2327 file forward by each epoch. 2328 seed: Specify the random number seed used by generator that randomizes 2329 records. 2330 name: Optional name for the operation. 2331 batches: None by default, creating a single batch op. Otherwise specifies 2332 how many batches to create, which are returned as a list when 2333 `get_yield_op()` is called. An example use case is to split processing 2334 between devices on one computer. 2335 compression_type: The type of compression for the file. Currently ZLIB and 2336 GZIP are supported. Defaults to none. 2337 2338 Raises: 2339 ValueError: If one of the arguments is invalid. 2340 """ 2341 self._batch_size = batch_size 2342 if batches is not None: 2343 self._batch_size *= batches 2344 self._batches = batches 2345 self._file_pattern = file_pattern 2346 self._buffer_size = buffer_size 2347 self._parallelism = parallelism 2348 self._shift_ratio = shift_ratio 2349 self._seed = seed 2350 self._name = name 2351 self._compression_type = python_io.TFRecordCompressionType.NONE 2352 if compression_type is not None: 2353 self._compression_type = compression_type 2354 2355 def get_yield_op(self): 2356 """Adds a node that yields a group of records every time it is executed. 2357 If RecordInput `batches` parameter is not None, it yields a list of 2358 record batches with the specified `batch_size`. 
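
    For illustration, a hedged sketch (the file pattern is a placeholder and
    matching TFRecord files are assumed to exist; graph mode):

    ```python
    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    record_input = data_flow_ops.RecordInput(
        file_pattern='/tmp/train-*.tfrecord',  # placeholder pattern
        batch_size=32,
        buffer_size=10000,
        parallelism=4,
        shift_ratio=0.1,
        seed=301)
    records = record_input.get_yield_op()  # 1-D string tensor of 32 records

    with tf.Session() as sess:
      raw_records = sess.run(records)
    ```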
2359 """ 2360 compression_type = python_io.TFRecordOptions.get_compression_type_string( 2361 python_io.TFRecordOptions(self._compression_type)) 2362 records = gen_data_flow_ops.record_input( 2363 file_pattern=self._file_pattern, 2364 file_buffer_size=self._buffer_size, 2365 file_parallelism=self._parallelism, 2366 file_shuffle_shift_ratio=self._shift_ratio, 2367 batch_size=self._batch_size, 2368 file_random_seed=self._seed, 2369 compression_type=compression_type, 2370 name=self._name) 2371 if self._batches is None: 2372 return records 2373 else: 2374 with ops.name_scope(self._name): 2375 batch_list = [[] for i in six.moves.range(self._batches)] 2376 records = array_ops.split(records, self._batch_size, 0) 2377 records = [array_ops.reshape(record, []) for record in records] 2378 for index, protobuf in zip(six.moves.range(len(records)), records): 2379 batch_index = index % self._batches 2380 batch_list[batch_index].append(protobuf) 2381 return batch_list 2382