# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SequenceQueueingStateSaver and wrappers.

Please see the reading data how-to for context.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numbers

import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.summary import summary
from tensorflow.python.training import queue_runner

# pylint: disable=protected-access
_restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
_store_sparse = sparse_ops._add_many_sparse_to_tensors_map
# pylint: enable=protected-access


class _SequenceInputWrapper(object):
  """A wrapper object for storing sequence-related input.

  The SequenceInputWrapper accepts four objects:

    length: A scalar int containing the length of the input sequence.
    key: A scalar string containing the unique key of the input sequence.
    sequences: A dict mapping labels, like `input`, to tensors
      whose initial index dimension is at least size `length`.
    context: A dict mapping labels, like `global_target`, to tensors
      that represent data across the entire example.
  """

  def __init__(self, length, key, sequences, context):
    length = ops.convert_to_tensor(length, name="length")
    key = ops.convert_to_tensor(key, name="key")
    if not isinstance(sequences, dict):
      raise TypeError("sequences must be a dict")
    if not isinstance(context, dict):
      raise TypeError("context must be a dict")
    if not sequences:
      raise ValueError("must have at least one sequence tensor")
    for k in sequences.keys():
      if not isinstance(k, six.string_types):
        raise TypeError("sequence key must be string: %s" % k)
      if ":" in k:
        raise ValueError("sequence key may not have a colon: '%s'" % k)
    for k in context.keys():
      if not isinstance(k, six.string_types):
        raise TypeError("context key must be string: %s" % k)
      if ":" in k:
        raise ValueError("context key may not have a colon: '%s'" % k)
    sequences = dict((k, ops.convert_to_tensor(
        v, name="sequence_%s" % k)) for k, v in sequences.items())
    context = dict((k, ops.convert_to_tensor(
        v, name="context_%s" % k)) for k, v in context.items())
    self._length = length
    self._key = key
    self._sequences = sequences
    self._context = context

  @property
  def length(self):
    return self._length

  @property
  def key(self):
    return self._key

  @property
  def sequences(self):
    return self._sequences

  @property
  def context(self):
    return self._context


def _check_multiple_of(value, multiple_of):
  """Checks that value `value` is a non-zero multiple of `multiple_of`.

  Args:
    value: an int32 scalar Tensor.
    multiple_of: an int or int32 scalar Tensor.

  Returns:
    new_value: an int32 scalar Tensor matching `value`, but which includes an
      assertion that `value` is a multiple of `multiple_of`.
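
  Example (a minimal sketch; `sequence` and `num_unroll` are hypothetical
  names for a sequence tensor and an unroll length):

  ```python
  padded_length = array_ops.shape(sequence)[0]
  checked_length = _check_multiple_of(padded_length, num_unroll)
  # Any op consuming `checked_length` also triggers the runtime assertion.
  ```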
120  """
121  assert isinstance(value, ops.Tensor)
122  with ops.control_dependencies([
123      control_flow_ops.Assert(
124          math_ops.logical_and(
125              math_ops.equal(math_ops.mod(value, multiple_of), 0),
126              math_ops.not_equal(value, 0)), [
127                  string_ops.string_join([
128                      "Tensor %s should be a multiple of: " % value.name,
129                      string_ops.as_string(multiple_of), ", but saw value: ",
130                      string_ops.as_string(value),
131                      ". Consider setting pad=True."
132                  ])
133              ])
134  ]):
135    new_value = array_ops.identity(value, name="multiple_of_checked")
136    return new_value
137
138
139def _check_rank(value, expected_rank):
140  """Check the rank of Tensor `value`, via shape inference and assertions.
141
142  Args:
    value: A Tensor, possibly with associated shape information.
    expected_rank: int32 scalar (optionally a `Tensor`).

  Returns:
    new_value: A Tensor matching `value`.  Accessing this tensor tests
      assertions on its rank.  If expected_rank is not a `Tensor`, then
      new_value's shape's rank has been set.

  Raises:
    ValueError: if `expected_rank` is not a `Tensor` and the rank of `value`
      is known and is not equal to `expected_rank`.
  """
  assert isinstance(value, ops.Tensor)
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.equal(expected_rank, array_ops.rank(value)), [
              string_ops.string_join([
                  "Rank of tensor %s should be: " % value.name,
                  string_ops.as_string(expected_rank), ", shape received:"
              ]), array_ops.shape(value)
          ])
  ]):
    new_value = array_ops.identity(value, name="rank_checked")
    if isinstance(expected_rank, ops.Tensor):
      expected_rank_value = tensor_util.constant_value(expected_rank)
      if expected_rank_value is not None:
        expected_rank = int(expected_rank_value)
    if not isinstance(expected_rank, ops.Tensor):
      try:
        new_value.set_shape(new_value.get_shape().with_rank(expected_rank))
      except ValueError as e:
        raise ValueError("Rank check failed for %s: %s" % (value.name, str(e)))
    return new_value


def _check_shape(value, expected_shape):
  """Check the shape of Tensor `value`, via shape inference and assertions.

  Args:
    value: A Tensor, possibly with associated shape information.
    expected_shape: a `TensorShape`, list of `int32`, or a vector `Tensor`.

  Returns:
    new_value: A Tensor matching `value`.  Accessing this tensor tests
      assertions on its shape.  If expected_shape is not a `Tensor`, then
      new_value's shape has been set.

  Raises:
    ValueError: if `expected_shape` is not a `Tensor` and the shape of `value`
      is known and is not equal to `expected_shape`.
  """
  assert isinstance(value, ops.Tensor)
  if isinstance(expected_shape, tensor_shape.TensorShape):
    expected_shape = expected_shape.as_list()
  if isinstance(expected_shape, ops.Tensor):
    expected_shape_value = tensor_util.constant_value(expected_shape)
    if expected_shape_value is not None:
      expected_shape = [int(d) for d in expected_shape_value]
  if isinstance(expected_shape, ops.Tensor):
    value = _check_rank(value, array_ops.size(expected_shape))
  else:
    value = _check_rank(value, len(expected_shape))
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.reduce_all(
              math_ops.equal(expected_shape, array_ops.shape(value))), [
                  string_ops.string_join([
                      "Shape of tensor %s should be: " % value.name,
                      string_ops.as_string(expected_shape),
                      ", shape received: ",
                      string_ops.as_string(array_ops.shape(value))
                  ])
              ])
  ]):
    new_value = array_ops.identity(value, name="shape_checked")
    if not isinstance(expected_shape, ops.Tensor):
      try:
        new_value.set_shape(new_value.get_shape().merge_with(expected_shape))
      except ValueError as e:
        raise ValueError("Shape check failed for %s: %s" % (value.name, str(e)))
    return new_value


def _check_dimensions(value, dimensions, expected_sizes, debug_prefix):
  """Check the dimensions of Tensor `value`, via shape inference and assertions.

  Args:
    value: A Tensor, with optional / partial associated shape information.
    dimensions: An int list, the dimensions to check.
    expected_sizes: list of mixed ints and int32 scalar tensors.
      Optionally also a vector `Tensor`.
    debug_prefix: A string, used for naming ops and printing debugging messages.

  Returns:
    new_value: A Tensor matching `value`.  Accessing this tensor tests
      assertions on its shape.  If expected_sizes is not a `Tensor`, then
      new_value's shape has been set for all `dimensions[i]` where
      `expected_sizes[i]` is not a `Tensor`.

  Raises:
    TypeError: if any of the input contains invalid types:
      if `value` is not a `Tensor`.
      if `dimensions` is not a `list` or `tuple`.
    ValueError: if input has incorrect sizes or inferred shapes do not match:
      if `dimensions` contains repeated dimensions.
      if `expected_sizes` is not a `Tensor` and its length does not match
        that of `dimensions`.
      if `value`'s shape has a well-defined rank, and one of the values in
        `dimensions` is equal to or above this rank.
      if `value`'s shape is well defined for some `dimensions[i]`, and
        `expected_sizes[i]` is not a `Tensor`, and these two values do
        not match.
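
  Example (a minimal sketch with hypothetical values):

  ```python
  # Statically sets dimension 1 of `v` to `num_unroll` (a Python int) and
  # adds a runtime assertion that dimension 0 equals `sequence_count`
  # (an int32 scalar Tensor).
  v = _check_dimensions(v, [0, 1], [sequence_count, num_unroll],
                        debug_prefix="example")
  ```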
  """

  if not isinstance(dimensions, (list, tuple)):
    raise TypeError("dimensions must be a list or tuple")
  if len(set(dimensions)) != len(dimensions):
    raise ValueError("dimensions are not unique: %s" % dimensions)
  if not isinstance(value, ops.Tensor):
    raise TypeError("value is not a Tensor: %s" % value)
  value_shape = value.get_shape()
  if not isinstance(expected_sizes, ops.Tensor):
    if len(dimensions) != len(expected_sizes):
      raise ValueError("len(dimensions) != len(expected_sizes): %d vs. %d" %
                       (len(dimensions), len(expected_sizes)))
    if value_shape.ndims is not None:
      if value_shape.ndims <= max(dimensions):
        raise ValueError(
            "%s: rank of input is not greater than max(dimensions): "
            "%d vs. %d" % (debug_prefix, value.get_shape().ndims,
                           max(dimensions)))
      value_dims = value_shape.as_list()
      for d, s in zip(dimensions, expected_sizes):
        if not isinstance(s, ops.Tensor):
          value_dims[d] = s
      try:
        value.set_shape(value.get_shape().merge_with(value_dims))
      except ValueError as e:
        raise ValueError("Dimensions check failed for %s: %s" %
                         (debug_prefix, str(e)))
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.equal(expected_size, array_ops.shape(value)[dimension]), [
              string_ops.string_join([
                  "Dimension %d of tensor labeled %s should be: " %
                  (dimension, debug_prefix),
                  string_ops.as_string(expected_size), ", shape received: ",
                  string_ops.as_string(array_ops.shape(value))
              ])
          ]) for (dimension, expected_size) in zip(dimensions, expected_sizes)
  ]):
    new_value = array_ops.identity(value, name="dims_checked_%s" % debug_prefix)
    return new_value


def _prepare_sequence_inputs(inputs, states):
  """Convert input to tensors and validate shape information.

  Args:
    inputs: A `_SequenceInputWrapper` instance.
    states: A dictionary mapping state names to input constants or tensors.

  Returns:
    The tuple (length, key, sorted_states, sorted_sequences, sorted_context),
    where each value has been checked for valid shape, and the sorted_* dicts
    are instances of OrderedDict; with key-value pairs sorted by key.

  Raises:
    ValueError: if the shapes of inputs.context.values(), states.values(),
      or inputs.sequences.values() are not fully defined (with the exception
      of the dimension of any `Tensor` in inputs.sequences.values()).
    TypeError: if the dtype of length is not int32.
  """
  # Convert state initial values to tensors
  states = dict((k, ops.convert_to_tensor(
      v, name="state_%s" % k)) for k, v in states.items())

  def _assert_fully_defined(label, dict_, ignore_first_dimension=False):
    start_dimension = 1 if ignore_first_dimension else 0
    for k, v in dict_.items():
      if not v.get_shape()[start_dimension:].is_fully_defined():
        raise ValueError("Shape for %s %s is not fully defined %s: %s" %
                         (label, k, "(ignoring first dimension)" if
                          ignore_first_dimension else "", v.get_shape()))

  _assert_fully_defined("state", states)
  _assert_fully_defined("context", inputs.context)
  # Sequences' first dimension (time) may be variable
  _assert_fully_defined(
      "sequence", inputs.sequences, ignore_first_dimension=True)

  # Get dictionaries' dtypes ordered by name - ordering is important
  # when switching between dicts and tuples for passing to Barrier.
  def _sort_by_name(d):
    return collections.OrderedDict(sorted(d.items(), key=lambda k_v: k_v[0]))

  sorted_sequences = _sort_by_name(inputs.sequences)
  sorted_context = _sort_by_name(inputs.context)
  sorted_states = _sort_by_name(states)

  length = _check_rank(inputs.length, 0)
  key = _check_rank(inputs.key, 0)

  if length.dtype != dtypes.int32:
    raise TypeError("length dtype must be int32, but received: %s" %
                    length.dtype)
  if key.dtype != dtypes.string:
    raise TypeError("key dtype must be string, but received: %s" % key.dtype)

  return (length, key, sorted_states, sorted_sequences, sorted_context)


# NextQueuedSequenceBatch works closely with
# SequenceQueueingStateSaver and requires access to its private properties
# pylint: disable=protected-access
class NextQueuedSequenceBatch(object):
  """NextQueuedSequenceBatch stores deferred SequenceQueueingStateSaver data.

  This class is instantiated by `SequenceQueueingStateSaver` and is accessible
  via its `next_batch` property.
  """

  def __init__(self, state_saver):
    self._state_saver = state_saver

  @property
  def total_length(self):
    """The lengths of the original (non-truncated) unrolled examples.

    Returns:
      An integer vector of length `batch_size`, the total lengths.
    """
    return self._state_saver._received_total_length

  @property
  def length(self):
    """The lengths of the given truncated unrolled examples.

    For initial iterations, for which `sequence * num_unroll < length`,
    this number is `num_unroll`.  For the remainder,
    this number is between `0` and `num_unroll`.

    Returns:
      An integer vector of length `batch_size`, the lengths.
    """
    return self._state_saver._received_length

  @property
  def batch_size(self):
    """The batch_size of the given batch.

    Usually, this is the batch_size requested when initializing the SQSS, but
    if allow_small_batch=True this will become smaller when inputs are
    exhausted.

    Returns:
      A scalar integer tensor, the batch_size
    """
    return self._state_saver._received_batch_size

  @property
  def insertion_index(self):
    """The insertion indices of the examples (when they were first added).

    These indices start with the value -2**63 and increase with every
    call to the prefetch op.  Each whole example gets its own insertion
    index, and this is used to prioritize the example so that its truncated
    segments appear in adjacent iterations, even if new examples are inserted
    by the prefetch op between iterations.

    Returns:
      An int64 vector of length `batch_size`, the insertion indices.
    """
    return self._state_saver._received_indices

  @property
  def key(self):
    """The key names of the given truncated unrolled examples.

    The format of the key is:

    ```python
    "%05d_of_%05d:%s" % (sequence, sequence_count, original_key)
    ```

    where `original_key` is the unique key read in by the prefetcher.

    Returns:
      A string vector of length `batch_size`, the keys.
    """
    return self._state_saver._received_keys

  @property
  def next_key(self):
    """The key names of the next (in iteration) truncated unrolled examples.

    The format of the key is:

    ```python
    "%05d_of_%05d:%s" % (sequence + 1, sequence_count, original_key)
    ```

    if `sequence + 1 < sequence_count`, otherwise:

    ```python
    "STOP:%s" % original_key
    ```

    where `original_key` is the unique key read in by the prefetcher.

    Returns:
      A string vector of length `batch_size`, the keys.
    """
    return self._state_saver._received_next_key

  @property
  def sequence(self):
    """An int32 vector, length `batch_size`: the sequence index of each entry.

    When an input is split up, the sequence values
    ```
    0, 1, ..., sequence_count - 1
    ```
    are assigned to each split.

    Returns:
      An int32 vector `Tensor`.
    """
    return self._state_saver._received_sequence

  @property
  def sequence_count(self):
    """An int32 vector, length `batch_size`: the sequence count of each entry.

    When an input is split up, the number of splits is equal to:
    `padded_length / num_unroll`.  This is the sequence_count.
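
    For example (hypothetical values), `padded_length=60` with
    `num_unroll=20` yields a `sequence_count` of `3`.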

    Returns:
      An int32 vector `Tensor`.
    """
    return self._state_saver._received_sequence_count

  @property
  def context(self):
    """A dict mapping keys of `input_context` to batched context.

    Returns:
      A dict mapping keys of `input_context` to tensors.
      If we had at input:

      ```python
      context["name"].get_shape() == [d1, d2, ...]
      ```

      then for this property:

      ```python
      context["name"].get_shape() == [batch_size, d1, d2, ...]
      ```

    """
    return self._state_saver._received_context

  @property
  def sequences(self):
    """A dict mapping keys of `input_sequences` to split and rebatched data.

    Returns:
      A dict mapping keys of `input_sequences` to tensors.
      If we had at input:

      ```python
      sequences["name"].get_shape() == [None, d1, d2, ...]
      ```

      where `None` meant the sequence time was dynamic, then for this property:

      ```python
      sequences["name"].get_shape() == [batch_size, num_unroll, d1, d2, ...].
      ```

    """
    return self._state_saver._received_sequences

  def state(self, state_name):
    """Returns batched state tensors.

    Args:
      state_name: string, matches a key provided in `initial_states`.

    Returns:
      A `Tensor`: a batched set of states, either initial states (if this is
      the first run of the given example), or a value as stored during
      a previous iteration via `save_state` control flow.
      Its type is the same as `initial_states[state_name].dtype`.
      If we had at input:

      ```python
      initial_states[state_name].get_shape() == [d1, d2, ...],
      ```

      then

      ```python
      state(state_name).get_shape() == [batch_size, d1, d2, ...]
      ```

    Raises:
      KeyError: if `state_name` does not match any of the initial states
        declared in `initial_states`.
    """
    return self._state_saver._received_states[state_name]

  def save_state(self, state_name, value, name=None):
    """Returns an op to save the current batch of state `state_name`.

    Args:
      state_name: string, matches a key provided in `initial_states`.
      value: A `Tensor`.
        Its type must match that of `initial_states[state_name].dtype`.
        If we had at input:

        ```python
        initial_states[state_name].get_shape() == [d1, d2, ...]
        ```

        then the shape of `value` must match:

        ```python
        tf.shape(value) == [batch_size, d1, d2, ...]
        ```

      name: string (optional).  The name scope for newly created ops.

    Returns:
      A control flow op that stores the new state of each entry into
      the state saver.  This op must be run for every iteration that
      accesses data from the state saver (otherwise the state saver
      will never progress through its states and run out of capacity).

    Raises:
      KeyError: if `state_name` does not match any of the initial states
        declared in `initial_states`.
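
    Example (a minimal sketch; `new_state` is a hypothetical tensor computed
    by the model from `batch.state("lstm_state")`):

    ```python
    save_op = batch.save_state("lstm_state", new_state)
    with tf.control_dependencies([save_op]):
      train_op = optimizer.minimize(loss)
    ```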
    """
    if state_name not in self._state_saver._received_states.keys():
      raise KeyError("state was not declared: %s" % state_name)
    default_name = "InputQueueingStateSaver_SaveState"
    with ops.name_scope(name, default_name, values=[value]):
      # Place all operations on the CPU. Barriers and queues are only
      # implemented for CPU, but all the other book-keeping operations
      # (reshape, shape, range, ...) would be placed on GPUs if available,
      # unless we explicitly tie them to CPU.
      with ops.colocate_with(self._state_saver._capacity_queue.queue_ref):
        indices_where_not_done = array_ops.reshape(
            array_ops.where(
                math_ops.logical_not(self._state_saver._sequence_is_done)),
            [-1])
        keeping_next_key = array_ops.gather(
            self._state_saver._received_next_key, indices_where_not_done)
        value = _check_shape(
            array_ops.identity(
                value, name="convert_%s" % state_name),
            array_ops.shape(self._state_saver._received_states[state_name]))
        keeping_state = array_ops.gather(value, indices_where_not_done)
        return self._state_saver._barrier.insert_many(
            self._state_saver._get_barrier_index("state", state_name),
            keeping_next_key,
            keeping_state,
            name="BarrierInsertState_%s" % state_name)


# pylint: enable=protected-access


class SequenceQueueingStateSaver(object):
  """SequenceQueueingStateSaver provides access to stateful values from input.

  This class is meant to be used instead of, e.g., a `Queue`, for splitting
  variable-length sequence inputs into segments of sequences with fixed length
  and batching them into mini-batches.  It maintains contexts and state for a
  sequence across the segments.  It can be used in conjunction with a
  `QueueRunner` (see the example below).

  The `SequenceQueueingStateSaver` (SQSS) accepts one example at a time via the
  inputs `input_length`, `input_key`, `input_sequences` (a dict),
  `input_context` (a dict), and `initial_states` (a dict).
  The sequences, values in `input_sequences`, may have variable first dimension
  (the `padded_length`), though this dimension must always be a multiple of
  `num_unroll`.  All other dimensions must be fixed and accessible via
  `get_shape` calls.  The length prior to padding can be recorded in
  `input_length`.  The context values in `input_context` must all have fixed and
  well defined dimensions.  The initial state values must all have fixed and
  well defined dimensions.

  The SQSS splits the sequences of an input example into segments of length
  `num_unroll`.  Across examples minibatches of size `batch_size` are formed.
  These minibatches contain a segment of the sequences, copy the context values,
  and maintain state, length, and key information of the original input
  examples.  In the first segment of an example the state is still the initial
  state.  It can then be updated; and updated state values are accessible in
  subsequent segments of the same example.  After each segment,
  `batch.save_state()` must be called, which is done by the `state_saving_rnn`.
  Without this call, the dequeue op associated with the SQSS will not run.
  Internally, SQSS has a queue for the input examples. Its `capacity` is
  configurable.  If set smaller than `batch_size` then the dequeue op will block
  indefinitely.  A small multiple of `batch_size` is a good rule of thumb to
  prevent that queue from becoming a bottleneck and slowing down training.
  If set too large (and note that it defaults to unbounded) memory consumption
  goes up.  Moreover, when iterating over the same input examples multiple times
  reusing the same `key` the `capacity` must be smaller than the number of
  examples.

  The prefetcher, which reads one unrolled, variable-length input sequence at
  a time, is accessible via `prefetch_op`.  The underlying `Barrier` object
  is accessible via `barrier`.  Processed minibatches, as well as
  state read and write capabilities are accessible via `next_batch`.
  Specifically, `next_batch` provides access to all of the minibatched
  data, including the following, see `NextQueuedSequenceBatch` for details:

  *  `total_length`, `length`, `insertion_index`, `key`, `next_key`,
  *  `sequence` (the time segment index of each minibatch entry),
  *  `sequence_count` (the total time segment count for each minibatch entry),
  *  `context` (a dict of the copied minibatched context values),
  *  `sequences` (a dict of the split minibatched variable-length sequences),
  *  `state` (to access the states of the current segments of these entries)
  *  `save_state` (to save the states for the next segments of these entries)

  Example usage:

  ```python
  batch_size = 32
  num_unroll = 20
  lstm_size = 8
  cell = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size)
  initial_state_values = tf.zeros(cell.state_size, dtype=tf.float32)

  raw_data = get_single_input_from_input_reader()
  length, key, sequences, context = my_parser(raw_data)
  assert "input" in sequences.keys()
  assert "label" in context.keys()
  initial_states = {"lstm_state": initial_state_values}

  stateful_reader = tf.SequenceQueueingStateSaver(
      batch_size, num_unroll,
      input_length=length, input_key=key, input_sequences=sequences,
      input_context=context, initial_states=initial_states,
      capacity=batch_size*100)

  batch = stateful_reader.next_batch
  inputs = batch.sequences["input"]
  context_label = batch.context["label"]

  inputs_by_time = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)
  assert len(inputs_by_time) == num_unroll

  lstm_output, _ = tf.contrib.rnn.static_state_saving_rnn(
    cell,
    inputs_by_time,
    state_saver=batch,
    state_name="lstm_state")

  # Start a prefetcher in the background
  session = tf.Session()
  num_threads = 3
  queue_runner = tf.train.QueueRunner(
      stateful_reader, [stateful_reader.prefetch_op] * num_threads)
  tf.train.add_queue_runner(queue_runner)
  tf.train.start_queue_runners(sess=session)

  while True:
    # Step through batches, perform training or inference...
    session.run([lstm_output])
  ```

  **Note**: Usually the barrier is given to a QueueRunner as in the
      example above.  The QueueRunner will close the barrier if the prefetch_op
      receives an OutOfRange Error from upstream input queues (i.e., reaches
      the end of the input).  If the barrier is closed no further new examples
      are added to the SQSS.  The underlying barrier might, however, still
      contain further unroll-steps of examples that have not undergone all
      iterations.  To gracefully finish all examples, the flag
      `allow_small_batch` must be set to true, which causes the SQSS to issue
      progressively smaller mini-batches with the remaining examples.
  """

  def __init__(self,
               batch_size,
               num_unroll,
               input_length,
               input_key,
               input_sequences,
               input_context,
               initial_states,
               capacity=None,
               allow_small_batch=False,
               name=None):
    """Creates the SequenceQueueingStateSaver.

    Args:
      batch_size: int or int32 scalar `Tensor`, how large minibatches should
        be when accessing the `state()` method and `context`, `sequences`,
        etc. properties.
      num_unroll: Python integer, how many time steps to unroll at a time.
        The input sequences of length `k` are then split into `k / num_unroll`
        many segments.
      input_length: An int32 scalar `Tensor`, the length of the sequence prior
        to padding.  This value may be at most `padded_length` for any given
        input (see below for the definition of `padded_length`).
        Batched and total lengths of the current iteration are made accessible
        via the `length` and `total_length` properties.  The shape of
        input_length (scalar) must be fully specified.
      input_key: A string scalar `Tensor`, the **unique** key for the given
        input.  This is used to keep track of the split minibatch elements
        of this input.  Batched keys of the current iteration are made
        accessible via the `key` property.  The shape of `input_key` (scalar)
        must be fully specified.
      input_sequences: A dict mapping string names to `Tensor` values.  The
        values must all have matching first dimension, called `padded_length`.
        The `SequenceQueueingStateSaver` will split these tensors along
        this first dimension into minibatch elements of dimension
        `num_unroll`. Batched and segmented sequences of the current iteration
        are made accessible via the `sequences` property.

        **Note**: `padded_length` may be dynamic, and may vary from input
        to input, but must always be a multiple of `num_unroll`.  The remainder
        of the shape (other than the first dimension) must be fully specified.
      input_context: A dict mapping string names to `Tensor` values.  The values
        are treated as "global" across all time splits of the given input,
        and will be copied across for all minibatch elements accordingly.
        Batched and copied context of the current iteration are made
        accessible via the `context` property.

        **Note**: All input_context values must have fully defined shapes.
      initial_states: A dict mapping string state names to multi-dimensional
        values (e.g. constants or tensors).  This input defines the set of
        states that will be kept track of during computing iterations, and
        which can be accessed via the `state` and `save_state` methods.

        **Note**: All initial_state values must have fully defined shapes.
      capacity: The max capacity of the SQSS in number of examples. Needs to be
        at least `batch_size`. Defaults to unbounded.
      allow_small_batch: If true, the SQSS will return smaller batches when
        there aren't enough input examples to fill a whole batch and the end of
        the input has been reached (i.e., the underlying barrier has been
        closed).
      name: An op name string (optional).

    Raises:
      TypeError: if any of the inputs is not an expected type.
      ValueError: if any of the input values is inconsistent, e.g. if
        not enough shape information is available from inputs to build
        the state saver.
    """
    if capacity is not None and isinstance(batch_size, ops.Tensor):
      with ops.control_dependencies([check_ops.assert_greater_equal(
          math_ops.cast(capacity, dtype=dtypes.int64),
          math_ops.cast(batch_size, dtype=dtypes.int64),
          message="capacity needs to be >= batch_size.")]):
        input_key = array_ops.identity(input_key)
    elif capacity is not None and capacity < batch_size:
      raise ValueError("capacity %d needs to be >= batch_size %d" % (
          capacity, batch_size))
    # The barrier is ignorant of the number of actual examples, since a long
    # example that requires many iterations produces more elements in the
    # barrier than a short example. Furthermore, we don't have an upper bound
    # on the length of examples, and hence have to keep the capacity of the
    # barrier at infinite to avoid dead-lock. Instead we have to keep track of
    # the number of active examples in this class, and block the prefetch_op
    # when capacity is reached. To this end, we employ a FIFOQueue in which we
    # store one token (its value doesn't matter) for each input example, and
    # dequeue a token for each completed example. Since the capacity of this
    # queue is limited the enqueue operation will block if capacity is reached.
    self._capacity_queue = data_flow_ops.FIFOQueue(
        capacity=capacity, dtypes=[dtypes.int32], shapes=[[]])
    # Place all operations on the CPU. Barriers and queues are only implemented
    # for CPU, but all the other book-keeping operations
    # (reshape, shape, range, ...) would be placed on GPUs if available,
    # unless we explicitly tie them to CPU.
    with ops.colocate_with(self._capacity_queue.queue_ref):
      if not isinstance(initial_states, dict):
        raise TypeError("initial_states must be a dictionary")
      if not initial_states:
        raise ValueError(
            "initial_states may not be empty: at least one state variable is "
            "required to properly enqueue split sequences to run in separate "
            "iterations")
      for k in initial_states:
        if not isinstance(k, six.string_types):
          raise TypeError("state name must be a string: %s" % k)
        if ":" in k:
          raise ValueError("state name may not have a colon: '%s'" % k)

      op_vars = ([input_length, input_key] + list(input_sequences.values()) +
                 list(input_context.values()))
      with ops.name_scope(name, "InputQueueingStateSaver", op_vars) as scope:
        inputs = _SequenceInputWrapper(input_length, input_key, input_sequences,
                                       input_context)
        self._batch_size = batch_size
        self._num_unroll = num_unroll
        self._name = scope

        # This step makes sure all shapes are well defined.  We can now
        # use get_shape() on any tensor in the output of this function
        # and get a fully-defined shape.
        (self._length, self._key, self._sorted_states, self._sorted_sequences,
         self._sorted_context) = _prepare_sequence_inputs(inputs,
                                                          initial_states)
        self._padded_length = array_ops.identity(
            array_ops.shape(six.next(six.itervalues(self._sorted_sequences)))[
                0],
            name="padded_length")  # The name is useful for debugging
        self._padded_length = _check_multiple_of(self._padded_length,
                                                 self._num_unroll)

        # sequences should have length == all matching
        self._sorted_sequences = collections.OrderedDict(
            (k, _check_dimensions(
                v, [0], [self._padded_length],
                debug_prefix="sorted_sequences_%s" % k))
            for k, v in self._sorted_sequences.items())
        self._uninitialized_states = self._sorted_states

        # Once this is set, self._get_barrier_*_index are available for use.
        self._store_index_maps(self._sorted_sequences, self._sorted_context,
                               self._sorted_states)

        # Make sure that the length is <= the padded_length
        with ops.control_dependencies([
            control_flow_ops.Assert(
                math_ops.less_equal(self._length, self._padded_length), [
                    "Input length should be <= the length from sequences:",
                    self._length, " vs. ", self._padded_length
                ])
        ]):
          self._length = array_ops.identity(self._length)

        # Only create barrier; enqueue and dequeue operations happen when you
        # access prefetch_op and next_batch.
        self._create_barrier()
        self._scope = scope
      self._allow_small_batch = allow_small_batch
      self._prefetch_op = None
      self._next_batch = None

  @property
  def name(self):
    return self._name

  @property
  def barrier(self):
    return self._barrier

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def num_unroll(self):
    return self._num_unroll

  @property
  def prefetch_op(self):
    """The op used to prefetch new data into the state saver.

    Running it once enqueues one new input example into the state saver.
    The first time this gets called, it additionally creates the prefetch_op.
    Subsequent calls simply return the previously created `prefetch_op`.

    It should be run in a separate thread via e.g. a `QueueRunner`.

    Returns:
      An `Operation` that performs prefetching.
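
    Example (a minimal sketch mirroring the class-level example; `saver`,
    `num_threads`, and `session` are hypothetical):

    ```python
    qr = tf.train.QueueRunner(saver, [saver.prefetch_op] * num_threads)
    tf.train.add_queue_runner(qr)
    tf.train.start_queue_runners(sess=session)
    ```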
915    """
916    if not self._prefetch_op:
917      with ops.name_scope(None), ops.name_scope(
918          self._scope, values=[self._barrier.barrier_ref]):
919        self._create_prefetch_op()
920    return self._prefetch_op
921
922  @property
923  def next_batch(self):
924    """The `NextQueuedSequenceBatch` providing access to batched output data.
925
926    Also provides access to the `state` and `save_state` methods.
927    The first time this gets called, it additionally prepares barrier reads
928    and creates `NextQueuedSequenceBatch` / next_batch objects. Subsequent
929    calls simply return the previously created `next_batch`.
930
931    In order to access data in `next_batch` without blocking, the `prefetch_op`
932    must have been run at least `batch_size` times (ideally in a separate
933    thread, or launched via a `QueueRunner`). After processing a segment in
934    `next_batch()`, `batch.save_state()` must be called which is done by the
935    state_saving_rnn. Without this call, the dequeue op associated with the SQSS
936    will not run.
937
938    Returns:
939      A cached `NextQueuedSequenceBatch` instance.
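
    Example (a minimal sketch; `saver` is a hypothetical
    `SequenceQueueingStateSaver` and `new_state` a tensor computed from
    `batch.state("lstm_state")`):

    ```python
    batch = saver.next_batch
    inputs = batch.sequences["input"]
    state = batch.state("lstm_state")
    save_op = batch.save_state("lstm_state", new_state)
    ```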
940    """
941    # This is needed to prevent errors if next_batch is called before
942    # prefetch_op is created.
943    if not self._prefetch_op:
944      with ops.name_scope(None), ops.name_scope(
945          self._scope, values=[self._barrier.barrier_ref]):
946        self._create_prefetch_op()
947    if not self._next_batch:
948      with ops.name_scope(None), ops.name_scope(
949          self._scope, values=[self._barrier.barrier_ref]):
950        self._prepare_barrier_reads()
951    return self._next_batch
952
953  def close(self, cancel_pending_enqueues=False, name=None):
954    """Closes the barrier and the FIFOQueue.
955
956    This operation signals that no more segments of new sequences will be
957    enqueued. New segments of already inserted sequences may still be enqueued
958    and dequeued if there is a sufficient number filling a batch or
959    allow_small_batch is true. Otherwise dequeue operations will fail
960    immediately.
961
962    Args:
963      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
964        `False`. If `True`, all pending enqueues to the underlying queues will
965        be cancelled, and completing already started sequences is not possible.
966      name: Optional name for the op.
967
968    Returns:
969      The operation that closes the barrier and the FIFOQueue.
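
    Example (a minimal sketch; `saver` and `session` are hypothetical):

    ```python
    close_op = saver.close()
    # Once the input is exhausted and all pending segments are processed:
    session.run(close_op)
    ```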
970    """
971    with ops.name_scope(name, "SQSSClose", [self._prefetch_op]) as name:
972      barrier_close = self.barrier.close(cancel_pending_enqueues,
973                                         "BarrierClose")
974      fifo_queue_close = self._capacity_queue.close(cancel_pending_enqueues,
975                                                    "FIFOClose")
976      return control_flow_ops.group(barrier_close, fifo_queue_close, name=name)
977
978  def _store_index_maps(self, sequences, context, states):
979    """Prepares the internal dictionaries _name_to_index and _index_to_name.
980
981    These dictionaries are used to keep track of indices into the barrier.
982
983    Args:
984      sequences: `OrderedDict` of string, `Tensor` pairs.
985      context: `OrderedDict` of string, `Tensor` pairs.
986      states: `OrderedDict` of string, `Tensor` pairs.
987    """
988    assert isinstance(sequences, dict)
989    assert isinstance(context, dict)
990    assert isinstance(states, dict)
991    self._name_to_index = {
992        name: ix
993        for (ix, name) in enumerate([
994            "__length", "__total_length", "__next_key", "__sequence",
995            "__sequence_count"
996        ] + ["__sequence__%s" % k for k in sequences.keys()] + [
997            "__context__%s" % k for k in context.keys()
998        ] + ["__state__%s" % k for k in states.keys()])}
999    self._index_to_name = [
1000        name
1001        for (name, _) in sorted(
1002            self._name_to_index.items(), key=lambda n_ix: n_ix[1])
1003    ]
1004
1005  def _get_barrier_length_index(self):
1006    return self._name_to_index["__length"]
1007
1008  def _get_barrier_total_length_index(self):
1009    return self._name_to_index["__total_length"]
1010
1011  def _get_barrier_next_key_index(self):
1012    return self._name_to_index["__next_key"]
1013
1014  def _get_barrier_sequence_index(self):
1015    return self._name_to_index["__sequence"]
1016
1017  def _get_barrier_sequence_count_index(self):
1018    return self._name_to_index["__sequence_count"]
1019
1020  def _get_barrier_index(self, index_type, name):
1021    assert index_type in ("sequence", "context", "state")
1022    key = "__%s__%s" % (index_type, name)
1023    assert key in self._name_to_index, (
1024        "Requested a name not in the value type %s: %s" % (index_type, name))
1025    return self._name_to_index[key]
1026
1027  def _create_barrier(self):
1028    """Create the barrier.
1029
1030    This method initializes the Barrier object with the right types and shapes.
1031    """
1032    # Create the barrier
1033    sequence_dtypes = [v.dtype for k, v in self._sorted_sequences.items()]
1034    context_dtypes = [v.dtype for k, v in self._sorted_context.items()]
1035    state_dtypes = [v.dtype for k, v in self._sorted_states.items()]
1036    types = ([
1037        dtypes.int32,  # length
1038        dtypes.int32,  # total_length
1039        dtypes.string,  # next_keys
1040        dtypes.int32,  # sequence
1041        dtypes.int32
1042    ]  # expanded_sequence_count
1043             + sequence_dtypes + context_dtypes + state_dtypes)
1044    sequence_shapes = [
1045        [self._num_unroll] + self._sorted_sequences[k].get_shape().as_list()[1:]
1046        for k in self._sorted_sequences.keys()
1047    ]
1048    context_shapes = [
1049        self._sorted_context[k].get_shape().as_list()
1050        for k in self._sorted_context.keys()
1051    ]
1052    state_shapes = [
1053        self._sorted_states[k].get_shape().as_list()
1054        for k in self._sorted_states.keys()
1055    ]
1056    shapes = ([
1057        (),  # length
1058        (),  # total_length
1059        (),  # next_keys
1060        (),  # sequence
1061        ()
1062    ]  # expanded_sequence_count
1063              + sequence_shapes + context_shapes + state_shapes)
1064
1065    self._barrier = data_flow_ops.Barrier(types=types, shapes=shapes)
1066
1067  def _create_prefetch_op(self):
1068    """Group insert_many ops and create prefetch_op.
1069
1070    This method implements the "meat" of the logic underlying the
1071    `SequenceQueueingStateSaver`.  It performs dynamic reshaping of
1072    sequences, copying of context, and initial insertion of these values,
1073    as well as the key, next_key, sequence, sequence_count, and initial
1074    states into the barrier.
1075    """
1076    # Step 1: identify how many barrier entries to split this input
1077    # into, store the result as a scalar
1078    sequence_count = math_ops.div(self._padded_length, self._num_unroll)
1079    sequence_count_vec = array_ops.expand_dims(sequence_count, 0)
1080
1081    # The final unrolled sequence's length is num_unroll only in
1082    # the case that num_unroll divides it evenly.
1083    ones = array_ops.ones(sequence_count_vec, dtype=dtypes.int32)
1084    sequence = math_ops.range(sequence_count)
1085    expanded_length = math_ops.maximum(
1086        0, self._length - self._num_unroll * sequence)
1087    expanded_length = math_ops.minimum(self._num_unroll, expanded_length)
1088    expanded_total_length = self._length * ones
1089    expanded_sequence_count = sequence_count * ones
1090    current_keys = string_ops.string_join(
1091        [
1092            string_ops.as_string(
1093                sequence, width=5, fill="0"), "_of_", string_ops.as_string(
1094                    sequence_count, width=5, fill="0"), ":", self._key
1095        ],
1096        name="StringJoinCurrentKeys")
1097    next_keys = array_ops.concat(
1098        [
1099            array_ops.slice(current_keys, [1], [-1]), array_ops.expand_dims(
1100                string_ops.string_join(
1101                    ["STOP:", self._key], name="StringJoinStop"),
1102                0)
1103        ],
1104        0,
1105        name="concat_next_keys")
1106    reshaped_sequences = collections.OrderedDict((
1107        k,
1108        _check_dimensions(
1109            # Reshape sequences to sequence_count rows
1110            array_ops.reshape(
1111                v,
1112                array_ops.concat(
1113                    [
1114                        array_ops.expand_dims(sequence_count, 0),
1115                        array_ops.expand_dims(self._num_unroll, 0),
1116                        v.get_shape().as_list()[1:]
1117                    ],
1118                    0,
1119                    name="concat_sequences_%s" % k),
1120                name="reshape_sequences_%s" % k),
1121            [0, 1] + list(range(2, v.get_shape().ndims + 1)),
1122            [sequence_count, self._num_unroll] + v.get_shape().as_list()[1:],
1123            debug_prefix="reshaped_sequences_%s" %
1124            k)) for k, v in self._sorted_sequences.items())
1125    expanded_context = collections.OrderedDict(
1126        (
1127            k,
1128            _check_dimensions(
1129                # Copy context to be sequence_count rows
1130                array_ops.tile(
1131                    array_ops.expand_dims(v, 0),
1132                    array_ops.concat(
1133                        [
1134                            array_ops.expand_dims(sequence_count, 0),
1135                            [1] * v.get_shape().ndims
1136                        ],
1137                        0,
1138                        name="concat_context_%s" % k),
1139                    name="tile_context_%s" % k),
1140                [0] + list(range(1, v.get_shape().ndims + 1)),
1141                [sequence_count] + v.get_shape().as_list(),
1142                debug_prefix="expanded_context_%s" % k))
1143        for k, v in self._sorted_context.items())
1144
1145    # Storing into the barrier, for each current_key:
1146    #   sequence_ix, sequence_count, next_key, length,
1147    #   context... (copied), sequences... (truncated)
1148    # Also storing into the barrier for the first key
1149    #   states (using initial_states).
1150    insert_sequence_op = self._barrier.insert_many(
1151        self._get_barrier_sequence_index(),
1152        current_keys,
1153        sequence,
1154        name="BarrierInsertSequence")
1155    insert_sequence_count_op = self._barrier.insert_many(
1156        self._get_barrier_sequence_count_index(),
1157        current_keys,
1158        expanded_sequence_count,
1159        name="BarrierInsertSequenceCount")
1160    insert_next_key_op = self._barrier.insert_many(
1161        self._get_barrier_next_key_index(),
1162        current_keys,
1163        next_keys,
1164        name="BarrierInsertNextKey")
1165    insert_length_op = self._barrier.insert_many(
1166        self._get_barrier_length_index(),
1167        current_keys,
1168        expanded_length,
1169        name="BarrierInsertLength")
1170    insert_total_length_op = self._barrier.insert_many(
1171        self._get_barrier_total_length_index(),
1172        current_keys,
1173        expanded_total_length,
1174        name="BarrierInsertTotalLength")
1175    insert_context_ops = dict((name, self._barrier.insert_many(
1176        self._get_barrier_index("context", name),
1177        current_keys,
1178        value,
1179        name="BarrierInsertContext_%s" % name))
1180                              for (name, value) in expanded_context.items())
1181    insert_sequences_ops = dict((name, self._barrier.insert_many(
1182        self._get_barrier_index("sequence", name),
1183        current_keys,
1184        value,
1185        name="BarrierInsertSequences_%s" % name))
1186                                for (name, value) in reshaped_sequences.items())
1187
1188    # An op that blocks if we reached capacity in number of active examples.
1189    TOKEN_WITH_IGNORED_VALUE = 21051976  # pylint: disable=invalid-name
1190    insert_capacity_token_op = self._capacity_queue.enqueue(
1191        (TOKEN_WITH_IGNORED_VALUE,))
1192
1193    # Insert just the initial state.  Specifically force this to run
1194    # the insert sequence op *first* so that the Barrier receives
1195    # an insert with *all* the segments and the segments all get the same index.
1196    with ops.control_dependencies(
1197        [insert_sequence_op, insert_capacity_token_op]):
1198      insert_initial_state_ops = dict(
1199          (name, self._barrier.insert_many(
1200              self._get_barrier_index("state", name),
1201              array_ops.stack([current_keys[0]]),
1202              array_ops.stack([value]),
1203              name="BarrierInitialInsertState_%s" % name))
1204          for (name, value) in self._uninitialized_states.items())
1205
1206    all_inserts = ([
1207        insert_capacity_token_op, insert_sequence_op, insert_sequence_count_op,
1208        insert_next_key_op, insert_length_op, insert_total_length_op
1209    ] + list(insert_initial_state_ops.values()) +
1210                   list(insert_context_ops.values()) +
1211                   list(insert_sequences_ops.values()))
1212
1213    self._prefetch_op = control_flow_ops.group(
1214        *all_inserts, name="StateSaverPrefetchGroup")
1215
1216  def _prepare_barrier_reads(self):
1217    """Creates ops for reading the barrier, as used by properties like `length`.
1218    """
1219    # Ops for reading from the barrier.  These ops must be run in a
1220    # different thread than the prefetcher op to avoid blocking.
1221    received = self._barrier.take_many(
1222        self._batch_size, self._allow_small_batch, name="BarrierTakeMany")
1223
1224    self._received_indices = received[0]
1225    self._received_keys = received[1]
1226    received_values = received[2]
1227
1228    self._received_sequence = received_values[self._get_barrier_sequence_index(
1229    )]
1230    self._received_sequence_count = received_values[
1231        self._get_barrier_sequence_count_index()]
1232    self._received_next_key = received_values[self._get_barrier_next_key_index(
1233    )]
1234    self._received_length = received_values[self._get_barrier_length_index()]
1235    self._received_total_length = received_values[
1236        self._get_barrier_total_length_index()]
1237    self._received_context = collections.OrderedDict(
1238        (name, received_values[self._get_barrier_index("context", name)])
1239        for name in self._sorted_context.keys())
1240    self._received_sequences = collections.OrderedDict(
1241        (name, received_values[self._get_barrier_index("sequence", name)])
1242        for name in self._sorted_sequences.keys())
1243
1244    self._received_batch_size = array_ops.squeeze(
1245        array_ops.shape(self._received_length))
1246
1247    # Which examples are we done with?
1248    self._sequence_is_done = (
1249        self._received_sequence + 1 >= self._received_sequence_count)
1250
1251    # Compute the number of finished sequences and dequeue as many tokens from
1252    # the capacity queue.
1253    finished_sequences = (math_ops.reduce_sum(
1254        math_ops.cast(self._sequence_is_done, dtypes.int32)))
1255    # TODO(ebrevdo): convert to dequeue_up_to when FIFOQueue supports it.
1256    dequeue_op = self._capacity_queue.dequeue_many(finished_sequences)
1257
1258    # Tie the dequeue_op to the received_state, such that it is definitely
1259    # carried out.
1260    with ops.control_dependencies([dequeue_op]):
1261      self._received_states = collections.OrderedDict(
1262          (name, array_ops.identity(received_values[self._get_barrier_index(
1263              "state", name)])) for name in self._sorted_states.keys())
1264    self._next_batch = NextQueuedSequenceBatch(self)
1265
1266
1267def batch_sequences_with_states(input_key,
1268                                input_sequences,
1269                                input_context,
1270                                input_length,
1271                                initial_states,
1272                                num_unroll,
1273                                batch_size,
1274                                num_threads=3,
1275                                capacity=1000,
1276                                allow_small_batch=True,
1277                                pad=True,
1278                                make_keys_unique=False,
1279                                make_keys_unique_seed=None,
1280                                name=None):
1281  """Creates batches of segments of sequential input.

  This method creates a `SequenceQueueingStateSaver` (SQSS) and adds it to
  the queuerunners. It returns a `NextQueuedSequenceBatch`.

  It accepts one example at a time identified by a unique `input_key`.
  `input_sequences` is a dict with values that are tensors with time as first
  dimension. This time dimension must be the same across those tensors of an
  example. It can vary across examples, but it always has to be a multiple
  of `num_unroll`. Hence, padding may be necessary, and it is turned on by
  default via `pad=True`.
1292
1293  `input_length` is a Tensor scalar or an int recording the time dimension prior
1294  to padding. It should be between 0 and the time dimension. One reason we want
1295  to keep track of it is so that we can take it into consideration when
1296  computing the loss. If `pad=True` then `input_length` can be `None` and will
1297  be inferred.
1298
1299  This methods segments `input_sequence` into segments of length `num_unroll`.
1300  It batches input sequences from `batch_size` many examples. These mini-batches
1301  are available through the `sequence` property of the output. Moreover, for
1302  each entry in the batch we can access its original `input_key` in `key` and
1303  its input length in `total_length`. `length` records within this segment how
1304  many non-padded time steps there are.
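
  For instance (a sketch with hypothetical numbers): with `num_unroll=20` and
  an example whose time dimension is 50, `pad=True` pads the example to
  length 60 and produces 3 segments. Each of those segments reports
  `total_length == 50`, while `length` is 20, 20, and 10 respectively, since
  the last segment contains 10 padded time steps.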

  Static features of an example that do not vary across time can be part of
  the `input_context`, a dict with `Tensor` values. This method copies the
  context for each segment and makes it available in the `context` of the
  output.

  This method can maintain and update a state for each example. It accepts
  `initial_states` as a dict with `Tensor` values. The first mini-batch in
  which an example is contained has the `initial_states` as its entry in
  `state`. If `save_state` is called, then the next segment will have the
  updated entry in `state`.
  See `NextQueuedSequenceBatch` for a complete list of properties and methods.
  Example usage:

  ```python
  batch_size = 32
  num_unroll = 20
  num_enqueue_threads = 3
  lstm_size = 8
  # state_is_tuple=False yields a single state Tensor of size 2 * lstm_size,
  # matching the single "lstm_state" entry in initial_states below.
  cell = tf.contrib.rnn.BasicLSTMCell(
      num_units=lstm_size, state_is_tuple=False)

  key, sequences, context = my_parser(raw_data)
  initial_state_values = tf.zeros((2 * lstm_size,), dtype=tf.float32)
  initial_states = {"lstm_state": initial_state_values}
  batch = tf.contrib.training.batch_sequences_with_states(
      input_key=key,
      input_sequences=sequences,
      input_context=context,
      input_length=tf.shape(sequences["input"])[0],
      initial_states=initial_states,
      num_unroll=num_unroll,
      batch_size=batch_size,
      num_threads=num_enqueue_threads,
      capacity=batch_size * num_enqueue_threads * 2)

  inputs = batch.sequences["input"]
  context_label = batch.context["label"]

  inputs_by_time = tf.split(value=inputs, num_or_size_splits=num_unroll, axis=1)
  assert len(inputs_by_time) == num_unroll

  lstm_output, _ = tf.contrib.rnn.static_state_saving_rnn(
      cell,
      inputs_by_time,
      state_saver=batch,
      state_name="lstm_state")

  # Start a prefetcher in the background
  session = tf.Session()

  tf.train.start_queue_runners(sess=session)

  while True:
    # Step through batches, perform training or inference...
    session.run([lstm_output])
  ```

  Args:
    input_key: A string scalar `Tensor`, the **unique** key for the given
      input example.  This is used to keep track of the split minibatch
      elements of this input.  Batched keys of the current iteration are made
      accessible via the `key` property.  The shape of `input_key` (scalar)
      must be fully specified.  Consider setting `make_keys_unique` to True
      when iterating over the same input multiple times.

      **Note**: if `make_keys_unique=False` then `input_key`s must be unique.
    input_sequences: A dict mapping string names to `Tensor` values.  The
      values must all have a matching first dimension, called `value_length`.
      It may vary from input to input. The remainder of the shape (other than
      the first dimension) must be fully specified.
      The `SequenceQueueingStateSaver` will split these tensors along
      this first dimension into minibatch elements of dimension `num_unroll`.
      Batched and segmented sequences of the current iteration are made
      accessible via the `sequences` property.

      **Note**: if `pad=False`, then `value_length` must always be a multiple
        of `num_unroll`.
    input_context: A dict mapping string names to `Tensor` values.  The values
      are treated as "global" across all time splits of the given input
      example, and will be copied across for all minibatch elements
      accordingly.  Batched and copied context of the current iteration are
      made accessible via the `context` property.

      **Note**: All input_context values must have fully defined shapes.
    input_length: None or an int32 scalar `Tensor`, the length of the sequence
      prior to padding. If `input_length=None` and `pad=True` then the length
      will be inferred and will be equal to `value_length`. If `pad=False`
      then `input_length` cannot be `None` and must be specified. The shape of
      `input_length` (scalar) must be fully specified. Its value may be at
      most `value_length` for any given input (see above for the definition
      of `value_length`). Batched and total lengths of the current iteration
      are made accessible via the `length` and `total_length` properties.
    initial_states: A dict mapping string state names to multi-dimensional
      values (e.g. constants or tensors).  This input defines the set of
      states that will be tracked during the computation and that can be
      accessed via the `state` and `save_state` methods.

      **Note**: All initial_states values must have fully defined shapes.
    num_unroll: Python integer, how many time steps to unroll at a time.
      An input sequence of length `k` is then split into `k / num_unroll`
      segments.
    batch_size: int or int32 scalar `Tensor`, how large minibatches should
      be when accessing the `state()` method and the `context`, `sequences`,
      etc. properties.
    num_threads: The int number of threads enqueuing input examples into a
      queue.
    capacity: The max capacity of the queue in number of examples. Needs to be
      at least `batch_size`. Defaults to 1000. When iterating over the same
      input example multiple times, reusing its key, the `capacity` must be
      smaller than the number of examples.
    allow_small_batch: If true, the queue will return smaller batches when
      there aren't enough input examples to fill a whole batch and the end of
      the input has been reached.
    pad: If `True`, `input_sequences` will be padded to a multiple of
      `num_unroll`. In that case `input_length` may be `None` and is assumed
      to be the length of the first dimension of the values in
      `input_sequences` (i.e. `value_length`).
    make_keys_unique: Whether to append a random integer to the `input_key` in
      an effort to make it unique. The seed can be set via
      `make_keys_unique_seed`.
    make_keys_unique_seed: If `make_keys_unique=True` this fixes the seed with
      which the random suffix is generated.
    name: An op name string (optional).

  Returns:
    A NextQueuedSequenceBatch with segmented and batched inputs and their
    states.

  Raises:
    TypeError: if any of the inputs is not an expected type.
    ValueError: if any of the input values is inconsistent, e.g. if
      not enough shape information is available from inputs to build
      the state saver.
  """
  tensor_list = (list(input_sequences.values()) + list(input_context.values()) +
                 list(initial_states.values()))
  with ops.name_scope(name, "batch_sequences_with_states", tensor_list) as name:
    if pad:
      length, input_sequences = _padding(input_sequences, num_unroll)
      input_length = input_length if input_length is not None else length
    elif input_sequences:
      # Assert that value_length is a multiple of num_unroll.
      checked_input_sequences = {}
      for key, value in input_sequences.items():
        if (isinstance(value, sparse_tensor.SparseTensor) or
            isinstance(value, sparse_tensor.SparseTensorValue)):
          value_length = value.dense_shape[0]
          with ops.control_dependencies([
              control_flow_ops.Assert(
                  math_ops.logical_and(
                      math_ops.equal(value_length % num_unroll, 0),
                      math_ops.not_equal(value_length, 0)),
                  [
                      string_ops.string_join([
                          "SparseTensor %s first dimension should be a "
                          "multiple of: " % key,
                          string_ops.as_string(num_unroll),
                          ", but saw value: ",
                          string_ops.as_string(value_length),
                          ". Consider setting pad=True."])])]):
            checked_input_sequences[key] = sparse_tensor.SparseTensor(
                indices=array_ops.identity(
                    value.indices, name="multiple_of_checked"),
                values=array_ops.identity(
                    value.values, name="multiple_of_checked"),
                dense_shape=array_ops.identity(
                    value.dense_shape, name="multiple_of_checked"))
        else:
          if not isinstance(value, ops.Tensor):
            try:
              value = ops.convert_to_tensor(value)
            except TypeError:
              raise TypeError(
                  "Unsupported input_sequences expected Tensor or SparseTensor "
                  "values, got: %s for key %s" % (str(type(value)), key))
          value_length = array_ops.shape(value)[0]
          with ops.control_dependencies([
              control_flow_ops.Assert(
                  math_ops.logical_and(
                      math_ops.equal(value_length % num_unroll, 0),
                      math_ops.not_equal(value_length, 0)),
                  [
                      string_ops.string_join([
                          "Tensor %s first dimension should be a multiple "
                          "of: " % key,
                          string_ops.as_string(num_unroll),
                          ", but saw value: ",
                          string_ops.as_string(value_length),
                          ". Consider setting pad=True."
                      ])
                  ])
          ]):
            checked_input_sequences[key] = array_ops.identity(
                value, name="multiple_of_checked")
      input_sequences = checked_input_sequences
    # Move SparseTensors in context into input_sequences.
    _move_sparse_tensor_out_context(input_context, input_sequences, num_unroll)
    # Deconstruct SparseTensors in sequences into dense Tensors before
    # inputting to SQSS.
    (transformed_input_seq,
     sparse_tensor_keys,
     tensor_list) = _deconstruct_sparse_tensor_seq(input_sequences)

    if make_keys_unique:
      input_key = string_ops.string_join([
          input_key,
          string_ops.as_string(
              random_ops.random_uniform(
                  (), minval=0, maxval=100000000, dtype=dtypes.int32,
                  seed=make_keys_unique_seed))])

    # Set up the stateful queue reader.
    stateful_reader = SequenceQueueingStateSaver(
        batch_size,
        num_unroll,
        input_length=input_length,
        input_key=input_key,
        input_sequences=transformed_input_seq,
        input_context=input_context,
        initial_states=initial_states,
        capacity=capacity,
        allow_small_batch=allow_small_batch)

    barrier = stateful_reader.barrier
    summary.scalar("queue/%s/ready_segment_batches_" % barrier.name,
                   math_ops.cast(barrier.ready_size(), dtypes.float32))

    q_runner = queue_runner.QueueRunner(
        stateful_reader, [stateful_reader.prefetch_op] * num_threads,
        queue_closed_exception_types=(errors.OutOfRangeError,
                                      errors.CancelledError))
    queue_runner.add_queue_runner(q_runner)
    batch = stateful_reader.next_batch

    # Reconstruct SparseTensors in sequence.
    _reconstruct_sparse_tensor_seq(
        batch.sequences,
        sparse_tensor_keys,
        tensor_list,
        batch_size,
        num_unroll)
    # Move select SparseTensors back to context.
    _move_sparse_tensor_in_context(batch.context, batch.sequences)
    return batch


def _padding(sequences, num_unroll):
  """For a dictionary of sequences, pads tensors to a multiple of `num_unroll`.

  Args:
    sequences: dictionary with `Tensor` values.
    num_unroll: int specifying the multiple to pad the sequences up to.
  Returns:
    length: Scalar `Tensor`, the size of dimension 0 of all the values in
      `sequences`.
    padded_sequence: Dictionary of sequences that are padded to a multiple of
      `num_unroll`.
  Raises:
    ValueError: If `num_unroll` is not an int.
    TypeError: If `sequences` is not a dictionary from string to `Tensor`.
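
  A minimal sketch of the behavior (hypothetical shapes): with `num_unroll=3`
  and `sequences = {"seq": t}` where `t` is a `Tensor` of shape `[7, 5]`, the
  returned `length` is 7 and the returned `"seq"` value has shape `[9, 5]`,
  with the two appended rows filled with zeros (`""` for string dtypes).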
1563  """
1564  if not isinstance(num_unroll, numbers.Integral):
1565    raise ValueError("Unsupported num_unroll expected int, got: %s" %
1566                     str(num_unroll))
1567  if not isinstance(sequences, dict):
1568    raise TypeError("Unsupported sequences expected dict, got: %s" %
1569                    str(sequences))
1570  for key, value in sequences.items():
1571    if not isinstance(key, six.string_types):
1572      raise TypeError("Unsupported sequences key expected string, got: %s" %
1573                      str(key))
1574  if not sequences:
1575    return 0, {}
1576
1577  # Sort 'sequences_dict' so 'length' will have a predictable value below.
1578  sequences_dict = collections.OrderedDict()
1579  for key, value in sorted(sequences.items()):
1580    if not (isinstance(value, sparse_tensor.SparseTensor) or
1581            isinstance(value, sparse_tensor.SparseTensorValue)):
1582      sequences_dict[key] = ops.convert_to_tensor(value)
1583    else:
1584      sequences_dict[key] = value
1585
1586  lengths = [array_ops.shape(value)[0] for value in sequences_dict.values()
1587             if isinstance(value, ops.Tensor)]
1588  if lengths:
1589    length = lengths[0]
1590    all_lengths_equal = [
1591        control_flow_ops.Assert(
1592            math_ops.equal(l, length), [string_ops.string_join(
1593                ["All sequence lengths must match, but received lengths: ",
1594                 string_ops.as_string(lengths)])])
1595        for l in lengths]
1596    length = control_flow_ops.with_dependencies(all_lengths_equal, length)
1597  else:  # Only have SparseTensors
1598    sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values()
1599                      if isinstance(value, sparse_tensor.SparseTensor)]
1600    length = math_ops.reduce_max(math_ops.cast(sparse_lengths, dtypes.int32))
1601
1602  unroll = array_ops.constant(num_unroll)
1603  padded_length = length + ((unroll - (length % unroll)) % unroll)
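  # E.g. (sketch): length=50, num_unroll=20 -> padded_length=60;
  # length=60 stays 60.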
  padded_sequences = {}
  for key, value in sequences_dict.items():
    if isinstance(value, ops.Tensor):
      # 1. create shape of paddings
      # first dimension of value will be increased by num_paddings to
      # padded_length
      num_paddings = [padded_length - array_ops.shape(value)[0]]
      # the shape of the paddings that we concat with the original value will
      # be [num_paddings, tf.shape(value)[1], tf.shape(value)[2], ...,
      #     tf.shape(value)[tf.rank(value) - 1]]
      padding_shape = array_ops.concat(
          (num_paddings, array_ops.shape(value)[1:]), 0)
      # 2. fill padding shape with dummies
      dummy = array_ops.constant(
          "" if value.dtype == dtypes.string else 0, dtype=value.dtype)
      paddings = array_ops.fill(dims=padding_shape, value=dummy)
      # 3. concat values with paddings
      padded_sequences[key] = array_ops.concat([value, paddings], 0)
    else:
      padded_shape = array_ops.concat(
          [[math_ops.cast(padded_length, dtypes.int64)], value.dense_shape[1:]],
          0)
      padded_sequences[key] = sparse_tensor.SparseTensor(
          indices=value.indices,
          values=value.values,
          dense_shape=padded_shape)
  return length, padded_sequences


_SPARSE_CONTEXT_PREFIX_KEY = "_context_in_seq_"


def _move_sparse_tensor_out_context(input_context, input_sequences, num_unroll):
  """Moves `SparseTensor`s from `input_context` into `input_sequences`.

  For each `key, value` pair in `input_context` where `value` is a
  `SparseTensor`, removes the pair from `input_context`, transforms `value`
  into a sequence, and adds `key` with the transformed `value` to
  `input_sequences`.
  The transformation is done by adding a new first dimension of `value_length`
  equal to that of the other values in `input_sequences` and tiling the
  `value` every `num_unroll` steps.

  Args:
    input_context: dictionary with `Tensor` or `SparseTensor` values. To be
      modified to take out `SparseTensor` values.
    input_sequences: dictionary with `Tensor` or `SparseTensor` values. To be
      modified to add transformed `SparseTensor` values from `input_context`.
    num_unroll: int, the step at which each `value` is tiled along the new
      first dimension.
  """
  value_length = array_ops.constant(1)
  if input_sequences:
    seq = list(input_sequences.values())[0]
    if isinstance(seq, ops.Tensor):
      with ops.control_dependencies([seq]):
        value_length = array_ops.shape(seq)[0]
    else:
      value_length = seq.dense_shape[0]
  value_length = math_ops.cast(value_length, dtype=dtypes.int64)
  def _copy_sparse_tensor(sp_tensor):
    """Tiles a sparse tensor along a newly added first dimension.

    Adds a new first dimension of size `value_length` and tiles `sp_tensor`
    every `num_unroll` steps.

    Args:
      sp_tensor: `SparseTensor`.
    Returns:
      `SparseTensor` sequence with `sp_tensor` tiled.
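
    A sketch with hypothetical numbers: if `value_length=6`, `num_unroll=3`,
    and `sp_tensor` has `dense_shape=[4]` with a single entry at index `[2]`,
    then `n=2` and the result has `dense_shape=[6, 4]` with entries at
    `[0, 2]` and `[3, 2]`, i.e. the original value repeated at time steps
    0 and 3.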
1672    """
1673    n = value_length // num_unroll
1674    n = math_ops.cast(n, dtype=dtypes.int32)
1675    values = array_ops.tile(sp_tensor.values, array_ops.expand_dims(n, 0))
1676    shape = array_ops.concat(
1677        [array_ops.expand_dims(value_length, 0), sp_tensor.dense_shape], 0)
1678
1679    # Construct new indices by multiplying old ones and prepending [0, n).
1680    # First multiply indices n times along a newly created 0-dimension.
1681    multiplied_indices = array_ops.tile(
1682        array_ops.expand_dims(sp_tensor.indices, 0),
1683        array_ops.stack([n, 1, 1]))
1684
1685    # Construct indicator for [0, n).
1686    # [ [ [0] [0] ... [0] ]
1687    #   [ [num_unroll] [num_unroll] ... [num_unroll] ]
1688    #     ...
1689    #   [ [num_unroll*(n-1)] [num_unroll*(n-1)] ... [num_unroll*(n-1)] ] ]
1690    # of shape [n, shape(sp_tensor.indices)[0], 1]
1691    # Get current dimensions of indices.
1692    dim0 = array_ops.shape(sp_tensor.indices)[0]
1693    dim1 = array_ops.shape(sp_tensor.indices)[1]
1694    ind = math_ops.range(start=0, limit=value_length, delta=num_unroll)
1695
1696    # ind.set_shape([n])
1697    ind = array_ops.expand_dims(ind, 1)
1698    ind = array_ops.expand_dims(ind, 2)
1699    ind = array_ops.tile(ind, [1, dim0, 1])
1700
1701    # Concatenate both and reshape.
1702    indices = array_ops.concat([ind, multiplied_indices], 2)
1703    indices = array_ops.reshape(indices, [dim0 * n, dim1 + 1])
1704
1705    return sparse_tensor.SparseTensor(indices=indices,
1706                                      values=values,
1707                                      dense_shape=shape)
1708
1709  sparse_tensor_keys = [
1710      k for k in sorted(input_context.keys())
1711      if (isinstance(input_context[k], sparse_tensor.SparseTensor) or
1712          isinstance(input_context[k], sparse_tensor.SparseTensorValue))]
1713  for key in sparse_tensor_keys:
1714    input_sequences[_SPARSE_CONTEXT_PREFIX_KEY + key] = _copy_sparse_tensor(
1715        input_context[key])
1716    del input_context[key]
1717
1718
def _move_sparse_tensor_in_context(context, sequences):
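  """Moves tiled `SparseTensor`s in `sequences` back into `context`.

  Inverse of `_move_sparse_tensor_out_context`: for every key carrying the
  `_SPARSE_CONTEXT_PREFIX_KEY` prefix, strips the prefix, removes the (now
  redundant) time dimension that was added there, and stores the result under
  the original key in `context`.

  Args:
    context: dictionary with `Tensor` or `SparseTensor` values; modified in
      place to receive the restored context values.
    sequences: dictionary with `Tensor` or `SparseTensor` values; modified in
      place to remove the moved keys.
  """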
  sparse_tensor_keys = [
      k for k in sorted(sequences) if k.startswith(_SPARSE_CONTEXT_PREFIX_KEY)]
  for key in sparse_tensor_keys:
    new_key = key[len(_SPARSE_CONTEXT_PREFIX_KEY):]
    sp_tensor = sequences[key]
    # Take out time dimension.
    sp_tensor = sparse_tensor.SparseTensor(
        sp_tensor.indices,  # with only 0s at column 1 representing time.
        sp_tensor.values,
        array_ops.concat(
            [[sp_tensor.dense_shape[0]],  # batch
             [1],  # time
             sp_tensor.dense_shape[2:]],  # SparseTensor shape prior to batching
            0))
    new_shape = array_ops.concat(
        [[sp_tensor.dense_shape[0]], sp_tensor.dense_shape[2:]], 0)
    context[new_key] = sparse_ops.sparse_reshape(sp_tensor, new_shape)
    del sequences[key]


def _deconstruct_sparse_tensor_seq(input_sequence, shared_name=None):
  """Converts `SparseTensor` values into `Tensor`s of IDs and metadata.

  Given a dict of keys -> `Tensor` or `SparseTensor`, transforms the
  `SparseTensor` values into `Tensor` values of IDs by calling `_store_sparse`.
  The IDs are pointers into an underlying `SparseTensorsMap` that is being
  constructed. Additional metadata is returned so that the `SparseTensor`
  values can be reconstructed after batching and segmenting the ID `Tensor`.

  Args:
    input_sequence: dictionary with `Tensor` or `SparseTensor` values.
    shared_name: The shared name for the underlying `SparseTensorsMap`
      (optional, defaults to the name of the newly created op).
  Returns:
    A tuple `(sequence, sparse_tensor_keys, tensor_list)` where `sequence` is
    a dictionary with the same keys as `input_sequence` but only `Tensor`
    values, `sparse_tensor_keys` is a list of the keys of the `SparseTensor`
    values that were converted, and `tensor_list` is a list of the same length
    containing the `Operation`s that produced the ID `Tensor`s.
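
  A minimal sketch (hypothetical inputs): given
  `input_sequence = {"dense": d, "sp": s}` with `d` a `Tensor` and `s` a
  `SparseTensor`, the result is `({"dense": d, "sp": ids}, ["sp"], [ids.op])`,
  where `ids` is the ID `Tensor` produced by `_store_sparse(s)`.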
1760  """
1761  sparse_tensor_keys = [
1762      k for k in sorted(input_sequence.keys())
1763      if (isinstance(input_sequence[k], sparse_tensor.SparseTensor) or
1764          isinstance(input_sequence[k], sparse_tensor.SparseTensorValue))]
1765  if not sparse_tensor_keys:
1766    return input_sequence, None, sparse_tensor_keys
1767  sparse_tensor_list = [input_sequence[k] for k in sparse_tensor_keys]
1768  tensor_list = [_store_sparse(sp_tensor, shared_name=shared_name)
1769                 for sp_tensor in sparse_tensor_list]
1770  transformed_input_seq = dict(input_sequence)
1771  tensor_op_list = []
1772  for i, k in enumerate(sparse_tensor_keys):
1773    transformed_input_seq[k] = tensor_list[i]
1774    tensor_op_list += [tensor_list[i].op]
1775  return transformed_input_seq, sparse_tensor_keys, tensor_op_list
1776
1777
def _reconstruct_sparse_tensor_seq(sequence,
                                   sparse_tensor_keys,
                                   tensor_op_list,
                                   batch_size,
                                   num_unroll):
  """Inverse of _deconstruct_sparse_tensor_seq.

  Given a dict of keys -> `Tensor`, reconstructs `SparseTensor` values for
  keys in `sparse_tensor_keys`. Their `Tensor` values are assumed to be IDs
  into the underlying `SparseTensorsMap`. The `dense_shape` of the
  `SparseTensor`s is `[batch_size, num_unroll, d_0, d_1, ..., d_n]` when the
  original `SparseTensor` that got deconstructed with
  `_deconstruct_sparse_tensor_seq` has a `dense_shape` of
  `[None, d_0, d_1, ..., d_n]`.

  Args:
    sequence: dictionary with only `Tensor` values that is being updated.
    sparse_tensor_keys: list of the keys present in `sequence` identifying
      `SparseTensor` values that should be reconstructed.
    tensor_op_list: list of the same length as `sparse_tensor_keys` with the
      `Operation` objects returned by `_deconstruct_sparse_tensor_seq`.
    batch_size: int or int32 scalar `Tensor`, how large minibatches should
      be.
    num_unroll: Python integer, how many time steps were unrolled at a time.
  """
  def _flatten_tensor(tensor):
    """Flattens a `Tensor` of shape `[batch_size, num_unroll]` into 1D.

    The main use of this function is to work around the limitation of
    `_restore_sparse` to only accept 1D handles.

    Args:
      tensor: 2D `Tensor` of shape `[batch_size, num_unroll]`.
    Returns:
      1D `Tensor`.
    """
    return array_ops.reshape(tensor, [-1])

  def _unflatten_sparse_tensor(sp_tensor):
    """Recreates `[batch_size, num_unroll]` dimensions in the `SparseTensor`.

    Counterpart of `_flatten_tensor`, which is called on the input of
    `_restore_sparse`, while this method is called on its output.
    Together they work around the limitation of `_restore_sparse` to only
    accept 1D handles.

    The `indices` in `sp_tensor` is a 2D `Tensor` of shape `[N, ndims]`, where
    `N` is the number of `values` and `ndims` is the number of dimensions in
    its dense counterpart. Among `ndims`, the first entry corresponds to the
    batch dimension `[0, num_unroll * batch_size)` from which we need to
    recreate the 2 dimensions `batch_size` and `num_unroll`.

    This reconstruction works because the output of `_restore_sparse`, despite
    being a `SparseTensor`, is actually dense w.r.t. that first entry.
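
    A sketch of the index arithmetic (hypothetical numbers): with
    `num_unroll=20`, a flat index of 45 in the first column of `indices` is
    split into `idx_batch = 45 // 20 = 2` and `idx_time = 45 % 20 = 5`, so
    the entry moves to batch element 2, time step 5.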

    Args:
      sp_tensor: A SparseTensor.
    Returns:
      A SparseTensor with a +1 higher rank than the input.
    """
    idx_batch = math_ops.cast(
        math_ops.floor(sp_tensor.indices[:, 0] / num_unroll), dtypes.int64)
    idx_time = math_ops.mod(sp_tensor.indices[:, 0], num_unroll)
    indices = array_ops.concat(
        [
            array_ops.expand_dims(idx_batch, 1),
            array_ops.expand_dims(idx_time, 1), sp_tensor.indices[:, 1:]
        ],
        axis=1)
    dense_shape = array_ops.concat(
        [[math_ops.cast(batch_size, dtype=dtypes.int64)],
         [math_ops.cast(num_unroll, dtype=dtypes.int64)],
         sp_tensor.dense_shape[1:]], axis=0)
    return sparse_tensor.SparseTensor(
        indices=indices,
        values=sp_tensor.values,
        dense_shape=dense_shape)

  if not sparse_tensor_keys:
    return
  tensor_list = [sequence[k] for k in sparse_tensor_keys]
  sp_tensors = [
      _restore_sparse(sparse_map_op=i,
                      # Flatten the 2D Tensor [batch_size, num_unroll] of
                      # handles to a 1D Tensor.
                      # Reconstruct the dimensions later.
                      # TODO(b/34247140): Remove this workaround.
                      sparse_handles=_flatten_tensor(s), rank=None)
      for i, s in zip(tensor_op_list, tensor_list)]
  num_unroll = ops.convert_to_tensor(num_unroll, dtype=dtypes.int64,
                                     name="num_unroll_int64")

  # Recreate the [batch_size, num_unroll] dimensions in the SparseTensors.
  # The dense_shape will have a +1 higher rank.
  # TODO(b/34247140): Remove this workaround.
  sp_tensors_higher_dim = [_unflatten_sparse_tensor(s) for s in sp_tensors]

  # Set values to SparseTensors for sparse_tensor_keys.
  for i, key in enumerate(sparse_tensor_keys):
    sequence[key] = sp_tensors_higher_dim[i]
  return
