# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Scan dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


class _ScanDataset(dataset_ops.Dataset):
  """A dataset that scans a function across its input."""

  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details."""
    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    with ops.name_scope("initial_state"):
      self._initial_state = nest.pack_sequence_as(initial_state, [
          ops.convert_to_tensor(t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])
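    # The conversion above preserves the nested structure of `initial_state`:
    # e.g. an initial state of `(0, {"total": 1.0})` yields tensors named
    # "component_0" and "component_1", packed back into the same structure.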

    # Compute initial values for the state shapes and types based on
    # the initial state. These will be refined by running
    # `tf_scan_func` one or more times below.
    # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.shape for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])

    # Will be populated by calling `tf_scan_func`.
    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
    need_to_rerun = True
    while need_to_rerun:

      flat_state_shapes = nest.flatten(self._state_shapes)
      flat_state_types = nest.flatten(self._state_types)

      # Create a list in which `tf_scan_func` will store the shapes of the
      # new state.
      flat_new_state_shapes = []

      @function.Defun(*(flat_state_types + nest.flatten(
          sparse.as_dense_types(input_dataset.output_types,
                                input_dataset.output_classes))))
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
        dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                              input_dataset.output_classes)
        for arg, shape in zip(args,
                              flat_state_shapes + nest.flatten(dense_shapes)):
          arg.set_shape(shape)

        pivot = len(flat_state_shapes)
        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
        input_value = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])

        ret = scan_func(old_state, input_value)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")
        new_state, output_value = ret

        flat_new_state = [
            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
        ]
        flat_output_value = [
            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
        ]

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.shape for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, flat_state_types):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, nest.pack_sequence_as(
                    self._state_types, [t.dtype for t in flat_new_state])))
        self._output_classes = nest.pack_sequence_as(
            output_value, [ops.Tensor for _ in flat_output_value])
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        return flat_new_state + flat_output_value

      # Use the private method that will execute `tf_scan_func` but delay
      # adding it to the graph in case we need to rerun the function.
      tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]
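      # `most_specific_compatible_shape()` relaxes any dimension on which the
      # two shapes disagree to `None` (e.g. [1] and [2] weaken to [?]), so
      # repeated reruns of `tf_scan_func` converge to a fixed point.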

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
        # `tf_scan_func`.
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = tf_scan_func

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.scan_dataset(
        input_t,
        nest.flatten(self._initial_state),
        self._scan_func.captured_inputs,
        f=self._scan_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of @{tf.data.Dataset.map}.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

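  For example, a running total over a dataset of integers can be computed as
  follows (an illustrative sketch; it assumes this function is exported as
  `tf.contrib.data.scan`):

    dataset = tf.data.Dataset.range(10)
    dataset = dataset.apply(tf.contrib.data.scan(
        initial_state=tf.constant(0, dtype=tf.int64),
        scan_func=lambda state, x: (state + x, state + x)))
    # ==> yields the running sums 0, 1, 3, 6, 10, 15, 21, 28, 36, 45
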
  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure of `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return _ScanDataset(dataset, initial_state, scan_func)

  return _apply_fn