# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Data Flow Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import hashlib
import threading

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=wildcard-import


def _as_type_list(dtypes):
  """Convert dtypes to a list of types."""
  assert dtypes is not None
  if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
    # We have a single type.
    return [dtypes]
  else:
    # We have a list or tuple of types.
    return list(dtypes)


def _as_shape_list(shapes,
                   dtypes,
                   unknown_dim_allowed=False,
                   unknown_rank_allowed=False):
  """Convert shapes to a list of tuples of int (or None)."""
  del dtypes
  if unknown_dim_allowed:
    if (not isinstance(shapes, collections.Sequence) or not shapes or
        any(shape is None or isinstance(shape, int) for shape in shapes)):
      raise ValueError(
          "When providing partial shapes, a list of shapes must be provided.")
  if shapes is None:
    return None
  if isinstance(shapes, tensor_shape.TensorShape):
    shapes = [shapes]
  if not isinstance(shapes, (tuple, list)):
    raise TypeError(
        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
  if all(shape is None or isinstance(shape, int) for shape in shapes):
    # We have a single shape.
    shapes = [shapes]
  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
  if not unknown_dim_allowed:
    if any([not shape.is_fully_defined() for shape in shapes]):
      raise ValueError("All shapes must be fully defined: %s" % shapes)
  if not unknown_rank_allowed:
    if any([shape.dims is None for shape in shapes]):
      raise ValueError("All shapes must have a defined rank: %s" % shapes)

  return shapes


def _as_name_list(names, dtypes):
  if names is None:
    return None
  if not isinstance(names, (list, tuple)):
    names = [names]
  if len(names) != len(dtypes):
    raise ValueError("List of names must have the same length as the list "
                     "of dtypes")
  return list(names)


def _shape_common(s1, s2):
  """The greatest lower bound (ordered by specificity) TensorShape."""
  s1 = tensor_shape.TensorShape(s1)
  s2 = tensor_shape.TensorShape(s2)
  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
    return tensor_shape.unknown_shape()
  d = [
      d1 if d1 is not None and d1 == d2 else None
      for (d1, d2) in zip(s1.as_list(), s2.as_list())
  ]
  return tensor_shape.TensorShape(d)


# pylint: disable=protected-access
@tf_export("QueueBase")
class QueueBase(object):
  """Base class for queue implementations.

  A queue is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that enqueue and dequeue
  tensors.
  Each queue element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.

  See @{tf.FIFOQueue} and
  @{tf.RandomShuffleQueue} for concrete
  implementations of this class, and instructions on how to create
  them.

  @compatibility(eager)
  Queues are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, dtypes, shapes, names, queue_ref):
    """Constructs a queue object from a queue reference.

    The two optional lists, `shapes` and `names`, must be of the same length
    as `dtypes` if provided.  The values at a given index `i` indicate the
    shape and name to use for the corresponding queue component in `dtypes`.

    Args:
      dtypes:  A list of types.  The length of dtypes must equal the number
        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes.  If the shape of any tensors in the element are constrained,
        all must be; shapes can be None if the shapes should not be constrained.
      names: Optional list of names.  If provided, the `enqueue()` and
        `dequeue()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      queue_ref: The queue reference, i.e. the output of the queue op.

    Raises:
      ValueError: If one of the arguments is invalid.
      RuntimeError: If eager execution is enabled.
    """
    if context.in_eager_mode():
      raise RuntimeError(
          "Queues are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")
    self._dtypes = dtypes
    if shapes is not None:
      if len(shapes) != len(dtypes):
        raise ValueError("Queue shapes must have the same length as dtypes")
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    if names is not None:
      if len(names) != len(dtypes):
        raise ValueError("Queue names must have the same length as dtypes")
      self._names = names
    else:
      self._names = None
    self._queue_ref = queue_ref
    if context.in_graph_mode():
      self._name = self._queue_ref.op.name.split("/")[-1]
    else:
      self._name = context.context().scope_name

  @staticmethod
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

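    For example (an illustrative sketch, assuming graph mode), dequeuing
    from one of two queues selected at run time:

      # Both queues must have matching component dtypes and names.
      q0 = tf.FIFOQueue(10, tf.float32)
      q1 = tf.FIFOQueue(10, tf.float32)
      which = tf.placeholder(tf.int32, shape=[])
      q = tf.QueueBase.from_list(which, [q0, q1])
      dequeued = q.dequeue()  # reads from q0 or q1 depending on `which`
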
    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    if ((not queues) or (not isinstance(queues, list)) or
        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all([dtypes == q.dtypes for q in queues[1:]]):
      raise TypeError("Queues do not have matching component dtypes.")

    names = queues[0].names
    if not all([names == q.names for q in queues[1:]]):
      raise TypeError("Queues do not have matching component names.")

    queue_shapes = [q.shapes for q in queues]
    reduced_shapes = [
        six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
    ]

    queue_refs = array_ops.stack([x.queue_ref for x in queues])
    selected_queue = array_ops.gather(queue_refs, index)
    return QueueBase(
        dtypes=dtypes,
        shapes=reduced_shapes,
        names=names,
        queue_ref=selected_queue)

  @property
  def queue_ref(self):
    """The underlying queue reference."""
    return self._queue_ref

  @property
  def name(self):
    """The name of the underlying queue."""
    if context.in_graph_mode():
      return self._queue_ref.op.name
    return self._name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a queue element."""
    return self._dtypes

  @property
  def shapes(self):
    """The list of shapes for each component of a queue element."""
    return self._shapes

  @property
  def names(self):
    """The list of names for each component of a queue element."""
    return self._names

  def _check_enqueue_dtypes(self, vals):
    """Validate and convert `vals` to a list of `Tensor`s.

    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    dictionary with tensor values.

    If it is a dictionary, the queue must have been constructed with a
    `names` attribute and the dictionary keys must match the queue names.
    If the queue was constructed with a `names` attribute, `vals` must
    be a dictionary.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      A list of `Tensor` objects.

    Raises:
      ValueError: If `vals` is invalid.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError("Queue must have names to enqueue a dictionary")
      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
        raise ValueError("Keys in dictionary to enqueue do not match "
                         "names of Queue.  Dictionary: (%s), Queue: (%s)" %
                         (sorted(vals.keys()), sorted(self._names)))
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals = [vals[k] for k in self._names]
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a Queue with names")
      if not isinstance(vals, (list, tuple)):
        vals = [vals]

    tensors = []
    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
      tensors.append(
          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))

    return tensors

  def _scope_vals(self, vals):
    """Return a list of values to pass to `name_scope()`.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      The values in vals as a list.
    """
    if isinstance(vals, (list, tuple)):
      return vals
    elif isinstance(vals, dict):
      return vals.values()
    else:
      return [vals]

  def enqueue(self, vals, name=None):
    """Enqueues one element to this queue.

    If the queue is full when this operation executes, it will block
    until the element has been enqueued.

    At runtime, this operation may raise an error if the queue is
    @{tf.QueueBase.close} before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    @{tf.Session.close},
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary containing
        the values to enqueue.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a new tuple of tensors to the queue.
    """
    with ops.name_scope(name, "%s_enqueue" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      for val, shape in zip(vals, self._shapes):
        val.get_shape().assert_is_compatible_with(shape)

      if self._queue_ref.dtype == _dtypes.resource:
        return gen_data_flow_ops._queue_enqueue_v2(
            self._queue_ref, vals, name=scope)
      else:
        return gen_data_flow_ops._queue_enqueue(
            self._queue_ref, vals, name=scope)

  def enqueue_many(self, vals, name=None):
    """Enqueues zero or more elements to this queue.

    This operation slices each component tensor along the 0th dimension to
    make multiple queue elements. All of the tensors in `vals` must have the
    same size in the 0th dimension.

    If the queue is full when this operation executes, it will block
    until all of the elements have been enqueued.

    At runtime, this operation may raise an error if the queue is
    @{tf.QueueBase.close} before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    @{tf.Session.close},
    `tf.errors.CancelledError` will be raised.

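    For example (an illustrative sketch, assuming graph mode), enqueuing
    two elements at once into a two-component queue:

      q = tf.FIFOQueue(10, [tf.int32, tf.string])
      # Slices along dimension 0: enqueues the elements (1, "a") and (2, "b").
      enqueue_op = q.enqueue_many(([1, 2], ["a", "b"]))
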
    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary
        from which the queue elements are taken.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a batch of tuples of tensors to the queue.
    """
    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
      for val, shape in zip(vals, self._shapes):
        batch_dim = batch_dim.merge_with(
            val.get_shape().with_rank_at_least(1)[0])
        val.get_shape()[1:].assert_is_compatible_with(shape)

      return gen_data_flow_ops._queue_enqueue_many_v2(
          self._queue_ref, vals, name=scope)

  def _dequeue_return_value(self, tensors):
    """Return the value to return from a dequeue op.

    If the queue has names, return a dictionary with the
    names as keys.  Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the dequeue op.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
    """
    if self._names:
      # The returned values in `tensors` are in the same order as
      # the names in `self._names`.
      return {n: tensors[i] for i, n in enumerate(self._names)}
    elif len(tensors) == 1:
      return tensors[0]
    else:
      return tensors

  def dequeue(self, name=None):
    """Dequeues one element from this queue.

    If the queue is empty when this operation executes, it will block
    until there is an element to dequeue.

    At runtime, this operation may raise an error if the queue is
    @{tf.QueueBase.close} before or during its execution. If the
    queue is closed, the queue is empty, and there are no pending
    enqueue operations that can fulfill this request,
    `tf.errors.OutOfRangeError` will be raised. If the session is
    @{tf.Session.close},
    `tf.errors.CancelledError` will be raised.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was dequeued.
    """
    if name is None:
      name = "%s_Dequeue" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      ret = gen_data_flow_ops._queue_dequeue_v2(
          self._queue_ref, self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops._queue_dequeue(
          self._queue_ref, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
    if context.in_graph_mode():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(shape)

    return self._dequeue_return_value(ret)

  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor.  All of the
    components in the dequeued tuple will have size `n` in the 0th dimension.

    If the queue is closed and there are less than `n` elements left, then an
    `OutOfRange` exception is raised.

    At runtime, this operation may raise an error if the queue is
    @{tf.QueueBase.close} before or during its execution. If the
    queue is closed, the queue contains fewer than `n` elements, and
    there are no pending enqueue operations that can fulfill this
    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is @{tf.Session.close},
    `tf.errors.CancelledError` will be raised.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The list of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    ret = gen_data_flow_ops._queue_dequeue_many_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    if context.in_graph_mode():
      op = ret[0].op
      batch_dim = tensor_shape.Dimension(
          tensor_util.constant_value(op.inputs[1]))
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def dequeue_up_to(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    **Note** This operation is not supported by all queues.  If a queue does not
    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. If the queue
    has not been closed, all of the components in the dequeued tuple
    will have size `n` in the 0th dimension.

    If the queue is closed and there are more than `0` but fewer than
    `n` elements remaining, then instead of raising a
    `tf.errors.OutOfRangeError` like @{tf.QueueBase.dequeue_many},
    less than `n` elements are returned immediately.  If the queue is
    closed and there are `0` elements left in the queue, then a
    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
    Otherwise the behavior is identical to `dequeue_many`.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueUpTo" % self._name

    ret = gen_data_flow_ops._queue_dequeue_up_to_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    if context.in_graph_mode():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would otherwise block waiting for more elements (if close
    hadn't been called) will now fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be canceled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops._queue_close_v2(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)
    else:
      return gen_data_flow_ops._queue_close(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)

  def is_closed(self, name=None):
    """Returns true if queue is closed.

    This operation returns true if the queue is closed and false if the queue
    is open.

    Args:
      name: A name for the operation (optional).

    Returns:
      True if the queue is closed and false if the queue is open.
    """
    if name is None:
      name = "%s_Is_Closed" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops._queue_size(self._queue_ref, name=name)


@tf_export("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
  """A queue implementation that dequeues elements in a random order.

  See @{tf.QueueBase} for a description of the methods on
  this class.

  @compatibility(eager)
  Queues are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self,
               capacity,
               min_after_dequeue,
               dtypes,
               shapes=None,
               names=None,
               seed=None,
               shared_name=None,
               name="random_shuffle_queue"):
    """Create a queue that dequeues elements in a random order.

    A `RandomShuffleQueue` has bounded capacity; supports multiple
    concurrent producers and consumers; and provides exactly-once
    delivery.

    A `RandomShuffleQueue` holds a list of up to `capacity`
    elements. Each element is a fixed-length tuple of tensors whose
    dtypes are described by `dtypes`, and whose shapes are optionally
    described by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    The `min_after_dequeue` argument allows the caller to specify a
    minimum number of elements that will remain in the queue after a
    `dequeue` or `dequeue_many` operation completes, to ensure a
    minimum level of mixing of elements. This invariant is maintained
    by blocking those operations until sufficient elements have been
    enqueued. The `min_after_dequeue` argument is ignored after the
    queue has been closed.

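    For example (an illustrative sketch, assuming graph mode and a
    `tf.Session`):

      q = tf.RandomShuffleQueue(capacity=10, min_after_dequeue=2,
                                dtypes=tf.float32)
      enqueue_op = q.enqueue_many(([1., 2., 3., 4.],))
      # Each dequeue returns a randomly chosen element, and blocks rather
      # than leave fewer than `min_after_dequeue` elements in the queue
      # (unless the queue has been closed).
      dequeued = q.dequeue()

      with tf.Session() as sess:
        sess.run(enqueue_op)
        print(sess.run(dequeued))  # ==> one of 1.0, 2.0, 3.0, 4.0
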
    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      min_after_dequeue: An integer (described above).
      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of string naming the components in the queue
        with the same length as `dtypes`, or `None`.  If specified the dequeue
        methods return a dictionary with the names as keys.
      seed: A Python integer. Used to create a random seed. See
        @{tf.set_random_seed}
        for behavior.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    seed1, seed2 = random_seed.get_seed(seed)
    if seed1 is None and seed2 is None:
      seed1, seed2 = 0, 0
    elif seed is None and shared_name is not None:
      # This means that graph seed is provided but op seed is not provided.
      # If shared_name is also provided, make seed2 depend only on the graph
      # seed and shared_name. (seed2 from get_seed() is generally dependent on
      # the id of the last op created.)
      string = (str(seed1) + shared_name).encode("utf-8")
      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
    queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        seed=seed1,
        seed2=seed2,
        shared_name=shared_name,
        name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("FIFOQueue")
class FIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out order.

  See @{tf.QueueBase} for a description of the methods on
  this class.

  @compatibility(eager)
  Queues are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

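    For example (an illustrative sketch, assuming graph mode and a
    `tf.Session`):

      q = tf.FIFOQueue(capacity=3, dtypes=tf.int32)
      enqueue_op = q.enqueue_many(([1, 2, 3],))
      dequeued = q.dequeue()

      with tf.Session() as sess:
        sess.run(enqueue_op)
        print(sess.run(dequeued))  # ==> 1 (elements come out in FIFO order)
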
    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of string naming the components in the queue
        with the same length as `dtypes`, or `None`.  If specified the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    queue_ref = gen_data_flow_ops._fifo_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        shared_name=shared_name,
        name=name)

    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("PaddingFIFOQueue")
class PaddingFIFOQueue(QueueBase):
  """A FIFOQueue that supports batching variable-sized tensors by padding.

  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
  supporting `dequeue_many`.  See the constructor for more details.

  See @{tf.QueueBase} for a description of the methods on
  this class.

  @compatibility(eager)
  Queues are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes,
               names=None,
               shared_name=None,
               name="padding_fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are described by the `shapes`
    argument.

    The `shapes` argument must be specified; each component of a queue
    element must have the respective shape.  Shapes of fixed
    rank but variable size are allowed by setting any shape dimension to None.
    In this case, the inputs' shape may vary along the given dimension, and
    `dequeue_many` will pad the given dimension with zeros up to the maximum
    shape of all elements in the given batch.

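    For example (an illustrative sketch, assuming graph mode and a
    `tf.Session`), batching variable-length vectors:

      q = tf.PaddingFIFOQueue(10, dtypes=[tf.int32], shapes=[[None]])
      enq_1 = q.enqueue(([1],))
      enq_2 = q.enqueue(([2, 3],))
      batch = q.dequeue_many(2)  # shorter elements are zero-padded

      with tf.Session() as sess:
        sess.run([enq_1, enq_2])
        print(sess.run(batch))  # ==> [[1, 0], [2, 3]] (padded)
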
    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes:  A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: A list of `TensorShape` objects, with the same length as
        `dtypes`.  Any dimension in the `TensorShape` containing value
        `None` is dynamic and allows values to be enqueued with
         variable size in that dimension.
      names: (Optional.) A list of string naming the components in the queue
        with the same length as `dtypes`, or `None`.  If specified the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
        and shapes do not match, or if names is specified and the lengths of
        dtypes and names do not match.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
    names = _as_name_list(names, dtypes)
    if len(dtypes) != len(shapes):
      raise ValueError("Shapes must be provided for all components, "
                       "but received %d dtypes and %d shapes." % (len(dtypes),
                                                                  len(shapes)))

    queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        shared_name=shared_name,
        name=name)

    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("PriorityQueue")
class PriorityQueue(QueueBase):
  """A queue implementation that dequeues elements in prioritized order.

  See @{tf.QueueBase} for a description of the methods on
  this class.

  @compatibility(eager)
  Queues are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self,
               capacity,
               types,
               shapes=None,
               names=None,
               shared_name=None,
               name="priority_queue"):
860    """Creates a queue that dequeues elements in a first-in first-out order.

    A `PriorityQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PriorityQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `types`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Enqueues and Dequeues to the `PriorityQueue` must include an additional
    tuple entry at the beginning: the `priority`.  The priority must be
    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).

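    For example (an illustrative sketch, assuming graph mode and a
    `tf.Session`):

      q = tf.PriorityQueue(capacity=10, types=[tf.string], shapes=[()])
      enq_a = q.enqueue((tf.constant(2, tf.int64), "b"))
      enq_b = q.enqueue((tf.constant(1, tf.int64), "a"))
      priority, value = q.dequeue()  # lowest priority value dequeues first

      with tf.Session() as sess:
        sess.run([enq_a, enq_b])
        print(sess.run(value))  # ==> "a" (priority 1 comes before priority 2)
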
    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      types:  A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, except the first priority
        element.  The first tensor in each element is the priority,
        which must be type int64.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `types`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`.  If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    types = _as_type_list(types)
    shapes = _as_shape_list(shapes, types)

    queue_ref = gen_data_flow_ops._priority_queue_v2(
        component_types=types,
        shapes=shapes,
        capacity=capacity,
        shared_name=shared_name,
        name=name)

    priority_dtypes = [_dtypes.int64] + types
    priority_shapes = [()] + shapes if shapes else shapes

    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
                                        queue_ref)


# TODO(josh11b): class BatchQueue(QueueBase):


class Barrier(object):
  """Represents a key-value map that persists across graph executions."""

  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
    """Creates a barrier that persists across different graph executions.

    A barrier represents a key-value map, where each key is a string, and
    each value is a tuple of tensors.

    At runtime, the barrier contains 'complete' and 'incomplete'
    elements. A complete element has defined tensors for all
    components of its value tuple, and may be accessed using
    take_many. An incomplete element has some undefined components in
    its value tuple, and may be updated using insert_many.

    The barrier call `take_many` outputs values in a particular order.
    First, it only outputs completed values.  Second, the order in which
    completed values are returned matches the order in which their very
    first component was inserted into the barrier.  So, for example, for this
    sequence of insertions and removals:

      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
      barrier.insert_many(1, keys=["k1"], values=[1]).run()
      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
      barrier.insert_many(1, keys=["k3"], values=[3]).run()
      barrier.insert_many(1, keys=["k2"], values=[2]).run()

      (indices, keys, values) = barrier.take_many(2)
      (indices_val, keys_val, values0_val, values1_val) =
         session.run([indices, keys, values[0], values[1]])

    The output will be (up to permutation of "k1" and "k2"):

      indices_val == (-2**63, -2**63)
      keys_val == ("k1", "k2")
      values0_val == ("a", "b")
      values1_val == (1, 2)

    Note the key "k2" was inserted into the barrier before "k3".  Even though
    "k3" was completed first, both are complete by the time
    take_many is called.  As a result, "k2" is prioritized and "k1" and "k2"
    are returned first.  "k3" remains in the barrier until the next execution
    of `take_many`.  Since "k1" and "k2" had their first insertions into
    the barrier together, their indices are the same (-2**63).  The index
    of "k3" will be -2**63 + 1, because it was the next new inserted key.

    Args:
      types: A single dtype or a tuple of dtypes, corresponding to the
        dtypes of the tensor elements that comprise a value in this barrier.
      shapes: Optional. Constraints on the shapes of tensors in the values:
        a single tensor shape tuple; a tuple of tensor shape tuples
        for each barrier-element tuple component; or None if the shape should
        not be constrained.
      shared_name: Optional. If non-empty, this barrier will be shared under
        the given name across multiple sessions.
      name: Optional name for the barrier op.

    Raises:
      ValueError: If one of the `shapes` indicate no elements.
    """
    self._types = _as_type_list(types)

    if shapes is not None:
      shapes = _as_shape_list(shapes, self._types)
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
      for i, shape in enumerate(self._shapes):
        if shape.num_elements() == 0:
          raise ValueError("Empty tensors are not supported, but received "
                           "shape '%s' at index %d" % (shape, i))
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]

    self._barrier_ref = gen_data_flow_ops._barrier(
        component_types=self._types,
        shapes=self._shapes,
        shared_name=shared_name,
        name=name)
    if context.in_graph_mode():
      self._name = self._barrier_ref.op.name.split("/")[-1]
    else:
      self._name = context.context().scope_name

  @property
  def barrier_ref(self):
    """Get the underlying barrier reference."""
    return self._barrier_ref

  @property
  def name(self):
    """The name of the underlying barrier."""
    if context.in_graph_mode():
      return self._barrier_ref.op.name
    return self._name

  def insert_many(self, component_index, keys, values, name=None):
    """For each key, assigns the respective value to the specified component.

    This operation updates each element at component_index.

    Args:
      component_index: The component of the value that is being assigned.
      keys: A vector of keys, with length n.
      values: An any-dimensional tensor of values, which are associated with the
        respective keys. The first dimension must have length n.
      name: Optional name for the op.

    Returns:
      The operation that performs the insertion.
    Raises:
      InvalidArgumentsError: If inserting keys and values without elements.
    """
    if name is None:
      name = "%s_BarrierInsertMany" % self._name
    return gen_data_flow_ops._barrier_insert_many(
        self._barrier_ref, keys, values, component_index, name=name)

  def take_many(self,
                num_elements,
                allow_small_batch=False,
                timeout=None,
                name=None):
    """Takes the given number of completed elements from this barrier.

    This operation concatenates completed-element component tensors along
    the 0th dimension to make a single component tensor.

    If barrier has no completed elements, this operation will block
    until there are 'num_elements' elements to take.

    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
    and may be extended to other cases in the future.

    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    already when the barrier is closed, it will block forever. Fix this
    by using asynchronous operations.

    Args:
      num_elements: The number of elements to take.
      allow_small_batch: If the barrier is closed, don't block if there are less
        completed elements than requested, but instead return all available
        completed elements.
      timeout: This specifies the number of milliseconds to block
        before returning with DEADLINE_EXCEEDED. (This option is not
        supported yet.)
      name: A name for the operation (optional).

    Returns:
      A tuple of (index, key, value_list).
      "index" is a int64 tensor of length num_elements containing the
        index of the insert_many call for which the very first component of
        the given element was inserted into the Barrier, starting with
        the value -2**63.  Note, this value is different from the
        index of the insert_many call for which the element was completed.
      "key" is a string tensor of length num_elements containing the keys.
      "value_list" is a tuple of tensors, each one with size num_elements
        in the 0th dimension for each component in the barrier's values.

    """
    if name is None:
      name = "%s_BarrierTakeMany" % self._name
    ret = gen_data_flow_ops._barrier_take_many(
        self._barrier_ref,
        num_elements,
        self._types,
        allow_small_batch,
        timeout,
        name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Barrier object.
    if context.in_graph_mode():
      op = ret[0].op
      if allow_small_batch:
        batch_dim = None
      else:
        batch_dim = tensor_shape.Dimension(
            tensor_util.constant_value(op.inputs[1]))
      op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
      op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return ret

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this barrier.

    This operation signals that no more new key values will be inserted in the
    given barrier. Subsequent InsertMany operations with new keys will fail.
    InsertMany operations that just complement already existing keys with other
    components, will continue to succeed. Subsequent TakeMany operations will
    continue to succeed if sufficient elements remain in the barrier. Subsequent
    TakeMany operations that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests to the
    underlying queue will also be canceled, and completing already
    started values will no longer be accepted.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: Optional name for the op.

    Returns:
      The operation that closes the barrier.
    """
    if name is None:
      name = "%s_BarrierClose" % self._name
    return gen_data_flow_ops._barrier_close(
        self._barrier_ref,
        cancel_pending_enqueues=cancel_pending_enqueues,
        name=name)

  def ready_size(self, name=None):
    """Compute the number of complete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of complete elements in the
      given barrier.
    """
    if name is None:
      name = "%s_BarrierReadySize" % self._name
    return gen_data_flow_ops._barrier_ready_size(self._barrier_ref, name=name)

  def incomplete_size(self, name=None):
    """Compute the number of incomplete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of incomplete elements in
      the given barrier.
    """
    if name is None:
      name = "%s_BarrierIncompleteSize" % self._name
    return gen_data_flow_ops._barrier_incomplete_size(
        self._barrier_ref, name=name)


@tf_export("ConditionalAccumulatorBase")
class ConditionalAccumulatorBase(object):
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
  """

  def __init__(self, dtype, shape, accumulator_ref):
    """Creates a new ConditionalAccumulator.

    Args:
      dtype: Datatype of the accumulated gradients.
      shape: Shape of the accumulated gradients.
      accumulator_ref: A handle to the conditional accumulator, created by sub-
        classes
    """
    self._dtype = dtype
    if shape is not None:
      self._shape = tensor_shape.TensorShape(shape)
    else:
      self._shape = tensor_shape.unknown_shape()
    self._accumulator_ref = accumulator_ref
    if context.in_graph_mode():
      self._name = self._accumulator_ref.op.name.split("/")[-1]
    else:
      self._name = context.context().scope_name

  @property
  def accumulator_ref(self):
    """The underlying accumulator reference."""
    return self._accumulator_ref

  @property
  def name(self):
    """The name of the underlying accumulator."""
    return self._name

  @property
  def dtype(self):
    """The datatype of the gradients accumulated by this accumulator."""
    return self._dtype

  def num_accumulated(self, name=None):
    """Number of gradients that have currently been aggregated in accumulator.

    Args:
      name: Optional name for the operation.

    Returns:
      Number of accumulated gradients currently in accumulator.
    """
    if name is None:
      name = "%s_NumAccumulated" % self._name
    return gen_data_flow_ops.accumulator_num_accumulated(
        self._accumulator_ref, name=name)

  def set_global_step(self, new_global_step, name=None):
    """Sets the global time step of the accumulator.

    The operation logs a warning if we attempt to set to a time step that is
    lower than the accumulator's own time step.

    Args:
      new_global_step: Value of new time step. Can be a variable or a constant
      name: Optional name for the operation.

    Returns:
      Operation that sets the accumulator's time step.
    """
    return gen_data_flow_ops.accumulator_set_global_step(
        self._accumulator_ref,
        math_ops.to_int64(ops.convert_to_tensor(new_global_step)),
        name=name)


@tf_export("ConditionalAccumulator")
class ConditionalAccumulator(ConditionalAccumulatorBase):
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
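
  For example (an illustrative sketch, assuming graph mode and a
  `tf.Session`):

    acc = tf.ConditionalAccumulator(dtype=tf.float32, shape=())
    apply_op = acc.apply_grad(2.0)
    avg = acc.take_grad(num_required=2)  # blocks until 2 gradients arrive

    with tf.Session() as sess:
      sess.run(apply_op)
      sess.run(apply_op)
      print(sess.run(avg))  # ==> 2.0, the average of the applied gradients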
1248  """
1249
1250  def __init__(self,
1251               dtype,
1252               shape=None,
1253               shared_name=None,
1254               name="conditional_accumulator"):
1255    """Creates a new ConditionalAccumulator.
1256
1257    Args:
1258      dtype: Datatype of the accumulated gradients.
1259      shape: Shape of the accumulated gradients.
1260      shared_name: Optional. If non-empty, this accumulator will be shared under
1261        the given name across multiple sessions.
1262      name: Optional name for the accumulator.
1263    """
1264    accumulator_ref = gen_data_flow_ops.conditional_accumulator(
1265        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
1266    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)
1267
1268  def apply_grad(self, grad, local_step=0, name=None):
1269    """Attempts to apply a gradient to the accumulator.
1270
1271    The attempt is silently dropped if the gradient is stale, i.e., local_step
1272    is less than the accumulator's global time step.
1273
1274    Args:
1275      grad: The gradient tensor to be applied.
1276      local_step: Time step at which the gradient was computed.
1277      name: Optional name for the operation.
1278
1279    Returns:
1280      The operation that (conditionally) applies a gradient to the accumulator.
1281
1282    Raises:
1283      ValueError: If grad is of the wrong shape
1284    """
1285    grad = ops.convert_to_tensor(grad, self._dtype)
1286    grad.get_shape().assert_is_compatible_with(self._shape)
1287    local_step = math_ops.to_int64(ops.convert_to_tensor(local_step))
1288    return gen_data_flow_ops.accumulator_apply_gradient(
1289        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)
1290
1291  def take_grad(self, num_required, name=None):
1292    """Attempts to extract the average gradient from the accumulator.
1293
1294    The operation blocks until sufficient number of gradients have been
1295    successfully applied to the accumulator.
1296
1297    Once successful, the following actions are also triggered:
1298
1299    - Counter of accumulated gradients is reset to 0.
1300    - Aggregated gradient is reset to 0 tensor.
1301    - Accumulator's internal time step is incremented by 1.
1302
1303    Args:
1304      num_required: Number of gradients that needs to have been aggregated
1305      name: Optional name for the operation
1306
1307    Returns:
1308      A tensor holding the value of the average gradient.
1309
1310    Raises:
1311      InvalidArgumentError: If num_required < 1
1312    """
1313    out = gen_data_flow_ops.accumulator_take_gradient(
1314        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1315    out.set_shape(self._shape)
1316    return out
1317
1318
1319@tf_export("SparseConditionalAccumulator")
1320class SparseConditionalAccumulator(ConditionalAccumulatorBase):
1321  """A conditional accumulator for aggregating sparse gradients.
1322
1323  Sparse gradients are represented by IndexedSlices.
1324
1325  Up-to-date gradients (i.e., time step at which gradient was computed is
1326  equal to the accumulator's time step) are added to the accumulator.
1327
1328  Extraction of the average gradient is blocked until the required number of
1329  gradients has been accumulated.
1330
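  For example (an illustrative sketch, assuming graph mode), applying an
  `IndexedSlices` gradient and reading back the average:

    acc = tf.SparseConditionalAccumulator(dtype=tf.float32, shape=[2, 2])
    grad = tf.IndexedSlices(
        values=tf.constant([[1.0, 1.0]]),
        indices=tf.constant([0], dtype=tf.int64),
        dense_shape=tf.constant([2, 2], dtype=tf.int64))
    apply_op = acc.apply_indexed_slices_grad(grad)
    avg = acc.take_indexed_slices_grad(num_required=1)
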
  Args:
    dtype: Datatype of the accumulated gradients.
    shape: Shape of the accumulated gradients.
    shared_name: Optional. If non-empty, this accumulator will be shared under
      the given name across multiple sessions.
    name: Optional name for the accumulator.
  """

  def __init__(self,
               dtype,
               shape=None,
               shared_name=None,
               name="sparse_conditional_accumulator"):
    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
        dtype=dtype, shape=shape, shared_name=shared_name, name=name)
    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
                                                       accumulator_ref)

  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
    """Attempts to apply a gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., local_step
    is less than the accumulator's global time step.

    Args:
      grad: The gradient IndexedSlices to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      InvalidArgumentError: If grad is of the wrong shape
    """
    return self.apply_grad(
        grad_indices=grad.indices,
        grad_values=grad.values,
        grad_shape=grad.dense_shape,
        local_step=local_step,
        name=name)

  def apply_grad(self,
                 grad_indices,
                 grad_values,
                 grad_shape=None,
                 local_step=0,
                 name=None):
    """Attempts to apply a sparse gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., local_step
    is less than the accumulator's global time step.

    A sparse gradient is represented by its indices, values and possibly empty
    or None shape. Indices must be a vector representing the locations of
    non-zero entries in the tensor. Values are the non-zero slices of the
    gradient, and must have the same first dimension as indices, i.e., the nnz
    represented by indices and values must be consistent. Shape, if not empty or
    None, must be consistent with the accumulator's shape (if also provided).

1391    Example:
      A tensor [[0, 0], [0, 1], [2, 3]] can be represented as
1393        indices: [1,2]
1394        values: [[0,1],[2,3]]
1395        shape: [3, 2]
1396
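    The tensor above could then be applied as follows (a hypothetical sketch;
    `accumulator` stands for a `SparseConditionalAccumulator` of dtype
    float32):

    ```python
    apply_op = accumulator.apply_grad(
        grad_indices=[1, 2],
        grad_values=[[0.0, 1.0], [2.0, 3.0]],
        grad_shape=[3, 2],
        local_step=0)
    ```
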
1397    Args:
1398      grad_indices: Indices of the sparse gradient to be applied.
1399      grad_values: Values of the sparse gradient to be applied.
1400      grad_shape: Shape of the sparse gradient to be applied.
1401      local_step: Time step at which the gradient was computed.
1402      name: Optional name for the operation.
1403
1404    Returns:
1405      The operation that (conditionally) applies a gradient to the accumulator.
1406
1407    Raises:
1408      InvalidArgumentError: If grad is of the wrong shape
1409    """
1410    local_step = math_ops.to_int64(ops.convert_to_tensor(local_step))
1411    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
1412        self._accumulator_ref,
1413        local_step=local_step,
1414        gradient_indices=math_ops.to_int64(grad_indices),
1415        gradient_values=grad_values,
1416        gradient_shape=math_ops.to_int64([]
1417                                         if grad_shape is None else grad_shape),
1418        has_known_shape=(grad_shape is not None),
1419        name=name)
1420
1421  def take_grad(self, num_required, name=None):
1422    """Attempts to extract the average gradient from the accumulator.
1423
    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:
    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to the zero tensor.
1430    - Accumulator's internal time step is incremented by 1.
1431
1432    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.
1435
1436    Returns:
1437      A tuple of indices, values, and shape representing the average gradient.
1438
1439    Raises:
1440      InvalidArgumentError: If num_required < 1
1441    """
1442    return gen_data_flow_ops.sparse_accumulator_take_gradient(
1443        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1444
1445  def take_indexed_slices_grad(self, num_required, name=None):
1446    """Attempts to extract the average gradient from the accumulator.
1447
    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:
    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to the zero tensor.
1454    - Accumulator's internal time step is incremented by 1.
1455
1456    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.
1459
1460    Returns:
1461      An IndexedSlices holding the value of the average gradient.
1462
1463    Raises:
1464      InvalidArgumentError: If num_required < 1
1465    """
1466    return_val = gen_data_flow_ops.sparse_accumulator_take_gradient(
1467        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
1468    return ops.IndexedSlices(
1469        indices=return_val.indices,
1470        values=return_val.values,
1471        dense_shape=return_val.shape)
1472
1473
1474class BaseStagingArea(object):
1475  """Base class for Staging Areas."""
1476  _identifier = 0
1477  _lock = threading.Lock()
1478
1479  def __init__(self,
1480               dtypes,
1481               shapes=None,
1482               names=None,
1483               shared_name=None,
1484               capacity=0,
1485               memory_limit=0):
1486    if shared_name is None:
1487      self._name = (
1488          ops.get_default_graph().unique_name(self.__class__.__name__))
1489    elif isinstance(shared_name, six.string_types):
1490      self._name = shared_name
1491    else:
1492      raise ValueError("shared_name must be a string")
1493
1494    self._dtypes = dtypes
1495
1496    if shapes is not None:
1497      if len(shapes) != len(dtypes):
1498        raise ValueError("StagingArea shapes must be the same length as dtypes")
1499      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
1500    else:
1501      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
1502
1503    if names is not None:
1504      if len(names) != len(dtypes):
1505        raise ValueError("StagingArea names must be the same length as dtypes")
1506      self._names = names
1507    else:
1508      self._names = None
1509
1510    self._capacity = capacity
1511    self._memory_limit = memory_limit
1512
1513    # all get and put ops must colocate with this op
1514    with ops.name_scope("%s_root" % self._name):
1515      self._coloc_op = control_flow_ops.no_op()
1516
1517  @property
1518  def name(self):
1519    """The name of the staging area."""
1520    return self._name
1521
1522  @property
1523  def dtypes(self):
1524    """The list of dtypes for each component of a staging area element."""
1525    return self._dtypes
1526
1527  @property
1528  def shapes(self):
1529    """The list of shapes for each component of a staging area element."""
1530    return self._shapes
1531
1532  @property
1533  def names(self):
1534    """The list of names for each component of a staging area element."""
1535    return self._names
1536
1537  @property
1538  def capacity(self):
1539    """The maximum number of elements of this staging area."""
1540    return self._capacity
1541
1542  @property
1543  def memory_limit(self):
1544    """The maximum number of bytes of this staging area."""
1545    return self._memory_limit
1546
1547  def _check_put_dtypes(self, vals, indices=None):
1548    """Validate and convert `vals` to a list of `Tensor`s.
1549
1550    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
1551    dictionary with tensor values.
1552
1553    If `vals` is a list, then the appropriate indices associated with the
1554    values must be provided.
1555
1556    If it is a dictionary, the staging area must have been constructed with a
1557    `names` attribute and the dictionary keys must match the staging area names.
1558    `indices` will be inferred from the dictionary keys.
1559    If the staging area was constructed with a `names` attribute, `vals` must
1560    be a dictionary.
1561
1562    Checks that the dtype and shape of each value matches that
1563    of the staging area.
1564
1565    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.
      indices: Optional list of indices associated with `vals` when `vals` is
        a list or tuple. Inferred from the dictionary keys when `vals` is a
        dictionary.
1567
1568    Returns:
1569      A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects
      and `indices` is a list of indices associated with the tensors.
1571
1572    Raises:
1573      ValueError: If `vals` or `indices` is invalid.
1574    """
1575    if isinstance(vals, dict):
1576      if not self._names:
1577        raise ValueError(
1578            "Staging areas must have names to enqueue a dictionary")
1579      if not set(vals.keys()).issubset(self._names):
1580        raise ValueError("Keys in dictionary to put do not match names "
1581                         "of staging area. Dictionary: (%s), Queue: (%s)" %
1582                         (sorted(vals.keys()), sorted(self._names)))
1583      # The order of values in `self._names` indicates the order in which the
1584      # tensors in the dictionary `vals` must be listed.
1585      vals, indices, n = zip(*[(vals[k], i, k)
1586                               for i, k in enumerate(self._names)
1587                               if k in vals])
1588    else:
1589      if self._names:
1590        raise ValueError("You must enqueue a dictionary in a staging area "
1591                         "with names")
1592
1593      if indices is None:
1594        raise ValueError("Indices must be supplied when inserting a list "
1595                         "of tensors")
1596
      if len(indices) != len(vals):
        raise ValueError("Number of indices '%s' doesn't match "
                         "number of values '%s'" % (len(indices), len(vals)))
1600
1601      if not isinstance(vals, (list, tuple)):
1602        vals = [vals]
1603        indices = [0]
1604
1605    # Sanity check number of values
1606    if not len(vals) <= len(self._dtypes):
1607      raise ValueError("Unexpected number of inputs '%s' vs '%s'" %
1608                       (len(vals), len(self._dtypes)))
1609
1610    tensors = []
1611
1612    for val, i in zip(vals, indices):
1613      dtype, shape = self._dtypes[i], self._shapes[i]
1614      # Check dtype
1615      if not val.dtype == dtype:
1616        raise ValueError("Datatypes do not match. '%s' != '%s'" %
1617                         (str(val.dtype), str(dtype)))
1618
1619      # Check shape
1620      val.get_shape().assert_is_compatible_with(shape)
1621
1622      tensors.append(
1623          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))
1624
1625    return tensors, indices
1626
1627  def _create_device_transfers(self, tensors):
1628    """Encode inter-device transfers if the current device
    is not the same as the Staging Area's device.
1630    """
1631
1632    if not isinstance(tensors, (tuple, list)):
1633      tensors = [tensors]
1634
1635    curr_device_scope = control_flow_ops.no_op().device
1636
1637    if curr_device_scope != self._coloc_op.device:
1638      tensors = [array_ops.identity(t) for t in tensors]
1639
1640    return tensors
1641
1642  def _get_return_value(self, tensors, indices):
1643    """Return the value to return from a get op.
1644
1645    If the staging area has names, return a dictionary with the
1646    names as keys.  Otherwise return either a single tensor
1647    or a list of tensors depending on the length of `tensors`.
1648
1649    Args:
1650      tensors: List of tensors from the get op.
1651      indices: Indices of associated names and shapes
1652
1653    Returns:
1654      A single tensor, a list of tensors, or a dictionary
1655      of tensors.
1656    """
1657
1658    tensors = self._create_device_transfers(tensors)
1659
1660    # Sets shape
1661    for output, i in zip(tensors, indices):
1662      output.set_shape(self._shapes[i])
1663
1664    if self._names:
1665      # The returned values in `tensors` are in the same order as
1666      # the names in `self._names`.
1667      return {self._names[i]: t for t, i in zip(tensors, indices)}
1668    return tensors
1669
1670  def _scope_vals(self, vals):
1671    """Return a list of values to pass to `name_scope()`.
1672
1673    Args:
1674      vals: A tensor, a list or tuple of tensors, or a dictionary.
1675
1676    Returns:
1677      The values in vals as a list.
1678    """
1679    if isinstance(vals, (list, tuple)):
1680      return vals
1681    elif isinstance(vals, dict):
1682      return vals.values()
1683    else:
1684      return [vals]
1685
1686
1687class StagingArea(BaseStagingArea):
1688  """Class for staging inputs. No ordering guarantees.
1689
1690  A `StagingArea` is a TensorFlow data structure that stores tensors across
1691  multiple steps, and exposes operations that can put and get tensors.
1692
1693  Each `StagingArea` element is a tuple of one or more tensors, where each
1694  tuple component has a static dtype, and may have a static shape.
1695
1696  The capacity of a `StagingArea` may be bounded or unbounded.
  It supports multiple concurrent producers and consumers, and
1698  provides exactly-once delivery.
1699
1700  Each element of a `StagingArea` is a fixed-length tuple of tensors whose
1701  dtypes are described by `dtypes`, and whose shapes are optionally described
1702  by the `shapes` argument.
1703
1704  If the `shapes` argument is specified, each component of a staging area
1705  element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.
1707
1708  It can be configured with a capacity in which case
1709  put(values) will block until space becomes available.
1710
1711  Similarly, it can be configured with a memory limit which
1712  will block put(values) until space is available.
1713  This is mostly useful for limiting the number of tensors on
1714  devices such as GPUs.
1715
1716  All get() and peek() commands block if the requested data
1717  is not present in the Staging Area.
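
  A minimal usage sketch (illustrative only; assumes graph mode with a
  `tf.Session` and that the class is constructed directly from this module,
  since the public import path has varied across releases):

  ```python
  area = StagingArea(dtypes=[tf.float32, tf.int32])

  put_op = area.put([tf.constant([1.0, 2.0]), tf.constant([3, 4])])
  x, y = area.get()

  with tf.Session() as sess:
    sess.run(put_op)     # stage one element
    sess.run([x, y])     # unstage it
  ```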
1718
1719  """
1720
1721  def __init__(self,
1722               dtypes,
1723               shapes=None,
1724               names=None,
1725               shared_name=None,
1726               capacity=0,
1727               memory_limit=0):
1728    """Constructs a staging area object.
1729
1730    The two optional lists, `shapes` and `names`, must be of the same length
1731    as `dtypes` if provided.  The values at a given index `i` indicate the
1732    shape and name to use for the corresponding queue component in `dtypes`.
1733
1734    The device scope at the time of object creation determines where the
1735    storage for the `StagingArea` will reside.  Calls to `put` will incur a copy
1736    to this memory space, if necessary.  Tensors returned by `get` will be
1737    placed according to the device scope when `get` is called.
1738
1739    Args:
1740      dtypes:  A list of types.  The length of dtypes must equal the number
1741        of tensors in each element.
1742      capacity: (Optional.) Maximum number of elements.
1743        An integer. If zero, the Staging Area is unbounded
1744      memory_limit: (Optional.) Maximum number of bytes of all tensors
1745        in the Staging Area.
1746        An integer. If zero, the Staging Area is unbounded
1747      shapes: (Optional.) Constraints on the shapes of tensors in an element.
1748        A list of shape tuples or None. This list is the same length
1749        as dtypes.  If the shape of any tensors in the element are constrained,
1750        all must be; shapes can be None if the shapes should not be constrained.
1751      names: (Optional.) If provided, the `get()` and
1752        `put()` methods will use dictionaries with these names as keys.
1753        Must be None or a list or tuple of the same length as `dtypes`.
1754      shared_name: (Optional.) A name to be used for the shared object. By
1755        passing the same name to two different python objects they will share
1756        the underlying staging area. Must be a string.
1757
1758    Raises:
1759      ValueError: If one of the arguments is invalid.
1760    """
1761
1762    super(StagingArea, self).__init__(dtypes, shapes, names, shared_name,
1763                                      capacity, memory_limit)
1764
1765  def put(self, values, name=None):
1766    """Create an op that places a value into the staging area.
1767
1768    This operation will block if the `StagingArea` has reached
1769    its capacity.
1770
1771    Args:
1772      values: Tensor (or a tuple of Tensors) to place into the staging area.
1773      name: A name for the operation (optional).
1774
1775    Returns:
1776        The created op.
1777
1778    Raises:
1779      ValueError: If the number or type of inputs don't match the staging area.
1780    """
1781    with ops.name_scope(name, "%s_put" % self._name,
1782                        self._scope_vals(values)) as scope:
1783
1784      # Hard-code indices for this staging area
1785      indices = (
1786          list(six.moves.range(len(values)))
1787          if isinstance(values, (list, tuple)) else None)
1788      vals, _ = self._check_put_dtypes(values, indices)
1789
1790      with ops.colocate_with(self._coloc_op):
1791        op = gen_data_flow_ops.stage(
1792            values=vals,
1793            shared_name=self._name,
1794            name=scope,
1795            capacity=self._capacity,
1796            memory_limit=self._memory_limit)
1797
1798      return op
1799
1800  def __internal_get(self, get_fn, name):
1801    with ops.colocate_with(self._coloc_op):
1802      ret = get_fn()
1803
1804    indices = list(six.moves.range(len(self._dtypes)))  # Hard coded
1805    return self._get_return_value(ret, indices)
1806
1807  def get(self, name=None):
1808    """Gets one element from this staging area.
1809
1810    If the staging area is empty when this operation executes, it will block
1811    until there is an element to dequeue.
1812
    Note that unlike other ops that can block, like the queue Dequeue
1814    operations, this can stop other work from happening.  To avoid this, the
1815    intended use is for this to be called only when there will be an element
1816    already available.  One method for doing this in a training loop would be to
1817    run a `put()` call during a warmup session.run call, and then call both
1818    `get()` and `put()` in each subsequent step.
1819
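    A sketch of that pattern (hypothetical names; `area` is a `StagingArea`,
    and `build_inputs()`, `train_step()` and `num_steps` are user-defined):

    ```python
    put_op = area.put(build_inputs())
    get_op = area.get()
    train_op = train_step(get_op)

    with tf.Session() as sess:
      sess.run(put_op)                  # warmup: stage the first element
      for _ in range(num_steps):
        sess.run([train_op, put_op])    # consume one element, stage the next
    ```
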
1820    The placement of the returned tensor will be determined by the current
1821    device scope when this function is called.
1822
1823    Args:
1824      name: A name for the operation (optional).
1825
1826    Returns:
1827      The tuple of tensors that was gotten.
1828    """
1829    if name is None:
1830      name = "%s_get" % self._name
1831
1832    # pylint: disable=bad-continuation
1833    fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes,
1834                    shared_name=self._name, name=name,
1835                    capacity=self._capacity,
1836                    memory_limit=self._memory_limit)
1837    # pylint: enable=bad-continuation
1838
1839    return self.__internal_get(fn, name)
1840
1841  def peek(self, index, name=None):
1842    """Peeks at an element in the staging area.
1843
1844    If the staging area is too small to contain the element at
1845    the specified index, it will block until enough elements
1846    are inserted to complete the operation.
1847
1848    The placement of the returned tensor will be determined by
1849    the current device scope when this function is called.
1850
1851    Args:
1852      index: The index of the tensor within the staging area
1853              to look up.
1854      name: A name for the operation (optional).
1855
1856    Returns:
1857      The tuple of tensors that was gotten.
1858    """
1859    if name is None:
1860      name = "%s_peek" % self._name
1861
1862    # pylint: disable=bad-continuation
1863    fn = lambda: gen_data_flow_ops.stage_peek(index,
1864                    dtypes=self._dtypes, shared_name=self._name,
1865                    name=name, capacity=self._capacity,
1866                    memory_limit=self._memory_limit)
1867    # pylint: enable=bad-continuation
1868
1869    return self.__internal_get(fn, name)
1870
1871  def size(self, name=None):
1872    """Returns the number of elements in the staging area.
1873
1874    Args:
1875        name: A name for the operation (optional)
1876
1877    Returns:
1878        The created op
1879    """
1880    if name is None:
1881      name = "%s_size" % self._name
1882
1883    return gen_data_flow_ops.stage_size(
1884        name=name,
1885        shared_name=self._name,
1886        dtypes=self._dtypes,
1887        capacity=self._capacity,
1888        memory_limit=self._memory_limit)
1889
1890  def clear(self, name=None):
1891    """Clears the staging area.
1892
1893    Args:
1894        name: A name for the operation (optional)
1895
1896    Returns:
1897        The created op
1898    """
1899    if name is None:
1900      name = "%s_clear" % self._name
1901
1902    return gen_data_flow_ops.stage_clear(
1903        name=name,
1904        shared_name=self._name,
1905        dtypes=self._dtypes,
1906        capacity=self._capacity,
1907        memory_limit=self._memory_limit)
1908
1909
1910class MapStagingArea(BaseStagingArea):
  """A `MapStagingArea` is a TensorFlow data structure that stores tensors
  across multiple steps, and exposes operations that can put and get tensors.
1912
1913  Each `MapStagingArea` element is a (key, value) pair.
  Only int64 keys are supported; other types should be
1915  hashed to produce a key.
1916  Values are a tuple of one or more tensors.
1917  Each tuple component has a static dtype,
1918  and may have a static shape.
1919
1920  The capacity of a `MapStagingArea` may be bounded or unbounded.
  It supports multiple concurrent producers and consumers, and
1922  provides exactly-once delivery.
1923
  Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors
  whose dtypes are described by `dtypes`, and whose shapes are optionally
  described by the `shapes` argument.
1928
1929  If the `shapes` argument is specified, each component of a staging area
1930  element must have the respective fixed shape. If it is
  unspecified, different elements may have different shapes.
1932
1933  It behaves like an associative container with support for:
1934
1935   - put(key, values)
1936   - peek(key)         like dict.get(key)
1937   - get(key)          like dict.pop(key)
1938   - get(key=None)     like dict.popitem()
1939   - size()
1940   - clear()
1941
  If ordered, a tree structure ordered by key will be used and get(key=None)
  will remove (key, value) pairs in increasing key order. Otherwise, an
  unordered hashtable is used and get(key=None) removes a random pair.
1945
1946  It can be configured with a capacity in which case
1947  put(key, values) will block until space becomes available.
1948
1949  Similarly, it can be configured with a memory limit which
1950  will block put(key, values) until space is available.
1951  This is mostly useful for limiting the number of tensors on
1952  devices such as GPUs.
1953
1954  All get() and peek() commands block if the requested
1955  (key, value) pair is not present in the staging area.
1956
1957  Partial puts are supported and will be placed in an incomplete
1958  map until such time as all values associated with the key have
1959  been inserted. Once completed, this (key, value) pair will be
1960  inserted into the map. Data in the incomplete map
  counts towards the memory limit, but not towards the capacity limit.
1962
1963  Partial gets from the map are also supported.
1964  This removes the partially requested tensors from the entry,
1965  but the entry is only removed from the map once all tensors
1966  associated with it are removed.
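
  A minimal usage sketch (illustrative only; assumes graph mode with a
  `tf.Session` and direct construction of the class defined in this module):

  ```python
  area = MapStagingArea(dtypes=[tf.float32], ordered=True)

  key = tf.constant(1, dtype=tf.int64)
  put_op = area.put(key, [tf.constant([1.0, 2.0])], indices=[0])
  peek_op = area.peek(key)          # like dict.get(key); does not remove
  got_key, got_vals = area.get()    # key=None: pops the smallest key

  with tf.Session() as sess:
    sess.run(put_op)
    sess.run(peek_op)
    sess.run([got_key, got_vals])
  ```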
1967  """
1968
1969  def __init__(self,
1970               dtypes,
1971               shapes=None,
1972               names=None,
1973               shared_name=None,
1974               ordered=False,
1975               capacity=0,
1976               memory_limit=0):
    """Constructs a `MapStagingArea` object.

    Args:
1979      dtypes:  A list of types.  The length of dtypes must equal the number
1980        of tensors in each element.
1981      capacity: (Optional.) Maximum number of elements.
1982        An integer. If zero, the Staging Area is unbounded
1983      memory_limit: (Optional.) Maximum number of bytes of all tensors
1984        in the Staging Area (excluding keys).
1985        An integer. If zero, the Staging Area is unbounded
1986      ordered: (Optional.) If True the underlying data structure
1987        is a tree ordered on key. Otherwise assume a hashtable.
1988      shapes: (Optional.) Constraints on the shapes of tensors in an element.
1989        A list of shape tuples or None. This list is the same length
1990        as dtypes.  If the shape of any tensors in the element are constrained,
1991        all must be; shapes can be None if the shapes should not be constrained.
1992      names: (Optional.) If provided, the `get()` and
1993        `put()` methods will use dictionaries with these names as keys.
1994        Must be None or a list or tuple of the same length as `dtypes`.
1995      shared_name: (Optional.) A name to be used for the shared object. By
1996        passing the same name to two different python objects they will share
1997        the underlying staging area. Must be a string.
1998
1999    Raises:
2000      ValueError: If one of the arguments is invalid.
2001
2002    """
2003
2004    super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name,
2005                                         capacity, memory_limit)
2006
2007    # Defer to different methods depending if the map is ordered
2008    self._ordered = ordered
2009
2010    if ordered:
2011      self._put_fn = gen_data_flow_ops.ordered_map_stage
2012      self._pop_fn = gen_data_flow_ops.ordered_map_unstage
2013      self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key
2014      self._peek_fn = gen_data_flow_ops.ordered_map_peek
2015      self._size_fn = gen_data_flow_ops.ordered_map_size
2016      self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size
2017      self._clear_fn = gen_data_flow_ops.ordered_map_clear
2018    else:
2019      self._put_fn = gen_data_flow_ops.map_stage
2020      self._pop_fn = gen_data_flow_ops.map_unstage
2021      self._popitem_fn = gen_data_flow_ops.map_unstage_no_key
2022      self._peek_fn = gen_data_flow_ops.map_peek
2023      self._size_fn = gen_data_flow_ops.map_size
2024      self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size
2025      self._clear_fn = gen_data_flow_ops.map_clear
2026
2027  def put(self, key, vals, indices=None, name=None):
2028    """Create an op that stores the (key, vals) pair in the staging area.
2029
    Incomplete puts are possible, preferably using a dictionary for vals,
    as the appropriate dtypes and shapes can then be inferred from the
    dictionary keys. If vals is a list or tuple, indices must also be
    specified so that the op knows at which element positions to perform
    the insert.
2035
2036    This operation will block if the capacity or memory limit of this
2037    container is reached.
2038
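    For example, a partial put split across two ops (a hypothetical sketch,
    assuming an area constructed with `names=["x", "y"]`):

    ```python
    area = MapStagingArea(dtypes=[tf.float32, tf.int32], names=["x", "y"])
    key = tf.constant(0, dtype=tf.int64)

    # The element only becomes complete (and visible to get/peek) once both
    # named components have been supplied for this key.
    put_x = area.put(key, {"x": tf.constant([1.0])})
    put_y = area.put(key, {"y": tf.constant([2])})
    ```
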
2039    Args:
2040        key: Key associated with the data
2041        vals: Tensor (or a dict/tuple of Tensors) to place
2042                into the staging area.
        indices: (Optional.) Required if vals is a tuple/list; the element
                positions at which to insert. Inferred from the keys when
                vals is a dictionary.
2044        name: A name for the operation (optional)
2045
2046    Returns:
2047        The created op
2048
2049    Raises:
2050        ValueError: If the number or type of inputs don't match the staging
2051        area.
2052    """
2053
2054    with ops.name_scope(name, "%s_put" % self._name,
2055                        self._scope_vals(vals)) as scope:
2056
2057      vals, indices = self._check_put_dtypes(vals, indices)
2058
2059      with ops.colocate_with(self._coloc_op):
2060        op = self._put_fn(
2061            key,
2062            indices,
2063            vals,
2064            dtypes=self._dtypes,
2065            shared_name=self._name,
2066            name=scope,
2067            capacity=self._capacity,
2068            memory_limit=self._memory_limit)
2069    return op
2070
2071  def _get_indices_and_dtypes(self, indices=None):
2072    if indices is None:
2073      indices = list(six.moves.range(len(self._dtypes)))
2074
2075    if not isinstance(indices, (tuple, list)):
2076      raise TypeError("Invalid indices type '%s'" % type(indices))
2077
2078    if len(indices) == 0:
2079      raise ValueError("Empty indices")
2080
2081    if all(isinstance(i, str) for i in indices):
2082      if self._names is None:
2083        raise ValueError("String indices provided '%s', but this Staging Area "
2084                         "was not created with names." % indices)
2085
      for n in indices:
        if n not in self._names:
          raise ValueError("Named index '%s' not in "
                           "Staging Area names '%s'" % (n, self._names))
      indices = [self._names.index(n) for n in indices]
2091    elif all(isinstance(i, int) for i in indices):
2092      pass
2093    else:
2094      raise TypeError("Mixed types in indices '%s'. "
2095                      "May only be str or int" % indices)
2096
2097    dtypes = [self._dtypes[i] for i in indices]
2098
2099    return indices, dtypes
2100
2101  def peek(self, key, indices=None, name=None):
2102    """Peeks at staging area data associated with the key.
2103
2104    If the key is not in the staging area, it will block
2105    until the associated (key, value) is inserted.
2106
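    For example, a partial peek by component name (a hypothetical sketch,
    assuming `area` was built with `names=["x", "y"]` and `key` is an int64
    scalar tensor):

    ```python
    # Returns a dict with just the "x" component for this key, leaving the
    # staged element in place.
    x_only = area.peek(key, indices=["x"])
    ```
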
2107    Args:
2108        key: Key associated with the required data
2109        indices: Partial list of tensors to retrieve (optional).
2110                A list of integer or string indices.
2111                String indices are only valid if the Staging Area
2112                has names associated with it.
2113        name: A name for the operation (optional)
2114
2115    Returns:
        The tuple of tensors (or dictionary of tensors, if the area has
        names) associated with the key.
2117    """
2118
2119    if name is None:
2120      name = "%s_pop" % self._name
2121
2122    indices, dtypes = self._get_indices_and_dtypes(indices)
2123
2124    with ops.colocate_with(self._coloc_op):
2125      result = self._peek_fn(
2126          key,
2127          shared_name=self._name,
2128          indices=indices,
2129          dtypes=dtypes,
2130          name=name,
2131          capacity=self._capacity,
2132          memory_limit=self._memory_limit)
2133
2134    return self._get_return_value(result, indices)
2135
2136  def get(self, key=None, indices=None, name=None):
    """If the key is provided, the associated (key, value) is returned from
    the staging area.
2138
2139    If the key is not in the staging area, this method will block until
2140    the associated (key, value) is inserted.
2141    If no key is provided and the staging area is ordered,
2142    the (key, value) with the smallest key will be returned.
2143    Otherwise, a random (key, value) will be returned.
2144
2145    If the staging area is empty when this operation executes,
2146    it will block until there is an element to dequeue.
2147
2148    Args:
2149        key: Key associated with the required data (Optional)
2150        indices: Partial list of tensors to retrieve (optional).
2151                A list of integer or string indices.
2152                String indices are only valid if the Staging Area
2153                has names associated with it.
2154        name: A name for the operation (optional)
2155
2156    Returns:
        A (key, value) tuple, where value is a tuple or dictionary of tensors.
2158    """
2159    if key is None:
2160      return self._popitem(indices=indices, name=name)
2161    else:
2162      return self._pop(key, indices=indices, name=name)
2163
2164  def _pop(self, key, indices=None, name=None):
    """Removes and returns the (key, value) pair associated with the key.

    If the key is not in the staging area, this method will block until
    the associated (key, value) is inserted.

    Args:
2170        key: Key associated with the required data
2171        indices: Partial list of tensors to retrieve (optional).
2172                A list of integer or string indices.
2173                String indices are only valid if the Staging Area
2174                has names associated with it.
2175        name: A name for the operation (optional)
2176
2177    Returns:
        A (key, value) tuple, where value is a tuple or dictionary of tensors.
2179    """
2180    if name is None:
2181      name = "%s_get" % self._name
2182
2183    indices, dtypes = self._get_indices_and_dtypes(indices)
2184
2185    with ops.colocate_with(self._coloc_op):
2186      result = self._pop_fn(
2187          key,
2188          shared_name=self._name,
2189          indices=indices,
2190          dtypes=dtypes,
2191          name=name,
2192          capacity=self._capacity,
2193          memory_limit=self._memory_limit)
2194
2195    return key, self._get_return_value(result, indices)
2196
2197  def _popitem(self, indices=None, name=None):
    """If the staging area is ordered, the (key, value) with the smallest key
    will be returned.
2199
2200    Otherwise, a random (key, value) will be returned.
2201    If the staging area is empty when this operation executes,
2202    it will block until there is an element to dequeue.
2203
2204    Args:
2206        indices: Partial list of tensors to retrieve (optional).
2207                A list of integer or string indices.
2208                String indices are only valid if the Staging Area
2209                has names associated with it.
2210        name: A name for the operation (optional)
2211
2212    Returns:
        A (key, value) tuple, where value is a tuple or dictionary of tensors.
2214    """
2215    if name is None:
2216      name = "%s_get_nokey" % self._name
2217
2218    indices, dtypes = self._get_indices_and_dtypes(indices)
2219
2220    with ops.colocate_with(self._coloc_op):
2221      key, result = self._popitem_fn(
2222          shared_name=self._name,
2223          indices=indices,
2224          dtypes=dtypes,
2225          name=name,
2226          capacity=self._capacity,
2227          memory_limit=self._memory_limit)
2228
2229    # Separate keys and results out from
2230    # underlying namedtuple
2231    key = self._create_device_transfers(key)[0]
2232    result = self._get_return_value(result, indices)
2233
2234    return key, result
2235
2236  def size(self, name=None):
2237    """Returns the number of elements in the staging area.
2238
2239    Args:
2240        name: A name for the operation (optional)
2241
2242    Returns:
2243        The created op
2244    """
2245    if name is None:
2246      name = "%s_size" % self._name
2247
2248    return self._size_fn(
2249        shared_name=self._name,
2250        name=name,
2251        dtypes=self._dtypes,
2252        capacity=self._capacity,
2253        memory_limit=self._memory_limit)
2254
2255  def incomplete_size(self, name=None):
2256    """Returns the number of incomplete elements in the staging area.
2257
2258    Args:
2259        name: A name for the operation (optional)
2260
2261    Returns:
2262        The created op
2263    """
2264    if name is None:
2265      name = "%s_incomplete_size" % self._name
2266
2267    return self._incomplete_size_fn(
2268        shared_name=self._name,
2269        name=name,
2270        dtypes=self._dtypes,
2271        capacity=self._capacity,
2272        memory_limit=self._memory_limit)
2273
2274  def clear(self, name=None):
2275    """Clears the staging area.
2276
2277    Args:
2278        name: A name for the operation (optional)
2279
2280    Returns:
2281        The created op
2282    """
2283    if name is None:
2284      name = "%s_clear" % self._name
2285
2286    return self._clear_fn(
2287        shared_name=self._name,
2288        name=name,
2289        dtypes=self._dtypes,
2290        capacity=self._capacity,
2291        memory_limit=self._memory_limit)
2292
2293
2294class RecordInput(object):
2295  """RecordInput asynchronously reads and randomly yields TFRecords.
2296
2297  A RecordInput Op will continuously read a batch of records asynchronously
2298  into a buffer of some fixed capacity. It can also asynchronously yield
2299  random records from this buffer.
2300
2301  It will not start yielding until at least `buffer_size / 2` elements have been
2302  placed into the buffer so that sufficient randomization can take place.
2303
2304  The order the files are read will be shifted each epoch by `shift_amount` so
2305  that the data is presented in a different order every epoch.
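
  A minimal usage sketch (illustrative only; the file pattern, batch size and
  thread counts below are placeholders):

  ```python
  record_input = RecordInput(
      file_pattern="/path/to/train-*.tfrecord",
      batch_size=32,
      buffer_size=10000,
      parallelism=16)

  # A string tensor holding one batch of serialized records per execution.
  records = record_input.get_yield_op()
  ```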
2306  """
2307
2308  def __init__(self,
2309               file_pattern,
2310               batch_size=1,
2311               buffer_size=1,
2312               parallelism=1,
2313               shift_ratio=0,
2314               seed=0,
2315               name=None,
2316               batches=None,
2317               compression_type=None):
2318    """Constructs a RecordInput Op.
2319
2320    Args:
2321      file_pattern: File path to the dataset, possibly containing wildcards.
2322        All matching files will be iterated over each epoch.
2323      batch_size: How many records to return at a time.
2324      buffer_size: The maximum number of records the buffer will contain.
2325      parallelism: How many reader threads to use for reading from files.
      shift_ratio: What percentage of the total number of files to move the
        start file forward by each epoch.
2328      seed: Specify the random number seed used by generator that randomizes
2329        records.
2330      name: Optional name for the operation.
2331      batches: None by default, creating a single batch op. Otherwise specifies
2332        how many batches to create, which are returned as a list when
2333        `get_yield_op()` is called. An example use case is to split processing
2334        between devices on one computer.
2335      compression_type: The type of compression for the file. Currently ZLIB and
2336        GZIP are supported. Defaults to none.
2337
2338    Raises:
2339      ValueError: If one of the arguments is invalid.
2340    """
2341    self._batch_size = batch_size
2342    if batches is not None:
2343      self._batch_size *= batches
2344    self._batches = batches
2345    self._file_pattern = file_pattern
2346    self._buffer_size = buffer_size
2347    self._parallelism = parallelism
2348    self._shift_ratio = shift_ratio
2349    self._seed = seed
2350    self._name = name
2351    self._compression_type = python_io.TFRecordCompressionType.NONE
2352    if compression_type is not None:
2353      self._compression_type = compression_type
2354
2355  def get_yield_op(self):
    """Adds a node that yields a group of records every time it is executed.

    If the RecordInput `batches` parameter is not None, it yields a list of
2358    record batches with the specified `batch_size`.
2359    """
2360    compression_type = python_io.TFRecordOptions.get_compression_type_string(
2361        python_io.TFRecordOptions(self._compression_type))
2362    records = gen_data_flow_ops.record_input(
2363        file_pattern=self._file_pattern,
2364        file_buffer_size=self._buffer_size,
2365        file_parallelism=self._parallelism,
2366        file_shuffle_shift_ratio=self._shift_ratio,
2367        batch_size=self._batch_size,
2368        file_random_seed=self._seed,
2369        compression_type=compression_type,
2370        name=self._name)
2371    if self._batches is None:
2372      return records
2373    else:
2374      with ops.name_scope(self._name):
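        # Split the single yielded batch into scalar records and distribute
        # them round-robin across `self._batches` output lists.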
2375        batch_list = [[] for i in six.moves.range(self._batches)]
2376        records = array_ops.split(records, self._batch_size, 0)
2377        records = [array_ops.reshape(record, []) for record in records]
2378        for index, protobuf in zip(six.moves.range(len(records)), records):
2379          batch_index = index % self._batches
2380          batch_list[batch_index].append(protobuf)
2381        return batch_list
2382