1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Grouping dataset transformations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.data.ops import dataset_ops 21from tensorflow.python.data.util import nest 22from tensorflow.python.data.util import structure 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_spec 26from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 27from tensorflow.python.util import deprecation 28from tensorflow.python.util.tf_export import tf_export 29 30 31@tf_export("data.experimental.group_by_reducer") 32def group_by_reducer(key_func, reducer): 33 """A transformation that groups elements and performs a reduction. 34 35 This transformation maps element of a dataset to a key using `key_func` and 36 groups the elements by key. The `reducer` is used to process each group; its 37 `init_func` is used to initialize state for each group when it is created, the 38 `reduce_func` is used to update the state every time an element is mapped to 39 the matching group, and the `finalize_func` is used to map the final state to 40 an output value. 41 42 Args: 43 key_func: A function mapping a nested structure of tensors 44 (having shapes and types defined by `self.output_shapes` and 45 `self.output_types`) to a scalar `tf.int64` tensor. 46 reducer: An instance of `Reducer`, which captures the reduction logic using 47 the `init_func`, `reduce_func`, and `finalize_func` functions. 48 49 Returns: 50 A `Dataset` transformation function, which can be passed to 51 `tf.data.Dataset.apply`. 52 """ 53 54 def _apply_fn(dataset): 55 """Function from `Dataset` to `Dataset` that applies the transformation.""" 56 return _GroupByReducerDataset(dataset, key_func, reducer) 57 58 return _apply_fn 59 60 61@deprecation.deprecated(None, "Use `tf.data.Dataset.group_by_window(...)`.") 62@tf_export("data.experimental.group_by_window") 63def group_by_window(key_func, 64 reduce_func, 65 window_size=None, 66 window_size_func=None): 67 """A transformation that groups windows of elements by key and reduces them. 68 69 This transformation maps each consecutive element in a dataset to a key 70 using `key_func` and groups the elements by key. It then applies 71 `reduce_func` to at most `window_size_func(key)` elements matching the same 72 key. All except the final window for each key will contain 73 `window_size_func(key)` elements; the final window may be smaller. 74 75 You may provide either a constant `window_size` or a window size determined by 76 the key through `window_size_func`. 77 78 Args: 79 key_func: A function mapping a nested structure of tensors 80 (having shapes and types defined by `self.output_shapes` and 81 `self.output_types`) to a scalar `tf.int64` tensor. 82 reduce_func: A function mapping a key and a dataset of up to `window_size` 83 consecutive elements matching that key to another dataset. 84 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 85 consecutive elements matching the same key to combine in a single 86 batch, which will be passed to `reduce_func`. Mutually exclusive with 87 `window_size_func`. 88 window_size_func: A function mapping a key to a `tf.int64` scalar 89 `tf.Tensor`, representing the number of consecutive elements matching 90 the same key to combine in a single batch, which will be passed to 91 `reduce_func`. Mutually exclusive with `window_size`. 92 93 Returns: 94 A `Dataset` transformation function, which can be passed to 95 `tf.data.Dataset.apply`. 96 97 Raises: 98 ValueError: if neither or both of {`window_size`, `window_size_func`} are 99 passed. 100 """ 101 102 def _apply_fn(dataset): 103 """Function from `Dataset` to `Dataset` that applies the transformation.""" 104 return dataset.group_by_window( 105 key_func=key_func, 106 reduce_func=reduce_func, 107 window_size=window_size, 108 window_size_func=window_size_func) 109 110 return _apply_fn 111 112 113@deprecation.deprecated(None, 114 "Use `tf.data.Dataset.bucket_by_sequence_length(...)`.") 115@tf_export("data.experimental.bucket_by_sequence_length") 116def bucket_by_sequence_length(element_length_func, 117 bucket_boundaries, 118 bucket_batch_sizes, 119 padded_shapes=None, 120 padding_values=None, 121 pad_to_bucket_boundary=False, 122 no_padding=False, 123 drop_remainder=False): 124 """A transformation that buckets elements in a `Dataset` by length. 125 126 Elements of the `Dataset` are grouped together by length and then are padded 127 and batched. 128 129 This is useful for sequence tasks in which the elements have variable length. 130 Grouping together elements that have similar lengths reduces the total 131 fraction of padding in a batch which increases training step efficiency. 132 133 Below is an example to bucketize the input data to the 3 buckets 134 "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2. 135 136 >>> elements = [ 137 ... [0], [1, 2, 3, 4], [5, 6, 7], 138 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 139 140 >>> dataset = tf.data.Dataset.from_generator( 141 ... lambda: elements, tf.int64, output_shapes=[None]) 142 143 >>> dataset = dataset.apply( 144 ... tf.data.experimental.bucket_by_sequence_length( 145 ... element_length_func=lambda elem: tf.shape(elem)[0], 146 ... bucket_boundaries=[3, 5], 147 ... bucket_batch_sizes=[2, 2, 2])) 148 149 >>> for elem in dataset.as_numpy_iterator(): 150 ... print(elem) 151 [[1 2 3 4] 152 [5 6 7 0]] 153 [[ 7 8 9 10 11 0] 154 [13 14 15 16 19 20]] 155 [[ 0 0] 156 [21 22]] 157 158 There is also a possibility to pad the dataset till the bucket boundary. 159 You can also provide which value to be used while padding the data. 160 Below example uses `-1` as padding and it also shows the input data 161 being bucketizied to two buckets "[0,3], [4,6]". 162 163 >>> elements = [ 164 ... [0], [1, 2, 3, 4], [5, 6, 7], 165 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 166 167 >>> dataset = tf.data.Dataset.from_generator( 168 ... lambda: elements, tf.int32, output_shapes=[None]) 169 170 >>> dataset = dataset.apply( 171 ... tf.data.experimental.bucket_by_sequence_length( 172 ... element_length_func=lambda elem: tf.shape(elem)[0], 173 ... bucket_boundaries=[4, 7], 174 ... bucket_batch_sizes=[2, 2, 2], 175 ... pad_to_bucket_boundary=True, 176 ... padding_values=-1)) 177 178 >>> for elem in dataset.as_numpy_iterator(): 179 ... print(elem) 180 [[ 0 -1 -1] 181 [ 5 6 7]] 182 [[ 1 2 3 4 -1 -1] 183 [ 7 8 9 10 11 -1]] 184 [[21 22 -1]] 185 [[13 14 15 16 19 20]] 186 187 When using `pad_to_bucket_boundary` option, it can be seen that it is 188 not always possible to maintain the bucket batch size. 189 You can drop the batches that do not maintain the bucket batch size by 190 using the option `drop_remainder`. Using the same input data as in the 191 above example you get the following result. 192 193 >>> elements = [ 194 ... [0], [1, 2, 3, 4], [5, 6, 7], 195 ... [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]] 196 197 >>> dataset = tf.data.Dataset.from_generator( 198 ... lambda: elements, tf.int32, output_shapes=[None]) 199 200 >>> dataset = dataset.apply( 201 ... tf.data.experimental.bucket_by_sequence_length( 202 ... element_length_func=lambda elem: tf.shape(elem)[0], 203 ... bucket_boundaries=[4, 7], 204 ... bucket_batch_sizes=[2, 2, 2], 205 ... pad_to_bucket_boundary=True, 206 ... padding_values=-1, 207 ... drop_remainder=True)) 208 209 >>> for elem in dataset.as_numpy_iterator(): 210 ... print(elem) 211 [[ 0 -1 -1] 212 [ 5 6 7]] 213 [[ 1 2 3 4 -1 -1] 214 [ 7 8 9 10 11 -1]] 215 216 Args: 217 element_length_func: function from element in `Dataset` to `tf.int32`, 218 determines the length of the element, which will determine the bucket it 219 goes into. 220 bucket_boundaries: `list<int>`, upper length boundaries of the buckets. 221 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be 222 `len(bucket_boundaries) + 1`. 223 padded_shapes: Nested structure of `tf.TensorShape` to pass to 224 `tf.data.Dataset.padded_batch`. If not provided, will use 225 `dataset.output_shapes`, which will result in variable length dimensions 226 being padded out to the maximum length in each batch. 227 padding_values: Values to pad with, passed to 228 `tf.data.Dataset.padded_batch`. Defaults to padding with 0. 229 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown 230 size to maximum length in batch. If `True`, will pad dimensions with 231 unknown size to bucket boundary minus 1 (i.e., the maximum length in each 232 bucket), and caller must ensure that the source `Dataset` does not contain 233 any elements with length longer than `max(bucket_boundaries)`. 234 no_padding: `bool`, indicates whether to pad the batch features (features 235 need to be either of type `tf.sparse.SparseTensor` or of same shape). 236 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 237 whether the last batch should be dropped in the case it has fewer than 238 `batch_size` elements; the default behavior is not to drop the smaller 239 batch. 240 241 Returns: 242 A `Dataset` transformation function, which can be passed to 243 `tf.data.Dataset.apply`. 244 245 Raises: 246 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 247 """ 248 249 def _apply_fn(dataset): 250 return dataset.bucket_by_sequence_length( 251 element_length_func=element_length_func, 252 bucket_boundaries=bucket_boundaries, 253 bucket_batch_sizes=bucket_batch_sizes, 254 padded_shapes=padded_shapes, 255 padding_values=padding_values, 256 pad_to_bucket_boundary=pad_to_bucket_boundary, 257 no_padding=no_padding, 258 drop_remainder=drop_remainder) 259 260 return _apply_fn 261 262 263class _GroupByReducerDataset(dataset_ops.UnaryDataset): 264 """A `Dataset` that groups its input and performs a reduction.""" 265 266 def __init__(self, input_dataset, key_func, reducer): 267 """See `group_by_reducer()` for details.""" 268 self._input_dataset = input_dataset 269 self._make_key_func(key_func, input_dataset) 270 self._make_init_func(reducer.init_func) 271 self._make_reduce_func(reducer.reduce_func, input_dataset) 272 self._make_finalize_func(reducer.finalize_func) 273 variant_tensor = ged_ops.experimental_group_by_reducer_dataset( 274 self._input_dataset._variant_tensor, # pylint: disable=protected-access 275 self._key_func.function.captured_inputs, 276 self._init_func.function.captured_inputs, 277 self._reduce_func.function.captured_inputs, 278 self._finalize_func.function.captured_inputs, 279 key_func=self._key_func.function, 280 init_func=self._init_func.function, 281 reduce_func=self._reduce_func.function, 282 finalize_func=self._finalize_func.function, 283 **self._flat_structure) 284 super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor) 285 286 def _make_key_func(self, key_func, input_dataset): 287 """Make wrapping defun for key_func.""" 288 self._key_func = dataset_ops.StructuredFunctionWrapper( 289 key_func, self._transformation_name(), dataset=input_dataset) 290 if not self._key_func.output_structure.is_compatible_with( 291 tensor_spec.TensorSpec([], dtypes.int64)): 292 raise ValueError( 293 "`key_func` must return a single tf.int64 tensor. " 294 "Got type=%s and shape=%s" 295 % (self._key_func.output_types, self._key_func.output_shapes)) 296 297 def _make_init_func(self, init_func): 298 """Make wrapping defun for init_func.""" 299 self._init_func = dataset_ops.StructuredFunctionWrapper( 300 init_func, 301 self._transformation_name(), 302 input_structure=tensor_spec.TensorSpec([], dtypes.int64)) 303 304 def _make_reduce_func(self, reduce_func, input_dataset): 305 """Make wrapping defun for reduce_func.""" 306 307 # Iteratively rerun the reduce function until reaching a fixed point on 308 # `self._state_structure`. 309 self._state_structure = self._init_func.output_structure 310 state_types = self._init_func.output_types 311 state_shapes = self._init_func.output_shapes 312 state_classes = self._init_func.output_classes 313 need_to_rerun = True 314 while need_to_rerun: 315 316 wrapped_func = dataset_ops.StructuredFunctionWrapper( 317 reduce_func, 318 self._transformation_name(), 319 input_structure=(self._state_structure, input_dataset.element_spec), 320 add_to_graph=False) 321 322 # Extract and validate class information from the returned values. 323 for new_state_class, state_class in zip( 324 nest.flatten(wrapped_func.output_classes), 325 nest.flatten(state_classes)): 326 if not issubclass(new_state_class, state_class): 327 raise TypeError( 328 "The element classes for the new state must match the initial " 329 "state. Expected %s; got %s." % 330 (self._state_classes, wrapped_func.output_classes)) 331 332 # Extract and validate type information from the returned values. 333 for new_state_type, state_type in zip( 334 nest.flatten(wrapped_func.output_types), nest.flatten(state_types)): 335 if new_state_type != state_type: 336 raise TypeError( 337 "The element types for the new state must match the initial " 338 "state. Expected %s; got %s." % 339 (self._init_func.output_types, wrapped_func.output_types)) 340 341 # Extract shape information from the returned values. 342 flat_state_shapes = nest.flatten(state_shapes) 343 flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) 344 weakened_state_shapes = [ 345 original.most_specific_compatible_shape(new) 346 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 347 ] 348 349 need_to_rerun = False 350 for original_shape, weakened_shape in zip(flat_state_shapes, 351 weakened_state_shapes): 352 if original_shape.ndims is not None and ( 353 weakened_shape.ndims is None or 354 original_shape.as_list() != weakened_shape.as_list()): 355 need_to_rerun = True 356 break 357 358 if need_to_rerun: 359 state_shapes = nest.pack_sequence_as( 360 self._init_func.output_shapes, weakened_state_shapes) 361 self._state_structure = structure.convert_legacy_structure( 362 state_types, state_shapes, state_classes) 363 364 self._reduce_func = wrapped_func 365 self._reduce_func.function.add_to_graph(ops.get_default_graph()) 366 367 def _make_finalize_func(self, finalize_func): 368 """Make wrapping defun for finalize_func.""" 369 self._finalize_func = dataset_ops.StructuredFunctionWrapper( 370 finalize_func, self._transformation_name(), 371 input_structure=self._state_structure) 372 373 @property 374 def element_spec(self): 375 return self._finalize_func.output_structure 376 377 def _functions(self): 378 return [ 379 self._key_func, self._init_func, self._reduce_func, self._finalize_func 380 ] 381 382 def _transformation_name(self): 383 return "tf.data.experimental.group_by_reducer()" 384 385 386@tf_export("data.experimental.Reducer") 387class Reducer(object): 388 """A reducer is used for reducing a set of elements. 389 390 A reducer is represented as a tuple of the three functions: 391 1) initialization function: key => initial state 392 2) reduce function: (old state, input) => new state 393 3) finalization function: state => result 394 """ 395 396 def __init__(self, init_func, reduce_func, finalize_func): 397 self._init_func = init_func 398 self._reduce_func = reduce_func 399 self._finalize_func = finalize_func 400 401 @property 402 def init_func(self): 403 return self._init_func 404 405 @property 406 def reduce_func(self): 407 return self._reduce_func 408 409 @property 410 def finalize_func(self): 411 return self._finalize_func 412