1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Grouping dataset transformations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.data.ops import dataset_ops 23from tensorflow.python.data.util import nest 24from tensorflow.python.data.util import structure 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import check_ops 31from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 32from tensorflow.python.ops import math_ops 33from tensorflow.python.util.tf_export import tf_export 34 35 36@tf_export("data.experimental.group_by_reducer") 37def group_by_reducer(key_func, reducer): 38 """A transformation that groups elements and performs a reduction. 39 40 This transformation maps element of a dataset to a key using `key_func` and 41 groups the elements by key. The `reducer` is used to process each group; its 42 `init_func` is used to initialize state for each group when it is created, the 43 `reduce_func` is used to update the state every time an element is mapped to 44 the matching group, and the `finalize_func` is used to map the final state to 45 an output value. 46 47 Args: 48 key_func: A function mapping a nested structure of tensors 49 (having shapes and types defined by `self.output_shapes` and 50 `self.output_types`) to a scalar `tf.int64` tensor. 51 reducer: An instance of `Reducer`, which captures the reduction logic using 52 the `init_func`, `reduce_func`, and `finalize_func` functions. 53 54 Returns: 55 A `Dataset` transformation function, which can be passed to 56 `tf.data.Dataset.apply`. 57 """ 58 59 def _apply_fn(dataset): 60 """Function from `Dataset` to `Dataset` that applies the transformation.""" 61 return _GroupByReducerDataset(dataset, key_func, reducer) 62 63 return _apply_fn 64 65 66@tf_export("data.experimental.group_by_window") 67def group_by_window(key_func, 68 reduce_func, 69 window_size=None, 70 window_size_func=None): 71 """A transformation that groups windows of elements by key and reduces them. 72 73 This transformation maps each consecutive element in a dataset to a key 74 using `key_func` and groups the elements by key. It then applies 75 `reduce_func` to at most `window_size_func(key)` elements matching the same 76 key. All except the final window for each key will contain 77 `window_size_func(key)` elements; the final window may be smaller. 78 79 You may provide either a constant `window_size` or a window size determined by 80 the key through `window_size_func`. 81 82 Args: 83 key_func: A function mapping a nested structure of tensors 84 (having shapes and types defined by `self.output_shapes` and 85 `self.output_types`) to a scalar `tf.int64` tensor. 86 reduce_func: A function mapping a key and a dataset of up to `window_size` 87 consecutive elements matching that key to another dataset. 88 window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of 89 consecutive elements matching the same key to combine in a single 90 batch, which will be passed to `reduce_func`. Mutually exclusive with 91 `window_size_func`. 92 window_size_func: A function mapping a key to a `tf.int64` scalar 93 `tf.Tensor`, representing the number of consecutive elements matching 94 the same key to combine in a single batch, which will be passed to 95 `reduce_func`. Mutually exclusive with `window_size`. 96 97 Returns: 98 A `Dataset` transformation function, which can be passed to 99 `tf.data.Dataset.apply`. 100 101 Raises: 102 ValueError: if neither or both of {`window_size`, `window_size_func`} are 103 passed. 104 """ 105 if (window_size is not None and window_size_func or 106 not (window_size is not None or window_size_func)): 107 raise ValueError("Must pass either window_size or window_size_func.") 108 109 if window_size is not None: 110 111 def constant_window_func(unused_key): 112 return ops.convert_to_tensor(window_size, dtype=dtypes.int64) 113 114 window_size_func = constant_window_func 115 116 assert window_size_func is not None 117 118 def _apply_fn(dataset): 119 """Function from `Dataset` to `Dataset` that applies the transformation.""" 120 return _GroupByWindowDataset(dataset, key_func, reduce_func, 121 window_size_func) 122 123 return _apply_fn 124 125 126@tf_export("data.experimental.bucket_by_sequence_length") 127def bucket_by_sequence_length(element_length_func, 128 bucket_boundaries, 129 bucket_batch_sizes, 130 padded_shapes=None, 131 padding_values=None, 132 pad_to_bucket_boundary=False, 133 no_padding=False, 134 drop_remainder=False): 135 """A transformation that buckets elements in a `Dataset` by length. 136 137 Elements of the `Dataset` are grouped together by length and then are padded 138 and batched. 139 140 This is useful for sequence tasks in which the elements have variable length. 141 Grouping together elements that have similar lengths reduces the total 142 fraction of padding in a batch which increases training step efficiency. 143 144 Args: 145 element_length_func: function from element in `Dataset` to `tf.int32`, 146 determines the length of the element, which will determine the bucket it 147 goes into. 148 bucket_boundaries: `list<int>`, upper length boundaries of the buckets. 149 bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be 150 `len(bucket_boundaries) + 1`. 151 padded_shapes: Nested structure of `tf.TensorShape` to pass to 152 `tf.data.Dataset.padded_batch`. If not provided, will use 153 `dataset.output_shapes`, which will result in variable length dimensions 154 being padded out to the maximum length in each batch. 155 padding_values: Values to pad with, passed to 156 `tf.data.Dataset.padded_batch`. Defaults to padding with 0. 157 pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown 158 size to maximum length in batch. If `True`, will pad dimensions with 159 unknown size to bucket boundary minus 1 (i.e., the maximum length in each 160 bucket), and caller must ensure that the source `Dataset` does not contain 161 any elements with length longer than `max(bucket_boundaries)`. 162 no_padding: `bool`, indicates whether to pad the batch features (features 163 need to be either of type `tf.SparseTensor` or of same shape). 164 drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing 165 whether the last batch should be dropped in the case it has fewer than 166 `batch_size` elements; the default behavior is not to drop the smaller 167 batch. 168 169 Returns: 170 A `Dataset` transformation function, which can be passed to 171 `tf.data.Dataset.apply`. 172 173 Raises: 174 ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`. 175 """ 176 with ops.name_scope("bucket_by_seq_length"): 177 if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1): 178 raise ValueError( 179 "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1") 180 181 batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64) 182 183 def element_to_bucket_id(*args): 184 """Return int64 id of the length bucket for this element.""" 185 seq_length = element_length_func(*args) 186 187 boundaries = list(bucket_boundaries) 188 buckets_min = [np.iinfo(np.int32).min] + boundaries 189 buckets_max = boundaries + [np.iinfo(np.int32).max] 190 conditions_c = math_ops.logical_and( 191 math_ops.less_equal(buckets_min, seq_length), 192 math_ops.less(seq_length, buckets_max)) 193 bucket_id = math_ops.reduce_min(array_ops.where(conditions_c)) 194 195 return bucket_id 196 197 def window_size_fn(bucket_id): 198 # The window size is set to the batch size for this bucket 199 window_size = batch_sizes[bucket_id] 200 return window_size 201 202 def make_padded_shapes(shapes, none_filler=None): 203 padded = [] 204 for shape in nest.flatten(shapes): 205 shape = tensor_shape.TensorShape(shape) 206 shape = [ 207 none_filler if tensor_shape.dimension_value(d) is None else d 208 for d in shape 209 ] 210 padded.append(shape) 211 return nest.pack_sequence_as(shapes, padded) 212 213 def batching_fn(bucket_id, grouped_dataset): 214 """Batch elements in dataset.""" 215 batch_size = window_size_fn(bucket_id) 216 if no_padding: 217 return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder) 218 none_filler = None 219 if pad_to_bucket_boundary: 220 err_msg = ("When pad_to_bucket_boundary=True, elements must have " 221 "length < max(bucket_boundaries).") 222 check = check_ops.assert_less( 223 bucket_id, 224 constant_op.constant(len(bucket_batch_sizes) - 1, 225 dtype=dtypes.int64), 226 message=err_msg) 227 with ops.control_dependencies([check]): 228 boundaries = constant_op.constant(bucket_boundaries, 229 dtype=dtypes.int64) 230 bucket_boundary = boundaries[bucket_id] 231 none_filler = bucket_boundary - 1 232 input_shapes = dataset_ops.get_legacy_output_shapes(grouped_dataset) 233 shapes = make_padded_shapes(padded_shapes or input_shapes, 234 none_filler=none_filler) 235 return grouped_dataset.padded_batch( 236 batch_size, shapes, padding_values, drop_remainder=drop_remainder) 237 238 def _apply_fn(dataset): 239 return dataset.apply( 240 group_by_window(element_to_bucket_id, batching_fn, 241 window_size_func=window_size_fn)) 242 243 return _apply_fn 244 245 246class _GroupByReducerDataset(dataset_ops.UnaryDataset): 247 """A `Dataset` that groups its input and performs a reduction.""" 248 249 def __init__(self, input_dataset, key_func, reducer): 250 """See `group_by_reducer()` for details.""" 251 self._input_dataset = input_dataset 252 self._make_key_func(key_func, input_dataset) 253 self._make_init_func(reducer.init_func) 254 self._make_reduce_func(reducer.reduce_func, input_dataset) 255 self._make_finalize_func(reducer.finalize_func) 256 variant_tensor = ged_ops.experimental_group_by_reducer_dataset( 257 self._input_dataset._variant_tensor, # pylint: disable=protected-access 258 self._key_func.function.captured_inputs, 259 self._init_func.function.captured_inputs, 260 self._reduce_func.function.captured_inputs, 261 self._finalize_func.function.captured_inputs, 262 key_func=self._key_func.function, 263 init_func=self._init_func.function, 264 reduce_func=self._reduce_func.function, 265 finalize_func=self._finalize_func.function, 266 **dataset_ops.flat_structure(self)) 267 super(_GroupByReducerDataset, self).__init__(input_dataset, variant_tensor) 268 269 def _make_key_func(self, key_func, input_dataset): 270 """Make wrapping defun for key_func.""" 271 self._key_func = dataset_ops.StructuredFunctionWrapper( 272 key_func, self._transformation_name(), dataset=input_dataset) 273 if not self._key_func.output_structure.is_compatible_with( 274 structure.TensorStructure(dtypes.int64, [])): 275 raise ValueError( 276 "`key_func` must return a single tf.int64 tensor. " 277 "Got type=%s and shape=%s" 278 % (self._key_func.output_types, self._key_func.output_shapes)) 279 280 def _make_init_func(self, init_func): 281 """Make wrapping defun for init_func.""" 282 self._init_func = dataset_ops.StructuredFunctionWrapper( 283 init_func, 284 self._transformation_name(), 285 input_structure=structure.TensorStructure(dtypes.int64, [])) 286 287 def _make_reduce_func(self, reduce_func, input_dataset): 288 """Make wrapping defun for reduce_func.""" 289 290 # Iteratively rerun the reduce function until reaching a fixed point on 291 # `self._state_structure`. 292 self._state_structure = self._init_func.output_structure 293 state_types = self._init_func.output_types 294 state_shapes = self._init_func.output_shapes 295 state_classes = self._init_func.output_classes 296 need_to_rerun = True 297 while need_to_rerun: 298 299 wrapped_func = dataset_ops.StructuredFunctionWrapper( 300 reduce_func, 301 self._transformation_name(), 302 input_structure=structure.NestedStructure( 303 (self._state_structure, input_dataset._element_structure)), # pylint: disable=protected-access 304 add_to_graph=False) 305 306 # Extract and validate class information from the returned values. 307 for new_state_class, state_class in zip( 308 nest.flatten(wrapped_func.output_classes), 309 nest.flatten(state_classes)): 310 if not issubclass(new_state_class, state_class): 311 raise TypeError( 312 "The element classes for the new state must match the initial " 313 "state. Expected %s; got %s." % 314 (self._state_classes, wrapped_func.output_classes)) 315 316 # Extract and validate type information from the returned values. 317 for new_state_type, state_type in zip( 318 nest.flatten(wrapped_func.output_types), nest.flatten(state_types)): 319 if new_state_type != state_type: 320 raise TypeError( 321 "The element types for the new state must match the initial " 322 "state. Expected %s; got %s." % 323 (self._init_func.output_types, wrapped_func.output_types)) 324 325 # Extract shape information from the returned values. 326 flat_state_shapes = nest.flatten(state_shapes) 327 flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes) 328 weakened_state_shapes = [ 329 original.most_specific_compatible_shape(new) 330 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 331 ] 332 333 need_to_rerun = False 334 for original_shape, weakened_shape in zip(flat_state_shapes, 335 weakened_state_shapes): 336 if original_shape.ndims is not None and ( 337 weakened_shape.ndims is None or 338 original_shape.as_list() != weakened_shape.as_list()): 339 need_to_rerun = True 340 break 341 342 if need_to_rerun: 343 state_shapes = nest.pack_sequence_as( 344 self._init_func.output_shapes, weakened_state_shapes) 345 self._state_structure = structure.convert_legacy_structure( 346 state_types, state_shapes, state_classes) 347 348 self._reduce_func = wrapped_func 349 self._reduce_func.function.add_to_graph(ops.get_default_graph()) 350 351 def _make_finalize_func(self, finalize_func): 352 """Make wrapping defun for finalize_func.""" 353 self._finalize_func = dataset_ops.StructuredFunctionWrapper( 354 finalize_func, self._transformation_name(), 355 input_structure=self._state_structure) 356 357 @property 358 def _element_structure(self): 359 return self._finalize_func.output_structure 360 361 def _functions(self): 362 return [ 363 self._key_func, self._init_func, self._reduce_func, self._finalize_func 364 ] 365 366 def _transformation_name(self): 367 return "tf.data.experimental.group_by_reducer()" 368 369 370class _GroupByWindowDataset(dataset_ops.UnaryDataset): 371 """A `Dataset` that groups its input and performs a windowed reduction.""" 372 373 def __init__(self, input_dataset, key_func, reduce_func, window_size_func): 374 """See `group_by_window()` for details.""" 375 self._input_dataset = input_dataset 376 self._make_key_func(key_func, input_dataset) 377 self._make_reduce_func(reduce_func, input_dataset) 378 self._make_window_size_func(window_size_func) 379 variant_tensor = ged_ops.experimental_group_by_window_dataset( 380 self._input_dataset._variant_tensor, # pylint: disable=protected-access 381 self._key_func.function.captured_inputs, 382 self._reduce_func.function.captured_inputs, 383 self._window_size_func.function.captured_inputs, 384 key_func=self._key_func.function, 385 reduce_func=self._reduce_func.function, 386 window_size_func=self._window_size_func.function, 387 **dataset_ops.flat_structure(self)) 388 super(_GroupByWindowDataset, self).__init__(input_dataset, variant_tensor) 389 390 def _make_window_size_func(self, window_size_func): 391 """Make wrapping defun for window_size_func.""" 392 393 def window_size_func_wrapper(key): 394 return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64) 395 self._window_size_func = dataset_ops.StructuredFunctionWrapper( 396 window_size_func_wrapper, 397 self._transformation_name(), 398 input_structure=structure.TensorStructure(dtypes.int64, [])) 399 if not self._window_size_func.output_structure.is_compatible_with( 400 structure.TensorStructure(dtypes.int64, [])): 401 raise ValueError( 402 "`window_size_func` must return a single tf.int64 scalar tensor.") 403 404 def _make_key_func(self, key_func, input_dataset): 405 """Make wrapping defun for key_func.""" 406 407 def key_func_wrapper(*args): 408 return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64) 409 self._key_func = dataset_ops.StructuredFunctionWrapper( 410 key_func_wrapper, self._transformation_name(), dataset=input_dataset) 411 if not self._key_func.output_structure.is_compatible_with( 412 structure.TensorStructure(dtypes.int64, [])): 413 raise ValueError( 414 "`key_func` must return a single tf.int64 scalar tensor.") 415 416 def _make_reduce_func(self, reduce_func, input_dataset): 417 """Make wrapping defun for reduce_func.""" 418 nested_dataset = dataset_ops.DatasetStructure( 419 input_dataset._element_structure) # pylint: disable=protected-access 420 input_structure = structure.NestedStructure( 421 (structure.TensorStructure(dtypes.int64, []), nested_dataset)) 422 self._reduce_func = dataset_ops.StructuredFunctionWrapper( 423 reduce_func, self._transformation_name(), 424 input_structure=input_structure) 425 if not isinstance( 426 self._reduce_func.output_structure, dataset_ops.DatasetStructure): 427 raise TypeError("`reduce_func` must return a `Dataset` object.") 428 # pylint: disable=protected-access 429 self._structure = ( 430 self._reduce_func.output_structure._element_structure) 431 432 @property 433 def _element_structure(self): 434 return self._structure 435 436 def _functions(self): 437 return [self._key_func, self._reduce_func, self._window_size_func] 438 439 def _transformation_name(self): 440 return "tf.data.experimental.group_by_window()" 441 442 443@tf_export("data.experimental.Reducer") 444class Reducer(object): 445 """A reducer is used for reducing a set of elements. 446 447 A reducer is represented as a tuple of the three functions: 448 1) initialization function: key => initial state 449 2) reduce function: (old state, input) => new state 450 3) finalization function: state => result 451 """ 452 453 def __init__(self, init_func, reduce_func, finalize_func): 454 self._init_func = init_func 455 self._reduce_func = reduce_func 456 self._finalize_func = finalize_func 457 458 @property 459 def init_func(self): 460 return self._init_func 461 462 @property 463 def reduce_func(self): 464 return self._reduce_func 465 466 @property 467 def finalize_func(self): 468 return self._finalize_func 469