1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Data Flow Operations."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import hashlib
23import threading
24
25import six
26
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes as _dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import random_seed
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.lib.io import python_io
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import control_flow_ops
36from tensorflow.python.ops import gen_data_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import resource_variable_ops
39# go/tf-wildcard-import
40# pylint: disable=wildcard-import
41from tensorflow.python.ops.gen_data_flow_ops import *
42from tensorflow.python.util import deprecation
43from tensorflow.python.util.tf_export import tf_export
44
45# pylint: enable=wildcard-import
46
47
48def _as_type_list(dtypes):
49  """Convert dtypes to a list of types."""
50  assert dtypes is not None
  if not isinstance(dtypes, (list, tuple)):
52    # We have a single type.
53    return [dtypes]
54  else:
55    # We have a list or tuple of types.
56    return list(dtypes)
57
58
59def _as_shape_list(shapes,
60                   dtypes,
61                   unknown_dim_allowed=False,
62                   unknown_rank_allowed=False):
63  """Convert shapes to a list of tuples of int (or None)."""
64  del dtypes
65  if unknown_dim_allowed:
66    if (not isinstance(shapes, collections.Sequence) or not shapes or
67        any(shape is None or isinstance(shape, int) for shape in shapes)):
68      raise ValueError(
69          "When providing partial shapes, a list of shapes must be provided.")
70  if shapes is None:
71    return None
72  if isinstance(shapes, tensor_shape.TensorShape):
73    shapes = [shapes]
74  if not isinstance(shapes, (tuple, list)):
75    raise TypeError(
76        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
77  if all(shape is None or isinstance(shape, int) for shape in shapes):
78    # We have a single shape.
79    shapes = [shapes]
80  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
81  if not unknown_dim_allowed:
82    if any(not shape.is_fully_defined() for shape in shapes):
83      raise ValueError("All shapes must be fully defined: %s" % shapes)
84  if not unknown_rank_allowed:
    if any(shape.dims is None for shape in shapes):
86      raise ValueError("All shapes must have a defined rank: %s" % shapes)
87
88  return shapes
89
90
91def _as_name_list(names, dtypes):
92  if names is None:
93    return None
94  if not isinstance(names, (list, tuple)):
95    names = [names]
96  if len(names) != len(dtypes):
97    raise ValueError("List of names must have the same length as the list "
98                     "of dtypes")
99  return list(names)
100
101
102def _shape_common(s1, s2):
103  """The greatest lower bound (ordered by specificity) TensorShape."""
104  s1 = tensor_shape.TensorShape(s1)
105  s2 = tensor_shape.TensorShape(s2)
106  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
107    return tensor_shape.unknown_shape()
108  d = [
109      d1 if d1 is not None and d1 == d2 else None
110      for (d1, d2) in zip(s1.as_list(), s2.as_list())
111  ]
112  return tensor_shape.TensorShape(d)
113
114
115# pylint: disable=protected-access
116@tf_export("queue.QueueBase",
117           v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
118@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
119class QueueBase(object):
120  """Base class for queue implementations.
121
122  A queue is a TensorFlow data structure that stores tensors across
123  multiple steps, and exposes operations that enqueue and dequeue
124  tensors.
125
126  Each queue element is a tuple of one or more tensors, where each
127  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.
131
132  See `tf.FIFOQueue` and
133  `tf.RandomShuffleQueue` for concrete
134  implementations of this class, and instructions on how to create
135  them.
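
  For example, a minimal sketch using a concrete subclass (assuming TF 1.x
  graph mode and a `tf.Session`):

    q = tf.FIFOQueue(capacity=3, dtypes=[tf.float32])
    enqueue_op = q.enqueue(10.0)
    dequeued = q.dequeue()
    with tf.Session() as sess:
      sess.run(enqueue_op)
      print(sess.run(dequeued))  # ==> 10.0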
136  """
137
138  def __init__(self, dtypes, shapes, names, queue_ref):
139    """Constructs a queue object from a queue reference.
140
141    The two optional lists, `shapes` and `names`, must be of the same length
142    as `dtypes` if provided.  The values at a given index `i` indicate the
143    shape and name to use for the corresponding queue component in `dtypes`.
144
145    Args:
146      dtypes:  A list of types.  The length of dtypes must equal the number
147        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes.  If the shape of any tensor in the element is
        constrained, all must be; `shapes` can be None if the shapes
        should not be constrained.
152      names: Optional list of names.  If provided, the `enqueue()` and
153        `dequeue()` methods will use dictionaries with these names as keys.
154        Must be None or a list or tuple of the same length as `dtypes`.
155      queue_ref: The queue reference, i.e. the output of the queue op.
156
157    Raises:
158      ValueError: If one of the arguments is invalid.
159    """
160    self._dtypes = dtypes
161    if shapes is not None:
162      if len(shapes) != len(dtypes):
163        raise ValueError("Queue shapes must have the same length as dtypes")
164      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
165    else:
166      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
167    if names is not None:
168      if len(names) != len(dtypes):
169        raise ValueError("Queue names must have the same length as dtypes")
170      self._names = names
171    else:
172      self._names = None
173    self._queue_ref = queue_ref
174    if context.executing_eagerly():
175      if context.context().scope_name:
176        self._name = context.context().scope_name
177      else:
178        self._name = "Empty"
179      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
180          queue_ref, None)
181    else:
182      self._name = self._queue_ref.op.name.split("/")[-1]
183
184  @staticmethod
185  def from_list(index, queues):
186    """Create a queue using the queue reference from `queues[index]`.
187
188    Args:
189      index: An integer scalar tensor that determines the input that gets
190        selected.
191      queues: A list of `QueueBase` objects.
192
193    Returns:
194      A `QueueBase` object.
195
196    Raises:
197      TypeError: When `queues` is not a list of `QueueBase` objects,
198        or when the data types of `queues` are not all the same.
199    """
200    if ((not queues) or (not isinstance(queues, list)) or
201        (not all(isinstance(x, QueueBase) for x in queues))):
202      raise TypeError("A list of queues expected")
203
204    dtypes = queues[0].dtypes
205    if not all(dtypes == q.dtypes for q in queues[1:]):
206      raise TypeError("Queues do not have matching component dtypes.")
207
208    names = queues[0].names
209    if not all(names == q.names for q in queues[1:]):
210      raise TypeError("Queues do not have matching component names.")
211
212    queue_shapes = [q.shapes for q in queues]
213    reduced_shapes = [
214        six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
215    ]
216
217    queue_refs = array_ops.stack([x.queue_ref for x in queues])
218    selected_queue = array_ops.gather(queue_refs, index)
219    return QueueBase(
220        dtypes=dtypes,
221        shapes=reduced_shapes,
222        names=names,
223        queue_ref=selected_queue)
224
225  @property
226  def queue_ref(self):
227    """The underlying queue reference."""
228    return self._queue_ref
229
230  @property
231  def name(self):
232    """The name of the underlying queue."""
233    if context.executing_eagerly():
234      return self._name
235    return self._queue_ref.op.name
236
237  @property
238  def dtypes(self):
239    """The list of dtypes for each component of a queue element."""
240    return self._dtypes
241
242  @property
243  def shapes(self):
244    """The list of shapes for each component of a queue element."""
245    return self._shapes
246
247  @property
248  def names(self):
249    """The list of names for each component of a queue element."""
250    return self._names
251
252  def _check_enqueue_dtypes(self, vals):
253    """Validate and convert `vals` to a list of `Tensor`s.
254
255    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
256    dictionary with tensor values.
257
258    If it is a dictionary, the queue must have been constructed with a
259    `names` attribute and the dictionary keys must match the queue names.
260    If the queue was constructed with a `names` attribute, `vals` must
261    be a dictionary.
262
263    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
265
266    Returns:
267      A list of `Tensor` objects.
268
269    Raises:
270      ValueError: If `vals` is invalid.
271    """
272    if isinstance(vals, dict):
273      if not self._names:
274        raise ValueError("Queue must have names to enqueue a dictionary")
275      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
276        raise ValueError("Keys in dictionary to enqueue do not match "
277                         "names of Queue.  Dictionary: (%s), Queue: (%s)" %
278                         (sorted(vals.keys()), sorted(self._names)))
279      # The order of values in `self._names` indicates the order in which the
280      # tensors in the dictionary `vals` must be listed.
281      vals = [vals[k] for k in self._names]
282    else:
283      if self._names:
284        raise ValueError("You must enqueue a dictionary in a Queue with names")
285      if not isinstance(vals, (list, tuple)):
286        vals = [vals]
287
288    tensors = []
289    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
290      tensors.append(
291          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
292
293    return tensors
294
295  def _scope_vals(self, vals):
296    """Return a list of values to pass to `name_scope()`.
297
298    Args:
299      vals: A tensor, a list or tuple of tensors, or a dictionary.
300
301    Returns:
302      The values in vals as a list.
303    """
304    if isinstance(vals, (list, tuple)):
305      return vals
306    elif isinstance(vals, dict):
307      return vals.values()
308    else:
309      return [vals]
310
311  def enqueue(self, vals, name=None):
312    """Enqueues one element to this queue.
313
314    If the queue is full when this operation executes, it will block
315    until the element has been enqueued.
316
317    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
319    queue is closed before this operation runs,
320    `tf.errors.CancelledError` will be raised. If this operation is
321    blocked, and either (i) the queue is closed by a close operation
322    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
324    `tf.errors.CancelledError` will be raised.
325
326    Args:
327      vals: A tensor, a list or tuple of tensors, or a dictionary containing
328        the values to enqueue.
329      name: A name for the operation (optional).
330
331    Returns:
332      The operation that enqueues a new tuple of tensors to the queue.
333    """
334    with ops.name_scope(name, "%s_enqueue" % self._name,
335                        self._scope_vals(vals)) as scope:
336      vals = self._check_enqueue_dtypes(vals)
337
338      # NOTE(mrry): Not using a shape function because we need access to
339      # the `QueueBase` object.
340      for val, shape in zip(vals, self._shapes):
341        val.get_shape().assert_is_compatible_with(shape)
342
343      if self._queue_ref.dtype == _dtypes.resource:
344        return gen_data_flow_ops.queue_enqueue_v2(
345            self._queue_ref, vals, name=scope)
346      else:
347        return gen_data_flow_ops.queue_enqueue(
348            self._queue_ref, vals, name=scope)
349
350  def enqueue_many(self, vals, name=None):
351    """Enqueues zero or more elements to this queue.
352
353    This operation slices each component tensor along the 0th dimension to
354    make multiple queue elements. All of the tensors in `vals` must have the
355    same size in the 0th dimension.
356
357    If the queue is full when this operation executes, it will block
358    until all of the elements have been enqueued.
359
360    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
362    queue is closed before this operation runs,
363    `tf.errors.CancelledError` will be raised. If this operation is
364    blocked, and either (i) the queue is closed by a close operation
365    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
367    `tf.errors.CancelledError` will be raised.
368
369    Args:
370      vals: A tensor, a list or tuple of tensors, or a dictionary
371        from which the queue elements are taken.
372      name: A name for the operation (optional).
373
374    Returns:
375      The operation that enqueues a batch of tuples of tensors to the queue.
376    """
377    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
378                        self._scope_vals(vals)) as scope:
379      vals = self._check_enqueue_dtypes(vals)
380
381      # NOTE(mrry): Not using a shape function because we need access to
382      # the `QueueBase` object.
      # NOTE(fchollet): the code that follows is verbose because it needs to be
384      # compatible with both TF v1 TensorShape behavior and TF v2 behavior.
385      batch_dim = tensor_shape.dimension_value(
386          vals[0].get_shape().with_rank_at_least(1)[0])
387      batch_dim = tensor_shape.Dimension(batch_dim)
388      for val, shape in zip(vals, self._shapes):
389        val_batch_dim = tensor_shape.dimension_value(
390            val.get_shape().with_rank_at_least(1)[0])
391        val_batch_dim = tensor_shape.Dimension(val_batch_dim)
392        batch_dim = batch_dim.merge_with(val_batch_dim)
393        val.get_shape()[1:].assert_is_compatible_with(shape)
394
395      return gen_data_flow_ops.queue_enqueue_many_v2(
396          self._queue_ref, vals, name=scope)
397
398  def _dequeue_return_value(self, tensors):
399    """Return the value to return from a dequeue op.
400
401    If the queue has names, return a dictionary with the
402    names as keys.  Otherwise return either a single tensor
403    or a list of tensors depending on the length of `tensors`.
404
405    Args:
406      tensors: List of tensors from the dequeue op.
407
408    Returns:
409      A single tensor, a list of tensors, or a dictionary
410      of tensors.
411    """
412    if self._names:
413      # The returned values in `tensors` are in the same order as
414      # the names in `self._names`.
415      return {n: tensors[i] for i, n in enumerate(self._names)}
416    elif len(tensors) == 1:
417      return tensors[0]
418    else:
419      return tensors
420
421  def dequeue(self, name=None):
422    """Dequeues one element from this queue.
423
424    If the queue is empty when this operation executes, it will block
425    until there is an element to dequeue.
426
427    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
429    queue is closed, the queue is empty, and there are no pending
430    enqueue operations that can fulfill this request,
431    `tf.errors.OutOfRangeError` will be raised. If the session is
    closed (see `tf.Session.close`),
433    `tf.errors.CancelledError` will be raised.
434
435    Args:
436      name: A name for the operation (optional).
437
438    Returns:
439      The tuple of tensors that was dequeued.
440    """
441    if name is None:
442      name = "%s_Dequeue" % self._name
443    if self._queue_ref.dtype == _dtypes.resource:
444      ret = gen_data_flow_ops.queue_dequeue_v2(
445          self._queue_ref, self._dtypes, name=name)
446    else:
447      ret = gen_data_flow_ops.queue_dequeue(
448          self._queue_ref, self._dtypes, name=name)
449
450    # NOTE(mrry): Not using a shape function because we need access to
451    # the `QueueBase` object.
452    if not context.executing_eagerly():
453      op = ret[0].op
454      for output, shape in zip(op.values(), self._shapes):
455        output.set_shape(shape)
456
457    return self._dequeue_return_value(ret)
458
459  def dequeue_many(self, n, name=None):
460    """Dequeues and concatenates `n` elements from this queue.
461
462    This operation concatenates queue-element component tensors along
463    the 0th dimension to make a single component tensor.  All of the
464    components in the dequeued tuple will have size `n` in the 0th dimension.
465
    If the queue is closed and there are fewer than `n` elements left, then an
467    `OutOfRange` exception is raised.
468
469    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
471    queue is closed, the queue contains fewer than `n` elements, and
472    there are no pending enqueue operations that can fulfill this
473    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is closed (see `tf.Session.close`),
475    `tf.errors.CancelledError` will be raised.
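
    For example (a minimal sketch, assuming a TF 1.x session and a queue of
    scalar elements):

      q = tf.FIFOQueue(capacity=10, dtypes=[tf.float32], shapes=[[]])
      enqueue_op = q.enqueue_many([[1.0, 2.0, 3.0]])
      # After `enqueue_op` runs, `batch` is a `tf.float32` tensor of shape [3].
      batch = q.dequeue_many(3)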
476
477    Args:
478      n: A scalar `Tensor` containing the number of elements to dequeue.
479      name: A name for the operation (optional).
480
481    Returns:
482      The list of concatenated tensors that was dequeued.
483    """
484    if name is None:
485      name = "%s_DequeueMany" % self._name
486
487    ret = gen_data_flow_ops.queue_dequeue_many_v2(
488        self._queue_ref, n=n, component_types=self._dtypes, name=name)
489
490    # NOTE(mrry): Not using a shape function because we need access to
491    # the Queue object.
492    if not context.executing_eagerly():
493      op = ret[0].op
494      batch_dim = tensor_shape.Dimension(
495          tensor_util.constant_value(op.inputs[1]))
496      for output, shape in zip(op.values(), self._shapes):
497        output.set_shape(
498            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
499
500    return self._dequeue_return_value(ret)
501
502  def dequeue_up_to(self, n, name=None):
503    """Dequeues and concatenates `n` elements from this queue.
504
    **Note**: This operation is not supported by all queues.  If a queue does
    not support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
507
508    This operation concatenates queue-element component tensors along
509    the 0th dimension to make a single component tensor. If the queue
510    has not been closed, all of the components in the dequeued tuple
511    will have size `n` in the 0th dimension.
512
513    If the queue is closed and there are more than `0` but fewer than
514    `n` elements remaining, then instead of raising a
515    `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
    fewer than `n` elements are returned immediately.  If the queue is
517    closed and there are `0` elements left in the queue, then a
518    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
519    Otherwise the behavior is identical to `dequeue_many`.
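
    For example (a minimal sketch, assuming a TF 1.x session and a queue type
    that supports DequeueUpTo):

      q = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[[]])
      enqueue_op = q.enqueue_many([[1, 2, 3]])
      close_op = q.close()
      # After `enqueue_op` and `close_op` run, this returns the 3 remaining
      # elements instead of raising `tf.errors.OutOfRangeError`.
      batch = q.dequeue_up_to(5)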
520
521    Args:
522      n: A scalar `Tensor` containing the number of elements to dequeue.
523      name: A name for the operation (optional).
524
525    Returns:
526      The tuple of concatenated tensors that was dequeued.
527    """
528    if name is None:
529      name = "%s_DequeueUpTo" % self._name
530
531    ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
532        self._queue_ref, n=n, component_types=self._dtypes, name=name)
533
534    # NOTE(mrry): Not using a shape function because we need access to
535    # the Queue object.
536    if not context.executing_eagerly():
537      op = ret[0].op
538      for output, shape in zip(op.values(), self._shapes):
539        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))
540
541    return self._dequeue_return_value(ret)
542
543  def close(self, cancel_pending_enqueues=False, name=None):
544    """Closes this queue.
545
546    This operation signals that no more elements will be enqueued in
547    the given queue. Subsequent `enqueue` and `enqueue_many`
548    operations will fail. Subsequent `dequeue` and `dequeue_many`
549    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would otherwise block waiting for more elements (if close
    hadn't been called) will now fail immediately.
553
554    If `cancel_pending_enqueues` is `True`, all pending requests will also
555    be canceled.
556
557    Args:
558      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
559        `False` (described above).
560      name: A name for the operation (optional).
561
562    Returns:
563      The operation that closes the queue.
564    """
565    if name is None:
566      name = "%s_Close" % self._name
567    if self._queue_ref.dtype == _dtypes.resource:
568      return gen_data_flow_ops.queue_close_v2(
569          self._queue_ref,
570          cancel_pending_enqueues=cancel_pending_enqueues,
571          name=name)
572    else:
573      return gen_data_flow_ops.queue_close(
574          self._queue_ref,
575          cancel_pending_enqueues=cancel_pending_enqueues,
576          name=name)
577
578  def is_closed(self, name=None):
579    """Returns true if queue is closed.
580
581    This operation returns true if the queue is closed and false if the queue
582    is open.
583
584    Args:
585      name: A name for the operation (optional).
586
587    Returns:
588      True if the queue is closed and false if the queue is open.
589    """
590    if name is None:
591      name = "%s_Is_Closed" % self._name
592    if self._queue_ref.dtype == _dtypes.resource:
593      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
594    else:
595      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)
596
597  def size(self, name=None):
598    """Compute the number of elements in this queue.
599
600    Args:
601      name: A name for the operation (optional).
602
603    Returns:
604      A scalar tensor containing the number of elements in this queue.
605    """
606    if name is None:
607      name = "%s_Size" % self._name
608    if self._queue_ref.dtype == _dtypes.resource:
609      return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
610    else:
611      return gen_data_flow_ops.queue_size(self._queue_ref, name=name)
612
613def _shared_name(shared_name):
614  if context.executing_eagerly():
615    return str(ops.uid())
616  return shared_name
617
618
619@tf_export(
620    "queue.RandomShuffleQueue",
621    v1=["queue.RandomShuffleQueue",
622        "io.RandomShuffleQueue", "RandomShuffleQueue"])
623@deprecation.deprecated_endpoints(
624    ["io.RandomShuffleQueue", "RandomShuffleQueue"])
625class RandomShuffleQueue(QueueBase):
626  """A queue implementation that dequeues elements in a random order.
627
628  See `tf.QueueBase` for a description of the methods on
629  this class.
630  """
631
632  def __init__(self,
633               capacity,
634               min_after_dequeue,
635               dtypes,
636               shapes=None,
637               names=None,
638               seed=None,
639               shared_name=None,
640               name="random_shuffle_queue"):
641    """Create a queue that dequeues elements in a random order.
642
643    A `RandomShuffleQueue` has bounded capacity; supports multiple
644    concurrent producers and consumers; and provides exactly-once
645    delivery.
646
647    A `RandomShuffleQueue` holds a list of up to `capacity`
648    elements. Each element is a fixed-length tuple of tensors whose
649    dtypes are described by `dtypes`, and whose shapes are optionally
650    described by the `shapes` argument.
651
652    If the `shapes` argument is specified, each component of a queue
653    element must have the respective fixed shape. If it is
654    unspecified, different queue elements may have different shapes,
655    but the use of `dequeue_many` is disallowed.
656
657    The `min_after_dequeue` argument allows the caller to specify a
658    minimum number of elements that will remain in the queue after a
659    `dequeue` or `dequeue_many` operation completes, to ensure a
660    minimum level of mixing of elements. This invariant is maintained
661    by blocking those operations until sufficient elements have been
662    enqueued. The `min_after_dequeue` argument is ignored after the
663    queue has been closed.
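
    For example, a minimal sketch (assuming TF 1.x graph mode; the producer
    here is an illustrative `tf.random_uniform` op):

      q = tf.RandomShuffleQueue(capacity=100, min_after_dequeue=10,
                                dtypes=[tf.float32], shapes=[[]])
      enqueue_op = q.enqueue_many([tf.random_uniform([100])])
      # After `enqueue_op` runs, `batch` holds 32 elements in a random order.
      batch = q.dequeue_many(32)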
664
665    Args:
666      capacity: An integer. The upper bound on the number of elements
667        that may be stored in this queue.
668      min_after_dequeue: An integer (described above).
669      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
670        the number of tensors in each queue element.
671      shapes: (Optional.) A list of fully-defined `TensorShape` objects
672        with the same length as `dtypes`, or `None`.
673      names: (Optional.) A list of string naming the components in the queue
674        with the same length as `dtypes`, or `None`.  If specified the dequeue
675        methods return a dictionary with the names as keys.
676      seed: A Python integer. Used to create a random seed. See
677        `tf.set_random_seed`
678        for behavior.
679      shared_name: (Optional.) If non-empty, this queue will be shared under
680        the given name across multiple sessions.
681      name: Optional name for the queue operation.
682    """
683    dtypes = _as_type_list(dtypes)
684    shapes = _as_shape_list(shapes, dtypes)
685    names = _as_name_list(names, dtypes)
686    seed1, seed2 = random_seed.get_seed(seed)
687    if seed1 is None and seed2 is None:
688      seed1, seed2 = 0, 0
689    elif seed is None and shared_name is not None:
690      # This means that graph seed is provided but op seed is not provided.
691      # If shared_name is also provided, make seed2 depend only on the graph
692      # seed and shared_name. (seed2 from get_seed() is generally dependent on
693      # the id of the last op created.)
694      string = (str(seed1) + shared_name).encode("utf-8")
695      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
696    queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
697        component_types=dtypes,
698        shapes=shapes,
699        capacity=capacity,
700        min_after_dequeue=min_after_dequeue,
701        seed=seed1,
702        seed2=seed2,
703        shared_name=_shared_name(shared_name),
704        name=name)
705
706    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
707
708
709@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
710@deprecation.deprecated_endpoints("FIFOQueue")
711class FIFOQueue(QueueBase):
712  """A queue implementation that dequeues elements in first-in first-out order.
713
714  See `tf.QueueBase` for a description of the methods on
715  this class.
716  """
717
718  def __init__(self,
719               capacity,
720               dtypes,
721               shapes=None,
722               names=None,
723               shared_name=None,
724               name="fifo_queue"):
725    """Creates a queue that dequeues elements in a first-in first-out order.
726
727    A `FIFOQueue` has bounded capacity; supports multiple concurrent
728    producers and consumers; and provides exactly-once delivery.
729
730    A `FIFOQueue` holds a list of up to `capacity` elements. Each
731    element is a fixed-length tuple of tensors whose dtypes are
732    described by `dtypes`, and whose shapes are optionally described
733    by the `shapes` argument.
734
735    If the `shapes` argument is specified, each component of a queue
736    element must have the respective fixed shape. If it is
737    unspecified, different queue elements may have different shapes,
738    but the use of `dequeue_many` is disallowed.
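
    For example, a minimal sketch (assuming TF 1.x graph mode) of a queue
    whose components are addressed by name:

      q = tf.FIFOQueue(capacity=10,
                       dtypes=[tf.string, tf.int32],
                       shapes=[[], []],
                       names=["label", "count"])
      enqueue_op = q.enqueue({"label": "a", "count": 1})
      element = q.dequeue()  # A dict with keys "label" and "count".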
739
740    Args:
741      capacity: An integer. The upper bound on the number of elements
742        that may be stored in this queue.
743      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
744        the number of tensors in each queue element.
745      shapes: (Optional.) A list of fully-defined `TensorShape` objects
746        with the same length as `dtypes`, or `None`.
747      names: (Optional.) A list of string naming the components in the queue
748        with the same length as `dtypes`, or `None`.  If specified the dequeue
749        methods return a dictionary with the names as keys.
750      shared_name: (Optional.) If non-empty, this queue will be shared under
751        the given name across multiple sessions.
752      name: Optional name for the queue operation.
753    """
754    dtypes = _as_type_list(dtypes)
755    shapes = _as_shape_list(shapes, dtypes)
756    names = _as_name_list(names, dtypes)
757    queue_ref = gen_data_flow_ops.fifo_queue_v2(
758        component_types=dtypes,
759        shapes=shapes,
760        capacity=capacity,
761        shared_name=_shared_name(shared_name),
762        name=name)
763
764    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
765
766
767@tf_export(
768    "queue.PaddingFIFOQueue",
769    v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
770@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
771class PaddingFIFOQueue(QueueBase):
772  """A FIFOQueue that supports batching variable-sized tensors by padding.
773
774  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
775  supporting `dequeue_many`.  See the constructor for more details.
776
777  See `tf.QueueBase` for a description of the methods on
778  this class.
779  """
780
781  def __init__(self,
782               capacity,
783               dtypes,
784               shapes,
785               names=None,
786               shared_name=None,
787               name="padding_fifo_queue"):
788    """Creates a queue that dequeues elements in a first-in first-out order.
789
790    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
791    producers and consumers; and provides exactly-once delivery.
792
793    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
794    element is a fixed-length tuple of tensors whose dtypes are
795    described by `dtypes`, and whose shapes are described by the `shapes`
796    argument.
797
798    The `shapes` argument must be specified; each component of a queue
799    element must have the respective shape.  Shapes of fixed
800    rank but variable size are allowed by setting any shape dimension to None.
801    In this case, the inputs' shape may vary along the given dimension, and
802    `dequeue_many` will pad the given dimension with zeros up to the maximum
803    shape of all elements in the given batch.
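
    For example, a minimal sketch (assuming TF 1.x graph mode) that batches
    variable-length vectors:

      q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[[None]])
      enqueue_a = q.enqueue([[1, 2, 3]])
      enqueue_b = q.enqueue([[4, 5]])
      # After both enqueues run (in that order), this yields
      # [[1, 2, 3], [4, 5, 0]]: the shorter vector is zero-padded to the
      # longest length in the batch.
      batch = q.dequeue_many(2)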
804
805    Args:
806      capacity: An integer. The upper bound on the number of elements
807        that may be stored in this queue.
808      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
809        the number of tensors in each queue element.
810      shapes: A list of `TensorShape` objects, with the same length as
811        `dtypes`.  Any dimension in the `TensorShape` containing value
812        `None` is dynamic and allows values to be enqueued with
813         variable size in that dimension.
814      names: (Optional.) A list of string naming the components in the queue
815        with the same length as `dtypes`, or `None`.  If specified the dequeue
816        methods return a dictionary with the names as keys.
817      shared_name: (Optional.) If non-empty, this queue will be shared under
818        the given name across multiple sessions.
819      name: Optional name for the queue operation.
820
821    Raises:
822      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
823        and shapes do not match, or if names is specified and the lengths of
824        dtypes and names do not match.
825    """
826    dtypes = _as_type_list(dtypes)
827    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
828    names = _as_name_list(names, dtypes)
829    if len(dtypes) != len(shapes):
830      raise ValueError("Shapes must be provided for all components, "
831                       "but received %d dtypes and %d shapes." % (len(dtypes),
832                                                                  len(shapes)))
833
834    queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
835        component_types=dtypes,
836        shapes=shapes,
837        capacity=capacity,
838        shared_name=_shared_name(shared_name),
839        name=name)
840
841    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
842
843
844@tf_export("queue.PriorityQueue",
845           v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
846@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
847class PriorityQueue(QueueBase):
848  """A queue implementation that dequeues elements in prioritized order.
849
850  See `tf.QueueBase` for a description of the methods on
851  this class.
852  """
853
854  def __init__(self,
855               capacity,
856               types,
857               shapes=None,
858               names=None,
859               shared_name=None,
860               name="priority_queue"):
861    """Creates a queue that dequeues elements in a first-in first-out order.
862
863    A `PriorityQueue` has bounded capacity; supports multiple concurrent
864    producers and consumers; and provides exactly-once delivery.
865
866    A `PriorityQueue` holds a list of up to `capacity` elements. Each
867    element is a fixed-length tuple of tensors whose dtypes are
868    described by `types`, and whose shapes are optionally described
869    by the `shapes` argument.
870
871    If the `shapes` argument is specified, each component of a queue
872    element must have the respective fixed shape. If it is
873    unspecified, different queue elements may have different shapes,
874    but the use of `dequeue_many` is disallowed.
875
876    Enqueues and Dequeues to the `PriorityQueue` must include an additional
877    tuple entry at the beginning: the `priority`.  The priority must be
878    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).
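
    For example, a minimal sketch (assuming TF 1.x graph mode); the leading
    component of every enqueued tuple is the int64 priority:

      q = tf.PriorityQueue(capacity=10, types=[tf.string], shapes=[[]])
      enqueue_a = q.enqueue((tf.constant(2, tf.int64), "later"))
      enqueue_b = q.enqueue((tf.constant(1, tf.int64), "sooner"))
      # Elements come out in ascending priority order: "sooner" before "later".
      priority, value = q.dequeue()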
879
880    Args:
881      capacity: An integer. The upper bound on the number of elements
882        that may be stored in this queue.
      types:  A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, not counting the leading
        priority component.  The first tensor in each element is the priority,
        which must be of type int64.
887      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
888        with the same length as `types`, or `None`.
889      names: (Optional.) A list of strings naming the components in the queue
890        with the same length as `dtypes`, or `None`.  If specified, the dequeue
891        methods return a dictionary with the names as keys.
892      shared_name: (Optional.) If non-empty, this queue will be shared under
893        the given name across multiple sessions.
894      name: Optional name for the queue operation.
895    """
896    types = _as_type_list(types)
897    shapes = _as_shape_list(shapes, types)
898
899    queue_ref = gen_data_flow_ops.priority_queue_v2(
900        component_types=types,
901        shapes=shapes,
902        capacity=capacity,
903        shared_name=_shared_name(shared_name),
904        name=name)
905
906    priority_dtypes = [_dtypes.int64] + types
    priority_shapes = ([()] + shapes) if shapes else shapes
908
909    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
910                                        queue_ref)
911
912
913# TODO(josh11b): class BatchQueue(QueueBase):
914
915
916class Barrier(object):
917  """Represents a key-value map that persists across graph executions."""
918
919  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
920    """Creates a barrier that persists across different graph executions.
921
922    A barrier represents a key-value map, where each key is a string, and
923    each value is a tuple of tensors.
924
925    At runtime, the barrier contains 'complete' and 'incomplete'
926    elements. A complete element has defined tensors for all
927    components of its value tuple, and may be accessed using
928    take_many. An incomplete element has some undefined components in
929    its value tuple, and may be updated using insert_many.
930
931    The barrier call `take_many` outputs values in a particular order.
932    First, it only outputs completed values.  Second, the order in which
933    completed values are returned matches the order in which their very
934    first component was inserted into the barrier.  So, for example, for this
935    sequence of insertions and removals:
936
937      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
938      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
939      barrier.insert_many(1, keys=["k1"], values=[1]).run()
940      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
941      barrier.insert_many(1, keys=["k3"], values=[3]).run()
942      barrier.insert_many(1, keys=["k2"], values=[2]).run()
943
944      (indices, keys, values) = barrier.take_many(2)
      (indices_val, keys_val, values0_val, values1_val) = session.run(
          [indices, keys, values[0], values[1]])
947
948    The output will be (up to permutation of "k1" and "k2"):
949
950      indices_val == (-2**63, -2**63)
951      keys_val == ("k1", "k2")
952      values0_val == ("a", "b")
953      values1_val == (1, 2)
954
955    Note the key "k2" was inserted into the barrier before "k3".  Even though
956    "k3" was completed first, both are complete by the time
957    take_many is called.  As a result, "k2" is prioritized and "k1" and "k2"
958    are returned first.  "k3" remains in the barrier until the next execution
959    of `take_many`.  Since "k1" and "k2" had their first insertions into
960    the barrier together, their indices are the same (-2**63).  The index
961    of "k3" will be -2**63 + 1, because it was the next new inserted key.
962
963    Args:
964      types: A single dtype or a tuple of dtypes, corresponding to the
965        dtypes of the tensor elements that comprise a value in this barrier.
966      shapes: Optional. Constraints on the shapes of tensors in the values:
967        a single tensor shape tuple; a tuple of tensor shape tuples
968        for each barrier-element tuple component; or None if the shape should
969        not be constrained.
970      shared_name: Optional. If non-empty, this barrier will be shared under
971        the given name across multiple sessions.
972      name: Optional name for the barrier op.
973
974    Raises:
      ValueError: If one of the `shapes` indicates no elements.
976    """
977    self._types = _as_type_list(types)
978
979    if shapes is not None:
980      shapes = _as_shape_list(shapes, self._types)
981      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
982      for i, shape in enumerate(self._shapes):
983        if shape.num_elements() == 0:
984          raise ValueError("Empty tensors are not supported, but received "
985                           "shape '%s' at index %d" % (shape, i))
986    else:
987      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
988
989    self._barrier_ref = gen_data_flow_ops.barrier(
990        component_types=self._types,
991        shapes=self._shapes,
992        shared_name=shared_name,
993        name=name)
994    if context.executing_eagerly():
995      self._name = context.context().scope_name
996    else:
997      self._name = self._barrier_ref.op.name.split("/")[-1]
998
999  @property
1000  def barrier_ref(self):
1001    """Get the underlying barrier reference."""
1002    return self._barrier_ref
1003
1004  @property
1005  def name(self):
1006    """The name of the underlying barrier."""
1007    if context.executing_eagerly():
1008      return self._name
1009    return self._barrier_ref.op.name
1010
1011  def insert_many(self, component_index, keys, values, name=None):
1012    """For each key, assigns the respective value to the specified component.
1013
1014    This operation updates each element at component_index.
1015
1016    Args:
1017      component_index: The component of the value that is being assigned.
1018      keys: A vector of keys, with length n.
1019      values: An any-dimensional tensor of values, which are associated with the
1020        respective keys. The first dimension must have length n.
1021      name: Optional name for the op.
1022
1023    Returns:
1024      The operation that performs the insertion.
1025    Raises:
      InvalidArgumentError: If inserting keys and values without elements.
1027    """
1028    if name is None:
1029      name = "%s_BarrierInsertMany" % self._name
1030    return gen_data_flow_ops.barrier_insert_many(
1031        self._barrier_ref, keys, values, component_index, name=name)
1032
1033  def take_many(self,
1034                num_elements,
1035                allow_small_batch=False,
1036                timeout=None,
1037                name=None):
1038    """Takes the given number of completed elements from this barrier.
1039
1040    This operation concatenates completed-element component tensors along
1041    the 0th dimension to make a single component tensor.
1042
    If the barrier has no completed elements, this operation will block
    until there are `num_elements` elements to take.
1045
1046    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
1047    and may be extended to other cases in the future.
1048
1049    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    already when the barrier is closed, it will block forever. Fix this
1051    by using asynchronous operations.
1052
1053    Args:
1054      num_elements: The number of elements to take.
      allow_small_batch: If the barrier is closed, don't block if there are
        fewer completed elements than requested, but instead return all
        available completed elements.
1058      timeout: This specifies the number of milliseconds to block
1059        before returning with DEADLINE_EXCEEDED. (This option is not
1060        supported yet.)
1061      name: A name for the operation (optional).
1062
1063    Returns:
1064      A tuple of (index, key, value_list).
1065      "index" is a int64 tensor of length num_elements containing the
1066        index of the insert_many call for which the very first component of
1067        the given element was inserted into the Barrier, starting with
1068        the value -2**63.  Note, this value is different from the
1069        index of the insert_many call for which the element was completed.
1070      "key" is a string tensor of length num_elements containing the keys.
1071      "value_list" is a tuple of tensors, each one with size num_elements
1072        in the 0th dimension for each component in the barrier's values.
1073
1074    """
1075    if name is None:
1076      name = "%s_BarrierTakeMany" % self._name
1077    ret = gen_data_flow_ops.barrier_take_many(
1078        self._barrier_ref,
1079        num_elements,
1080        self._types,
1081        allow_small_batch,
1082        timeout,
1083        name=name)
1084
1085    # NOTE(mrry): Not using a shape function because we need access to
1086    # the Barrier object.
1087    if not context.executing_eagerly():
1088      op = ret[0].op
1089      if allow_small_batch:
1090        batch_dim = None
1091      else:
1092        batch_dim = tensor_shape.Dimension(
1093            tensor_util.constant_value(op.inputs[1]))
1094      op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
1095      op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
1096      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
1097        output.set_shape(
1098            tensor_shape.TensorShape([batch_dim]).concatenate(shape))
1099
1100    return ret
1101
1102  def close(self, cancel_pending_enqueues=False, name=None):
1103    """Closes this barrier.
1104
1105    This operation signals that no more new key values will be inserted in the
    given barrier. Subsequent InsertMany operations with new keys will fail.
    InsertMany operations that just complement already existing keys with
    other components will continue to succeed. Subsequent TakeMany operations
    will continue to succeed if sufficient elements remain in the barrier.
    Subsequent TakeMany operations that would block will fail immediately.
1111
    If `cancel_pending_enqueues` is `True`, all pending requests to the
    underlying queue will also be canceled, and completing already started
    values will no longer be possible.
1115
1116    Args:
1117      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
1118        `False` (described above).
1119      name: Optional name for the op.
1120
1121    Returns:
1122      The operation that closes the barrier.
1123    """
1124    if name is None:
1125      name = "%s_BarrierClose" % self._name
1126    return gen_data_flow_ops.barrier_close(
1127        self._barrier_ref,
1128        cancel_pending_enqueues=cancel_pending_enqueues,
1129        name=name)
1130
1131  def ready_size(self, name=None):
1132    """Compute the number of complete elements in the given barrier.
1133
1134    Args:
1135      name: A name for the operation (optional).
1136
1137    Returns:
1138      A single-element tensor containing the number of complete elements in the
1139      given barrier.
1140    """
1141    if name is None:
1142      name = "%s_BarrierReadySize" % self._name
1143    return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)
1144
1145  def incomplete_size(self, name=None):
1146    """Compute the number of incomplete elements in the given barrier.
1147
1148    Args:
1149      name: A name for the operation (optional).
1150
1151    Returns:
1152      A single-element tensor containing the number of incomplete elements in
1153      the given barrier.
1154    """
1155    if name is None:
1156      name = "%s_BarrierIncompleteSize" % self._name
1157    return gen_data_flow_ops.barrier_incomplete_size(
1158        self._barrier_ref, name=name)
1159
1160
1161@tf_export(v1=["ConditionalAccumulatorBase"])
1162class ConditionalAccumulatorBase(object):
1163  """A conditional accumulator for aggregating gradients.
1164
1165  Up-to-date gradients (i.e., time step at which gradient was computed is
1166  equal to the accumulator's time step) are added to the accumulator.
1167
1168  Extraction of the average gradient is blocked until the required number of
1169  gradients has been accumulated.
1170  """
1171
1172  def __init__(self, dtype, shape, accumulator_ref):
1173    """Creates a new ConditionalAccumulator.
1174
1175    Args:
1176      dtype: Datatype of the accumulated gradients.
1177      shape: Shape of the accumulated gradients.
1178      accumulator_ref: A handle to the conditional accumulator, created by sub-
1179        classes
1180    """
1181    self._dtype = dtype
1182    if shape is not None:
1183      self._shape = tensor_shape.TensorShape(shape)
1184    else:
1185      self._shape = tensor_shape.unknown_shape()
1186    self._accumulator_ref = accumulator_ref
1187    if context.executing_eagerly():
1188      self._name = context.context().scope_name
1189    else:
1190      self._name = self._accumulator_ref.op.name.split("/")[-1]
1191
1192  @property
1193  def accumulator_ref(self):
1194    """The underlying accumulator reference."""
1195    return self._accumulator_ref
1196
1197  @property
1198  def name(self):
1199    """The name of the underlying accumulator."""
1200    return self._name
1201
1202  @property
1203  def dtype(self):
1204    """The datatype of the gradients accumulated by this accumulator."""
1205    return self._dtype
1206
1207  def num_accumulated(self, name=None):
1208    """Number of gradients that have currently been aggregated in accumulator.
1209
1210    Args:
1211      name: Optional name for the operation.
1212
1213    Returns:
1214      Number of accumulated gradients currently in accumulator.
1215    """
1216    if name is None:
1217      name = "%s_NumAccumulated" % self._name
1218    return gen_data_flow_ops.accumulator_num_accumulated(
1219        self._accumulator_ref, name=name)
1220
1221  def set_global_step(self, new_global_step, name=None):
1222    """Sets the global time step of the accumulator.
1223
    The operation logs a warning if we attempt to set it to a time step that
    is lower than the accumulator's own time step.
1226
1227    Args:
      new_global_step: Value of the new time step. Can be a variable or a
        constant.
1229      name: Optional name for the operation.
1230
1231    Returns:
1232      Operation that sets the accumulator's time step.
1233    """
1234    return gen_data_flow_ops.accumulator_set_global_step(
1235        self._accumulator_ref,
1236        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
1237        name=name)
1238
1239
1240@tf_export(v1=["ConditionalAccumulator"])
1241class ConditionalAccumulator(ConditionalAccumulatorBase):
1242  """A conditional accumulator for aggregating gradients.
1243
1244  Up-to-date gradients (i.e., time step at which gradient was computed is
1245  equal to the accumulator's time step) are added to the accumulator.
1246
1247  Extraction of the average gradient is blocked until the required number of
1248  gradients has been accumulated.
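
  For example, a minimal sketch (assuming a TF 1.x session; the gradient
  values below are illustrative constants):

    acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=[])
    apply_ops = [acc.apply_grad(g, local_step=0) for g in (1.0, 3.0)]
    # After both apply ops run, this evaluates to the mean gradient, 2.0.
    mean_grad = acc.take_grad(num_required=2)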
1249  """
1250
1251  def __init__(self,
1252               dtype,
1253               shape=None,
1254               shared_name=None,
1255               name="conditional_accumulator",
1256               reduction_type="MEAN"):
1257    """Creates a new ConditionalAccumulator.
1258
1259    Args:
1260      dtype: Datatype of the accumulated gradients.
1261      shape: Shape of the accumulated gradients.
1262      shared_name: Optional. If non-empty, this accumulator will be shared under
1263        the given name across multiple sessions.
1264      name: Optional name for the accumulator.
1265      reduction_type: Reduction type to use when taking the gradient.
1266    """
1267    accumulator_ref = gen_data_flow_ops.conditional_accumulator(
1268        dtype=dtype,
1269        shape=shape,
1270        shared_name=shared_name,
1271        name=name,
1272        reduction_type=reduction_type)
1273    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
1274
1275  def apply_grad(self, grad, local_step=0, name=None):
1276    """Attempts to apply a gradient to the accumulator.
1277
1278    The attempt is silently dropped if the gradient is stale, i.e., local_step
1279    is less than the accumulator's global time step.
1280
1281    Args:
1282      grad: The gradient tensor to be applied.
1283      local_step: Time step at which the gradient was computed.
1284      name: Optional name for the operation.
1285
1286    Returns:
1287      The operation that (conditionally) applies a gradient to the accumulator.
1288
1289    Raises:
1290      ValueError: If grad is of the wrong shape
1291    """
1292    grad = ops.convert_to_tensor(grad, self._dtype)
1293    grad.get_shape().assert_is_compatible_with(self._shape)
1294    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1295    return gen_data_flow_ops.accumulator_apply_gradient(
1296        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
1297
1298  def take_grad(self, num_required, name=None):
1299    """Attempts to extract the average gradient from the accumulator.
1300
    The operation blocks until a sufficient number of gradients have been
1302    successfully applied to the accumulator.
1303
1304    Once successful, the following actions are also triggered:
1305
1306    - Counter of accumulated gradients is reset to 0.
1307    - Aggregated gradient is reset to 0 tensor.
1308    - Accumulator's internal time step is incremented by 1.
1309
1310    Args:
      num_required: Number of gradients that need to have been aggregated
1312      name: Optional name for the operation
1313
1314    Returns:
1315      A tensor holding the value of the average gradient.
1316
1317    Raises:
1318      InvalidArgumentError: If num_required < 1
1319    """
1320    out = gen_data_flow_ops.accumulator_take_gradient(
1321        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1322    out.set_shape(self._shape)
1323    return out
1324
1325
1326@tf_export(
1327    "sparse.SparseConditionalAccumulator",
1328    v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
1329@deprecation.deprecated_endpoints("SparseConditionalAccumulator")
1330class SparseConditionalAccumulator(ConditionalAccumulatorBase):
1331  """A conditional accumulator for aggregating sparse gradients.
1332
1333  Sparse gradients are represented by IndexedSlices.
1334
1335  Up-to-date gradients (i.e., time step at which gradient was computed is
1336  equal to the accumulator's time step) are added to the accumulator.
1337
1338  Extraction of the average gradient is blocked until the required number of
1339  gradients has been accumulated.
1340
1341  Args:
1342    dtype: Datatype of the accumulated gradients.
1343    shape: Shape of the accumulated gradients.
1344    shared_name: Optional. If non-empty, this accumulator will be shared under
1345      the given name across multiple sessions.
1346    name: Optional name for the accumulator.
1347    reduction_type: Reduction type to use when taking the gradient.
1348  """
1349
1350  def __init__(self,
1351               dtype,
1352               shape=None,
1353               shared_name=None,
1354               name="sparse_conditional_accumulator",
1355               reduction_type="MEAN"):
1356    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
1357        dtype=dtype,
1358        shape=shape,
1359        shared_name=shared_name,
1360        name=name,
1361        reduction_type=reduction_type)
1362    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
1363                                                       accumulator_ref)
1364
1365  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
1366    """Attempts to apply a gradient to the accumulator.
1367
1368    The attempt is silently dropped if the gradient is stale, i.e., local_step
1369    is less than the accumulator's global time step.
1370
1371    Args:
1372      grad: The gradient IndexedSlices to be applied.
1373      local_step: Time step at which the gradient was computed.
1374      name: Optional name for the operation.
1375
1376    Returns:
1377      The operation that (conditionally) applies a gradient to the accumulator.
1378
1379    Raises:
1380      InvalidArgumentError: If grad is of the wrong shape
1381    """
1382    return self.apply_grad(
1383        grad_indices=grad.indices,
1384        grad_values=grad.values,
1385        grad_shape=grad.dense_shape,
1386        local_step=local_step,
1387        name=name)
1388
1389  def apply_grad(self,
1390                 grad_indices,
1391                 grad_values,
1392                 grad_shape=None,
1393                 local_step=0,
1394                 name=None):
1395    """Attempts to apply a sparse gradient to the accumulator.
1396
1397    The attempt is silently dropped if the gradient is stale, i.e., local_step
1398    is less than the accumulator's global time step.
1399
1400    A sparse gradient is represented by its indices, values and possibly empty
1401    or None shape. Indices must be a vector representing the locations of
1402    non-zero entries in the tensor. Values are the non-zero slices of the
1403    gradient, and must have the same first dimension as indices, i.e., the nnz
1404    represented by indices and values must be consistent. Shape, if not empty or
1405    None, must be consistent with the accumulator's shape (if also provided).
1406
1407    Example:
1408      A tensor [[0, 0], [0, 1], [2, 3]] can be represented by
1409        indices: [1, 2]
1410        values: [[0, 1], [2, 3]]
1411        shape: [3, 2]
1412
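    A hedged sketch of the corresponding call (the `accumulator` instance is
    assumed to have been created with a matching dtype and shape):

      accumulator.apply_grad(grad_indices=[1, 2],
                             grad_values=[[0, 1], [2, 3]],
                             grad_shape=[3, 2])
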
1413    Args:
1414      grad_indices: Indices of the sparse gradient to be applied.
1415      grad_values: Values of the sparse gradient to be applied.
1416      grad_shape: Shape of the sparse gradient to be applied.
1417      local_step: Time step at which the gradient was computed.
1418      name: Optional name for the operation.
1419
1420    Returns:
1421      The operation that (conditionally) applies a gradient to the accumulator.
1422
1423    Raises:
1424      InvalidArgumentError: If grad is of the wrong shape
1425    """
1426    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
1427    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
1428        self._accumulator_ref,
1429        local_step=local_step,
1430        gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
1431        gradient_values=grad_values,
1432        gradient_shape=math_ops.cast(
1433            [] if grad_shape is None else grad_shape, _dtypes.int64),
1434        has_known_shape=(grad_shape is not None),
1435        name=name)
1436
1437  def take_grad(self, num_required, name=None):
1438    """Attempts to extract the average gradient from the accumulator.
1439
1440    The operation blocks until a sufficient number of gradients have been
1441    successfully applied to the accumulator.
1442
1443    Once successful, the following actions are also triggered:
1444    - Counter of accumulated gradients is reset to 0.
1445    - Aggregated gradient is reset to 0 tensor.
1446    - Accumulator's internal time step is incremented by 1.
1447
1448    Args:
1449      num_required: Number of gradients that need to have been aggregated
1450      name: Optional name for the operation
1451
1452    Returns:
1453      A tuple of indices, values, and shape representing the average gradient.
1454
1455    Raises:
1456      InvalidArgumentError: If num_required < 1
1457    """
1458    return gen_data_flow_ops.sparse_accumulator_take_gradient(
1459        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1460
1461  def take_indexed_slices_grad(self, num_required, name=None):
1462    """Attempts to extract the average gradient from the accumulator.
1463
1464    The operation blocks until a sufficient number of gradients have been
1465    successfully applied to the accumulator.
1466
1467    Once successful, the following actions are also triggered:
1468    - Counter of accumulated gradients is reset to 0.
1469    - Aggregated gradient is reset to 0 tensor.
1470    - Accumulator's internal time step is incremented by 1.
1471
1472    Args:
1473      num_required: Number of gradients that need to have been aggregated
1474      name: Optional name for the operation
1475
1476    Returns:
1477      An IndexedSlices holding the value of the average gradient.
1478
1479    Raises:
1480      InvalidArgumentError: If num_required < 1
1481    """
1482    return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
1483        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1484    return ops.IndexedSlices(
1485        indices=return_val.indices,
1486        values=return_val.values,
1487        dense_shape=return_val.shape)
1488
1489
1490class BaseStagingArea(object):
1491  """Base class for Staging Areas."""
1492  _identifier = 0
1493  _lock = threading.Lock()
1494
1495  def __init__(self,
1496               dtypes,
1497               shapes=None,
1498               names=None,
1499               shared_name=None,
1500               capacity=0,
1501               memory_limit=0):
1502    if shared_name is None:
1503      self._name = (
1504          ops.get_default_graph().unique_name(self.__class__.__name__))
1505    elif isinstance(shared_name, six.string_types):
1506      self._name = shared_name
1507    else:
1508      raise ValueError("shared_name must be a string")
1509
1510    self._dtypes = dtypes
1511
1512    if shapes is not None:
1513      if len(shapes) != len(dtypes):
1514        raise ValueError("StagingArea shapes must be the same length as dtypes")
1515      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1516    else:
1517      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
1518
1519    if names is not None:
1520      if len(names) != len(dtypes):
1521        raise ValueError("StagingArea names must be the same length as dtypes")
1522      self._names = names
1523    else:
1524      self._names = None
1525
1526    self._capacity = capacity
1527    self._memory_limit = memory_limit
1528
1529    # all get and put ops must colocate with this op
1530    with ops.name_scope("%s_root" % self._name):
1531      self._coloc_op = control_flow_ops.no_op()
1532
1533  @property
1534  def name(self):
1535    """The name of the staging area."""
1536    return self._name
1537
1538  @property
1539  def dtypes(self):
1540    """The list of dtypes for each component of a staging area element."""
1541    return self._dtypes
1542
1543  @property
1544  def shapes(self):
1545    """The list of shapes for each component of a staging area element."""
1546    return self._shapes
1547
1548  @property
1549  def names(self):
1550    """The list of names for each component of a staging area element."""
1551    return self._names
1552
1553  @property
1554  def capacity(self):
1555    """The maximum number of elements of this staging area."""
1556    return self._capacity
1557
1558  @property
1559  def memory_limit(self):
1560    """The maximum number of bytes of this staging area."""
1561    return self._memory_limit
1562
1563  def _check_put_dtypes(self, vals, indices=None):
1564    """Validate and convert `vals` to a list of `Tensor`s.
1565
1566    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
1567    dictionary with tensor values.
1568
1569    If `vals` is a list, then the appropriate indices associated with the
1570    values must be provided.
1571
1572    If it is a dictionary, the staging area must have been constructed with a
1573    `names` attribute and the dictionary keys must match the staging area names.
1574    `indices` will be inferred from the dictionary keys.
1575    If the staging area was constructed with a `names` attribute, `vals` must
1576    be a dictionary.
1577
1578    Checks that the dtype and shape of each value matches that
1579    of the staging area.
1580
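    For example (illustrative), for a staging area constructed with
    `names=('a', 'b', 'c')`:

      tensors, indices = staging_area._check_put_dtypes({'a': t0, 'c': t2})
      # tensors == [t0, t2] (converted to Tensors), indices == (0, 2)
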
1581    Args:
1582      vals: A tensor, a list or tuple of tensors, or a dictionary.
1583
1584    Returns:
1585      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
1586      and `indices` is a list of indices associated with the tensors.
1587
1588    Raises:
1589      ValueError: If `vals` or `indices` is invalid.
1590    """
1591    if isinstance(vals, dict):
1592      if not self._names:
1593        raise ValueError(
1594            "Staging areas must have names to enqueue a dictionary")
1595      if not set(vals.keys()).issubset(self._names):
1596        raise ValueError("Keys in dictionary to put do not match names "
1597                         "of staging area. Dictionary: (%s), Queue: (%s)" %
1598                         (sorted(vals.keys()), sorted(self._names)))
1599      # The order of values in `self._names` indicates the order in which the
1600      # tensors in the dictionary `vals` must be listed.
1601      vals, indices, _ = zip(*[(vals[k], i, k)
1602                               for i, k in enumerate(self._names)
1603                               if k in vals])
1604    else:
1605      if self._names:
1606        raise ValueError("You must enqueue a dictionary in a staging area "
1607                         "with names")
1608
1609      if indices is None:
1610        raise ValueError("Indices must be supplied when inserting a list "
1611                         "of tensors")
1612
1613      if len(indices) != len(vals):
1614        raise ValueError("Number of indices '%s' doesn't match "
1615                         "number of values '%s'" % (len(indices), len(vals)))
1616
1617      if not isinstance(vals, (list, tuple)):
1618        vals = [vals]
1619        indices = [0]
1620
1621    # Sanity check number of values
1622    if not len(vals) <= len(self._dtypes):
1623      raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
1624                       (len(vals), len(self._dtypes)))
1625
1626    tensors = []
1627
1628    for val, i in zip(vals, indices):
1629      dtype, shape = self._dtypes[i], self._shapes[i]
1630      # Check dtype
1631      if val.dtype != dtype:
1632        raise ValueError("Datatypes do not match. '%s' != '%s'" %
1633                         (str(val.dtype), str(dtype)))
1634
1635      # Check shape
1636      val.get_shape().assert_is_compatible_with(shape)
1637
1638      tensors.append(
1639          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
1640
1641    return tensors, indices
1642
1643  def _create_device_transfers(self, tensors):
1644    """Encode inter-device transfers if the current device
1645    is not the same as the Staging Area's device.
1646    """
1647
1648    if not isinstance(tensors, (tuple, list)):
1649      tensors = [tensors]
1650
1651    curr_device_scope = control_flow_ops.no_op().device
1652
1653    if curr_device_scope != self._coloc_op.device:
1654      tensors = [array_ops.identity(t) for t in tensors]
1655
1656    return tensors
1657
1658  def _get_return_value(self, tensors, indices):
1659    """Return the value to return from a get op.
1660
1661    If the staging area has names, return a dictionary with the
1662    names as keys.  Otherwise return either a single tensor
1663    or a list of tensors depending on the length of `tensors`.
1664
1665    Args:
1666      tensors: List of tensors from the get op.
1667      indices: Indices of associated names and shapes
1668
1669    Returns:
1670      A single tensor, a list of tensors, or a dictionary
1671      of tensors.
1672    """
1673
1674    tensors = self._create_device_transfers(tensors)
1675
1676    # Sets shape
1677    for output, i in zip(tensors, indices):
1678      output.set_shape(self._shapes[i])
1679
1680    if self._names:
1681      # The returned values in `tensors` are in the same order as
1682      # the names in `self._names`.
1683      return {self._names[i]: t for t, i in zip(tensors, indices)}
1684    return tensors
1685
1686  def _scope_vals(self, vals):
1687    """Return a list of values to pass to `name_scope()`.
1688
1689    Args:
1690      vals: A tensor, a list or tuple of tensors, or a dictionary.
1691
1692    Returns:
1693      The values in vals as a list.
1694    """
1695    if isinstance(vals, (list, tuple)):
1696      return vals
1697    elif isinstance(vals, dict):
1698      return vals.values()
1699    else:
1700      return [vals]
1701
1702
1703class StagingArea(BaseStagingArea):
1704  """Class for staging inputs. No ordering guarantees.
1705
1706  A `StagingArea` is a TensorFlow data structure that stores tensors across
1707  multiple steps, and exposes operations that can put and get tensors.
1708
1709  Each `StagingArea` element is a tuple of one or more tensors, where each
1710  tuple component has a static dtype, and may have a static shape.
1711
1712  The capacity of a `StagingArea` may be bounded or unbounded.
1713  It supports multiple concurrent producers and consumers; and
1714  provides exactly-once delivery.
1715
1716  Each element of a `StagingArea` is a fixed-length tuple of tensors whose
1717  dtypes are described by `dtypes`, and whose shapes are optionally described
1718  by the `shapes` argument.
1719
1720  If the `shapes` argument is specified, each component of a staging area
1721  element must have the respective fixed shape. If it is
1722  unspecified, different elements may have different shapes,
1723
1724  It can be configured with a capacity in which case
1725  put(values) will block until space becomes available.
1726
1727  Similarly, it can be configured with a memory limit which
1728  will block put(values) until space is available.
1729  This is mostly useful for limiting the number of tensors on
1730  devices such as GPUs.
1731
1732  All get() and peek() commands block if the requested data
1733  is not present in the Staging Area.
1734
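  A minimal sketch (assuming the TF 1.x graph API; `StagingArea` is imported
  here from this module directly, and all values are illustrative):

    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    area = data_flow_ops.StagingArea(
        dtypes=[tf.float32, tf.int32], shapes=[[2], []])
    put_op = area.put([tf.constant([1.0, 2.0]), tf.constant(3)])
    x, y = area.get()

    with tf.Session() as sess:
      sess.run(put_op)
      print(sess.run([x, y]))
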
1735  """
1736
1737  def __init__(self,
1738               dtypes,
1739               shapes=None,
1740               names=None,
1741               shared_name=None,
1742               capacity=0,
1743               memory_limit=0):
1744    """Constructs a staging area object.
1745
1746    The two optional lists, `shapes` and `names`, must be of the same length
1747    as `dtypes` if provided.  The values at a given index `i` indicate the
1748    shape and name to use for the corresponding queue component in `dtypes`.
1749
1750    The device scope at the time of object creation determines where the
1751    storage for the `StagingArea` will reside.  Calls to `put` will incur a copy
1752    to this memory space, if necessary.  Tensors returned by `get` will be
1753    placed according to the device scope when `get` is called.
1754
1755    Args:
1756      dtypes:  A list of types.  The length of dtypes must equal the number
1757        of tensors in each element.
1758      shapes: (Optional.) Constraints on the shapes of tensors in an element.
1759        A list of shape tuples or None. This list is the same length
1760        as dtypes.  If the shape of any tensor in the element is constrained,
1761        all must be; shapes can be None if the shapes should not be constrained.
1762      names: (Optional.) If provided, the `get()` and
1763        `put()` methods will use dictionaries with these names as keys.
1764        Must be None or a list or tuple of the same length as `dtypes`.
1765      shared_name: (Optional.) A name to be used for the shared object. By
1766        passing the same name to two different python objects they will share
1767        the underlying staging area. Must be a string.
1768      capacity: (Optional.) Maximum number of elements.
1769        An integer. If zero, the Staging Area is unbounded
1770      memory_limit: (Optional.) Maximum number of bytes of all tensors
1771        in the Staging Area.
1772        An integer. If zero, the Staging Area is unbounded
1773
1774    Raises:
1775      ValueError: If one of the arguments is invalid.
1776    """
1777
1778    super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
1779                                      capacity, memory_limit)
1780
1781  def put(self, values, name=None):
1782    """Create an op that places a value into the staging area.
1783
1784    This operation will block if the `StagingArea` has reached
1785    its capacity.
1786
1787    Args:
1788      values: A single tensor, a list or tuple of tensors, or a dictionary with
1789        tensor values. The number of elements must match the length of the
1790        list provided to the dtypes argument when creating the StagingArea.
1791      name: A name for the operation (optional).
1792
1793    Returns:
1794        The created op.
1795
1796    Raises:
1797      ValueError: If the number or type of inputs don't match the staging area.
1798    """
1799    with ops.name_scope(name, "%s_put" % self._name,
1800                        self._scope_vals(values)) as scope:
1801
1802      if not isinstance(values, (list, tuple, dict)):
1803        values = [values]
1804
1805      # Hard-code indices for this staging area
1806      indices = list(six.moves.range(len(values)))
1807      vals, _ = self._check_put_dtypes(values, indices)
1808
1809      with ops.colocate_with(self._coloc_op):
1810        op = gen_data_flow_ops.stage(
1811            values=vals,
1812            shared_name=self._name,
1813            name=scope,
1814            capacity=self._capacity,
1815            memory_limit=self._memory_limit)
1816
1817      return op
1818
1819  def __internal_get(self, get_fn, name):
1820    with ops.colocate_with(self._coloc_op):
1821      ret = get_fn()
1822
1823    indices = list(six.moves.range(len(self._dtypes)))  # Hard coded
1824    return self._get_return_value(ret, indices)
1825
1826  def get(self, name=None):
1827    """Gets one element from this staging area.
1828
1829    If the staging area is empty when this operation executes, it will block
1830    until there is an element to dequeue.
1831
1832    Note that unlike other ops that can block, such as the queue Dequeue
1833    operations, this can stop other work from happening.  To avoid this, the
1834    intended use is for this to be called only when there will be an element
1835    already available.  One method for doing this in a training loop would be to
1836    run a `put()` call during a warmup session.run call, and then call both
1837    `get()` and `put()` in each subsequent step (see the sketch below).
1838
1839    The placement of the returned tensor will be determined by the current
1840    device scope when this function is called.
1841
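    A sketch of that warmup pattern (illustrative; `area`, `next_batch`,
    `train_step`, `num_steps` and a `tf.Session` are assumed to be available):

      put_op = area.put(next_batch)   # stages the next batch
      batch = area.get()              # retrieves the previously staged batch
      train_op = train_step(batch)

      with tf.Session() as sess:
        sess.run(put_op)              # warmup: fill the staging area once
        for _ in range(num_steps):
          sess.run([train_op, put_op])  # overlap training and staging
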
1842    Args:
1843      name: A name for the operation (optional).
1844
1845    Returns:
1846      The tuple of tensors that was gotten.
1847    """
1848    if name is None:
1849      name = "%s_get" % self._name
1850
1851    # pylint: disable=bad-continuation
1852    fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
1853                    shared_name=self._name, name=name,
1854                    capacity=self._capacity,
1855                    memory_limit=self._memory_limit)
1856    # pylint: enable=bad-continuation
1857
1858    return self.__internal_get(fn, name)
1859
1860  def peek(self, index, name=None):
1861    """Peeks at an element in the staging area.
1862
1863    If the staging area is too small to contain the element at
1864    the specified index, it will block until enough elements
1865    are inserted to complete the operation.
1866
1867    The placement of the returned tensor will be determined by
1868    the current device scope when this function is called.
1869
1870    Args:
1871      index: The index of the tensor within the staging area
1872              to look up.
1873      name: A name for the operation (optional).
1874
1875    Returns:
1876      The tuple of tensors that was gotten.
1877    """
1878    if name is None:
1879      name = "%s_peek" % self._name
1880
1881    # pylint: disable=bad-continuation
1882    fn = lambda: gen_data_flow_ops.stage_peek(index,
1883                    dtypes=self._dtypes, shared_name=self._name,
1884                    name=name, capacity=self._capacity,
1885                    memory_limit=self._memory_limit)
1886    # pylint: enable=bad-continuation
1887
1888    return self.__internal_get(fn, name)
1889
1890  def size(self, name=None):
1891    """Returns the number of elements in the staging area.
1892
1893    Args:
1894        name: A name for the operation (optional)
1895
1896    Returns:
1897        The created op
1898    """
1899    if name is None:
1900      name = "%s_size" % self._name
1901
1902    return gen_data_flow_ops.stage_size(
1903        name=name,
1904        shared_name=self._name,
1905        dtypes=self._dtypes,
1906        capacity=self._capacity,
1907        memory_limit=self._memory_limit)
1908
1909  def clear(self, name=None):
1910    """Clears the staging area.
1911
1912    Args:
1913        name: A name for the operation (optional)
1914
1915    Returns:
1916        The created op
1917    """
1918    if name is None:
1919      name = "%s_clear" % self._name
1920
1921    return gen_data_flow_ops.stage_clear(
1922        name=name,
1923        shared_name=self._name,
1924        dtypes=self._dtypes,
1925        capacity=self._capacity,
1926        memory_limit=self._memory_limit)
1927
1928
1929class MapStagingArea(BaseStagingArea):
1930  """A `MapStagingArea` is a TensorFlow data structure that stores tensors
1931  across multiple steps, and exposes operations that can put and get tensors.
1932
1933  Each `MapStagingArea` element is a (key, value) pair.
1934  Only int64 keys are supported, other types should be
1935  hashed to produce a key.
1936  Values are a tuple of one or more tensors.
1937  Each tuple component has a static dtype,
1938  and may have a static shape.
1939
1940  The capacity of a `MapStagingArea` may be bounded or unbounded.
1941  It supports multiple concurrent producers and consumers; and
1942  provides exactly-once delivery.
1943
1944  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
1945  whose dtypes are described by `dtypes`, and whose shapes are optionally
1946  described by the `shapes` argument.
1948
1949  If the `shapes` argument is specified, each component of a staging area
1950  element must have the respective fixed shape. If it is
1951  unspecified, different elements may have different shapes.
1952
1953  It behaves like an associative container with support for:
1954
1955   - put(key, values)
1956   - peek(key)         like dict.get(key)
1957   - get(key)          like dict.pop(key)
1958   - get(key=None)     like dict.popitem()
1959   - size()
1960   - clear()
1961
1962  If ordered, a tree structure ordered by key will be used and get(key=None)
1963  will remove (key, value) pairs in increasing key order. Otherwise a
1964  hashtable is used and get(key=None) removes an arbitrary (key, value) pair.
1965
1966  It can be configured with a capacity in which case
1967  put(key, values) will block until space becomes available.
1968
1969  Similarly, it can be configured with a memory limit which
1970  will block put(key, values) until space is available.
1971  This is mostly useful for limiting the number of tensors on
1972  devices such as GPUs.
1973
1974  All get() and peek() commands block if the requested
1975  (key, value) pair is not present in the staging area.
1976
1977  Partial puts are supported and will be placed in an incomplete
1978  map until such time as all values associated with the key have
1979  been inserted. Once completed, this (key, value) pair will be
1980  inserted into the map. Data in the incomplete map
1981  counts towards the memory limit, but not towards capacity limit.
1982
1983  Partial gets from the map are also supported.
1984  This removes the partially requested tensors from the entry,
1985  but the entry is only removed from the map once all tensors
1986  associated with it are removed.
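
  A minimal sketch (assuming the TF 1.x graph API; `MapStagingArea` has no
  `tf.*` export in this file, so it is imported from this module directly, and
  all values are illustrative):

    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    area = data_flow_ops.MapStagingArea(dtypes=[tf.float32], ordered=True)
    put_op = area.put(tf.constant(0, dtype=tf.int64),
                      [tf.constant([1.0, 2.0])], indices=[0])
    key, values = area.get()  # with ordered=True, pops the smallest key

    with tf.Session() as sess:
      sess.run(put_op)
      print(sess.run([key, values]))
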
1987  """
1988
1989  def __init__(self,
1990               dtypes,
1991               shapes=None,
1992               names=None,
1993               shared_name=None,
1994               ordered=False,
1995               capacity=0,
1996               memory_limit=0):
1997    """Constructs a map staging area object.
1998
    Args:
1999      dtypes:  A list of types.  The length of dtypes must equal the number
2000        of tensors in each element.
2001      capacity: (Optional.) Maximum number of elements.
2002        An integer. If zero, the Staging Area is unbounded
2003      memory_limit: (Optional.) Maximum number of bytes of all tensors
2004        in the Staging Area (excluding keys).
2005        An integer. If zero, the Staging Area is unbounded
2006      ordered: (Optional.) If True the underlying data structure
2007        is a tree ordered on key. Otherwise assume a hashtable.
2008      shapes: (Optional.) Constraints on the shapes of tensors in an element.
2009        A list of shape tuples or None. This list is the same length
2010        as dtypes.  If the shape of any tensor in the element is constrained,
2011        all must be; shapes can be None if the shapes should not be constrained.
2012      names: (Optional.) If provided, the `get()` and
2013        `put()` methods will use dictionaries with these names as keys.
2014        Must be None or a list or tuple of the same length as `dtypes`.
2015      shared_name: (Optional.) A name to be used for the shared object. By
2016        passing the same name to two different python objects they will share
2017        the underlying staging area. Must be a string.
2018
2019    Raises:
2020      ValueError: If one of the arguments is invalid.
2021
2022    """
2023
2024    super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
2025                                         capacity, memory_limit)
2026
2027    # Defer to different methods depending on whether the map is ordered.
2028    self._ordered = ordered
2029
2030    if ordered:
2031      self._put_fn = gen_data_flow_ops.ordered_map_stage
2032      self._pop_fn = gen_data_flow_ops.ordered_map_unstage
2033      self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
2034      self._peek_fn = gen_data_flow_ops.ordered_map_peek
2035      self._size_fn = gen_data_flow_ops.ordered_map_size
2036      self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
2037      self._clear_fn = gen_data_flow_ops.ordered_map_clear
2038    else:
2039      self._put_fn = gen_data_flow_ops.map_stage
2040      self._pop_fn = gen_data_flow_ops.map_unstage
2041      self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
2042      self._peek_fn = gen_data_flow_ops.map_peek
2043      self._size_fn = gen_data_flow_ops.map_size
2044      self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
2045      self._clear_fn = gen_data_flow_ops.map_clear
2046
2047  def put(self, key, vals, indices=None, name=None):
2048    """Create an op that stores the (key, vals) pair in the staging area.
2049
2050    Incomplete puts are possible, preferably using a dictionary for vals,
2051    as the appropriate dtypes and shapes can be inferred from the dictionary
2052    keys. If vals is a list or tuple, indices must also be specified so that
2053    the op knows at which element positions to perform the insert (see the
2054    sketches below).
2055
2056    This operation will block if the capacity or memory limit of this
2057    container is reached.
2058
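    For example (illustrative; `area` is assumed to have been created with
    `names=('x', 'y')` and `key` is an int64 scalar tensor), the two partial
    puts below together complete a single element:

      put_x = area.put(key, {'x': x_tensor})
      put_y = area.put(key, {'y': y_tensor})

    For a staging area created without names, a list plus explicit indices is
    used instead:

      put_op = area.put(key, [x_tensor, y_tensor], indices=[0, 1])
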
2059    Args:
2060        key: Key associated with the data
2061        vals: Tensor (or a dict/tuple of Tensors) to place
2062                into the staging area.
2063        indices: (Optional) Required if vals is a tuple/list; the element
                positions at which to insert vals.
2064        name: A name for the operation (optional)
2065
2066    Returns:
2067        The created op
2068
2069    Raises:
2070        ValueError: If the number or type of inputs don't match the staging
2071        area.
2072    """
2073
2074    with ops.name_scope(name, "%s_put" % self._name,
2075                        self._scope_vals(vals)) as scope:
2076
2077      vals, indices = self._check_put_dtypes(vals, indices)
2078
2079      with ops.colocate_with(self._coloc_op):
2080        op = self._put_fn(
2081            key,
2082            indices,
2083            vals,
2084            dtypes=self._dtypes,
2085            shared_name=self._name,
2086            name=scope,
2087            capacity=self._capacity,
2088            memory_limit=self._memory_limit)
2089    return op
2090
2091  def _get_indices_and_dtypes(self, indices=None):
2092    if indices is None:
2093      indices = list(six.moves.range(len(self._dtypes)))
2094
2095    if not isinstance(indices, (tuple, list)):
2096      raise TypeError("Invalid indices type '%s'" % type(indices))
2097
2098    if len(indices) == 0:
2099      raise ValueError("Empty indices")
2100
2101    if all(isinstance(i, str) for i in indices):
2102      if self._names is None:
2103        raise ValueError("String indices provided '%s', but this Staging Area "
2104                         "was not created with names." % indices)
2105
2106      try:
2107        indices = [self._names.index(n) for n in indices]
2108      except ValueError:
2109        raise ValueError("One of the named indices in '%s' is not in "
2110                         "Staging Area names '%s'" % (indices, self._names))
2111    elif all(isinstance(i, int) for i in indices):
2112      pass
2113    else:
2114      raise TypeError("Mixed types in indices '%s'. "
2115                      "May only be str or int" % indices)
2116
2117    dtypes = [self._dtypes[i] for i in indices]
2118
2119    return indices, dtypes
2120
2121  def peek(self, key, indices=None, name=None):
2122    """Peeks at staging area data associated with the key.
2123
2124    If the key is not in the staging area, it will block
2125    until the associated (key, value) is inserted.
2126
2127    Args:
2128        key: Key associated with the required data
2129        indices: Partial list of tensors to retrieve (optional).
2130                A list of integer or string indices.
2131                String indices are only valid if the Staging Area
2132                has names associated with it.
2133        name: A name for the operation (optional)
2134
2135    Returns:
2136        The tensors (or dictionary of tensors) associated with the key.
2137    """
2138
2139    if name is None:
2140      name = "%s_pop" % self._name
2141
2142    indices, dtypes = self._get_indices_and_dtypes(indices)
2143
2144    with ops.colocate_with(self._coloc_op):
2145      result = self._peek_fn(
2146          key,
2147          shared_name=self._name,
2148          indices=indices,
2149          dtypes=dtypes,
2150          name=name,
2151          capacity=self._capacity,
2152          memory_limit=self._memory_limit)
2153
2154    return self._get_return_value(result, indices)
2155
2156  def get(self, key=None, indices=None, name=None):
2157    """If the key is provided, the associated (key, value) is returned from the staging area.
2158
2159    If the key is not in the staging area, this method will block until
2160    the associated (key, value) is inserted.
2161    If no key is provided and the staging area is ordered,
2162    the (key, value) with the smallest key will be returned.
2163    Otherwise, a random (key, value) will be returned.
2164
2165    If the staging area is empty when this operation executes,
2166    it will block until there is an element to dequeue.
2167
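    For example (illustrative; `area` is assumed to have been created with
    `names=('x', 'y')`), a partial get of only the 'x' component looks like:

      key, vals = area.get(key, indices=['x'])
      x = vals['x']
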
2168    Args:
2169        key: Key associated with the required data (Optional)
2170        indices: Partial list of tensors to retrieve (optional).
2171                A list of integer or string indices.
2172                String indices are only valid if the Staging Area
2173                has names associated with it.
2174        name: A name for the operation (optional)
2175
2176    Returns:
2177        A tuple of the key and its associated tensors (or dictionary of
            tensors).
2178    """
2179    if key is None:
2180      return self._popitem(indices=indices, name=name)
2181    else:
2182      return self._pop(key, indices=indices, name=name)
2183
2184  def _pop(self, key, indices=None, name=None):
2185    """Removes and returns the (key, value) pair associated with the key.
2186
2187    If the key is not in the staging area, this method will block until
2188    the associated (key, value) is inserted.

2189    Args:
2190        key: Key associated with the required data
2191        indices: Partial list of tensors to retrieve (optional).
2192                A list of integer or string indices.
2193                String indices are only valid if the Staging Area
2194                has names associated with it.
2195        name: A name for the operation (optional)
2196
2197    Returns:
2198        A tuple of the key and its associated tensors (or dictionary of
            tensors).
2199    """
2200    if name is None:
2201      name = "%s_get" % self._name
2202
2203    indices, dtypes = self._get_indices_and_dtypes(indices)
2204
2205    with ops.colocate_with(self._coloc_op):
2206      result = self._pop_fn(
2207          key,
2208          shared_name=self._name,
2209          indices=indices,
2210          dtypes=dtypes,
2211          name=name,
2212          capacity=self._capacity,
2213          memory_limit=self._memory_limit)
2214
2215    return key, self._get_return_value(result, indices)
2216
2217  def _popitem(self, indices=None, name=None):
2218    """If the staging area is ordered, the (key, value) with the smallest key will be returned.
2219
2220    Otherwise, a random (key, value) will be returned.
2221    If the staging area is empty when this operation executes,
2222    it will block until there is an element to dequeue.
2223
2224    Args:
2226        indices: Partial list of tensors to retrieve (optional).
2227                A list of integer or string indices.
2228                String indices are only valid if the Staging Area
2229                has names associated with it.
2230        name: A name for the operation (optional)
2231
2232    Returns:
2233        A tuple of the removed key and its associated tensors (or dictionary
            of tensors).
2234    """
2235    if name is None:
2236      name = "%s_get_nokey" % self._name
2237
2238    indices, dtypes = self._get_indices_and_dtypes(indices)
2239
2240    with ops.colocate_with(self._coloc_op):
2241      key, result = self._popitem_fn(
2242          shared_name=self._name,
2243          indices=indices,
2244          dtypes=dtypes,
2245          name=name,
2246          capacity=self._capacity,
2247          memory_limit=self._memory_limit)
2248
2249    # Separate keys and results out from
2250    # underlying namedtuple
2251    key = self._create_device_transfers(key)[0]
2252    result = self._get_return_value(result, indices)
2253
2254    return key, result
2255
2256  def size(self, name=None):
2257    """Returns the number of elements in the staging area.
2258
2259    Args:
2260        name: A name for the operation (optional)
2261
2262    Returns:
2263        The created op
2264    """
2265    if name is None:
2266      name = "%s_size" % self._name
2267
2268    return self._size_fn(
2269        shared_name=self._name,
2270        name=name,
2271        dtypes=self._dtypes,
2272        capacity=self._capacity,
2273        memory_limit=self._memory_limit)
2274
2275  def incomplete_size(self, name=None):
2276    """Returns the number of incomplete elements in the staging area.
2277
2278    Args:
2279        name: A name for the operation (optional)
2280
2281    Returns:
2282        The created op
2283    """
2284    if name is None:
2285      name = "%s_incomplete_size" % self._name
2286
2287    return self._incomplete_size_fn(
2288        shared_name=self._name,
2289        name=name,
2290        dtypes=self._dtypes,
2291        capacity=self._capacity,
2292        memory_limit=self._memory_limit)
2293
2294  def clear(self, name=None):
2295    """Clears the staging area.
2296
2297    Args:
2298        name: A name for the operation (optional)
2299
2300    Returns:
2301        The created op
2302    """
2303    if name is None:
2304      name = "%s_clear" % self._name
2305
2306    return self._clear_fn(
2307        shared_name=self._name,
2308        name=name,
2309        dtypes=self._dtypes,
2310        capacity=self._capacity,
2311        memory_limit=self._memory_limit)
2312
2313
2314class RecordInput(object):
2315  """RecordInput asynchronously reads and randomly yields TFRecords.
2316
2317  A RecordInput Op will continuously read a batch of records asynchronously
2318  into a buffer of some fixed capacity. It can also asynchronously yield
2319  random records from this buffer.
2320
2321  It will not start yielding until at least `buffer_size / 2` elements have been
2322  placed into the buffer so that sufficient randomization can take place.
2323
2324  The order in which the files are read will be shifted each epoch by
2325  `shift_ratio` so that the data is presented in a different order every epoch.
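
  A minimal sketch (illustrative; `RecordInput` carries no `tf_export`
  decorator here, so it is imported from this module directly, and the file
  pattern is assumed to match existing TFRecord files):

    import tensorflow as tf
    from tensorflow.python.ops import data_flow_ops

    record_input = data_flow_ops.RecordInput(
        file_pattern="/tmp/train-*.tfrecord", batch_size=32,
        buffer_size=1000, parallelism=4)
    records = record_input.get_yield_op()  # string tensor of 32 records

    with tf.Session() as sess:
      serialized_records = sess.run(records)
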
2326  """
2327
2328  def __init__(self,
2329               file_pattern,
2330               batch_size=1,
2331               buffer_size=1,
2332               parallelism=1,
2333               shift_ratio=0,
2334               seed=0,
2335               name=None,
2336               batches=None,
2337               compression_type=None):
2338    """Constructs a RecordInput Op.
2339
2340    Args:
2341      file_pattern: File path to the dataset, possibly containing wildcards.
2342        All matching files will be iterated over each epoch.
2343      batch_size: How many records to return at a time.
2344      buffer_size: The maximum number of records the buffer will contain.
2345      parallelism: How many reader threads to use for reading from files.
2346      shift_ratio: Fraction of the total number of files by which to shift
2347        the starting file forward each epoch.
2348      seed: The random number seed used by the generator that randomizes
2349        records.
2350      name: Optional name for the operation.
2351      batches: None by default, creating a single batch op. Otherwise specifies
2352        how many batches to create, which are returned as a list when
2353        `get_yield_op()` is called. An example use case is to split processing
2354        between devices on one computer.
2355      compression_type: The type of compression for the file. Currently ZLIB and
2356        GZIP are supported. Defaults to none.
2357
2358    Raises:
2359      ValueError: If one of the arguments is invalid.
2360    """
2361    self._batch_size = batch_size
2362    if batches is not None:
2363      self._batch_size *= batches
2364    self._batches = batches
2365    self._file_pattern = file_pattern
2366    self._buffer_size = buffer_size
2367    self._parallelism = parallelism
2368    self._shift_ratio = shift_ratio
2369    self._seed = seed
2370    self._name = name
2371    self._compression_type = python_io.TFRecordCompressionType.NONE
2372    if compression_type is not None:
2373      self._compression_type = compression_type
2374
2375  def get_yield_op(self):
2376    """Adds a node that yields a group of records every time it is executed.

2377    If the `batches` parameter is not None, it yields a list of record
2378    batches with the specified `batch_size`.
2379    """
2380    compression_type = python_io.TFRecordOptions.get_compression_type_string(
2381        python_io.TFRecordOptions(self._compression_type))
2382    records = gen_data_flow_ops.record_input(
2383        file_pattern=self._file_pattern,
2384        file_buffer_size=self._buffer_size,
2385        file_parallelism=self._parallelism,
2386        file_shuffle_shift_ratio=self._shift_ratio,
2387        batch_size=self._batch_size,
2388        file_random_seed=self._seed,
2389        compression_type=compression_type,
2390        name=self._name)
2391    if self._batches is None:
2392      return records
2393    else:
2394      with ops.name_scope(self._name):
2395        batch_list = [[] for _ in six.moves.range(self._batches)]
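        # Split the flat output into individual scalar records and distribute
        # them round-robin, so that each of the `batches` lists ends up with
        # the originally requested `batch_size` records.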
2396        records = array_ops.split(records, self._batch_size, 0)
2397        records = [array_ops.reshape(record, []) for record in records]
2398        for index, protobuf in zip(six.moves.range(len(records)), records):
2399          batch_index = index % self._batches
2400          batch_list[batch_index].append(protobuf)
2401        return batch_list
2402