# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Distribution Strategy-related dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops


class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that shards the `Dataset` automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is rewritten to a legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

  Args:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
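
  For illustration only, a minimal sketch of constructing this dataset directly
  (in practice it is applied for you by `tf.distribute` when distributing a
  dataset across workers; the file names below are placeholders):

  ```python
  dataset = tf.data.TFRecordDataset(["file_0.tfrecord", "file_1.tfrecord"])
  # Worker 1 of 2: with FILE sharding this worker reads a disjoint subset of
  # the files; with DATA sharding it keeps every second element.
  dataset = _AutoShardDataset(dataset, num_workers=2, index=1)
  ```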
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        index=index,
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


def _AutoShardDatasetV1(input_dataset, num_workers, index, num_replicas=None):  # pylint: disable=invalid-name
  return dataset_ops.DatasetV1Adapter(
      _AutoShardDataset(input_dataset, num_workers, index, num_replicas))


class _RebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that rebatches elements from its input into new batch sizes.

  `_RebatchDataset(input_dataset, batch_sizes)` is functionally equivalent to
  `input_dataset.unbatch().batch(N)`, where the value of N cycles through the
  `batch_sizes` input list. The elements produced by this dataset have the same
  rank as the elements of the input dataset.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1])
  for elem in ds:
    print(elem)
  >> [0, 1], [2], [3], [4, 5], [6], [7]

  ds = tf.data.Dataset.range(16)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[6])
  for elem in ds:
    print(elem)
  >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15]
  ```
  """

  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      batch_sizes: A `tf.int64` scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
        whether the last batch should be dropped in the case it has fewer than
        `batch_sizes[cycle_index]` elements; the default behavior is not to
        drop the smaller batch.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    # pylint: disable=protected-access
    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))
    # pylint: enable=protected-access

    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        batch_sizes=batch_sizes,
        drop_remainder=drop_remainder,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  def _compute_static_batch_dim(self):
    """Computes the static batch dimension of a dataset if it can be determined.

    Given the _RebatchDataset parameters, determines the batch dimension of
    this dataset statically. Returns None if this cannot be determined or is
    variable.
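
    For example (illustrative): with `batch_sizes=[2, 2, 2]` and
    `drop_remainder=True`, the static batch dimension is 2, whereas with
    `batch_sizes=[2, 1]` the produced batch sizes vary, so None is returned.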

    Returns:
      An integer representing the batch dimension of the dataset. If it cannot
      be determined statically, returns None.

    Raises:
      ValueError: The batch_sizes parameter is malformed, input_dataset is
        not batched, or input_dataset batch sizes are incompatible with each
        other.
    """
    new_batch_dim = tensor_util.constant_value(self._batch_sizes)
    if new_batch_dim is None:
      return None

    if isinstance(new_batch_dim, np.ndarray):
      if len(new_batch_dim.shape) == 1:
        if np.all(new_batch_dim == new_batch_dim[0]):
          new_batch_dim = new_batch_dim[0]
        else:
          return None
      elif len(new_batch_dim.shape) > 1:
        raise ValueError("Expected batch_sizes to be a scalar or vector.")

    if self._may_form_partial_batches(new_batch_dim):
      return None

    return new_batch_dim

  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches."""
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      if len(shape) < 1:
        raise ValueError("Expected a dataset whose elements have rank >= 1 "
                         "but found a dataset whose elements are scalars. "
                         "You can fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)
    if not np.all(known_input_batch_dims == known_input_batch_dims[0]):
      raise ValueError("Batch dimensions of input dataset are not compatible.")

    return known_input_batch_dims[0] % desired_batch_size != 0

  @property
  def element_spec(self):
    return self._element_spec


class _LegacyRebatchDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that divides its input batches into `num_replicas` sub-batches.

  For each batch in the input dataset, _LegacyRebatchDataset will produce
  `num_replicas` smaller batches whose sizes add up to the original batch size.

  For example:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=3)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [], [4, 5], [6, 7], []
  ```
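
  When the input batch size is evenly divisible by `num_replicas`, no empty
  sub-batches are produced. A minimal sketch of that case:

  ```python
  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _LegacyRebatchDataset(ds, num_replicas=2)
  for elem in ds:
    print(elem)
  >> [0, 1], [2, 3], [4, 5], [6, 7]
  ```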
  """

  def __init__(self, input_dataset, num_replicas):
    """Creates a _LegacyRebatchDataset.

    Args:
      input_dataset: `Dataset` to rebatch.
      num_replicas: A `tf.int64` scalar, representing the number of sub-batches
        to split each batch from `input_dataset` into.
    """

    def recalculate_batch_size(type_spec):
      """Recalculates the batch size after dividing it by num_replicas."""
      output_shape = type_spec._to_legacy_output_shapes()  # pylint: disable=protected-access
      if not isinstance(output_shape, tensor_shape.TensorShape):
        return None

      # If the output shape is unknown, we set the batch dimension to unknown.
      if output_shape.rank is None:
        return None

      if len(output_shape) < 1:
        raise ValueError("Expected a dataset whose elements have rank >= 1 "
                         "but found a dataset whose elements are scalars. "
                         "You can fix the issue by adding the `batch` "
                         "transformation to the dataset.")
      output_dims = [d.value for d in output_shape.dims]

      if output_dims[0] is not None and output_dims[0] % num_replicas == 0:
        return output_dims[0] // num_replicas

      # Set the batch dimension to unknown. If the global batch size is not
      # divisible by num_replicas, the sub-batches may have different sizes.
      return None

    def rebatch(type_spec):
      # pylint: disable=protected-access
      batch_size = recalculate_batch_size(type_spec)
      return type_spec._unbatch()._batch(batch_size)
      # pylint: enable=protected-access

    self._element_spec = nest.map_structure(
        rebatch, dataset_ops.get_structure(input_dataset))
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset(
        input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_LegacyRebatchDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec


class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset on a given `device` given a graph def."""

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device):
      variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec


def replicate(dataset, devices):
  """A transformation that replicates `dataset` onto a list of devices.

  Args:
    dataset: A `tf.data.Dataset` object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
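
  For example, a minimal sketch (the device names below are illustrative
  placeholders):

  ```python
  dataset = tf.data.Dataset.range(100).batch(10)
  replicated = replicate(dataset, ["/job:worker/task:0/device:CPU:0",
                                   "/job:worker/task:1/device:CPU:0"])
  # Each key of the returned dictionary is one of the requested devices.
  dataset_on_task_1 = replicated["/job:worker/task:1/device:CPU:0"]
  ```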
  """
  if not isinstance(dataset, dataset_ops.DatasetV2):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  # pylint: disable=protected-access
  dataset_device = dataset._variant_tensor.device

  datasets = {}
  if len(devices) == 1 and devices[0] == dataset_device:
    datasets[devices[0]] = dataset
    return datasets

  with ops.colocate_with(dataset._variant_tensor):
    # We apply options before replicating the dataset because options are
    # currently not automatically preserved through dataset serialization and
    # thus an explicit application of options here is needed to avoid losing
    # `dataset` options.
    #
    # TODO(b/147325552): Propagating options to C++ upon their setting would
    # allow us to preserve the options across both variant and GraphDef based
    # serialization, avoiding the need to explicitly apply options here.
    dataset = dataset._apply_options()
    policy = dataset.options().experimental_external_state_policy
    if policy is None:
      policy = ExternalStatePolicy.WARN
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True, external_state_policy=policy)
  for device in devices:
    ds = _RemoteDataset(graph_def, device, dataset.element_spec)
    datasets[device] = ds
  return datasets


def batch_sizes_for_worker(global_batch_size, num_workers,
                           num_replicas_per_worker, worker_index):
  """Determines how to rebatch a dataset for the given worker.

  Given the global batch size, number of workers, number of replicas per
  worker, and worker index, returns the correct batch sizes for rebatching a
  dataset on worker `worker_index` of `num_workers`, such that each global step
  (across all workers and replicas) will consume `global_batch_size` elements.
  The returned value should be passed as the `batch_sizes` input parameter to
  `tf.data.experimental.rebatch()`. The returned batch sizes meet the following
  constraints:

  Let G = global_batch_size, W = num_workers, R = num_replicas_per_worker
  (A) for any worker, len(batch_sizes) == W * R
  (B) for any worker, sum(batch_sizes) == G
  (C) for any global step (i.e. R iterations on each worker), the sum of the
      batch sizes consumed by replicas across all workers is G.
  (D) any two batch sizes of any two replicas differ by at most one.

  For example, suppose we have G = 7, W = 2, R = 2, and suppose we have two
  files which each contain 7 elements:

  ```python
  # WORKER 0
  batch_sizes_0 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=0)
  print(batch_sizes_0)
  >> [2, 2, 2, 1]

  dataset_0 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_0 = dataset_0.shard(num_shards=2, index=0)
  dataset_0 = dataset_0.batch(7)
  dataset_0 = dataset_0.apply(tf.data.experimental.rebatch(batch_sizes_0))
  for elem in dataset_0:
    print(elem)
  >> [[A0, A1], [A2, A3], [A4, A5], [A6]]

  # WORKER 1
  batch_sizes_1 = batch_sizes_for_worker(global_batch_size=7,
                                         num_workers=2,
                                         num_replicas_per_worker=2,
                                         worker_index=1)
  print(batch_sizes_1)
  >> [2, 1, 2, 2]

  dataset_1 = tf.data.Dataset.from_tensor_slices(["file_a", "file_b"])
  dataset_1 = dataset_1.shard(num_shards=2, index=1)
  dataset_1 = dataset_1.batch(7)
  dataset_1 = dataset_1.apply(tf.data.experimental.rebatch(batch_sizes_1))
  for elem in dataset_1:
    print(elem)
  >> [[B0, B1], [B2], [B3, B4], [B5, B6]]
  ```

  The above example will produce the following elements:

  Step 1:
    Worker 0 Replica 0: [A0, A1]
    Worker 0 Replica 1: [A2, A3]
    Worker 1 Replica 0: [B0, B1]
    Worker 1 Replica 1: [B2]
  Total batch size = 7

  Step 2:
    Worker 0 Replica 0: [A4, A5]
    Worker 0 Replica 1: [A6]
    Worker 1 Replica 0: [B3, B4]
    Worker 1 Replica 1: [B5, B6]
  Total batch size = 7
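
  To check constraints (B) and (D) for this example: N = W * R = 4,
  floor(G/N) = floor(7/4) = 1, and G - N * floor(G/N) = 3, so each worker's
  batch sizes consist of three sub-batches of size 2 and one of size 1
  (summing to 7), rotated by R * worker_index positions; this yields
  [2, 2, 2, 1] for worker 0 and [2, 1, 2, 2] for worker 1.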

  Args:
    global_batch_size: A `tf.int64` scalar, representing the global batch size.
    num_workers: An integer representing the number of workers the dataset will
      be distributed across.
    num_replicas_per_worker: An integer representing the number of replicas per
      worker. All workers are assumed to have the same number of replicas.
    worker_index: An integer index of the worker to be rebatched.

  Returns:
    A `tf.int64` vector, representing the batch sizes to rebatch the dataset
    into.
  """
  # Constraint (A)
  num_subbatches = num_workers * num_replicas_per_worker

  offset = worker_index * num_replicas_per_worker

  const_value = tensor_util.constant_value(global_batch_size)
  if const_value is not None:
    # Use the constant global batch size for further calculations.
    global_batch_size = const_value

  # Let N = W * R. Constraints (B) and (D) jointly mean that each of the N
  # sub-batches a batch is split into should have size either floor(G/N) or
  # ceil(G/N). Namely, G - N * floor(G/N) of them will have size ceil(G/N), and
  # the rest will have size floor(G/N).
  floor = global_batch_size // num_subbatches
  num_ceil = global_batch_size - (num_subbatches * floor)

  # For worker 0, we assign the first num_ceil subbatches to have size
  # ceil(G/N), and the remainder to have size floor(G/N). The other workers
  # will each be offset by R * worker_index in order to meet constraint (C).
  if const_value is not None:
    # If the global batch size is a known constant value, we return a constant
    # tensor directly instead of manipulating it with TF ops. This allows for
    # better downstream shape inference.
    worker_0 = [floor + 1] * num_ceil + [floor] * (num_subbatches - num_ceil)
    return ops.convert_to_tensor(
        worker_0[offset:] + worker_0[:offset],
        dtype=dtypes.int64,
        name="batch_sizes")

  worker_0 = array_ops.ones(num_subbatches, dtype=dtypes.int64)
  worker_0 = floor * worker_0 + array_ops.concat(
      [
          array_ops.ones(num_ceil, dtype=dtypes.int64),
          array_ops.zeros(num_subbatches - num_ceil, dtype=dtypes.int64)
      ],
      axis=0)

  return array_ops.concat([worker_0[offset:], worker_0[:offset]], axis=0)


def compute_batch_size(dataset):
  """An operation that returns the batch size of the dataset.

  This op tries to infer the batch size statically by walking up the dataset
  tree from the final dataset node and returning the batch size of the first
  batching dataset (such as from .batch() and .padded_batch()) that it
  encounters. This differs from using the `element_spec` of a dataset in that
  it does not account for partial batches.

  This operation may fail if it encounters contradictory batch sizes (for
  example, if the dataset is created by zipping together two datasets with
  different batch sizes), if there are no explicit batching transformations, or
  if there are operations downstream from the batching transformation that may
  modify its batch size. In these cases, it returns -1.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.int64` Tensor representing the batch size of the dataset sans partial
    batches. If this cannot be inferred statically, the value of this tensor
    will be -1.
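
  For example, a minimal sketch (in the second case the static batch dimension
  is unknown, so the result comes from the runtime batch size computation and,
  per the behavior described above, ignores the partial final batch):

  ```python
  ds = tf.data.Dataset.range(32).batch(4, drop_remainder=True)
  print(compute_batch_size(ds))
  >> 4

  ds = tf.data.Dataset.range(10).batch(4)  # The final batch is partial.
  print(compute_batch_size(ds))
  >> 4
  ```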
  """

  def get_static_batch_dim(output_shape):
    if output_shape.rank is None:
      return None
    return output_shape.dims[0].value

  batch_dims = [
      get_static_batch_dim(ts._to_legacy_output_shapes())  # pylint: disable=protected-access
      for ts in nest.flatten(dataset_ops.get_structure(dataset))
  ]

  if all(d is not None for d in batch_dims):

    if all(d == batch_dims[0] for d in batch_dims):
      # If all batch dimensions are known and equal, return that directly.
      batch_dim = batch_dims[0]
    else:
      # If all batch dimensions are known but not all equal, return -1.
      batch_dim = -1

    return constant_op.constant(
        batch_dim, dtype=dtypes.int64, name="static_batch_size")

  # If any batch dimensions are unknown, use the compute_batch_size op.
  return ged_ops.compute_batch_size(dataset._variant_tensor)  # pylint: disable=protected-access


_AutoShardDatasetV1.__doc__ = _AutoShardDataset.__doc__