# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various classes representing distributed inputs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import sys

import six

from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_ops
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.distribute_lib import InputReplicationMode
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.types import distribute as distribute_types
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None):
  """Returns a distributed dataset from the given tf.data.Dataset instance.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    num_replicas_in_sync: Optional integer. If this is not None, the value is
      used to decide how to rebatch datasets into smaller batches so that
      the total batch size for each step (across all workers and replicas)
      adds up to `dataset`'s batch size.
    input_context: `InputContext` for sharding. Only pass this in for between
      graph multi-worker cases where there is only one `input_worker`. In
      these cases, we will shard based on the `input_pipeline_id` and
      `num_input_pipelines` in the `InputContext`.

  Returns:
    A distributed dataset instance.
  """
  if tf2.enabled():
    return DistributedDataset(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
  else:
    return DistributedDatasetV1(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)
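
# NOTE: The helper above is internal plumbing used by the strategy
# implementations rather than a public entry point. As a rough usage sketch
# (illustrative only; `strategy` is assumed to be an existing
# `tf.distribute.Strategy` instance and the device strings are made up):
#
#   workers = InputWorkers([("/cpu:0", ("/gpu:0", "/gpu:1"))])
#   dist_dataset = get_distributed_dataset(
#       tf.data.Dataset.range(8).batch(4), workers, strategy,
#       num_replicas_in_sync=2)
#   for per_replica_batch in iter(dist_dataset):
#     ...  # a PerReplica value with one component per replica.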


def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
      iterators should be created.
    input_contexts: A list of `InputContext` instances to be passed to call(s)
      to `dataset_fn`. Length and order should match worker order in
      `worker_device_pairs`.
    strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
      handle last partial batch.
    options: Default is None. `tf.distribute.InputOptions` used to control
      options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if `options.experimental_replication_mode` and
      `options.experimental_place_dataset_on_device` are not consistent
  """
  if (options is not None and
      options.experimental_replication_mode != InputReplicationMode.PER_REPLICA
      and options.experimental_place_dataset_on_device):
    raise ValueError(
        "When `experimental_place_dataset_on_device` is set for dataset "
        "placement, you must also specify `PER_REPLICA` for the "
        "replication mode")

  if (options is not None and
      options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
      and options.experimental_prefetch_to_device and
      options.experimental_place_dataset_on_device):
    raise ValueError(
        "`experimental_place_dataset_on_device` can not be set to True "
        "when experimental_prefetch_to_device is True and "
        "replication mode is set to `PER_REPLICA`")

  if tf2.enabled():
    return DistributedDatasetsFromFunction(dataset_fn, input_workers,
                                           input_contexts, strategy, options)
  else:
    return DistributedDatasetsFromFunctionV1(
        dataset_fn,
        input_workers,
        input_contexts,
        strategy,
        options)
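
# NOTE: The two checks above mean that placing the dataset directly on the
# replica devices is only valid together with `PER_REPLICA` replication and
# with prefetching to the device turned off. An options combination that
# passes both checks (an illustrative sketch using the public
# `tf.distribute.InputOptions` API) would be:
#
#   options = tf.distribute.InputOptions(
#       experimental_replication_mode=(
#           tf.distribute.InputReplicationMode.PER_REPLICA),
#       experimental_place_dataset_on_device=True,
#       experimental_prefetch_to_device=False)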


@tf_export("distribute.DistributedIterator", v1=[])
class DistributedIteratorInterface(collections_abc.Iterator,
                                   distribute_types.Iterator):
  """An iterator over `tf.distribute.DistributedDataset`.

  `tf.distribute.DistributedIterator` is the primary mechanism for enumerating
  elements of a `tf.distribute.DistributedDataset`. It supports the Python
  Iterator protocol, which means it can be iterated over using a for-loop or by
  fetching individual elements explicitly via `get_next()`.

  You can create a `tf.distribute.DistributedIterator` by calling `iter` on
  a `tf.distribute.DistributedDataset` or creating a python loop over a
  `tf.distribute.DistributedDataset`.

  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """

  def get_next(self):
    """Returns the next input from the iterator for all replicas.

    Example use:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.range(100).batch(2)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset_iterator = iter(dist_dataset)
    >>> @tf.function
    ... def one_step(input):
    ...   return input
    >>> step_num = 5
    >>> for _ in range(step_num):
    ...   strategy.run(one_step, args=(dist_dataset_iterator.get_next(),))
    >>> strategy.experimental_local_results(dist_dataset_iterator.get_next())
    (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>,
     <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>)

    Returns:
      A single `tf.Tensor` or a `tf.distribute.DistributedValues` which
      contains the next input for all replicas.

    Raises:
      `tf.errors.OutOfRangeError`: If the end of the iterator has been reached.
    """
    raise NotImplementedError(
        "DistributedIterator.get_next() must be implemented in descendants.")

  @property
  def element_spec(self):
    # pylint: disable=line-too-long
    """The type specification of an element of `tf.distribute.DistributedIterator`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> distributed_iterator.element_spec
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedIterator`. This returned value
      is typically a `tf.distribute.DistributedValues` object and specifies the
      `tf.TensorSpec` of individual components.
    """
    raise NotImplementedError(
        "DistributedIterator.element_spec() must be implemented in descendants")

  def get_next_as_optional(self):
    # pylint: disable=line-too-long
    """Returns a `tf.experimental.Optional` that contains the next value for all replicas.

    If the `tf.distribute.DistributedIterator` has reached the end of the
    sequence, the returned `tf.experimental.Optional` will have no value.

    Example usage:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> global_batch_size = 2
    >>> steps_per_loop = 2
    >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size)
    >>> distributed_iterator = iter(
    ...     strategy.experimental_distribute_dataset(dataset))
    >>> def step_fn(x):
    ...   # train the model with inputs
    ...   return x
    >>> @tf.function
    ... def train_fn(distributed_iterator):
    ...   for _ in tf.range(steps_per_loop):
    ...     optional_data = distributed_iterator.get_next_as_optional()
    ...     if not optional_data.has_value():
    ...       break
    ...     per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),))
    ...     tf.print(strategy.experimental_local_results(per_replica_results))
    >>> train_fn(distributed_iterator)
    ... # ([0 1], [2 3])
    ... # ([4], [])

    Returns:
      A `tf.experimental.Optional` object representing the next value from the
      `tf.distribute.DistributedIterator` (if it has one) or no value.
    """
    # pylint: enable=line-too-long
    raise NotImplementedError(
        "get_next_as_optional() not implemented in descendants")


@tf_export("distribute.DistributedDataset", v1=[])
class DistributedDatasetInterface(collections_abc.Iterable,
                                  distribute_types.Iterable):
  # pylint: disable=line-too-long
  """Represents a dataset distributed among devices and machines.

  A `tf.distribute.DistributedDataset` could be thought of as a "distributed"
  dataset. When you use `tf.distribute` API to scale training to multiple
  devices or machines, you also need to distribute the input data, which leads
  to a `tf.distribute.DistributedDataset` instance, instead of a
  `tf.data.Dataset` instance in the non-distributed case. In TF 2.x,
  `tf.distribute.DistributedDataset` objects are Python iterables.

  Note: `tf.distribute.DistributedDataset` instances are *not* of type
  `tf.data.Dataset`. It only supports two usages we will mention below:
  iteration and `element_spec`. We don't support any other APIs to transform or
  inspect the dataset.

  There are two APIs to create a `tf.distribute.DistributedDataset` object:
  `tf.distribute.Strategy.experimental_distribute_dataset(dataset)` and
  `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`.
  *When to use which?* When you have a `tf.data.Dataset` instance, and the
  regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance
  with a new batch size that is equal to the global batch size divided by the
  number of replicas in sync) and autosharding (i.e. the
  `tf.data.experimental.AutoShardPolicy` options) work for you, use the former
  API. Otherwise, if you are *not* using a canonical `tf.data.Dataset`
  instance, or you would like to customize the batch splitting or sharding,
  you can wrap this logic in a `dataset_fn` and use the latter API. Both APIs
  handle prefetching to the device for the user. For more details and
  examples, follow the links to the APIs.


  There are two main usages of a `DistributedDataset` object:

  1. Iterate over it to generate the input for a single device or multiple
  devices, which is a `tf.distribute.DistributedValues` instance. To do this,
  you can:

    * use a pythonic for-loop construct:

      >>> global_batch_size = 4
      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
      >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size)
      >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
      >>> @tf.function
      ... def train_step(input):
      ...   features, labels = input
      ...   return labels - 0.3 * features
      >>> for x in dist_dataset:
      ...   # train_step trains the model using the dataset elements
      ...   loss = strategy.run(train_step, args=(x,))
      ...   print("Loss is", loss)
      Loss is PerReplica:{
        0: tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32),
        1: tf.Tensor(
      [[0.7]
       [0.7]], shape=(2, 1), dtype=float32)
      }

  Placing the loop inside a `tf.function` will give a performance boost.
  However `break` and `return` are currently not supported if the loop is
  placed inside a `tf.function`. We also don't support placing the loop
  inside a `tf.function` when using
  `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
  `tf.distribute.experimental.TPUStrategy` with multiple workers.

    * use `__iter__` to create an explicit iterator, which is of type
      `tf.distribute.DistributedIterator`

      >>> global_batch_size = 4
      >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
      >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size)
      >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
      >>> @tf.function
      ... def distributed_train_step(dataset_inputs):
      ...   def train_step(input):
      ...     loss = tf.constant(0.1)
      ...     return loss
      ...   per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
      ...   return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
      >>> EPOCHS = 2
      >>> STEPS = 3
      >>> for epoch in range(EPOCHS):
      ...   total_loss = 0.0
      ...   num_batches = 0
      ...   dist_dataset_iterator = iter(train_dist_dataset)
      ...   for _ in range(STEPS):
      ...     total_loss += distributed_train_step(next(dist_dataset_iterator))
      ...     num_batches += 1
      ...   average_train_loss = total_loss / num_batches
      ...   template = ("Epoch {}, Loss: {:.4f}")
      ...   print(template.format(epoch + 1, average_train_loss))
      Epoch 1, Loss: 0.2000
      Epoch 2, Loss: 0.2000


  To achieve a performance improvement, you can also wrap the `strategy.run`
  call with a `tf.range` inside a `tf.function`. This runs multiple steps in a
  `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker.
  However, it is less flexible compared with running a single step inside
  `tf.function`. For example, you cannot run things eagerly or arbitrary
  python code within the steps.


  2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`.

  `tf.distribute.DistributedDataset` generates
  `tf.distribute.DistributedValues` as input to the devices. If you pass the
  input to a `tf.function` and would like to specify the shape and type of
  each Tensor argument to the function, you can pass a `tf.TypeSpec` object to
  the `input_signature` argument of the `tf.function`. To get the
  `tf.TypeSpec` of the input, you can use the `element_spec` property of the
  `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator`
  object.

  For example:

  >>> global_batch_size = 4
  >>> epochs = 1
  >>> steps_per_epoch = 1
  >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size)
  >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset)
  >>> @tf.function(input_signature=[dist_dataset.element_spec])
  ... def train_step(per_replica_inputs):
  ...   def step_fn(inputs):
  ...     return tf.square(inputs)
  ...   return mirrored_strategy.run(step_fn, args=(per_replica_inputs,))
  >>> for _ in range(epochs):
  ...   iterator = iter(dist_dataset)
  ...   for _ in range(steps_per_epoch):
  ...     output = train_step(next(iterator))
  ...     print(output)
  PerReplica:{
    0: tf.Tensor(
  [[4.]
   [4.]], shape=(2, 1), dtype=float32),
    1: tf.Tensor(
  [[4.]
   [4.]], shape=(2, 1), dtype=float32)
  }


  Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input)
  on distributed input for more examples and caveats.
  """

  def __iter__(self):
    """Creates an iterator for the `tf.distribute.DistributedDataset`.

    The returned iterator implements the Python Iterator protocol.

    Example usage:

    >>> global_batch_size = 4
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size)
    >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset))
    >>> print(next(distributed_iterator))
    PerReplica:{
      0: tf.Tensor([1 2], shape=(2,), dtype=int32),
      1: tf.Tensor([3 4], shape=(2,), dtype=int32)
    }

    Returns:
      A `tf.distribute.DistributedIterator` instance for the given
      `tf.distribute.DistributedDataset` object to enumerate over the
      distributed data.
    """
    raise NotImplementedError("Must be implemented in descendants")

  @property
  def element_spec(self):
    """The type specification of an element of this `tf.distribute.DistributedDataset`.

    Example usage:

    >>> global_batch_size = 16
    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size)
    >>> dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> dist_dataset.element_spec
    (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)),
     PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None),
                    TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)))

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this `tf.distribute.DistributedDataset`. This returned value
      is typically a `tf.distribute.DistributedValues` object and specifies
      the `tf.TensorSpec` of individual components.
    """
    raise NotImplementedError(
        "DistributedDataset.element_spec must be implemented in descendants.")

  @doc_controls.do_not_generate_docs
  def reduce(self, initial_state, reduce_func):
    raise NotImplementedError(
        "DistributedDataset.reduce must be implemented in descendants.")


class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  def __init__(self, worker_device_pairs):
    """Initialize an `InputWorkers` object.

    Args:
      worker_device_pairs: A sequence of pairs:
        `(input device, a tuple of compute devices fed by that input device)`.
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f)
                              for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    return self._worker_device_pairs

  def deserialize(self, worker_device_pairs):
    return InputWorkers(worker_device_pairs)
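
# NOTE: A minimal sketch of how `InputWorkers` is read (the device strings
# below are made up for illustration): each input device feeds the tuple of
# compute devices paired with it.
#
#   workers = InputWorkers([
#       ("/job:worker/task:0/device:CPU:0",
#        ("/job:worker/task:0/device:GPU:0",
#         "/job:worker/task:0/device:GPU:1")),
#   ])
#   workers.num_workers                    # -> 1
#   workers.worker_devices                 # -> ("/job:worker/task:0/device:CPU:0",)
#   workers.compute_devices_for_worker(0)  # -> the two GPU devices, canonicalized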


def _get_next_as_optional(iterator, strategy, return_per_replica=False):
  """Returns an empty dataset indicator and the next input from the iterator.

  Args:
    iterator: a DistributedIterator object.
    strategy: the `tf.distribute.Strategy` instance.
    return_per_replica: a boolean. If True, the returned data will be wrapped
      with `PerReplica` structure. Otherwise it is a 2D
      num_input_workers*num_replicas_per_worker list.

  Returns:
    A tuple (a boolean tensor indicating whether the next batch has value
    globally, data from all replicas).
  """
  replicas = []
  worker_has_values = []
  worker_devices = []
  for i, worker in enumerate(iterator._input_workers.worker_devices):  # pylint: disable=protected-access
    with ops.device(worker):
      worker_has_value, next_element = (
          iterator._iterators[i].get_next_as_list())  # pylint: disable=protected-access
      # Collective all-reduce requires explicit devices for inputs.
      with ops.device("/cpu:0"):
        # Converting to integers for all-reduce.
        worker_has_value = math_ops.cast(worker_has_value, dtypes.int64)
        worker_devices.append(worker_has_value.device)
        worker_has_values.append(worker_has_value)
        # Make `replicas` a flat list of values across all replicas.
        replicas.append(next_element)

  if return_per_replica:
    flattened_data = []
    for per_worker_data in replicas:
      flattened_data.extend(per_worker_data)
    replicas = _create_per_replica(flattened_data, strategy)

  # Run an all-reduce to see whether any worker has values.
  # TODO(b/131423105): we should be able to short-cut the all-reduce in some
  # cases.
  if getattr(strategy.extended, "_support_per_replica_values", True):
    # `reduce` expects a `PerReplica`, so we pass it one, even
    # though it doesn't actually have a value per replica
    worker_has_values = values.PerReplica(worker_has_values)
    global_has_value = strategy.reduce(
        reduce_util.ReduceOp.SUM, worker_has_values, axis=None)
  else:
    assert len(worker_has_values) == 1
    global_has_value = worker_has_values[0]
  global_has_value = array_ops.reshape(
      math_ops.cast(global_has_value, dtypes.bool), [])
  return global_has_value, replicas


def _is_statically_shaped(element_spec):
  """Test if an iterator output is statically shaped.

  For sparse and ragged tensors this only tests the batch dimension.

  Args:
    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
      dataset of the iterator.

  Returns:
    True if the shape is static, false otherwise.
  """

  for spec in nest.flatten(element_spec):
    if isinstance(
        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
      # For sparse or ragged tensor, we should only check the first
      # dimension in order to get_next_as_optional. This is because
      # when these tensors get batched by dataset only the batch dimension
      # is set.
      if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
        return False
    else:
      for component in nest.flatten(spec._component_specs):  # pylint: disable=protected-access
        if not component.shape.is_fully_defined():
          return False
  return True
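
# NOTE: A small illustration of the rule implemented above (the specs are
# hypothetical and shown only in comments):
#
#   _is_statically_shaped(tf.TensorSpec(shape=(8, 3), dtype=tf.float32))
#   # -> True: every dimension is known.
#   _is_statically_shaped(tf.TensorSpec(shape=(None, 3), dtype=tf.float32))
#   # -> False: the batch dimension is unknown.
#   _is_statically_shaped(tf.RaggedTensorSpec(shape=(8, None), dtype=tf.float32))
#   # -> True: for ragged/sparse specs only the batch dimension is checked.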


class DistributedIteratorBase(DistributedIteratorInterface):
  """Common implementation for all input iterators."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers, iterators, strategy,
               enable_get_next_as_optional):
    assert isinstance(input_workers, InputWorkers)
    if not input_workers.worker_devices:
      raise ValueError("Should have at least one worker for input iterator.")

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._enable_get_next_as_optional = enable_get_next_as_optional

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

  def get_next_as_optional(self):
    global_has_value, replicas = _get_next_as_optional(
        self, self._strategy, return_per_replica=True)

    def return_none():
      return optional_ops.Optional.empty(self._element_spec)

    return control_flow_ops.cond(
        global_has_value, lambda: optional_ops.Optional.from_value(replicas),
        return_none)

  def get_next(self, name=None):
    """Returns the next input from the iterator for all replicas."""
    if not self._enable_get_next_as_optional:
      replicas = []
      for i, worker in enumerate(self._input_workers.worker_devices):
        if name is not None:
          d = tf_device.DeviceSpec.from_string(worker)
          new_name = "%s_%s_%d" % (name, d.job, d.task)
        else:
          new_name = None
        with ops.device(worker):
          # Make `replicas` a flat list of values across all replicas.
          replicas.extend(
              self._iterators[i].get_next_as_list_static_shapes(new_name))
      return _create_per_replica(replicas, self._strategy)

    out_of_range_replicas = []
    def out_of_range_fn(worker_index, device):
      """This function will throw an OutOfRange error."""
      # Since this is only called when there is no data left, calling
      # get_next() will trigger an OutOfRange error.
      data = self._iterators[worker_index].get_next(device)
      out_of_range_replicas.append(data)
      return data

    global_has_value, replicas = _get_next_as_optional(
        self, self._strategy, return_per_replica=False)
    results = []
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        devices = self._input_workers.compute_devices_for_worker(i)
        for j, device in enumerate(devices):
          with ops.device(device):
            # pylint: disable=undefined-loop-variable
            # pylint: disable=cell-var-from-loop
            # It is fine for the lambda to capture variables from the loop as
            # the lambda is executed in the loop as well.
            result = control_flow_ops.cond(
                global_has_value,
                lambda: replicas[i][j],
                lambda: out_of_range_fn(i, device),
                strict=True,
            )
            # pylint: enable=cell-var-from-loop
            # pylint: enable=undefined-loop-variable
            results.append(result)
    replicas = results

    return _create_per_replica(replicas, self._strategy)


class DistributedIteratorV1(DistributedIteratorBase):
  """Input Iterator for a distributed dataset."""

  # We need a private initializer method for re-initializing multidevice
  # iterators when used with Keras training loops. If we don't reinitialize the
  # iterator we run into memory leak issues (b/123315763).
  @property
  def _initializer(self):
    init_ops = []
    for it in self._iterators:
      init_ops.extend(it.initialize())
    return control_flow_ops.group(init_ops)

  @deprecated(None, "Use the iterator's `initializer` property instead.")
  def initialize(self):
    """Initialize underlying iterators.

    Returns:
      A list of any initializer ops that should be run.
    """
    return self._initializer

  @property
  def initializer(self):
    """Returns a list of ops that initialize the iterator."""
    return self.initialize()

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_classes(self):
    return self._iterators[0].output_classes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_shapes(self):
    return self._iterators[0].output_shapes

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  @property
  def output_types(self):
    return self._iterators[0].output_types

  # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs.
  def get_iterator(self, worker):
    for i, w in enumerate(self._input_workers.worker_devices):
      if worker == w:
        return self._iterators[i]
    return None

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec


class DistributedIteratorSpec(type_spec.TypeSpec):
  """Type specification for `DistributedIterator`."""

  __slots__ = [
      "_input_workers", "_element_spec", "_strategy",
      "_enable_get_next_as_optional", "_options"
  ]

  def __init__(self, input_workers, element_spec, strategy,
               enable_get_next_as_optional, options):
    # We don't want to allow deserialization of this class because we don't
    # serialize the strategy object. Currently the only place where
    # _deserialize is called is when we save/restore using SavedModels.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    else:
      self._input_workers = input_workers
      self._element_spec = element_spec
      self._strategy = strategy
      self._enable_get_next_as_optional = enable_get_next_as_optional
      self._options = options

  @property
  def value_type(self):
    return DistributedIterator

  def _serialize(self):
    # We cannot serialize the strategy object so we convert it to an id that we
    # can use for comparison.
    return (self._input_workers.serialize(), self._element_spec,
            id(self._strategy), id(self._options))

  def _deserialize(self):
    raise ValueError("Deserialization is currently unsupported for "
                     "DistributedIteratorSpec.")

  # Overriding this method so that we can merge and reconstruct the spec object
  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    if self._input_workers.serialize() != other._input_workers.serialize():
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))
    element_spec = nest.map_structure(
        lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
        other._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)

  @property
  def _component_specs(self):
    specs = []
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access

    for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
      element_spec = nest.map_structure(
          functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
      specs.append(
          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
                                           element_spec, self._options))
    return specs

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(
        input_workers=self._input_workers,
        iterators=None,
        components=components,
        element_spec=self._element_spec,
        strategy=self._strategy,
        enable_get_next_as_optional=self._enable_get_next_as_optional,
        options=self._options)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                   value._strategy,
                                   value._enable_get_next_as_optional,
                                   value._options)

  def _with_tensor_ranks_only(self):
    element_spec = nest.map_structure(
        lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
        self._element_spec)
    return DistributedIteratorSpec(self._input_workers, element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)


class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self,
               input_workers=None,
               iterators=None,
               strategy=None,
               components=None,
               element_spec=None,
               enable_get_next_as_optional=False,
               options=None):
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    error_message = ("Either `input_workers` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")
    self._options = options

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      self._strategy = strategy
      self._enable_get_next_as_optional = enable_get_next_as_optional
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator,
            self).__init__(input_workers, iterators, strategy,
                           enable_get_next_as_optional)

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy,
                                   self._enable_get_next_as_optional,
                                   self._options)


class _IterableInput(DistributedDatasetInterface):
  """Base class for iterable inputs for distribution strategies."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers):
    assert isinstance(input_workers, InputWorkers)
    self._input_workers = input_workers

  def __iter__(self):
    raise NotImplementedError("must be implemented in descendants")

  def reduce(self, initial_state, reduce_fn):
    """Execute a `reduce_fn` over all the elements of the input."""
    iterator = iter(self)
    has_data, data = _get_next_as_optional(
        iterator, self._strategy, return_per_replica=True)

    def cond(has_data, data, state):
      del data, state  # Unused.
      return has_data

    def loop_body(has_data, data, state):
      """Executes `reduce_fn` in a loop till the dataset is empty."""
      del has_data  # Unused.
      state = reduce_fn(state, data)
      has_data, data = _get_next_as_optional(
          iterator, self._strategy, return_per_replica=True)
      return has_data, data, state

    has_data, data, final_state = control_flow_ops.while_loop(
        cond, loop_body, [has_data, data, initial_state], parallel_iterations=1)
    return final_state


class DistributedDataset(_IterableInput):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    """Distribute the dataset on all workers.

    If `num_replicas_in_sync` is not None, we split each batch of the dataset
    into `num_replicas_in_sync` smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      dataset: `tf.data.Dataset` that will be used as the input source.
      input_workers: an `InputWorkers` object.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to `dataset`'s batch size.
      input_context: `InputContext` for sharding. Only pass this in for between
        graph multi-worker cases where there is only one `input_worker`. In
        these cases, we will shard based on the `input_pipeline_id` and
        `num_input_pipelines` in the `InputContext`.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)
    # We clone and shard the dataset on each worker. The current setup tries to
    # shard the dataset by files if possible so that each worker sees a
    # different subset of files. If that is not possible, we will attempt to
    # shard the final input such that each worker will run the entire
    # preprocessing pipeline and only receive its own shard of the dataset.

    # Additionally, we rebatch the dataset on each worker into
    # `num_replicas_in_sync` smaller batches to be distributed among that
    # worker's replicas, so that the batch size for a global step (across all
    # workers and replicas) adds up to the original dataset's batch size.
    if num_replicas_in_sync is not None:
      num_workers = input_context.num_input_pipelines if input_context else len(
          input_workers.worker_devices)
      rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                         num_replicas_in_sync)
    else:
      rebatch_fn = None

    self._cloned_datasets = []
    if input_context:
      # Between-graph where we rely on the input_context for sharding
      assert input_workers.num_workers == 1
      if rebatch_fn is not None:
        dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
      dataset = input_ops.auto_shard_dataset(dataset,
                                             input_context.num_input_pipelines,
                                             input_context.input_pipeline_id,
                                             num_replicas_in_sync)
      self._cloned_datasets.append(dataset)
    else:
      replicated_ds = distribute.replicate(dataset,
                                           input_workers.worker_devices)
      for i, worker in enumerate(input_workers.worker_devices):
        with ops.device(worker):
          cloned_dataset = replicated_ds[worker]
          cloned_dataset = cloned_dataset.with_options(dataset.options())
          if rebatch_fn is not None:
            cloned_dataset = rebatch_fn(cloned_dataset, i)
          cloned_dataset = input_ops.auto_shard_dataset(
              cloned_dataset, len(input_workers.worker_devices), i,
              num_replicas_in_sync)
          self._cloned_datasets.append(cloned_dataset)

    self._input_workers = input_workers
    self._strategy = strategy
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, dataset)
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, self._cloned_datasets[0].element_spec)

  def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
    """Returns a callable that rebatches the input dataset.

    Args:
      dataset: A `tf.data.Dataset` representing the dataset to be distributed.
      num_workers: An integer representing the number of workers to distribute
        `dataset` among.
      num_replicas_in_sync: An integer representing the number of replicas in
        sync across all workers.
    """
    if num_replicas_in_sync % num_workers:
      raise ValueError(
          "tf.distribute expects every worker to have the same number of "
          "replicas. However, encountered `num_replicas_in_sync` ({}) that "
          "cannot be divided by `num_workers` ({})".format(
              num_replicas_in_sync, num_workers))

    num_replicas_per_worker = num_replicas_in_sync // num_workers
    with ops.colocate_with(dataset._variant_tensor):  # pylint: disable=protected-access
      batch_size = distribute.compute_batch_size(dataset)

    def rebatch_fn(dataset, worker_index):
      try:
        # pylint: disable=protected-access
        def apply_rebatch():
          batch_sizes = distribute.batch_sizes_for_worker(
              batch_size, num_workers, num_replicas_per_worker, worker_index)
          return distribute._RebatchDataset(
              dataset, batch_sizes).prefetch(num_replicas_per_worker)

        def apply_legacy_rebatch():
          return distribute._LegacyRebatchDataset(
              dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)

        with ops.colocate_with(dataset._variant_tensor):
          return control_flow_ops.cond(
              math_ops.not_equal(batch_size, -1),
              true_fn=apply_rebatch,
              false_fn=apply_legacy_rebatch)
      except errors.InvalidArgumentError as e:
        if "without encountering a batch" in str(e):
          six.reraise(
              ValueError,
              ValueError(
                  "Call the `batch` method on the input Dataset in order to be "
                  "able to split your input across {} replicas.\n Please see "
                  "the tf.distribute.Strategy guide. {}".format(
                      num_replicas_in_sync, e)),
              sys.exc_info()[2])
        else:
          raise

    return rebatch_fn
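
  # NOTE: To make the rebatching above concrete (numbers chosen purely for
  # illustration): with a global batch size of 16, `num_workers=2` and
  # `num_replicas_in_sync=8`, each worker gets `num_replicas_per_worker=4`,
  # and every global batch of 16 is split into per-replica batches of
  # 16 // 8 = 2, so one global step across all replicas still consumes 16
  # examples in total.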

  def __iter__(self):
    if not (context.executing_eagerly() or
            ops.get_default_graph().building_function):
      raise RuntimeError("__iter__() is only supported inside of tf.function "
                         "or when eager execution is enabled.")

    # This is an optional flag that can be used to turn off using
    # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators
    # as a stop gap solution that will allow us to roll out this change.
    enable_legacy_iterators = getattr(self._strategy,
                                      "_enable_legacy_iterators", False)
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    enable_legacy_iterators)
    if enable_legacy_iterators:
      iterator = DistributedIteratorV1(
          self._input_workers,
          worker_iterators,
          self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    else:
      iterator = DistributedIterator(
          self._input_workers,
          worker_iterators,
          self._strategy,
          enable_get_next_as_optional=self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before being passed to a multi device function, so add a
    # sync point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec


class DistributedDatasetV1(DistributedDataset):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               dataset,
               input_workers,
               strategy,
               num_replicas_in_sync=None,
               input_context=None):
    self._input_workers = input_workers
    super(DistributedDatasetV1, self).__init__(
        dataset,
        input_workers,
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context)

  def make_one_shot_iterator(self):
    """Get a one time use iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use `for ... in dataset:` to iterate
    over the dataset or `iter` to create an iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_one_shot_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for DistributedDatasetV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def make_initializable_iterator(self):
    """Get an initializable iterator for DistributedDatasetV1.

    Note: This API is deprecated. Please use
    `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an
    initializable iterator.

    Returns:
      A DistributedIteratorV1 instance.
    """
    return self._make_initializable_iterator()

  def _make_initializable_iterator(self, shared_name=None):  # pylint: disable=unused-argument
    """Get an initializable iterator for DistributedDatasetV1."""
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    worker_iterators = _create_iterators_per_worker(self._cloned_datasets,
                                                    self._input_workers,
                                                    True)
    iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self.element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before being passed to a multi device function, so add a
    # sync point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")
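
# NOTE: For the V1 class above, a typical graph-mode pattern (an illustrative
# sketch only; `dist_dataset` is assumed to be a `DistributedDatasetV1` and
# `strategy` the `tf.compat.v1` strategy that produced it) looks like:
#
#   iterator = dist_dataset.make_initializable_iterator()
#   per_replica_batch = iterator.get_next()
#   with tf.compat.v1.Session() as sess:
#     sess.run(iterator.initializer)
#     sess.run(strategy.experimental_local_results(per_replica_batch))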


# TODO(priyag): Add other replication modes.
class DistributedDatasetsFromFunction(_IterableInput):
  """Inputs created from dataset function."""

  def __init__(self, dataset_fn, input_workers, input_contexts, strategy,
               options):
    """Makes an iterable from datasets created by the given function.

    Args:
      dataset_fn: A function that returns a `Dataset` given an `InputContext`.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `dataset_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)

    if input_workers.num_workers != len(input_contexts):
      raise ValueError(
          "Number of input workers (%d) is not same as number of "
          "input_contexts (%d)" %
          (input_workers.num_workers, len(input_contexts)))

    self._input_workers = input_workers
    self._input_contexts = input_contexts
    self._strategy = strategy
    self._options = options
    self._datasets, element_spec = (
        _create_datasets_from_function_with_input_context(
            self._input_contexts, self._input_workers, dataset_fn))
    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, self._datasets[0])
    self._element_spec = _create_distributed_tensor_spec(
        self._strategy, element_spec)

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      # This is an optional flag that can be used to turn off using
      # OwnedMultiDeviceIterators and instead use the legacy
      # MultiDeviceIterators as a stop gap solution that will allow us to roll
      # out this change.
      enable_legacy_iterators = getattr(self._strategy,
                                        "_enable_legacy_iterators", False)
      iterators = _create_iterators_per_worker(self._datasets,
                                               self._input_workers,
                                               enable_legacy_iterators,
                                               self._options)
      if enable_legacy_iterators:
        iterator = DistributedIteratorV1(
            self._input_workers,
            iterators,
            self._strategy,
            enable_get_next_as_optional=self._enable_get_next_as_optional)
      else:
        iterator = DistributedIterator(
            input_workers=self._input_workers,
            iterators=iterators,
            strategy=self._strategy,
            enable_get_next_as_optional=self._enable_get_next_as_optional,
            options=self._options)
      iterator._element_spec = self._element_spec  # pylint: disable=protected-access

      # When async eager is enabled, sometimes the iterator may not finish
      # initialization before being passed to a multi device function, so add
      # a sync point here to make sure all underlying iterators are
      # initialized.
      if context.executing_eagerly():
        context.async_wait()

      return iterator

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")

  @property
  def element_spec(self):
    """The type specification of an element of this dataset."""
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batch handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec


class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
  """Inputs created from dataset function."""

  def _make_initializable_iterator(self, shared_name=None):
    """Get an initializable iterator for DistributedDatasetsFromFunctionV1."""
    del shared_name  # Unused
    # Eager mode generates already initialized iterators. Hence we cannot
    # create an initializable iterator.
    if context.executing_eagerly():
      raise ValueError("Cannot create initializable iterator in Eager mode. "
                       "Please use `iter()` instead.")
    return self._get_iterator()

  def _make_one_shot_iterator(self):
    """Get an iterator for iterating over DistributedDatasetsFromFunctionV1."""
    # Graph mode with one shot iterator is disabled because we have to call
    # `initialize` on the iterator which is only required if we are using a
    # tf.distribute strategy.
    if not context.executing_eagerly():
      raise ValueError("Cannot create a one shot iterator. Please use "
                       "`make_initializable_iterator()` instead.")
    return self._get_iterator()

  def _get_iterator(self):
    iterators = _create_iterators_per_worker(self._datasets,
                                             self._input_workers, True)
    iterator = DistributedIteratorV1(self._input_workers, iterators,
                                     self._strategy,
                                     self._enable_get_next_as_optional)
    iterator._element_spec = self._element_spec  # pylint: disable=protected-access

    # When async eager is enabled, sometimes the iterator may not finish
    # initialization before being passed to a multi device function, so add a
    # sync point here to make sure all underlying iterators are initialized.
    if context.executing_eagerly():
      context.async_wait()

    return iterator

  def __iter__(self):
    if (ops.executing_eagerly_outside_functions() or
        ops.get_default_graph().building_function):
      return self._get_iterator()

    raise RuntimeError("__iter__() is only supported inside of tf.function "
                       "or when eager execution is enabled.")


# TODO(anjalisridhar): This class will soon be removed in favor of newer
# APIs.
class InputFunctionIterator(DistributedIteratorV1):
  """Iterator created from input function."""

  def __init__(self, input_fn, input_workers, input_contexts, strategy):
    """Make an iterator for input provided via an input function.

    Currently implements PER_WORKER mode, in which the `input_fn` is called
    once on each worker.

    TODO(priyag): Add other replication modes.

    Args:
      input_fn: Input function that returns a `tf.data.Dataset` object.
      input_workers: an `InputWorkers` object.
      input_contexts: A list of `InputContext` instances to be passed to call(s)
        to `input_fn`. Length and order should match worker order in
        `worker_device_pairs`.
      strategy: a `tf.distribute.Strategy` object, used to run all-reduce to
        handle last partial batch.
1391 """ 1392 assert isinstance(input_workers, InputWorkers) 1393 if input_workers.num_workers != len(input_contexts): 1394 raise ValueError( 1395 "Number of input workers (%d) is not same as number of " 1396 "input_contexts (%d)" % 1397 (input_workers.num_workers, len(input_contexts))) 1398 1399 iterators = [] 1400 for i, ctx in enumerate(input_contexts): 1401 worker = input_workers.worker_devices[i] 1402 with ops.device(worker): 1403 result = input_fn(ctx) 1404 devices = input_workers.compute_devices_for_worker(i) 1405 if isinstance(result, dataset_ops.DatasetV2): 1406 iterator = _SingleWorkerDatasetIterator(result, worker, devices) 1407 elif callable(result): 1408 iterator = _SingleWorkerCallableIterator(result, worker, devices) 1409 else: 1410 raise ValueError( 1411 "input_fn must return a tf.data.Dataset or a callable.") 1412 iterators.append(iterator) 1413 1414 super(InputFunctionIterator, self).__init__( 1415 input_workers, iterators, strategy, enable_get_next_as_optional=False) 1416 self._enable_get_next_as_optional = False 1417 1418 1419# TODO(anjalisridhar): This class will soon be removed and users should move 1420# to using DistributedIterator. 1421class DatasetIterator(DistributedIteratorV1): 1422 """Iterator created from input dataset.""" 1423 1424 def __init__(self, 1425 dataset, 1426 input_workers, 1427 strategy, 1428 num_replicas_in_sync=None, 1429 input_context=None): 1430 """Make an iterator for the dataset on given devices. 1431 1432 If `num_replicas_in_sync` is not None, we split each batch of the dataset 1433 into `num_replicas_in_sync` smaller batches, to be distributed among that 1434 worker's replicas, so that the batch size for a global step (across all 1435 workers and replicas) is as expected. 1436 1437 Args: 1438 dataset: `tf.data.Dataset` that will be used as the input source. 1439 input_workers: an `InputWorkers` object. 1440 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1441 handle last partial batch. 1442 num_replicas_in_sync: Optional integer. If this is not None, the value is 1443 used to decide how to rebatch datasets into smaller batches so that the 1444 total batch size for each step (across all workers and replicas) adds up 1445 to `dataset`'s batch size. 1446 input_context: `InputContext` for sharding. Only pass this in for between 1447 graph multi-worker cases where there is only one `input_worker`. In 1448 these cases, we will shard based on the `input_pipeline_id` and 1449 `num_input_pipelines` in the `InputContext`. 1450 """ 1451 dist_dataset = DistributedDatasetV1( 1452 dataset, 1453 input_workers, 1454 strategy, 1455 num_replicas_in_sync=num_replicas_in_sync, 1456 input_context=input_context) 1457 worker_iterators = _create_iterators_per_worker( 1458 dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access 1459 super(DatasetIterator, 1460 self).__init__(input_workers, worker_iterators, strategy, 1461 dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access 1462 self._element_spec = dist_dataset.element_spec 1463 1464 1465def _dummy_tensor_fn(value_structure): 1466 """A function to create dummy tensors from `value_structure`.""" 1467 1468 def create_dummy_tensor(spec): 1469 """Create a dummy tensor with possible batch dimensions set to 0.""" 1470 if isinstance(spec, ragged_tensor.RaggedTensorSpec): 1471 # Splice out the ragged dimensions. 
1472 # pylint: disable=protected-access 1473 feature_shape = spec._shape[:1].concatenate( 1474 spec._shape[(1 + spec._ragged_rank):]) 1475 feature_type = spec._dtype 1476 # pylint: enable=protected-access 1477 else: 1478 feature_shape = spec.shape 1479 feature_type = spec.dtype 1480 # Ideally we should set the batch dimension to 0, however as in 1481 # DistributionStrategy we don't know the batch dimension, we try to 1482 # guess it as much as possible. If the feature has unknown dimensions, we 1483 # will set them to 0. If the feature shape is already static, we guess the 1484 # first dimension as batch dimension and set it to 0. 1485 dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()] 1486 if feature_shape else []) 1487 if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or 1488 feature_shape.is_fully_defined()): 1489 dims[0] = tensor_shape.Dimension(0) 1490 1491 if isinstance(spec, sparse_tensor.SparseTensorSpec): 1492 return sparse_tensor.SparseTensor( 1493 values=array_ops.zeros(0, feature_type), 1494 indices=array_ops.zeros((0, len(dims)), dtypes.int64), 1495 dense_shape=dims) 1496 1497 # Create the dummy tensor. 1498 dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type) 1499 if isinstance(spec, ragged_tensor.RaggedTensorSpec): 1500 # Reinsert the ragged dimensions with size 0. 1501 # pylint: disable=protected-access 1502 row_splits = array_ops.zeros(1, spec._row_splits_dtype) 1503 dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits( 1504 dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False) 1505 # pylint: enable=protected-access 1506 return dummy_tensor 1507 1508 return nest.map_structure(create_dummy_tensor, value_structure) 1509 1510 1511def _recover_shape_fn(data, value_structure): 1512 """Recover the shape of `data` the same as shape of `value_structure`.""" 1513 1514 flattened_data = nest.flatten(data) 1515 for i, spec in enumerate(nest.flatten(value_structure)): 1516 for target, source in zip( 1517 nest.flatten(flattened_data[i], expand_composites=True), 1518 nest.flatten(spec, expand_composites=True)): 1519 target.set_shape(source.shape) 1520 # `SparseTensor` shape is not determined by the shape of its component 1521 # tensors. Rather, its shape depends on a tensor's values. 1522 if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape: 1523 dense_shape = spec.shape 1524 with ops.device(flattened_data[i].op.device): 1525 # For partially defined shapes, fill in missing values from tensor. 1526 if not dense_shape.is_fully_defined(): 1527 dense_shape = array_ops.stack([ 1528 flattened_data[i].dense_shape[j] if dim is None else dim 1529 for j, dim in enumerate(dense_shape.as_list()) 1530 ]) 1531 flattened_data[i] = sparse_tensor.SparseTensor( 1532 indices=flattened_data[i].indices, 1533 values=flattened_data[i].values, 1534 dense_shape=dense_shape) 1535 data = nest.pack_sequence_as(data, flattened_data) 1536 return data 1537 1538 1539class _SingleWorkerDatasetIteratorBase(object): 1540 """Iterator for a single `tf.data.Dataset`.""" 1541 1542 def __init__(self, dataset, worker, devices, options=None): 1543 """Create iterator for the `dataset` to fetch data to worker's `devices` . 1544 1545 A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch 1546 input to the devices on the given worker. 1547 1548 Args: 1549 dataset: A `tf.data.Dataset` instance. 1550 worker: Worker on which ops should be created. 1551 devices: Distribute data from `dataset` to these devices. 
      options: `tf.distribute.InputOptions` used to control how this dataset
        is distributed.
    """
    self._dataset = dataset
    self._worker = worker
    self._devices = devices
    self._element_spec = dataset.element_spec
    self._options = options
    self._make_iterator()

  def _make_iterator(self):
    raise NotImplementedError("must be implemented in descendants")

  def _format_data_list_with_options(self, data_list):
    """Wrap the data in a list if required.

    The OwnedMultiDeviceIterator returns its data as a list, while the
    PER_REPLICA iterator (when used with prefetching to the devices disabled)
    returns it without the enclosing list. This method fixes that
    inconsistency.

    Args:
      data_list: the data returned by the underlying iterator.

    Returns:
      The data as a list.
    """
    if (self._options and self._options.experimental_replication_mode ==
        InputReplicationMode.PER_REPLICA and
        not self._options.experimental_prefetch_to_device):
      return [data_list]
    else:
      return data_list

  def get_next(self, device, name=None):
    """Get next element for the given device."""
    del name
    with ops.device(self._worker):
      if _should_use_multi_device_iterator(self._options):
        return self._iterator.get_next(device)
      else:
        return self._iterator.get_next()

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the underlying iterator.

    Runs the iterator get_next() within a device scope. Since this doesn't use
    get_next_as_optional(), it is considerably faster than get_next_as_list(),
    but can only be used when the shapes are static.

    Args:
      name: not used.

    Returns:
      A list consisting of the next data from each device.
    """
    del name
    with ops.device(self._worker):
      return self._format_data_list_with_options(self._iterator.get_next())

  def get_next_as_list(self, name=None):
    """Get next element from the underlying iterator.

    If there is no data left, a list of dummy tensors with possible batch
    dimensions set to 0 will be returned. Use of get_next_as_optional() and the
    extra logic adds overhead compared to get_next_as_list_static_shapes(), but
    allows us to handle non-static shapes.

    Args:
      name: not used.

    Returns:
      A boolean tensor indicating whether there is any data in the next
      element, and the real data of the next element (or a list of dummy
      tensors if there is no data left).
    """
    del name
    with ops.device(self._worker):
      data_list = self._format_data_list_with_options(
          self._iterator.get_next_as_optional())
      result = []
      for i, data in enumerate(data_list):
        # Place the condition op on the same device as the data so the data
        # doesn't need to be sent back to the worker.
        with ops.device(self._devices[i]):
          # Data will be fetched in order, so we only need to check whether the
          # first replica has a value to see whether there is data left for
          # this single worker.
          if i == 0:
            worker_has_value = data.has_value()

          # pylint: disable=unnecessary-lambda
          # pylint: disable=cell-var-from-loop
          real_data = control_flow_ops.cond(
              data.has_value(),
              lambda: data.get_value(),
              lambda: _dummy_tensor_fn(data.element_spec),
              strict=True,
          )
          # Some dimensions in `replicas` will become unknown after we
          # conditionally return the real tensors or the dummy tensors. Recover
          # the shapes from `data.element_spec`. We only need to do this in
          # non-eager mode because we always know the runtime shape of the
          # tensors in eager mode.
          if not context.executing_eagerly():
            real_data = _recover_shape_fn(real_data, data.element_spec)
          result.append(real_data)
          # pylint: enable=cell-var-from-loop
          # pylint: enable=unnecessary-lambda

      return worker_has_value, result


class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]

  def __init__(self, worker, devices, element_spec, options):
    self._worker = worker
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._element_spec = element_spec
    self._options = options

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    return (self._worker, self._devices, self._element_spec, self._options)

  @property
  def _component_specs(self):
    specs = []
    if _should_use_multi_device_iterator(self._options):
      specs.append(
          multi_device_iterator_ops.MultiDeviceIteratorSpec(
              self._devices, self._worker, element_spec=self._element_spec))
    else:
      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
    return specs

  def _to_components(self, value):
    return [value._iterator]  # pylint: disable=protected-access

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec,
        options=self._options)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
                                            value._element_spec, value._options)


class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self,
               dataset=None,
               worker=None,
               devices=None,
               components=None,
               element_spec=None,
               options=None):
    """Create an iterator for `dataset` to fetch data to the worker's `devices`.

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    Python object. Once the Python object goes out of scope or we return from
    a tf.function, the underlying iterator resource is deleted.

    Either `dataset`, or both `components` and `element_spec`, must be
    provided.

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents
        the type specification of elements of the iterator.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    self._options = options
    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator,
            self).__init__(dataset, worker, devices, options)

  def _make_iterator(self):
    """Make an appropriate iterator on the dataset."""
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    if _should_use_multi_device_iterator(self._options):
      host_device = device_util.get_host_for_device(self._worker)
      with ops.device(self._worker):
        self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
            self._dataset, self._devices, source_device=host_device)
    else:
      with ops.device(self._worker):
        self._iterator = iter(self._dataset)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec, self._options)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this dataset.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)


class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make an appropriate iterator on the dataset."""
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
          self._dataset, self._devices)

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)


class _SingleWorkerCallableIterator(object):
  """Iterator for a single tensor-returning callable."""

  def __init__(self, fn, worker, devices):
    self._fn = fn
    self._worker = worker
    self._devices = devices

  def get_next(self, device, name=None):
    """Get next element for the given device from the callable."""
    del device, name
    with ops.device(self._worker):
      return self._fn()

  def get_next_as_list_static_shapes(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return data_list

  def get_next_as_list(self, name=None):
    """Get next element from the callable."""
    del name
    with ops.device(self._worker):
      data_list = [self._fn() for _ in self._devices]
      return constant_op.constant(True), data_list

  def initialize(self):
    # TODO(petebu) Should this throw an exception instead?
    return []


def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 enable_legacy_iterators,
                                 options=None):
  """Create a multidevice iterator on each of the workers."""
  assert isinstance(input_workers, InputWorkers)
  assert len(worker_datasets) == len(input_workers.worker_devices)
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(
            dataset=worker_datasets[i],
            worker=worker,
            devices=worker_devices,
            options=options)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                                worker_devices, options)
      iterators.append(iterator)
  return iterators


def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      datasets.append(dataset)
  return datasets, dataset.element_spec


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
  """Get the batched dataset from `d`."""
  # pylint: disable=protected-access
  if isinstance(d, dataset_ops.DatasetV1Adapter):
    d = d._dataset

  if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)):
    return d
  elif isinstance(d, (dataset_ops.PrefetchDataset,
                      dataset_ops._OptionsDataset)):
    return _get_batched_dataset(d._input_dataset)

  raise ValueError(
      "Unable to get batched dataset from the input dataset. `batch` or "
      "`map_and_batch` needs to be the last operation on the dataset. "
      "The batch operation can be followed by a prefetch.")


def _get_batched_dataset_attributes(d):
  """Get `batch_size` and `drop_remainder` of the dataset."""
  # pylint: disable=protected-access
  assert isinstance(d,
                    (dataset_ops.BatchDataset, batching._MapAndBatchDataset))
  if isinstance(d, dataset_ops.BatchDataset):
    batch_size = d._batch_size
    drop_remainder = d._drop_remainder
  elif isinstance(d, batching._MapAndBatchDataset):
    batch_size = d._batch_size_t
    drop_remainder = d._drop_remainder_t
  # pylint: enable=protected-access

  if tensor_util.is_tf_type(batch_size):
    batch_size = tensor_util.constant_value(batch_size)

  if tensor_util.is_tf_type(drop_remainder):
    drop_remainder = tensor_util.constant_value(drop_remainder)

  return batch_size, drop_remainder


# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_dataset_attributes(dataset):
  """Get the underlying attributes from the dataset object."""
  # pylint: disable=protected-access

  # First, get batch_size and drop_remainder from the dataset. We need
  # to walk back the dataset creation process and find the batched version in
  # order to get the attributes.
  batched_dataset = _get_batched_dataset(dataset)
  batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)

  # Second, the prefetch buffer size should be obtained from the original
  # dataset.
  prefetch_buffer = None
  if isinstance(dataset, dataset_ops.PrefetchDataset):
    prefetch_buffer = dataset._buffer_size
  elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
        and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
    prefetch_buffer = dataset._dataset._buffer_size

  return batch_size, drop_remainder, prefetch_buffer


def _should_use_multi_device_iterator(options):
  """Determine whether to use multi_device_iterator_ops."""
  if (options is None or
      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
      or
      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
       and options.experimental_prefetch_to_device)):
    return True
  return False


class MultiStepContext(object):
  """A context object that can be used to capture things when running steps.

  This context object is useful when running multiple steps at a time using the
  `experimental_run_steps_on_iterator` API. For example, it allows the user's
  step function to specify which outputs to emit at what frequency. Currently
  it supports capturing output from the last step, as well as capturing
  non-tensor outputs. In the future it will be augmented to support other use
  cases such as output every N steps.
  """

  def __init__(self):
    """Initialize an output context.

    Returns:
      A context object.
    """
    self._last_step_outputs = {}
    self._last_step_outputs_reduce_ops = {}
    self._non_tensor_outputs = {}

  @property
  def last_step_outputs(self):
    """A dictionary consisting of outputs to be captured on last step.

    Keys in the dictionary are names of tensors to be captured, as specified
    when `set_last_step_output` is called.
    Values in the dictionary are the tensors themselves. If
    `set_last_step_output` was called with a `reduce_op` for this output,
    then the value is the reduced value.

    Returns:
      A dictionary with last step outputs.
    """
    return self._last_step_outputs

  def _set_last_step_outputs(self, outputs):
    """Replace the entire dictionary of last step outputs."""
    if not isinstance(outputs, dict):
      raise ValueError("Need a dictionary to set last_step_outputs.")
    self._last_step_outputs = outputs

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be output from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match tensor
        name.
      output: The tensors that should be output with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross_replica_context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and reduction is set, output
        must be a `PerReplica` value.
        The reduce method is also recorded in the
        `_last_step_outputs_reduce_ops` dictionary, so that it can later be
        determined whether the outputs were already reduced or not.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
    else:
      assert reduce_op is not None
      def merge_fn(distribution, value):
        self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                            axis=None)
        # Setting this inside the `merge_fn` because all replicas share the
        # same context object, so it's more robust to set it only once (even
        # if all the replicas are trying to set the same value).
        self._last_step_outputs_reduce_ops[name] = reduce_op

      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))

  @property
  def non_tensor_outputs(self):
    """A dictionary consisting of any non-tensor outputs to be captured."""
    return self._non_tensor_outputs

  def set_non_tensor_output(self, name, output):
    """Set `output` with `name` to be captured as a non-tensor output."""
    if distribution_strategy_context.in_cross_replica_context():
      self._non_tensor_outputs[name] = output
    else:
      def merge_fn(distribution, value):
        # NOTE(priyag): For non-tensor outputs, we simply return all the values
        # in a list as reduction doesn't make sense on non-tensors.
        self._non_tensor_outputs[name] = (
            distribution.experimental_local_results(value))
      distribution_strategy_context.get_replica_context().merge_call(
          merge_fn, args=(output,))


def _create_distributed_tensor_spec(strategy, tensor_spec):
  """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`.

  Args:
    strategy: The given `tf.distribute` strategy.
    tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the
      shape should be None if you have partial batches.

  Returns:
    A `tf.TypeSpec` that matches the values produced by a given strategy. This
    can be a `tf.TensorSpec` or a `PerReplicaSpec`.
  """
  num_replicas = len(strategy.extended.worker_devices)

  # For a one-device strategy that is not MultiWorkerMirroredStrategy, return
  # the tensor_spec as is, since we don't wrap the output with PerReplica in
  # this case.
  # TODO(b/166464552): remove after we always wrap for all strategies.
  if not _always_wrap(strategy):
    return tensor_spec

  # For other cases we assume the input to tf.function is a per-replica type.
  def _get_value_per_replica(tensor_spec_per_input):
    value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
    return values.PerReplicaSpec(*value_specs)

  return nest.map_structure(_get_value_per_replica, tensor_spec)


def _replace_per_replica_spec(spec, i):
  """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec."""
  if isinstance(spec, values.PerReplicaSpec):
    return spec._value_specs[i]  # pylint: disable=protected-access
  else:
    return spec


def _enable_get_next_as_optional(strategy, dataset):
  """Returns whether to enable using partial batch handling."""
  # TODO(b/133073708): we currently need a flag to control the usage because
  # there is a performance difference between get_next() and
  # get_next_as_optional(). And we only enable get_next_as_optional when the
  # output shapes are not static.
  #
  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
  # when user passed input_fn instead of dataset.
  if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
                 False):
    return False

  if context.executing_eagerly():
    # If the dataset is infinite, we don't need to enable last partial batch
    # support. Currently the logic only applies to the case that distributed
    # dataset is created in eager mode, as we need to evaluate the dataset
    # cardinality.
    with ops.device(dataset._variant_tensor.device):  # pylint: disable=protected-access
      if dataset.cardinality().numpy() == cardinality.INFINITE:
        return False

  return not _is_statically_shaped(
      dataset.element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access


def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi-client
  setups when different clients retrace at different times, since retracing
  changes the collective keys in the tf.function and causes mismatches among
  clients.

  For single-client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the `tf.distribute.Strategy`.

  Returns:
    a structure of PerReplica.

  """
  # TODO(b/166464552): always wrap for all one device strategies as well.
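  # `always_wrap=True` makes regroup() return PerReplica values even when there
  # is only a single value per replica, so the result has a consistent
  # structure across partial batches and across clients.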
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas


def _always_wrap(strategy):
  """Returns whether to always wrap the values in a DistributedValues."""
  return strategy.extended._in_multi_worker_mode() or len(  # pylint: disable=protected-access
      strategy.extended.worker_devices) > 1


def _rebatch_as_dynamic(per_replica_spec):
  """Rebatch the spec to have a dynamic batch dimension."""
  assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec

  # pylint: disable=protected-access
  def _rebatch(spec):
    # Rebatch if possible.
    try:
      return spec._unbatch()._batch(None)
    except ValueError:
      pass
    return spec

  return values.PerReplicaSpec(
      *nest.map_structure(_rebatch, per_replica_spec._value_specs))
  # pylint: enable=protected-access
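

# The sketch below is illustrative only and not part of the original module: it
# shows how the machinery above is typically reached through the public
# tf.distribute API, assuming a local MirroredStrategy and a toy dataset.
#
#   import tensorflow as tf
#
#   strategy = tf.distribute.MirroredStrategy()
#   dataset = tf.data.Dataset.range(8).batch(4)
#   # experimental_distribute_dataset() goes through get_distributed_dataset(),
#   # which wraps `dataset` in a DistributedDataset; iterating it builds one
#   # iterator per input worker via _create_iterators_per_worker().
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#
#   @tf.function
#   def step(per_replica_batch):
#     return strategy.run(lambda x: x * 2, args=(per_replica_batch,))
#
#   for per_replica_batch in dist_dataset:
#     # `per_replica_batch` is regrouped by _create_per_replica() above (a
#     # PerReplica value when there is more than one replica).
#     step(per_replica_batch)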