# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Grouping dataset transformations."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.util import nest
24from tensorflow.python.data.util import structure
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import check_ops
31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
32from tensorflow.python.ops import math_ops
33from tensorflow.python.util.tf_export import tf_export
34
35
@tf_export("data.experimental.group_by_reducer")
def group_by_reducer(key_func, reducer):
  """A transformation that groups elements and performs a reduction.

  This transformation maps each element of a dataset to a key using `key_func`
  and groups the elements by key. The `reducer` is used to process each group:
  its `init_func` initializes the state for a group when the group is created,
  its `reduce_func` updates the state every time an element is mapped to the
  matching group, and its `finalize_func` maps the final state to an output
  value.
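
  For example, the following sketch (names and values are illustrative, not
  part of the API) sums the elements in each parity group:

  ```python
  # A minimal sketch, assuming TF 1.x-style tf.data APIs.
  import tensorflow as tf

  reducer = tf.data.experimental.Reducer(
      init_func=lambda key: tf.constant(0, dtype=tf.int64),  # state starts at 0
      reduce_func=lambda state, element: state + element,    # accumulate sum
      finalize_func=lambda state: state)                     # emit final sum

  dataset = tf.data.Dataset.range(10).apply(
      tf.data.experimental.group_by_reducer(
          key_func=lambda x: x % 2,  # key 0 for evens, key 1 for odds
          reducer=reducer))
  # Produces the per-key sums: 20 (evens) and 25 (odds).
  ```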

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reducer: An instance of `Reducer`, which captures the reduction logic using
      the `init_func`, `reduce_func`, and `finalize_func` functions.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
57  """
58
59  def _apply_fn(dataset):
60    """Function from `Dataset` to `Dataset` that applies the transformation."""
61    return _GroupByReducerDataset(dataset, key_func, reducer)
62
63  return _apply_fn
64
65
@tf_export("data.experimental.group_by_window")
def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """A transformation that groups windows of elements by key and reduces them.

  This transformation maps each consecutive element in a dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You must provide exactly one of a constant `window_size` or a per-key
  window size determined through `window_size_func`.
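
  For example, the following sketch (illustrative only) windows a range of
  integers by parity and batches each window:

  ```python
  # A minimal sketch, assuming TF 1.x-style tf.data APIs.
  import tensorflow as tf

  dataset = tf.data.Dataset.range(10).apply(
      tf.data.experimental.group_by_window(
          key_func=lambda x: x % 2,                 # group by parity
          reduce_func=lambda key, ds: ds.batch(5),  # batch each full window
          window_size=5))
  # Produces the batches [0, 2, 4, 6, 8] and [1, 3, 5, 7, 9].
  ```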

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
104  """
105  if (window_size is not None and window_size_func or
106      not (window_size is not None or window_size_func)):
107    raise ValueError("Must pass either window_size or window_size_func.")
108
109  if window_size is not None:
110
111    def constant_window_func(unused_key):
112      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
113
114    window_size_func = constant_window_func
115
116  assert window_size_func is not None
117
118  def _apply_fn(dataset):
119    """Function from `Dataset` to `Dataset` that applies the transformation."""
120    return _GroupByWindowDataset(dataset, key_func, reduce_func,
121                                 window_size_func)
122
123  return _apply_fn
124
125
@tf_export("data.experimental.bucket_by_sequence_length")
def bucket_by_sequence_length(element_length_func,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None,
                              padding_values=None,
                              pad_to_bucket_boundary=False,
                              no_padding=False,
                              drop_remainder=False):
  """A transformation that buckets elements in a `Dataset` by length.

  Elements of the `Dataset` are grouped together by length and then are padded
  and batched.

  This is useful for sequence tasks in which the elements have variable length.
  Grouping together elements that have similar lengths reduces the total
  fraction of padding in a batch, which increases training step efficiency.
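
  For example, the following sketch (illustrative values only) buckets
  variable-length sequences into padded batches:

  ```python
  # A minimal sketch, assuming TF 1.x-style tf.data APIs and a dataset of
  # 1-D tf.int64 sequences of varying length.
  import tensorflow as tf

  dataset = tf.data.Dataset.from_generator(
      lambda: iter([[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]),
      output_types=tf.int64,
      output_shapes=tf.TensorShape([None]))

  dataset = dataset.apply(
      tf.data.experimental.bucket_by_sequence_length(
          element_length_func=lambda elem: tf.shape(elem)[0],
          bucket_boundaries=[3],       # two buckets: length < 3 and >= 3
          bucket_batch_sizes=[2, 2]))  # batch size for each bucket
  # [1] and [2, 2] share a bucket and are padded to length 2; [3, 3, 3] and
  # [4, 4, 4, 4] share the other bucket and are padded to length 4.
  ```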

  Args:
    element_length_func: A function from an element of the `Dataset` to
      `tf.int32`; determines the length of the element, which in turn
      determines the bucket it goes into.
    bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
    bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
      `len(bucket_boundaries) + 1`.
    padded_shapes: Nested structure of `tf.TensorShape` to pass to
      `tf.data.Dataset.padded_batch`. If not provided, will use
      `dataset.output_shapes`, which will result in variable length dimensions
      being padded out to the maximum length in each batch.
    padding_values: Values to pad with, passed to
      `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
    pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
      size to maximum length in batch. If `True`, will pad dimensions with
      unknown size to bucket boundary minus 1 (i.e., the maximum length in each
      bucket), and caller must ensure that the source `Dataset` does not
      contain any elements with length equal to or greater than
      `max(bucket_boundaries)`.
    no_padding: `bool`, indicates whether to pad the batch features (features
      need to be either of type `tf.SparseTensor` or of same shape).
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
175  """
176  with ops.name_scope("bucket_by_seq_length"):
177    if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
178      raise ValueError(
179          "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
180
181    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
182
183    def element_to_bucket_id(*args):
184      """Return int64 id of the length bucket for this element."""
      seq_length = element_length_func(*args)

      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
      conditions_c = math_ops.logical_and(
          math_ops.less_equal(buckets_min, seq_length),
          math_ops.less(seq_length, buckets_max))
      bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))

      return bucket_id

    def window_size_fn(bucket_id):
      # The window size is set to the batch size for this bucket.
      window_size = batch_sizes[bucket_id]
      return window_size

    def make_padded_shapes(shapes, none_filler=None):
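      # Replace unknown (None) dimensions with `none_filler`: a None dimension
      # makes `padded_batch` pad to the longest element in the batch, whereas
      # a concrete `none_filler` pads every element to that fixed length.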
      padded = []
      for shape in nest.flatten(shapes):
        shape = tensor_shape.TensorShape(shape)
        shape = [
            none_filler if tensor_shape.dimension_value(d) is None else d
            for d in shape
        ]
        padded.append(shape)
      return nest.pack_sequence_as(shapes, padded)

    def batching_fn(bucket_id, grouped_dataset):
      """Batch elements in dataset."""
      batch_size = window_size_fn(bucket_id)
      if no_padding:
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
      none_filler = None
      if pad_to_bucket_boundary:
        err_msg = ("When pad_to_bucket_boundary=True, elements must have "
                   "length < max(bucket_boundaries).")
        check = check_ops.assert_less(
            bucket_id,
            constant_op.constant(len(bucket_batch_sizes) - 1,
                                 dtype=dtypes.int64),
            message=err_msg)
        with ops.control_dependencies([check]):
          boundaries = constant_op.constant(bucket_boundaries,
                                            dtype=dtypes.int64)
          bucket_boundary = boundaries[bucket_id]
          none_filler = bucket_boundary - 1
      input_shapes = dataset_ops.get_legacy_output_shapes(grouped_dataset)
      shapes = make_padded_shapes(padded_shapes or input_shapes,
                                  none_filler=none_filler)
      return grouped_dataset.padded_batch(
          batch_size, shapes, padding_values, drop_remainder=drop_remainder)

    def _apply_fn(dataset):
      return dataset.apply(
          group_by_window(element_to_bucket_id, batching_fn,
                          window_size_func=window_size_fn))

    return _apply_fn


class _GroupByReducerDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a reduction."""

  def __init__(self, input_dataset, key_func, reducer):
    """See `group_by_reducer()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_init_func(reducer.init_func)
    self._make_reduce_func(reducer.reduce_func, input_dataset)
    self._make_finalize_func(reducer.finalize_func)
    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._init_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        key_func=self._key_func.function,
        init_func=self._init_func.function,
        reduce_func=self._reduce_func.function,
        finalize_func=self._finalize_func.function,
        **dataset_ops.flat_structure(self))
    super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""
    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        structure.TensorStructure(dtypes.int64, [])):
      raise ValueError(
          "`key_func` must return a single tf.int64 tensor. "
          "Got type=%s and shape=%s"
          % (self._key_func.output_types, self._key_func.output_shapes))

  def _make_init_func(self, init_func):
    """Make wrapping defun for init_func."""
    self._init_func = dataset_ops.StructuredFunctionWrapper(
        init_func,
        self._transformation_name(),
        input_structure=structure.TensorStructure(dtypes.int64, []))

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `self._state_structure`.
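    # For example, if `init_func` returns a state of shape [0] and
    # `reduce_func` concatenates one element onto it, the returned state has a
    # different static shape, so the state shape is weakened (here to [None])
    # and the function is traced again until the structure stabilizes.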
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          reduce_func,
          self._transformation_name(),
          input_structure=structure.NestedStructure(
              (self._state_structure, input_dataset._element_structure)),  # pylint: disable=protected-access
          add_to_graph=False)

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(wrapped_func.output_classes),
          nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (state_classes, wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
      for new_state_type, state_type in zip(
          nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_types, wrapped_func.output_types))

      # Extract shape information from the returned values.
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        state_shapes = nest.pack_sequence_as(
            self._init_func.output_shapes, weakened_state_shapes)
        self._state_structure = structure.convert_legacy_structure(
            state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())

  def _make_finalize_func(self, finalize_func):
    """Make wrapping defun for finalize_func."""
    self._finalize_func = dataset_ops.StructuredFunctionWrapper(
        finalize_func, self._transformation_name(),
        input_structure=self._state_structure)

  @property
  def _element_structure(self):
    return self._finalize_func.output_structure

  def _functions(self):
    return [
        self._key_func, self._init_func, self._reduce_func, self._finalize_func
    ]

  def _transformation_name(self):
    return "tf.data.experimental.group_by_reducer()"


class _GroupByWindowDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)
    variant_tensor = ged_ops.experimental_group_by_window_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._window_size_func.function.captured_inputs,
        key_func=self._key_func.function,
        reduce_func=self._reduce_func.function,
        window_size_func=self._window_size_func.function,
        **dataset_ops.flat_structure(self))
    super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping defun for window_size_func."""

    def window_size_func_wrapper(key):
      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

    self._window_size_func = dataset_ops.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=structure.TensorStructure(dtypes.int64, []))
    if not self._window_size_func.output_structure.is_compatible_with(
        structure.TensorStructure(dtypes.int64, [])):
      raise ValueError(
          "`window_size_func` must return a single tf.int64 scalar tensor.")

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""

    def key_func_wrapper(*args):
      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        structure.TensorStructure(dtypes.int64, [])):
      raise ValueError(
          "`key_func` must return a single tf.int64 scalar tensor.")

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""
    nested_dataset = dataset_ops.DatasetStructure(
        input_dataset._element_structure)  # pylint: disable=protected-access
    input_structure = structure.NestedStructure(
        (structure.TensorStructure(dtypes.int64, []), nested_dataset))
    self._reduce_func = dataset_ops.StructuredFunctionWrapper(
        reduce_func, self._transformation_name(),
        input_structure=input_structure)
    if not isinstance(
        self._reduce_func.output_structure, dataset_ops.DatasetStructure):
      raise TypeError("`reduce_func` must return a `Dataset` object.")
    # pylint: disable=protected-access
    self._structure = (
        self._reduce_func.output_structure._element_structure)

  @property
  def _element_structure(self):
    return self._structure

  def _functions(self):
    return [self._key_func, self._reduce_func, self._window_size_func]

  def _transformation_name(self):
    return "tf.data.experimental.group_by_window()"


@tf_export("data.experimental.Reducer")
class Reducer(object):
  """A reducer is used for reducing a set of elements.

  A reducer is represented as a tuple of three functions:
    1) initialization function: key => initial state
    2) reduce function: (old state, input) => new state
    3) finalization function: state => result
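
  For example, a sketch of a reducer that counts the elements in each group
  (illustrative only):

  ```python
  # A minimal sketch, assuming TF 1.x-style tf.data APIs.
  import tensorflow as tf

  count_reducer = tf.data.experimental.Reducer(
      init_func=lambda key: tf.constant(0, dtype=tf.int64),  # start at 0
      reduce_func=lambda state, element: state + 1,          # count elements
      finalize_func=lambda state: state)                     # emit the count
  ```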
451  """
452
453  def __init__(self, init_func, reduce_func, finalize_func):
454    self._init_func = init_func
455    self._reduce_func = reduce_func
456    self._finalize_func = finalize_func
457
458  @property
459  def init_func(self):
460    return self._init_func
461
462  @property
463  def reduce_func(self):
464    return self._reduce_func
465
466  @property
467  def finalize_func(self):
468    return self._finalize_func
469