# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Grouping dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.group_by_reducer")
def group_by_reducer(key_func, reducer):
  """A transformation that groups elements and performs a reduction.

  This transformation maps each element of a dataset to a key using `key_func`
  and groups the elements by key. The `reducer` is used to process each group:
  its `init_func` initializes the state for a group when the group is created,
  its `reduce_func` updates the state every time an element is mapped to that
  group, and its `finalize_func` maps the final state to an output value.

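  For example, here is an illustrative sketch that sums the elements in each
  parity group (the per-group results are sorted before printing so the output
  is deterministic):

  >>> reducer = tf.data.experimental.Reducer(
  ...     init_func=lambda key: tf.constant(0, dtype=tf.int64),
  ...     reduce_func=lambda state, value: state + value,
  ...     finalize_func=lambda state: state)
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.group_by_reducer(
  ...         key_func=lambda x: x % 2, reducer=reducer))
  >>> for result in sorted(dataset.as_numpy_iterator()):
  ...   print(result)
  20
  25
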
  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reducer: An instance of `Reducer`, which captures the reduction logic using
      the `init_func`, `reduce_func`, and `finalize_func` functions.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _GroupByReducerDataset(dataset, key_func, reducer)

  return _apply_fn


@deprecation.deprecated(None, "Use `tf.data.Dataset.group_by_window(...)`.")
@tf_export("data.experimental.group_by_window")
def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """A transformation that groups windows of elements by key and reduces them.

  This transformation maps each consecutive element in a dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You may provide either a constant `window_size` or a window size determined by
  the key through `window_size_func`.

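  For example, here is an illustrative sketch that groups the integers `0..9`
  by parity and emits each full window of five elements as a batch:

  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.group_by_window(
  ...         key_func=lambda x: x % 2,
  ...         reduce_func=lambda key, window: window.batch(5),
  ...         window_size=5))
  >>> for batch in dataset.as_numpy_iterator():
  ...   print(batch)
  [0 2 4 6 8]
  [1 3 5 7 9]
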
  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return dataset.group_by_window(
        key_func=key_func,
        reduce_func=reduce_func,
        window_size=window_size,
        window_size_func=window_size_func)

  return _apply_fn


@deprecation.deprecated(None,
                        "Use `tf.data.Dataset.bucket_by_sequence_length(...)`.")
@tf_export("data.experimental.bucket_by_sequence_length")
def bucket_by_sequence_length(element_length_func,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None,
                              padding_values=None,
                              pad_to_bucket_boundary=False,
                              no_padding=False,
                              drop_remainder=False):
  """A transformation that buckets elements in a `Dataset` by length.

  Elements of the `Dataset` are grouped together by length and then are padded
  and batched.

  This is useful for sequence tasks in which the elements have variable length.
  Grouping together elements that have similar lengths reduces the total
  fraction of padding in a batch, which increases training step efficiency.

  Below is an example that bucketizes the input data into the 3 buckets
  "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.

  >>> elements = [
  ...   [0], [1, 2, 3, 4], [5, 6, 7],
  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

  >>> dataset = tf.data.Dataset.from_generator(
  ...     lambda: elements, tf.int64, output_shapes=[None])

  >>> dataset = dataset.apply(
  ...     tf.data.experimental.bucket_by_sequence_length(
  ...         element_length_func=lambda elem: tf.shape(elem)[0],
  ...         bucket_boundaries=[3, 5],
  ...         bucket_batch_sizes=[2, 2, 2]))

  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [[1 2 3 4]
   [5 6 7 0]]
  [[ 7  8  9 10 11  0]
   [13 14 15 16 19 20]]
  [[ 0  0]
   [21 22]]

  It is also possible to pad elements up to the bucket boundary, and you can
  choose the value used for padding. The example below pads with `-1` and shows
  the input data being bucketized into the two buckets "[0, 3], [4, 6]".

  >>> elements = [
  ...   [0], [1, 2, 3, 4], [5, 6, 7],
  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

  >>> dataset = tf.data.Dataset.from_generator(
  ...   lambda: elements, tf.int32, output_shapes=[None])

  >>> dataset = dataset.apply(
  ...     tf.data.experimental.bucket_by_sequence_length(
  ...         element_length_func=lambda elem: tf.shape(elem)[0],
  ...         bucket_boundaries=[4, 7],
  ...         bucket_batch_sizes=[2, 2, 2],
  ...         pad_to_bucket_boundary=True,
  ...         padding_values=-1))

  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [[ 0 -1 -1]
   [ 5  6  7]]
  [[ 1  2  3  4 -1 -1]
   [ 7  8  9 10 11 -1]]
  [[21 22 -1]]
  [[13 14 15 16 19 20]]

  When the `pad_to_bucket_boundary` option is used, it is not always possible
  to maintain the bucket batch size. You can drop the batches that do not
  maintain the bucket batch size by using the `drop_remainder` option. Using
  the same input data as in the example above, you get the following result.

  >>> elements = [
  ...   [0], [1, 2, 3, 4], [5, 6, 7],
  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

  >>> dataset = tf.data.Dataset.from_generator(
  ...   lambda: elements, tf.int32, output_shapes=[None])

  >>> dataset = dataset.apply(
  ...     tf.data.experimental.bucket_by_sequence_length(
  ...         element_length_func=lambda elem: tf.shape(elem)[0],
  ...         bucket_boundaries=[4, 7],
  ...         bucket_batch_sizes=[2, 2, 2],
  ...         pad_to_bucket_boundary=True,
  ...         padding_values=-1,
  ...         drop_remainder=True))

  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [[ 0 -1 -1]
   [ 5  6  7]]
  [[ 1  2  3  4 -1 -1]
   [ 7  8  9 10 11 -1]]

  Args:
    element_length_func: function from element in `Dataset` to `tf.int32`,
      determines the length of the element, which will determine the bucket it
      goes into.
    bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
    bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
      `len(bucket_boundaries) + 1`.
    padded_shapes: Nested structure of `tf.TensorShape` to pass to
      `tf.data.Dataset.padded_batch`. If not provided, will use
      `dataset.output_shapes`, which will result in variable length dimensions
      being padded out to the maximum length in each batch.
    padding_values: Values to pad with, passed to
      `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
    pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
      size to maximum length in batch. If `True`, will pad dimensions with
      unknown size to bucket boundary minus 1 (i.e., the maximum length in each
      bucket), and caller must ensure that the source `Dataset` does not contain
      any elements with length longer than `max(bucket_boundaries)`.
    no_padding: `bool`, if `True`, the batch features are not padded (features
      must then either be of type `tf.sparse.SparseTensor` or all have the same
      shape).
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
  """

  def _apply_fn(dataset):
    return dataset.bucket_by_sequence_length(
        element_length_func=element_length_func,
        bucket_boundaries=bucket_boundaries,
        bucket_batch_sizes=bucket_batch_sizes,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        pad_to_bucket_boundary=pad_to_bucket_boundary,
        no_padding=no_padding,
        drop_remainder=drop_remainder)

  return _apply_fn


class _GroupByReducerDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a reduction."""

  def __init__(self, input_dataset, key_func, reducer):
    """See `group_by_reducer()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_init_func(reducer.init_func)
    self._make_reduce_func(reducer.reduce_func, input_dataset)
    self._make_finalize_func(reducer.finalize_func)
    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._init_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        key_func=self._key_func.function,
        init_func=self._init_func.function,
        reduce_func=self._reduce_func.function,
        finalize_func=self._finalize_func.function,
        **self._flat_structure)
    super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""
    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`key_func` must return a single scalar tf.int64 tensor. "
          "Got type=%s and shape=%s"
          % (self._key_func.output_types, self._key_func.output_shapes))

  def _make_init_func(self, init_func):
    """Make wrapping defun for init_func."""
    self._init_func = dataset_ops.StructuredFunctionWrapper(
        init_func,
        self._transformation_name(),
        input_structure=tensor_spec.TensorSpec([], dtypes.int64))

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `self._state_structure`.
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          reduce_func,
          self._transformation_name(),
          input_structure=(self._state_structure, input_dataset.element_spec),
          add_to_graph=False)

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(wrapped_func.output_classes),
          nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_classes, wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
      for new_state_type, state_type in zip(
          nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_types, wrapped_func.output_types))

      # Extract shape information from the returned values.
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

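      # Re-trace `reduce_func` only if weakening changed a state shape whose
      # rank was previously known; otherwise the state structure has reached a
      # fixed point.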
      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        state_shapes = nest.pack_sequence_as(
            self._init_func.output_shapes, weakened_state_shapes)
        self._state_structure = structure.convert_legacy_structure(
            state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())

  def _make_finalize_func(self, finalize_func):
    """Make wrapping defun for finalize_func."""
    self._finalize_func = dataset_ops.StructuredFunctionWrapper(
        finalize_func, self._transformation_name(),
        input_structure=self._state_structure)

  @property
  def element_spec(self):
    return self._finalize_func.output_structure

  def _functions(self):
    return [
        self._key_func, self._init_func, self._reduce_func, self._finalize_func
    ]

  def _transformation_name(self):
    return "tf.data.experimental.group_by_reducer()"


@tf_export("data.experimental.Reducer")
class Reducer(object):
  """A reducer is used for reducing a set of elements.

  A reducer is represented as a tuple of three functions:
    1) initialization function: key => initial state
    2) reduce function: (old state, input) => new state
    3) finalization function: state => result
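
  For example, an illustrative reducer that counts the elements mapped to each
  key:

  >>> reducer = tf.data.experimental.Reducer(
  ...     init_func=lambda key: tf.constant(0, dtype=tf.int64),
  ...     reduce_func=lambda state, _: state + 1,
  ...     finalize_func=lambda state: state)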
  """

  def __init__(self, init_func, reduce_func, finalize_func):
    self._init_func = init_func
    self._reduce_func = reduce_func
    self._finalize_func = finalize_func

  @property
  def init_func(self):
    return self._init_func

  @property
  def reduce_func(self):
    return self._reduce_func

  @property
  def finalize_func(self):
    return self._finalize_func