# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Grouping dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.group_by_reducer")
def group_by_reducer(key_func, reducer):
  """A transformation that groups elements and performs a reduction.

  This transformation maps each element of a dataset to a key using `key_func`
  and groups the elements by key. The `reducer` is used to process each group;
  its `init_func` is used to initialize state for each group when it is
  created, the `reduce_func` is used to update the state every time an element
  is mapped to the matching group, and the `finalize_func` is used to map the
  final state to an output value.

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reducer: An instance of `Reducer`, which captures the reduction logic using
      the `init_func`, `reduce_func`, and `finalize_func` functions.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _GroupByReducerDataset(dataset, key_func, reducer)

  return _apply_fn


@tf_export("data.experimental.group_by_window")
def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """A transformation that groups windows of elements by key and reduces them.

  This transformation maps each consecutive element in a dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You may provide either a constant `window_size` or a window size determined
  by the key through `window_size_func`.
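
  For example, the minimal sketch below partitions `range(10)` by parity and
  batches each group into windows of at most 5 elements; the key function and
  window size are illustrative choices, not requirements of the API.

  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(
  ...     tf.data.experimental.group_by_window(
  ...         key_func=lambda x: x % 2,
  ...         reduce_func=lambda key, window: window.batch(5),
  ...         window_size=5))
  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [0 2 4 6 8]
  [1 3 5 7 9]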

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  if (window_size is not None and window_size_func or
      not (window_size is not None or window_size_func)):
    raise ValueError("Must pass either window_size or window_size_func.")

  if window_size is not None:

    def constant_window_func(unused_key):
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    window_size_func = constant_window_func

  assert window_size_func is not None

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return _GroupByWindowDataset(dataset, key_func, reduce_func,
                                 window_size_func)

  return _apply_fn


@tf_export("data.experimental.bucket_by_sequence_length")
def bucket_by_sequence_length(element_length_func,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None,
                              padding_values=None,
                              pad_to_bucket_boundary=False,
                              no_padding=False,
                              drop_remainder=False):
  """A transformation that buckets elements in a `Dataset` by length.

  Elements of the `Dataset` are grouped together by length and then are padded
  and batched.

  This is useful for sequence tasks in which the elements have variable length.
  Grouping together elements that have similar lengths reduces the total
  fraction of padding in a batch, which increases training step efficiency.

  Below is an example that bucketizes the input data into the 3 buckets
  "[0, 3), [3, 5), [5, inf)" based on sequence length, with batch size 2.

  >>> elements = [
  ...   [0], [1, 2, 3, 4], [5, 6, 7],
  ...   [7, 8, 9, 10, 11], [13, 14, 15, 16, 19, 20], [21, 22]]

  >>> dataset = tf.data.Dataset.from_generator(
  ...     lambda: elements, tf.int64, output_shapes=[None])

  >>> dataset = dataset.apply(
  ...     tf.data.experimental.bucket_by_sequence_length(
  ...         element_length_func=lambda elem: tf.shape(elem)[0],
  ...         bucket_boundaries=[3, 5],
  ...         bucket_batch_sizes=[2, 2, 2]))

  >>> for elem in dataset.as_numpy_iterator():
  ...   print(elem)
  [[1 2 3 4]
   [5 6 7 0]]
  [[ 7  8  9 10 11  0]
   [13 14 15 16 19 20]]
  [[ 0  0]
   [21 22]]

  Args:
    element_length_func: function from element in `Dataset` to `tf.int32`,
      determines the length of the element, which will determine the bucket it
      goes into.
    bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
    bucket_batch_sizes: `list<int>`, batch size per bucket.
      Length should be `len(bucket_boundaries) + 1`.
    padded_shapes: Nested structure of `tf.TensorShape` to pass to
      `tf.data.Dataset.padded_batch`. If not provided, will use
      `dataset.output_shapes`, which will result in variable length dimensions
      being padded out to the maximum length in each batch.
    padding_values: Values to pad with, passed to
      `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
    pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
      size to maximum length in batch. If `True`, will pad dimensions with
      unknown size to bucket boundary minus 1 (i.e., the maximum length in
      each bucket), and caller must ensure that the source `Dataset` does not
      contain any elements with length longer than `max(bucket_boundaries)`.
    no_padding: `bool`, indicates whether to pad the batch features (features
      need to be either of type `tf.sparse.SparseTensor` or of same shape).
    drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
      whether the last batch should be dropped in the case it has fewer than
      `batch_size` elements; the default behavior is not to drop the smaller
      batch.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.

  Raises:
    ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
  """
  with ops.name_scope("bucket_by_seq_length"):
    if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
      raise ValueError(
          "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")

    batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)

    def element_to_bucket_id(*args):
      """Return int64 id of the length bucket for this element."""
      seq_length = element_length_func(*args)

      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
      conditions_c = math_ops.logical_and(
          math_ops.less_equal(buckets_min, seq_length),
          math_ops.less(seq_length, buckets_max))
      bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))

      return bucket_id

    def window_size_fn(bucket_id):
      # The window size is set to the batch size for this bucket.
      window_size = batch_sizes[bucket_id]
      return window_size

    def make_padded_shapes(shapes, none_filler=None):
      padded = []
      for shape in nest.flatten(shapes):
        shape = tensor_shape.TensorShape(shape)
        shape = [
            none_filler if tensor_shape.dimension_value(d) is None else d
            for d in shape
        ]
        padded.append(shape)
      return nest.pack_sequence_as(shapes, padded)

    def batching_fn(bucket_id, grouped_dataset):
      """Batch elements in dataset."""
      batch_size = window_size_fn(bucket_id)
      if no_padding:
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
      none_filler = None
      if pad_to_bucket_boundary:
        err_msg = ("When pad_to_bucket_boundary=True, elements must have "
                   "length < max(bucket_boundaries).")
        check = check_ops.assert_less(
            bucket_id,
            constant_op.constant(len(bucket_batch_sizes) - 1,
                                 dtype=dtypes.int64),
            message=err_msg)
        with ops.control_dependencies([check]):
          boundaries = constant_op.constant(bucket_boundaries,
                                            dtype=dtypes.int64)
          bucket_boundary = boundaries[bucket_id]
          none_filler = bucket_boundary - 1
      input_shapes = dataset_ops.get_legacy_output_shapes(grouped_dataset)
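      # Pad unknown (None) dimensions either to the longest element in this
      # batch (none_filler=None) or, when pad_to_bucket_boundary=True, to the
      # bucket boundary minus one computed above.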
      shapes = make_padded_shapes(padded_shapes or input_shapes,
                                  none_filler=none_filler)
      return grouped_dataset.padded_batch(
          batch_size, shapes, padding_values, drop_remainder=drop_remainder)

    def _apply_fn(dataset):
      return dataset.apply(
          group_by_window(element_to_bucket_id, batching_fn,
                          window_size_func=window_size_fn))

    return _apply_fn


class _GroupByReducerDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a reduction."""

  def __init__(self, input_dataset, key_func, reducer):
    """See `group_by_reducer()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_init_func(reducer.init_func)
    self._make_reduce_func(reducer.reduce_func, input_dataset)
    self._make_finalize_func(reducer.finalize_func)
    variant_tensor = ged_ops.experimental_group_by_reducer_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._init_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._finalize_func.function.captured_inputs,
        key_func=self._key_func.function,
        init_func=self._init_func.function,
        reduce_func=self._reduce_func.function,
        finalize_func=self._finalize_func.function,
        **self._flat_structure)
    super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor)

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""
    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`key_func` must return a single tf.int64 tensor. "
          "Got type=%s and shape=%s"
          % (self._key_func.output_types, self._key_func.output_shapes))

  def _make_init_func(self, init_func):
    """Make wrapping defun for init_func."""
    self._init_func = dataset_ops.StructuredFunctionWrapper(
        init_func,
        self._transformation_name(),
        input_structure=tensor_spec.TensorSpec([], dtypes.int64))

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""

    # Iteratively rerun the reduce function until reaching a fixed point on
    # `self._state_structure`.
    self._state_structure = self._init_func.output_structure
    state_types = self._init_func.output_types
    state_shapes = self._init_func.output_shapes
    state_classes = self._init_func.output_classes
    need_to_rerun = True
    while need_to_rerun:

      wrapped_func = dataset_ops.StructuredFunctionWrapper(
          reduce_func,
          self._transformation_name(),
          input_structure=(self._state_structure, input_dataset.element_spec),
          add_to_graph=False)

      # Extract and validate class information from the returned values.
      for new_state_class, state_class in zip(
          nest.flatten(wrapped_func.output_classes),
          nest.flatten(state_classes)):
        if not issubclass(new_state_class, state_class):
          raise TypeError(
              "The element classes for the new state must match the initial "
              "state. Expected %s; got %s." %
              (state_classes, wrapped_func.output_classes))

      # Extract and validate type information from the returned values.
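      # Types, unlike shapes, are never relaxed: a mismatch here is an error
      # rather than a reason to rerun the fixed-point iteration below.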
      for new_state_type, state_type in zip(
          nest.flatten(wrapped_func.output_types), nest.flatten(state_types)):
        if new_state_type != state_type:
          raise TypeError(
              "The element types for the new state must match the initial "
              "state. Expected %s; got %s." %
              (self._init_func.output_types, wrapped_func.output_types))

      # Extract shape information from the returned values.
      flat_state_shapes = nest.flatten(state_shapes)
      flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        state_shapes = nest.pack_sequence_as(
            self._init_func.output_shapes, weakened_state_shapes)
        self._state_structure = structure.convert_legacy_structure(
            state_types, state_shapes, state_classes)

    self._reduce_func = wrapped_func
    self._reduce_func.function.add_to_graph(ops.get_default_graph())

  def _make_finalize_func(self, finalize_func):
    """Make wrapping defun for finalize_func."""
    self._finalize_func = dataset_ops.StructuredFunctionWrapper(
        finalize_func, self._transformation_name(),
        input_structure=self._state_structure)

  @property
  def element_spec(self):
    return self._finalize_func.output_structure

  def _functions(self):
    return [
        self._key_func, self._init_func, self._reduce_func, self._finalize_func
    ]

  def _transformation_name(self):
    return "tf.data.experimental.group_by_reducer()"


class _GroupByWindowDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    self._input_dataset = input_dataset
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)
    variant_tensor = ged_ops.group_by_window_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._key_func.function.captured_inputs,
        self._reduce_func.function.captured_inputs,
        self._window_size_func.function.captured_inputs,
        key_func=self._key_func.function,
        reduce_func=self._reduce_func.function,
        window_size_func=self._window_size_func.function,
        **self._flat_structure)
    super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping defun for window_size_func."""

    def window_size_func_wrapper(key):
      return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)

    self._window_size_func = dataset_ops.StructuredFunctionWrapper(
        window_size_func_wrapper,
        self._transformation_name(),
        input_structure=tensor_spec.TensorSpec([], dtypes.int64))
    if not self._window_size_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`window_size_func` must return a single tf.int64 scalar tensor.")

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping defun for key_func."""

    def key_func_wrapper(*args):
      return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)

    self._key_func = dataset_ops.StructuredFunctionWrapper(
        key_func_wrapper, self._transformation_name(), dataset=input_dataset)
    if not self._key_func.output_structure.is_compatible_with(
        tensor_spec.TensorSpec([], dtypes.int64)):
      raise ValueError(
          "`key_func` must return a single tf.int64 scalar tensor.")

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping defun for reduce_func."""
    nested_dataset = dataset_ops.DatasetSpec(
        input_dataset.element_spec)
    input_structure = (tensor_spec.TensorSpec([], dtypes.int64), nested_dataset)
    self._reduce_func = dataset_ops.StructuredFunctionWrapper(
        reduce_func, self._transformation_name(),
        input_structure=input_structure)
    if not isinstance(
        self._reduce_func.output_structure, dataset_ops.DatasetSpec):
      raise TypeError("`reduce_func` must return a `Dataset` object.")
    # pylint: disable=protected-access
    self._element_spec = (
        self._reduce_func.output_structure._element_spec)

  @property
  def element_spec(self):
    return self._element_spec

  def _functions(self):
    return [self._key_func, self._reduce_func, self._window_size_func]

  def _transformation_name(self):
    return "tf.data.experimental.group_by_window()"


@tf_export("data.experimental.Reducer")
class Reducer(object):
  """A reducer is used for reducing a set of elements.

  A reducer is represented as a tuple of three functions:
    1) initialization function: key => initial state
    2) reduce function: (old state, input) => new state
    3) finalization function: state => result
  """

  def __init__(self, init_func, reduce_func, finalize_func):
    self._init_func = init_func
    self._reduce_func = reduce_func
    self._finalize_func = finalize_func

  @property
  def init_func(self):
    return self._init_func

  @property
  def reduce_func(self):
    return self._reduce_func

  @property
  def finalize_func(self):
    return self._finalize_func
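

# A minimal usage sketch showing how `Reducer` plugs into `group_by_reducer`
# to compute one reduced value per key. The key function, initial state, and
# reduction below are illustrative choices only, not requirements of the API.
#
#   reducer = tf.data.experimental.Reducer(
#       init_func=lambda key: np.int64(0),             # per-key state starts at 0
#       reduce_func=lambda state, element: state + 1,  # count elements in group
#       finalize_func=lambda state: state)             # emit the final count
#   dataset = tf.data.Dataset.range(10).apply(
#       tf.data.experimental.group_by_reducer(
#           key_func=lambda x: x % 2, reducer=reducer))
#   # Produces one element per key: the counts 5 and 5.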