• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Grouping dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.util import nest
22from tensorflow.python.data.util import sparse
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import function
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import gen_dataset_ops
27
28
def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """Returns a transformation that groups elements by key into windows.

  Each element of the input dataset is mapped to a key by `key_func`, and
  elements sharing a key are collected into windows. A window holds at most
  `window_size_func(key)` elements: all except the final window for a key
  contain exactly that many elements, while the final window may be smaller.
  Every window is passed, together with its key, to `reduce_func`.

  The window size is given either as a constant (`window_size`) or as a
  function of the key (`window_size_func`), but not both.

  Args:
    key_func: A function mapping a nested structure of tensors (having the
      shapes and types of the input dataset's `output_shapes` and
      `output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  # Exactly one of the two sizing arguments must be supplied; passing both or
  # neither is rejected with the same error.
  constant_size_given = window_size is not None
  size_func_given = bool(window_size_func)
  if constant_size_given == size_func_given:
    raise ValueError("Must pass either window_size or window_size_func.")

  if constant_size_given:
    # Adapt the constant size to the key -> size callable interface.
    def _fixed_window_size(unused_key):
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    resolved_size_func = _fixed_window_size
  else:
    resolved_size_func = window_size_func

  assert resolved_size_func is not None

  def _apply_fn(dataset):
    """Wraps `dataset` in a `GroupByWindowDataset` using the given functions."""
    return GroupByWindowDataset(
        dataset, key_func, reduce_func, resolved_size_func)

  return _apply_fn
86
87
class _VariantDataset(dataset_ops.Dataset):
  """Exposes a `tf.variant` function argument through the `Dataset` interface.

  The variant tensor and the structure metadata (types, shapes, classes) are
  supplied by the caller and returned unchanged by the accessors below.
  """

  def __init__(self, dataset_variant, output_types, output_shapes,
               output_classes):
    """Stores the variant tensor together with its element structure.

    Args:
      dataset_variant: The variant tensor returned by `_as_variant_tensor()`.
      output_types: Value reported by the `output_types` property.
      output_shapes: Value reported by the `output_shapes` property.
      output_classes: Value reported by the `output_classes` property.
    """
    super(_VariantDataset, self).__init__()
    self._output_classes = output_classes
    self._output_shapes = output_shapes
    self._output_types = output_types
    self._dataset_variant = dataset_variant

  @property
  def output_types(self):
    """The element types given at construction."""
    return self._output_types

  @property
  def output_shapes(self):
    """The element shapes given at construction."""
    return self._output_shapes

  @property
  def output_classes(self):
    """The element classes given at construction."""
    return self._output_classes

  def _as_variant_tensor(self):
    """Returns the wrapped variant tensor as-is."""
    return self._dataset_variant
113
114
class GroupByWindowDataset(dataset_ops.Dataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details.

    Args:
      input_dataset: The `Dataset` whose elements are grouped.
      key_func: Function mapping an input element to a `tf.int64` key.
      reduce_func: Function mapping `(key, window_dataset)` to a `Dataset`.
      window_size_func: Function mapping a key to a `tf.int64` window size.
    """
    super(GroupByWindowDataset, self).__init__()

    self._input_dataset = input_dataset

    # Each helper wraps the user function in a Defun, stores it on `self` and
    # adds it to the default graph.
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping Defun for window_size_func.

    Args:
      window_size_func: Function mapping a scalar int64 key to a value
        convertible to a `tf.int64` tensor.
    """

    @function.Defun(dtypes.int64)
    def tf_window_size_func(key):
      # The key is a scalar.
      key.set_shape([])
      window_size = ops.convert_to_tensor(
          window_size_func(key), dtype=dtypes.int64)
      if window_size.dtype != dtypes.int64:
        raise ValueError(
            "`window_size_func` must return a single tf.int64 tensor.")
      return window_size

    self._window_size_func = tf_window_size_func
    self._window_size_func.add_to_graph(ops.get_default_graph())

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping Defun for key_func.

    Args:
      key_func: Function mapping an element of `input_dataset` to a value
        convertible to a `tf.int64` tensor.
      input_dataset: The `Dataset` whose types, shapes and classes describe
        the arguments of `key_func`.
    """

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element structure from the flat argument list.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping Defun for reduce_func.

    The output structure of this dataset (`_output_classes`, `_output_types`,
    `_output_shapes`) is taken from the dataset returned by `reduce_func` and
    is assigned inside the wrapped function body, so it is only available once
    that body has run (presumably triggered by `add_to_graph` below -- confirm
    against `function.Defun`).

    Args:
      reduce_func: Function mapping `(key, window_dataset)` to a `Dataset`.
      input_dataset: The `Dataset` whose types, shapes and classes describe
        the elements of each window.

    Raises:
      TypeError: if `reduce_func` does not return a `Dataset`.
    """

    @function.Defun(dtypes.int64, dtypes.variant)
    def tf_reduce_func(key, window_dataset_variant):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      # The window has the same element structure as the input dataset.
      # `_VariantDataset` is itself a `Dataset`, so the wrapper needs no type
      # check before being handed to `reduce_func`.
      window_dataset = _VariantDataset(
          window_dataset_variant, input_dataset.output_types,
          input_dataset.output_shapes, input_dataset.output_classes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, dataset_ops.Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      self._output_classes = output_dataset.output_classes
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph())

  @property
  def output_classes(self):
    """Output classes of the dataset returned by `reduce_func`."""
    return self._output_classes

  @property
  def output_shapes(self):
    """Output shapes of the dataset returned by `reduce_func`."""
    return self._output_shapes

  @property
  def output_types(self):
    """Output types of the dataset returned by `reduce_func`."""
    return self._output_types

  def _as_variant_tensor(self):
    """Builds the `group_by_window_dataset` op for this dataset."""
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size_func.captured_inputs,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        window_size_func=self._window_size_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
224