# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Grouping dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """A transformation that groups windows of elements by key and reduces them.

  This transformation maps each consecutive element in a dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You may provide either a constant `window_size` or a window size determined by
  the key through `window_size_func`.

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  # `window_size` is tested against None (0 is a value, not "missing"), while
  # `window_size_func` is tested for truthiness, as callers pass a callable.
  has_constant_size = window_size is not None
  has_size_func = bool(window_size_func)
  # Exactly one of the two options must be supplied.
  if has_constant_size == has_size_func:
    raise ValueError("Must pass either window_size or window_size_func.")

  if has_constant_size:

    def constant_window_func(unused_key):
      """Returns the constant `window_size` as an int64 tensor for any key."""
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    window_size_func = constant_window_func

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return GroupByWindowDataset(dataset, key_func, reduce_func,
                                window_size_func)

  return _apply_fn


class _VariantDataset(dataset_ops.Dataset):
  """A Dataset wrapper for a tf.variant-typed function argument."""

  def __init__(self, dataset_variant, output_types, output_shapes,
               output_classes):
    """Stores the variant tensor together with its element structure.

    Args:
      dataset_variant: The value returned unchanged by `_as_variant_tensor()`.
      output_types: The value exposed through the `output_types` property.
      output_shapes: The value exposed through the `output_shapes` property.
      output_classes: The value exposed through the `output_classes` property.
    """
    super(_VariantDataset, self).__init__()
    self._dataset_variant = dataset_variant
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes

  def _as_variant_tensor(self):
    return self._dataset_variant

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class GroupByWindowDataset(dataset_ops.Dataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    super(GroupByWindowDataset, self).__init__()

    self._input_dataset = input_dataset

    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping Defun for window_size_func.

    Stores the wrapped function in `self._window_size_func` and adds it to the
    default graph.

    Args:
      window_size_func: A function called with the (scalar) key; its result is
        converted to a `tf.int64` tensor.
    """

    @function.Defun(dtypes.int64)
    def tf_window_size_func(key):
      """Calls `window_size_func` on a scalar key and returns int64."""
      key.set_shape([])
      window_size = ops.convert_to_tensor(
          window_size_func(key), dtype=dtypes.int64)
      # NOTE(review): the conversion above already requests int64, so this
      # check is presumably defensive -- confirm before removing.
      if window_size.dtype != dtypes.int64:
        raise ValueError(
            "`window_size_func` must return a single tf.int64 tensor.")
      return window_size

    self._window_size_func = tf_window_size_func
    self._window_size_func.add_to_graph(ops.get_default_graph())

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping Defun for key_func.

    Stores the wrapped function in `self._key_func` and adds it to the default
    graph.

    Args:
      key_func: A function called with the (re-nested) input element; its
        result is converted to a `tf.int64` tensor.
      input_dataset: The dataset whose `output_types`, `output_shapes` and
        `output_classes` describe the arguments of the wrapped function.
    """

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element structure from the flat argument list.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        ret = key_func(*nested_args)
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      # NOTE(review): the conversion above already requests int64, so this
      # check is presumably defensive -- confirm before removing.
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping Defun for reduce_func.

    Stores the wrapped function in `self._reduce_func` and adds it to the
    default graph. The wrapper also records this dataset's `_output_classes`,
    `_output_types` and `_output_shapes` from the dataset that `reduce_func`
    returns.

    Args:
      reduce_func: A function called with the key and a window `Dataset`; it
        must return a `Dataset`.
      input_dataset: The dataset whose element structure is used to describe
        the window dataset passed to `reduce_func`.
    """

    @function.Defun(dtypes.int64, dtypes.variant)
    def tf_reduce_func(key, window_dataset_variant):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      # Wrap the variant tensor so `reduce_func` receives a `Dataset` with the
      # same element structure as the input dataset.
      window_dataset = _VariantDataset(
          window_dataset_variant, input_dataset.output_types,
          input_dataset.output_shapes, input_dataset.output_classes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, dataset_ops.Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      self._output_classes = output_dataset.output_classes
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph())

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types

  def _as_variant_tensor(self):
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size_func.captured_inputs,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        window_size_func=self._window_size_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))