1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various classes representing distributed inputs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import sys 23 24import six 25 26from tensorflow.python import tf2 27from tensorflow.python.data.experimental.ops import batching 28from tensorflow.python.data.experimental.ops import cardinality 29from tensorflow.python.data.experimental.ops import distribute 30from tensorflow.python.data.ops import dataset_ops 31from tensorflow.python.data.ops import iterator_ops 32from tensorflow.python.data.ops import multi_device_iterator_ops 33from tensorflow.python.data.ops import optional_ops 34from tensorflow.python.distribute import device_util 35from tensorflow.python.distribute import distribute_lib 36from tensorflow.python.distribute import distribute_utils 37from tensorflow.python.distribute import distribution_strategy_context 38from tensorflow.python.distribute import input_ops 39from tensorflow.python.distribute import reduce_util 40from tensorflow.python.distribute import values 41from tensorflow.python.distribute.distribute_lib import InputReplicationMode 42from tensorflow.python.eager import context 43from tensorflow.python.framework import composite_tensor 44from tensorflow.python.framework import constant_op 45from tensorflow.python.framework import device as tf_device 46from tensorflow.python.framework import dtypes 47from tensorflow.python.framework import errors 48from tensorflow.python.framework import ops 49from tensorflow.python.framework import sparse_tensor 50from tensorflow.python.framework import tensor_shape 51from tensorflow.python.framework import tensor_util 52from tensorflow.python.framework import type_spec 53from tensorflow.python.ops import array_ops 54from tensorflow.python.ops import control_flow_ops 55from tensorflow.python.ops import math_ops 56from tensorflow.python.ops.ragged import ragged_tensor 57from tensorflow.python.platform import tf_logging as logging 58from tensorflow.python.types import distribute as distribute_types 59from tensorflow.python.util import nest 60from tensorflow.python.util.compat import collections_abc 61from tensorflow.python.util.deprecation import deprecated 62from tensorflow.python.util.tf_export import tf_export 63from tensorflow.tools.docs import doc_controls 64 65 66def get_distributed_dataset(dataset, 67 input_workers, 68 strategy, 69 num_replicas_in_sync=None, 70 input_context=None, 71 options=None, 72 build=True): 73 """Returns a distributed dataset from the given tf.data.Dataset instance. 74 75 This is a common function that is used by all strategies to return a 76 distributed dataset. The distributed dataset instance returned is different 77 depending on if we are in a TF 1 or TF 2 context. The distributed dataset 78 instances returned differ from each other in the APIs supported by each of 79 them. 80 81 Args: 82 dataset: a tf.data.Dataset instance. 83 input_workers: an InputWorkers object which specifies devices on which 84 iterators should be created. 85 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 86 handle last partial batch. 87 num_replicas_in_sync: Optional integer. If this is not None, the value is 88 used to decide how to rebatch datasets into smaller batches so that 89 the total batch size for each step (across all workers and replicas) 90 adds up to `dataset`'s batch size. 91 input_context: `InputContext` for sharding. Only pass this in for between 92 graph multi-worker cases where there is only one `input_worker`. In 93 these cases, we will shard based on the `input_pipeline_id` and 94 `num_input_pipelines` in the `InputContext`. 95 options: Default is None. `tf.distribute.InputOptions` used to control 96 options on how this dataset is distributed. 97 build: whether to build underlying datasets when a DistributedDataset is 98 created. This is only useful for `ParameterServerStrategy` now. 99 100 Returns: 101 A distributed dataset instance. 102 """ 103 if tf2.enabled(): 104 return DistributedDataset( 105 input_workers, 106 strategy, 107 dataset, 108 num_replicas_in_sync=num_replicas_in_sync, 109 input_context=input_context, 110 build=build, 111 options=options) 112 else: 113 return DistributedDatasetV1( 114 dataset, 115 input_workers, 116 strategy, 117 num_replicas_in_sync=num_replicas_in_sync, 118 input_context=input_context, 119 options=options) 120 121 122def get_distributed_datasets_from_function(dataset_fn, 123 input_workers, 124 input_contexts, 125 strategy, 126 options=None, 127 build=True): 128 """Returns a distributed dataset from the given input function. 129 130 This is a common function that is used by all strategies to return a 131 distributed dataset. The distributed dataset instance returned is different 132 depending on if we are in a TF 1 or TF 2 context. The distributed dataset 133 instances returned differ from each other in the APIs supported by each of 134 them. 135 136 Args: 137 dataset_fn: a function that returns a tf.data.Dataset instance. 138 input_workers: an InputWorkers object which specifies devices on which 139 iterators should be created. 140 input_contexts: A list of `InputContext` instances to be passed to call(s) 141 to `dataset_fn`. Length and order should match worker order in 142 `worker_device_pairs`. 143 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 144 handle last partial batch. 145 options: Default is None. `tf.distribute.InputOptions` used to control 146 options on how this dataset is distributed. 147 build: whether to build underlying datasets when a 148 `DistributedDatasetFromFunction` is created. This is only useful for 149 `ParameterServerStrategy` now. 150 151 Returns: 152 A distributed dataset instance. 153 154 Raises: 155 ValueError: if `options.experimental_replication_mode` and 156 `options.experimental_place_dataset_on_device` are not consistent 157 """ 158 if (options is not None and 159 options.experimental_replication_mode != InputReplicationMode.PER_REPLICA 160 and options.experimental_place_dataset_on_device): 161 raise ValueError( 162 "When `experimental_place_dataset_on_device` is set for dataset " 163 "placement, you must also specify `PER_REPLICA` for the " 164 "replication mode") 165 166 if (options is not None and 167 options.experimental_replication_mode == InputReplicationMode.PER_REPLICA 168 and options.experimental_fetch_to_device and 169 options.experimental_place_dataset_on_device): 170 raise ValueError( 171 "`experimental_place_dataset_on_device` can not be set to True " 172 "when experimental_fetch_to_device is True and " 173 "replication mode is set to `PER_REPLICA`") 174 175 if tf2.enabled(): 176 return DistributedDatasetsFromFunction( 177 input_workers, 178 strategy, 179 input_contexts=input_contexts, 180 dataset_fn=dataset_fn, 181 options=options, 182 build=build, 183 ) 184 else: 185 return DistributedDatasetsFromFunctionV1(input_workers, strategy, 186 input_contexts, dataset_fn, 187 options) 188 189 190def get_iterator_spec_from_dataset(strategy, dataset): 191 """Returns an iterator spec from dataset function. 192 193 This function constructs type spec for iterator obtained from 194 iter(dataset). 195 196 Args: 197 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 198 handle last partial batch. 199 dataset: A tf.data.Dataset instance. If using a function that returns a 200 tf.data.Dataset instance, pass dataset_fn.structured_outputs. 201 202 Returns: 203 A type_spec for iterator for dataset instance. 204 205 """ 206 output_element_spec = dataset.element_spec 207 if isinstance(dataset._type_spec, # pylint: disable=protected-access 208 (DistributedDatasetSpec, 209 DistributedDatasetsFromFunctionSpec)): 210 iterator_type_spec = DistributedIteratorSpec( 211 strategy.extended._input_workers_with_options( # pylint: disable=protected-access 212 ), output_element_spec, 213 strategy.extended._container_strategy(), True, # pylint: disable=protected-access 214 None) 215 else: 216 if strategy.extended._num_gpus_per_worker: # pylint: disable=protected-access 217 logging.warning( 218 f"{strategy.extended._num_gpus_per_worker} GPUs " # pylint: disable=protected-access 219 "are allocated per worker. Please use DistributedDataset by " 220 "calling strategy.experimental_distribute_dataset or strategy." 221 "distribute_datasets_from_function to make best use of GPU " 222 "resources" 223 ) 224 iterator_type_spec = iterator_ops.IteratorSpec(output_element_spec) 225 return iterator_type_spec 226 227 228@tf_export("distribute.DistributedIterator", v1=[]) 229class DistributedIteratorInterface(collections_abc.Iterator, 230 distribute_types.Iterator): 231 """An iterator over `tf.distribute.DistributedDataset`. 232 233 `tf.distribute.DistributedIterator` is the primary mechanism for enumerating 234 elements of a `tf.distribute.DistributedDataset`. It supports the Python 235 Iterator protocol, which means it can be iterated over using a for-loop or by 236 fetching individual elements explicitly via `get_next()`. 237 238 You can create a `tf.distribute.DistributedIterator` by calling `iter` on 239 a `tf.distribute.DistributedDataset` or creating a python loop over a 240 `tf.distribute.DistributedDataset`. 241 242 Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input) 243 on distributed input for more examples and caveats. 244 """ 245 246 def get_next(self): 247 """Returns the next input from the iterator for all replicas. 248 249 Example use: 250 251 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 252 >>> dataset = tf.data.Dataset.range(100).batch(2) 253 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 254 >>> dist_dataset_iterator = iter(dist_dataset) 255 >>> @tf.function 256 ... def one_step(input): 257 ... return input 258 >>> step_num = 5 259 >>> for _ in range(step_num): 260 ... strategy.run(one_step, args=(dist_dataset_iterator.get_next(),)) 261 >>> strategy.experimental_local_results(dist_dataset_iterator.get_next()) 262 (<tf.Tensor: shape=(1,), dtype=int64, numpy=array([10])>, 263 <tf.Tensor: shape=(1,), dtype=int64, numpy=array([11])>) 264 265 Returns: 266 A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains 267 the next input for all replicas. 268 269 Raises: 270 `tf.errors.OutOfRangeError`: If the end of the iterator has been reached. 271 """ 272 raise NotImplementedError( 273 "DistributedIterator.get_next() must be implemented in descendants.") 274 275 @property 276 def element_spec(self): 277 # pylint: disable=line-too-long 278 """The type specification of an element of `tf.distribute.DistributedIterator`. 279 280 Example usage: 281 282 >>> global_batch_size = 16 283 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 284 >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) 285 >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 286 >>> distributed_iterator.element_spec 287 (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), 288 TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), 289 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), 290 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) 291 292 Returns: 293 A nested structure of `tf.TypeSpec` objects matching the structure of an 294 element of this `tf.distribute.DistributedIterator`. This returned value 295 is typically a `tf.distribute.DistributedValues` object and specifies the 296 `tf.TensorSpec` of individual components. 297 """ 298 raise NotImplementedError( 299 "DistributedIterator.element_spec() must be implemented in descendants") 300 301 def get_next_as_optional(self): 302 # pylint: disable=line-too-long 303 """Returns a `tf.experimental.Optional` that contains the next value for all replicas. 304 305 If the `tf.distribute.DistributedIterator` has reached the end of the 306 sequence, the returned `tf.experimental.Optional` will have no value. 307 308 Example usage: 309 310 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 311 >>> global_batch_size = 2 312 >>> steps_per_loop = 2 313 >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size) 314 >>> distributed_iterator = iter( 315 ... strategy.experimental_distribute_dataset(dataset)) 316 >>> def step_fn(x): 317 ... # train the model with inputs 318 ... return x 319 >>> @tf.function 320 ... def train_fn(distributed_iterator): 321 ... for _ in tf.range(steps_per_loop): 322 ... optional_data = distributed_iterator.get_next_as_optional() 323 ... if not optional_data.has_value(): 324 ... break 325 ... per_replica_results = strategy.run(step_fn, args=(optional_data.get_value(),)) 326 ... tf.print(strategy.experimental_local_results(per_replica_results)) 327 >>> train_fn(distributed_iterator) 328 ... # ([0 1], [2 3]) 329 ... # ([4], []) 330 331 Returns: 332 An `tf.experimental.Optional` object representing the next value from the 333 `tf.distribute.DistributedIterator` (if it has one) or no value. 334 """ 335 # pylint: enable=line-too-long 336 raise NotImplementedError( 337 "get_next_as_optional() not implemented in descendants") 338 339 340@tf_export("distribute.DistributedDataset", v1=[]) 341class DistributedDatasetInterface(collections_abc.Iterable, 342 distribute_types.Iterable): 343 # pylint: disable=line-too-long 344 """Represents a dataset distributed among devices and machines. 345 346 A `tf.distribute.DistributedDataset` could be thought of as a "distributed" 347 dataset. When you use `tf.distribute` API to scale training to multiple 348 devices or machines, you also need to distribute the input data, which leads 349 to a `tf.distribute.DistributedDataset` instance, instead of a 350 `tf.data.Dataset` instance in the non-distributed case. In TF 2.x, 351 `tf.distribute.DistributedDataset` objects are Python iterables. 352 353 Note: `tf.distribute.DistributedDataset` instances are *not* of type 354 `tf.data.Dataset`. It only supports two usages we will mention below: 355 iteration and `element_spec`. We don't support any other APIs to transform or 356 inspect the dataset. 357 358 There are two APIs to create a `tf.distribute.DistributedDataset` object: 359 `tf.distribute.Strategy.experimental_distribute_dataset(dataset)`and 360 `tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn)`. 361 *When to use which?* When you have a `tf.data.Dataset` instance, and the 362 regular batch splitting (i.e. re-batch the input `tf.data.Dataset` instance 363 with a new batch size that is equal to the global batch size divided by the 364 number of replicas in sync) and autosharding (i.e. the 365 `tf.data.experimental.AutoShardPolicy` options) work for you, use the former 366 API. Otherwise, if you are *not* using a canonical `tf.data.Dataset` instance, 367 or you would like to customize the batch splitting or sharding, you can wrap 368 these logic in a `dataset_fn` and use the latter API. Both API handles 369 prefetch to device for the user. For more details and examples, follow the 370 links to the APIs. 371 372 373 There are two main usages of a `DistributedDataset` object: 374 375 1. Iterate over it to generate the input for a single device or multiple 376 devices, which is a `tf.distribute.DistributedValues` instance. To do this, 377 you can: 378 379 * use a pythonic for-loop construct: 380 381 >>> global_batch_size = 4 382 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 383 >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size) 384 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 385 >>> @tf.function 386 ... def train_step(input): 387 ... features, labels = input 388 ... return labels - 0.3 * features 389 >>> for x in dist_dataset: 390 ... # train_step trains the model using the dataset elements 391 ... loss = strategy.run(train_step, args=(x,)) 392 ... print("Loss is", loss) 393 Loss is PerReplica:{ 394 0: tf.Tensor( 395 [[0.7] 396 [0.7]], shape=(2, 1), dtype=float32), 397 1: tf.Tensor( 398 [[0.7] 399 [0.7]], shape=(2, 1), dtype=float32) 400 } 401 402 Placing the loop inside a `tf.function` will give a performance boost. 403 However `break` and `return` are currently not supported if the loop is 404 placed inside a `tf.function`. We also don't support placing the loop 405 inside a `tf.function` when using 406 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 407 `tf.distribute.experimental.TPUStrategy` with multiple workers. 408 409 * use `__iter__` to create an explicit iterator, which is of type 410 `tf.distribute.DistributedIterator` 411 412 >>> global_batch_size = 4 413 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 414 >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size) 415 >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) 416 >>> @tf.function 417 ... def distributed_train_step(dataset_inputs): 418 ... def train_step(input): 419 ... loss = tf.constant(0.1) 420 ... return loss 421 ... per_replica_losses = strategy.run(train_step, args=(dataset_inputs,)) 422 ... return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=None) 423 >>> EPOCHS = 2 424 >>> STEPS = 3 425 >>> for epoch in range(EPOCHS): 426 ... total_loss = 0.0 427 ... num_batches = 0 428 ... dist_dataset_iterator = iter(train_dist_dataset) 429 ... for _ in range(STEPS): 430 ... total_loss += distributed_train_step(next(dist_dataset_iterator)) 431 ... num_batches += 1 432 ... average_train_loss = total_loss / num_batches 433 ... template = ("Epoch {}, Loss: {:.4f}") 434 ... print (template.format(epoch+1, average_train_loss)) 435 Epoch 1, Loss: 0.2000 436 Epoch 2, Loss: 0.2000 437 438 439 To achieve a performance improvement, you can also wrap the `strategy.run` 440 call with a `tf.range` inside a `tf.function`. This runs multiple steps in a 441 `tf.function`. Autograph will convert it to a `tf.while_loop` on the worker. 442 However, it is less flexible comparing with running a single step inside 443 `tf.function`. For example, you cannot run things eagerly or arbitrary 444 python code within the steps. 445 446 447 2. Inspect the `tf.TypeSpec` of the data generated by `DistributedDataset`. 448 449 `tf.distribute.DistributedDataset` generates 450 `tf.distribute.DistributedValues` as input to the devices. If you pass the 451 input to a `tf.function` and would like to specify the shape and type of 452 each Tensor argument to the function, you can pass a `tf.TypeSpec` object to 453 the `input_signature` argument of the `tf.function`. To get the 454 `tf.TypeSpec` of the input, you can use the `element_spec` property of the 455 `tf.distribute.DistributedDataset` or `tf.distribute.DistributedIterator` 456 object. 457 458 For example: 459 460 >>> global_batch_size = 4 461 >>> epochs = 1 462 >>> steps_per_epoch = 1 463 >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 464 >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size) 465 >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset) 466 >>> @tf.function(input_signature=[dist_dataset.element_spec]) 467 ... def train_step(per_replica_inputs): 468 ... def step_fn(inputs): 469 ... return tf.square(inputs) 470 ... return mirrored_strategy.run(step_fn, args=(per_replica_inputs,)) 471 >>> for _ in range(epochs): 472 ... iterator = iter(dist_dataset) 473 ... for _ in range(steps_per_epoch): 474 ... output = train_step(next(iterator)) 475 ... print(output) 476 PerReplica:{ 477 0: tf.Tensor( 478 [[4.] 479 [4.]], shape=(2, 1), dtype=float32), 480 1: tf.Tensor( 481 [[4.] 482 [4.]], shape=(2, 1), dtype=float32) 483 } 484 485 486 Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input) 487 on distributed input for more examples and caveats. 488 """ 489 490 def __iter__(self): 491 """Creates an iterator for the `tf.distribute.DistributedDataset`. 492 493 The returned iterator implements the Python Iterator protocol. 494 495 Example usage: 496 497 >>> global_batch_size = 4 498 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 499 >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size) 500 >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) 501 >>> print(next(distributed_iterator)) 502 PerReplica:{ 503 0: tf.Tensor([1 2], shape=(2,), dtype=int32), 504 1: tf.Tensor([3 4], shape=(2,), dtype=int32) 505 } 506 507 Returns: 508 An `tf.distribute.DistributedIterator` instance for the given 509 `tf.distribute.DistributedDataset` object to enumerate over the 510 distributed data. 511 """ 512 raise NotImplementedError("Must be implemented in descendants") 513 514 @property 515 def element_spec(self): 516 """The type specification of an element of this `tf.distribute.DistributedDataset`. 517 518 Example usage: 519 520 >>> global_batch_size = 16 521 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 522 >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) 523 >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) 524 >>> dist_dataset.element_spec 525 (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), 526 TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), 527 PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), 528 TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) 529 530 Returns: 531 A nested structure of `tf.TypeSpec` objects matching the structure of an 532 element of this `tf.distribute.DistributedDataset`. This returned value is 533 typically a `tf.distribute.DistributedValues` object and specifies the 534 `tf.TensorSpec` of individual components. 535 """ 536 raise NotImplementedError( 537 "DistributedDataset.element_spec must be implemented in descendants.") 538 539 @doc_controls.do_not_generate_docs 540 def reduce(self, initial_state, reduce_func): 541 raise NotImplementedError( 542 "DistributedDataset.reduce must be implemented in descendants.") 543 544 545class InputWorkers(object): 546 """A 1-to-many mapping from input worker devices to compute devices.""" 547 548 # TODO(ishark): Remove option canonicalize_devices and make all the callers 549 # pass canonicalized or raw device strings as relevant from strategy. 550 def __init__(self, worker_device_pairs, canonicalize_devices=True): 551 """Initialize an `InputWorkers` object. 552 553 Args: 554 worker_device_pairs: A sequence of pairs: `(input device, a tuple of 555 compute devices fed by that input device)`. 556 canonicalize_devices: Whether to canonicalize devices for workers fully or 557 partially. If False, it will partially canonicalize devices by removing 558 job and task. 559 """ 560 self._worker_device_pairs = worker_device_pairs 561 self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs) 562 self._canonicalize_devices = canonicalize_devices 563 if canonicalize_devices: 564 self._fed_devices = tuple( 565 tuple(device_util.canonicalize(d) 566 for d in f) 567 for _, f in self._worker_device_pairs) 568 else: 569 self._fed_devices = tuple( 570 tuple(device_util.canonicalize_without_job_and_task(d) 571 for d in f) 572 for _, f in self._worker_device_pairs) 573 574 @property 575 def num_workers(self): 576 return len(self._input_worker_devices) 577 578 @property 579 def worker_devices(self): 580 return self._input_worker_devices 581 582 def compute_devices_for_worker(self, worker_index): 583 return self._fed_devices[worker_index] 584 585 def __repr__(self): 586 devices = self.worker_devices 587 debug_repr = ",\n".join(" %d %s: %s" % 588 (i, devices[i], self._fed_devices[i]) 589 for i in range(len(devices))) 590 return "%s:{\n%s}" % (self.__class__.__name__, debug_repr) 591 592 def serialize(self): 593 return (self._worker_device_pairs, self._canonicalize_devices) 594 595 def deserialize(self, serialized): 596 return InputWorkers(serialized) 597 598 599def _get_next_as_optional(iterator, strategy, return_per_replica=False): 600 """Returns an empty dataset indicator and the next input from the iterator. 601 602 Args: 603 iterator: a DistributedIterator object. 604 strategy: the `tf.distribute.Strategy` instance. 605 return_per_replica: a boolean. If True, the returned data will be wrapped 606 with `PerReplica` structure. Otherwise it is a 2D 607 num_input_workers*num_replicas_per_worker list. 608 609 Returns: 610 A tuple (a boolean tensor indicating whether the next batch has value 611 globally, data from all replicas). 612 """ 613 replicas = [] 614 worker_has_values = [] 615 worker_devices = [] 616 with distribution_strategy_context.enter_or_assert_strategy(strategy): 617 if distribution_strategy_context.get_replica_context() is not None: 618 raise ValueError("next(iterator) should be called from outside of " 619 "replica_fn. e.g. strategy.run(replica_fn, " 620 "args=(next(iterator),))") 621 622 for i, worker in enumerate(iterator._input_workers.worker_devices): # pylint: disable=protected-access 623 with ops.device(worker): 624 worker_has_value, next_element = ( 625 iterator._iterators[i].get_next_as_list()) # pylint: disable=protected-access 626 # Collective all-reduce requires explicit devices for inputs. 627 with ops.device("/cpu:0"): 628 # Converting to integers for all-reduce. 629 worker_has_value = math_ops.cast(worker_has_value, dtypes.int64) 630 worker_devices.append(worker_has_value.device) 631 worker_has_values.append(worker_has_value) 632 # Make `replicas` a flat list of values across all replicas. 633 replicas.append(next_element) 634 635 if return_per_replica: 636 flattened_data = [] 637 for per_worker_data in replicas: 638 flattened_data.extend(per_worker_data) 639 replicas = _create_per_replica(flattened_data, strategy) 640 641 # Run an all-reduce to see whether any worker has values. 642 # TODO(b/131423105): we should be able to short-cut the all-reduce in some 643 # cases. 644 if getattr(strategy.extended, "_support_per_replica_values", True): 645 # `reduce` expects a `PerReplica`, so we pass it one, even 646 # though it doesn't actually have a value per replica 647 worker_has_values = values.PerReplica(worker_has_values) 648 global_has_value = strategy.reduce( 649 reduce_util.ReduceOp.SUM, worker_has_values, axis=None) 650 else: 651 assert len(worker_has_values) == 1 652 global_has_value = worker_has_values[0] 653 global_has_value = array_ops.reshape( 654 math_ops.cast(global_has_value, dtypes.bool), []) 655 return global_has_value, replicas 656 657 658def _is_statically_shaped(element_spec): 659 """Test if an iterator output is statically shaped. 660 661 For sparse and ragged tensors this only tests the batch dimension. 662 663 Args: 664 element_spec: a nest structure of `tf.TypeSpec`. The element spec of the 665 dataset of the iterator. 666 667 Returns: 668 True if the shape is static, false otherwise. 669 """ 670 671 for spec in nest.flatten(element_spec): 672 if isinstance( 673 spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)): 674 # For sparse or ragged tensor, we should only check the first 675 # dimension in order to get_next_as_optional. This is because 676 # when these tensors get batched by dataset only the batch dimension 677 # is set. 678 if spec.shape.rank > 0 and spec.shape.as_list()[0] is None: 679 return False 680 else: 681 for component in nest.flatten(spec._component_specs): # pylint: disable=protected-access 682 if not component.shape.is_fully_defined(): 683 return False 684 return True 685 686 687class DistributedIteratorBase(DistributedIteratorInterface): 688 """Common implementation for all input iterators.""" 689 690 # pylint: disable=super-init-not-called 691 def __init__(self, input_workers, iterators, strategy, 692 enable_get_next_as_optional): 693 assert isinstance(input_workers, InputWorkers) 694 if not input_workers.worker_devices: 695 raise ValueError("Should have at least one worker for input iterator.") 696 697 self._iterators = iterators 698 self._input_workers = input_workers 699 self._strategy = strategy 700 self._enable_get_next_as_optional = enable_get_next_as_optional 701 702 def next(self): 703 return self.__next__() 704 705 def __next__(self): 706 try: 707 return self.get_next() 708 except errors.OutOfRangeError: 709 raise StopIteration 710 711 def __iter__(self): 712 return self 713 714 def get_next_as_optional(self): 715 global_has_value, replicas = _get_next_as_optional( 716 self, self._strategy, return_per_replica=True) 717 718 def return_none(): 719 return optional_ops.Optional.empty(self._element_spec) 720 721 return control_flow_ops.cond( 722 global_has_value, lambda: optional_ops.Optional.from_value(replicas), 723 return_none) 724 725 def get_next(self, name=None): 726 """Returns the next input from the iterator for all replicas.""" 727 if not self._enable_get_next_as_optional: 728 with distribution_strategy_context.enter_or_assert_strategy( 729 self._strategy): 730 if distribution_strategy_context.get_replica_context() is not None: 731 raise ValueError("next(iterator) should be called from outside of " 732 "replica_fn. e.g. strategy.run(replica_fn, " 733 "args=(next(iterator),))") 734 735 replicas = [] 736 for i, worker in enumerate(self._input_workers.worker_devices): 737 if name is not None: 738 d = tf_device.DeviceSpec.from_string(worker) 739 new_name = "%s_%s_%d" % (name, d.job, d.task) 740 else: 741 new_name = None 742 with ops.device(worker): 743 # Make `replicas` a flat list of values across all replicas. 744 replicas.extend( 745 self._iterators[i].get_next_as_list_static_shapes(new_name)) 746 return _create_per_replica(replicas, self._strategy) 747 748 out_of_range_replicas = [] 749 def out_of_range_fn(worker_index, device): 750 """This function will throw an OutOfRange error.""" 751 # As this will be only called when there is no data left, so calling 752 # get_next() will trigger an OutOfRange error. 753 data = self._iterators[worker_index].get_next(device) 754 out_of_range_replicas.append(data) 755 return data 756 757 global_has_value, replicas = _get_next_as_optional( 758 self, self._strategy, return_per_replica=False) 759 results = [] 760 for i, worker in enumerate(self._input_workers.worker_devices): 761 with ops.device(worker): 762 devices = self._input_workers.compute_devices_for_worker(i) 763 for j, device in enumerate(devices): 764 with ops.device(device): 765 # pylint: disable=undefined-loop-variable 766 # pylint: disable=cell-var-from-loop 767 # It is fine for the lambda to capture variables from the loop as 768 # the lambda is executed in the loop as well. 769 result = control_flow_ops.cond( 770 global_has_value, 771 lambda: replicas[i][j], 772 lambda: out_of_range_fn(i, device), 773 strict=True, 774 ) 775 # pylint: enable=cell-var-from-loop 776 # pylint: enable=undefined-loop-variable 777 results.append(result) 778 replicas = results 779 780 return _create_per_replica(replicas, self._strategy) 781 782 783class DistributedIteratorV1(DistributedIteratorBase): 784 """Input Iterator for a distributed dataset.""" 785 786 # We need a private initializer method for re-initializing multidevice 787 # iterators when used with Keras training loops. If we don't reinitialize the 788 # iterator we run into memory leak issues (b/123315763). 789 @property 790 def _initializer(self): 791 init_ops = [] 792 for it in self._iterators: 793 init_ops.extend(it.initialize()) 794 return control_flow_ops.group(init_ops) 795 796 @deprecated(None, "Use the iterator's `initializer` property instead.") 797 def initialize(self): 798 """Initialize underlying iterators. 799 800 Returns: 801 A list of any initializer ops that should be run. 802 """ 803 return self._initializer 804 805 @property 806 def initializer(self): 807 """Returns a list of ops that initialize the iterator.""" 808 return self.initialize() 809 810 # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. 811 @property 812 def output_classes(self): 813 return self._iterators[0].output_classes 814 815 # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. 816 @property 817 def output_shapes(self): 818 return self._iterators[0].output_shapes 819 820 # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. 821 @property 822 def output_types(self): 823 return self._iterators[0].output_types 824 825 # TODO(priyag): Remove when we switch to using `MultiDeviceIterator` for TPUs. 826 def get_iterator(self, worker): 827 for i, w in enumerate(self._input_workers.worker_devices): 828 if worker == w: 829 return self._iterators[i] 830 return None 831 832 @property 833 def element_spec(self): 834 """The type specification of an element of this iterator.""" 835 return self._element_spec 836 837 838class DistributedDatasetAndIteratorSpec(type_spec.TypeSpec): 839 """Common Type specification for `DistributedDataset and DistributedDatasetsFromFunction.""" 840 841 __slots__ = [ 842 "_input_workers", "_element_spec", "_strategy", 843 "_enable_get_next_as_optional", "_options", 844 "_canonicalize_devices" 845 ] 846 847 def __init__(self, 848 input_workers, 849 element_spec, 850 strategy, 851 options, 852 enable_get_next_as_optional=None): 853 # We don't want to allow deserialization of this class because we don't 854 # serialize the strategy object. Currently the only places where 855 # _deserialize is called is when we save/restore using SavedModels. 856 if isinstance(input_workers, tuple): 857 raise NotImplementedError("DistributedIteratorSpec does not have support " 858 "for deserialization.") 859 else: 860 self._input_workers = input_workers 861 self._element_spec = element_spec 862 self._strategy = strategy 863 self._enable_get_next_as_optional = enable_get_next_as_optional 864 self._options = options 865 if self._strategy: 866 self._canonicalize_devices = getattr(self._strategy, 867 "_canonicalize_devices", True) 868 else: 869 self._canonicalize_devices = True 870 871 def _serialize(self): 872 # We cannot serialize the strategy object so we convert it to an id that we 873 # can use for comparison. 874 return (self._input_workers.serialize(), self._element_spec, 875 id(self._strategy), id(self._options)) 876 877 def _deserialize(self): 878 raise ValueError( 879 f"Deserialization is currently unsupported for {type(self)}.") 880 881 def sanity_check_type(self, other): 882 """Returns the most specific TypeSpec compatible with `self` and `other`. 883 884 Args: 885 other: A `TypeSpec`. 886 887 Raises: 888 ValueError: If there is no TypeSpec that is compatible with both `self` 889 and `other`. 890 """ 891 # pylint: disable=protected-access 892 if type(self) is not type(other): 893 raise ValueError("No TypeSpec is compatible with both %s and %s" % 894 (self, other)) 895 if self._input_workers.serialize() != other._input_workers.serialize(): 896 raise ValueError("_input_workers is not compatible with both %s " 897 "and %s" % (self, other)) 898 if self._strategy is not other._strategy: 899 raise ValueError("tf.distribute strategy is not compatible with both %s " 900 "and %s" % (self, other)) 901 902 903class DistributedIteratorSpec(DistributedDatasetAndIteratorSpec): 904 """Type specification for `DistributedIterator`.""" 905 906 def __init__(self, input_workers, element_spec, strategy, 907 enable_get_next_as_optional, options): 908 super(DistributedIteratorSpec, 909 self).__init__(input_workers, element_spec, strategy, options, 910 enable_get_next_as_optional) 911 912 @property 913 def value_type(self): 914 return DistributedIterator 915 916 # Overriding this method so that we can merge and reconstruct the spec object 917 def most_specific_compatible_type(self, other): 918 """Returns the most specific TypeSpec compatible with `self` and `other`. 919 920 Args: 921 other: A `TypeSpec`. 922 923 Raises: 924 ValueError: If there is no TypeSpec that is compatible with both `self` 925 and `other`. 926 """ 927 # pylint: disable=protected-access 928 self.sanity_check_type(other) 929 element_spec = nest.map_structure( 930 lambda a, b: a.most_specific_compatible_type(b), self._element_spec, 931 other._element_spec) 932 return DistributedIteratorSpec(self._input_workers, element_spec, 933 self._strategy, 934 self._enable_get_next_as_optional, 935 self._options) 936 937 @property 938 def _component_specs(self): 939 specs = [] 940 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 941 942 for i, (input_device, compute_devices) in enumerate(worker_device_pairs): 943 element_spec = nest.map_structure( 944 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 945 specs.append( 946 _SingleWorkerDatasetIteratorSpec(input_device, compute_devices, 947 element_spec, self._options, 948 self._canonicalize_devices)) 949 return specs 950 951 def _to_components(self, value): 952 return value._iterators # pylint: disable=protected-access 953 954 def _from_components(self, components): 955 return DistributedIterator( 956 input_workers=self._input_workers, 957 iterators=None, 958 components=components, 959 element_spec=self._element_spec, 960 strategy=self._strategy, 961 enable_get_next_as_optional=self._enable_get_next_as_optional, 962 options=self._options) 963 964 @staticmethod 965 def from_value(value): 966 # pylint: disable=protected-access 967 return DistributedIteratorSpec(value._input_workers, value._element_spec, 968 value._strategy, 969 value._enable_get_next_as_optional, 970 value._options) 971 972 def _with_tensor_ranks_only(self): 973 element_spec = nest.map_structure( 974 lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access 975 self._element_spec) 976 return DistributedIteratorSpec(self._input_workers, element_spec, 977 self._strategy, 978 self._enable_get_next_as_optional, 979 self._options) 980 981 982class DistributedIterator(DistributedIteratorBase, 983 composite_tensor.CompositeTensor): 984 """Input Iterator for a distributed dataset.""" 985 986 def __init__(self, 987 input_workers=None, 988 iterators=None, 989 strategy=None, 990 components=None, 991 element_spec=None, 992 enable_get_next_as_optional=False, 993 options=None): 994 if input_workers is None: 995 raise ValueError("`input_workers` should be " 996 "provided.") 997 998 error_message = ("Either `input_workers` or " 999 "both `components` and `element_spec` need to be " 1000 "provided.") 1001 self._options = options 1002 1003 if iterators is None: 1004 if (components is None or element_spec is None): 1005 raise ValueError(error_message) 1006 self._element_spec = element_spec 1007 self._input_workers = input_workers 1008 self._iterators = components 1009 self._strategy = strategy 1010 self._enable_get_next_as_optional = enable_get_next_as_optional 1011 else: 1012 if (components is not None and element_spec is not None): 1013 raise ValueError(error_message) 1014 1015 super(DistributedIterator, 1016 self).__init__(input_workers, iterators, strategy, 1017 enable_get_next_as_optional) 1018 1019 @property 1020 def element_spec(self): 1021 # When partial batch handling is enabled, always set the batch dimension to 1022 # None, otherwise we just follow element_spec of the underlying dataset 1023 # (whose batch dimension may also be None). This is because with partial 1024 # batching handling we could always produce empty batches. 1025 if (self._enable_get_next_as_optional and 1026 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 1027 return nest.map_structure( 1028 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 1029 return self._element_spec 1030 1031 @property 1032 def _type_spec(self): 1033 # Note that we use actual element_spec instead of the rebatched-as-dynamic 1034 # one to create DistributedIteratorSpec, to be consistent with the 1035 # underlying iterators' specs. 1036 return DistributedIteratorSpec(self._input_workers, self._element_spec, 1037 self._strategy, 1038 self._enable_get_next_as_optional, 1039 self._options) 1040 1041 1042class _IterableInput(DistributedDatasetInterface): 1043 """Base class for iterable inputs for distribution strategies.""" 1044 1045 # pylint: disable=super-init-not-called 1046 def __init__(self, input_workers): 1047 assert isinstance(input_workers, InputWorkers) 1048 self._input_workers = input_workers 1049 1050 def __iter__(self): 1051 raise NotImplementedError("must be implemented in descendants") 1052 1053 def reduce(self, initial_state, reduce_fn): 1054 """Execute a `reduce_fn` over all the elements of the input.""" 1055 iterator = iter(self) 1056 has_data, data = _get_next_as_optional( 1057 iterator, self._strategy, return_per_replica=True) 1058 1059 def cond(has_data, data, state): 1060 del data, state # Unused. 1061 return has_data 1062 1063 def loop_body(has_data, data, state): 1064 """Executes `reduce_fn` in a loop till the dataset is empty.""" 1065 del has_data # Unused. 1066 state = reduce_fn(state, data) 1067 has_data, data = _get_next_as_optional( 1068 iterator, self._strategy, return_per_replica=True) 1069 return has_data, data, state 1070 1071 has_data, data, final_state = control_flow_ops.while_loop( 1072 cond, loop_body, [has_data, data, initial_state], parallel_iterations=1) 1073 return final_state 1074 1075 1076class DistributedDatasetSpec(DistributedDatasetAndIteratorSpec): 1077 """Type specification for `DistributedDataset.""" 1078 1079 def __init__(self, input_workers, element_spec, strategy, 1080 enable_get_next_as_optional, options): 1081 super(DistributedDatasetSpec, 1082 self).__init__(input_workers, element_spec, strategy, options, 1083 enable_get_next_as_optional) 1084 1085 @property 1086 def value_type(self): 1087 return DistributedDataset 1088 1089 # Overriding this method so that we can merge and reconstruct the spec object 1090 def most_specific_compatible_type(self, other): 1091 """Returns the most specific TypeSpec compatible with `self` and `other`. 1092 1093 Args: 1094 other: A `TypeSpec`. 1095 1096 Raises: 1097 ValueError: If there is no TypeSpec that is compatible with both `self` 1098 and `other`. 1099 """ 1100 # pylint: disable=protected-access 1101 self.sanity_check_type(other) 1102 element_spec = nest.map_structure( 1103 lambda a, b: a.most_specific_compatible_type(b), self._element_spec, 1104 other._element_spec) 1105 return DistributedDatasetSpec(self._input_workers, element_spec, 1106 self._strategy, 1107 self._enable_get_next_as_optional, 1108 self._options) 1109 1110 @property 1111 def _component_specs(self): 1112 specs = [] 1113 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 1114 1115 for i, _ in enumerate(worker_device_pairs): 1116 element_spec = nest.map_structure( 1117 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 1118 specs.append(dataset_ops.DatasetSpec(element_spec)) 1119 return specs 1120 1121 def _to_components(self, value): 1122 return value._cloned_datasets # pylint: disable=protected-access 1123 1124 def _from_components(self, components): 1125 return DistributedDataset( 1126 input_workers=self._input_workers, 1127 strategy=self._strategy, 1128 components=components, 1129 element_spec=self._element_spec, 1130 enable_get_next_as_optional=self._enable_get_next_as_optional, 1131 options=self._options) 1132 1133 @staticmethod 1134 def from_value(value): 1135 # pylint: disable=protected-access 1136 return DistributedDatasetSpec(value._input_workers, value._element_spec, 1137 value._strategy, 1138 value._enable_get_next_as_optional, 1139 value._options) 1140 1141 1142class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor): 1143 """Distributed dataset that supports prefetching to multiple devices.""" 1144 1145 def __init__(self, 1146 input_workers, 1147 strategy, 1148 dataset=None, 1149 num_replicas_in_sync=None, 1150 input_context=None, 1151 components=None, 1152 element_spec=None, 1153 enable_get_next_as_optional=None, 1154 build=True, 1155 options=None): 1156 """Distribute the dataset on all workers. 1157 1158 If `num_replicas_in_sync` is not None, we split each batch of the dataset 1159 into `num_replicas_in_sync` smaller batches, to be distributed among that 1160 worker's replicas, so that the batch size for a global step (across all 1161 workers and replicas) is as expected. 1162 1163 Args: 1164 input_workers: an `InputWorkers` object. 1165 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1166 handle last partial batch. 1167 dataset: `tf.data.Dataset` that will be used as the input source. Either 1168 dataset or components field should be passed when constructing 1169 DistributedDataset. Use this when contructing DistributedDataset from a 1170 new `tf.data.Dataset`. Use components when constructing using 1171 DistributedDatasetSpec. 1172 num_replicas_in_sync: Optional integer. If this is not None, the value 1173 is used to decide how to rebatch datasets into smaller batches so that 1174 the total batch size for each step (across all workers and replicas) 1175 adds up to `dataset`'s batch size. 1176 input_context: `InputContext` for sharding. Only pass this in for between 1177 graph multi-worker cases where there is only one `input_worker`. In 1178 these cases, we will shard based on the `input_pipeline_id` and 1179 `num_input_pipelines` in the `InputContext`. 1180 components: datasets when DistributedDataset is constructed from 1181 DistributedDatasetSpec. Either field dataset or components should be 1182 passed. 1183 element_spec: element spec for DistributedDataset when constructing from 1184 DistributedDatasetSpec. This will be used to set the element_spec for 1185 DistributedDataset and verified against element_spec from components. 1186 enable_get_next_as_optional: this is required when components is passed 1187 instead of dataset. 1188 build: whether to build underlying datasets when this object is created. 1189 This is only useful for `ParameterServerStrategy` now. 1190 options: `tf.distribute.InputOptions` used to control options on how this 1191 dataset is distributed. 1192 """ 1193 super(DistributedDataset, self).__init__(input_workers=input_workers) 1194 if input_workers is None or strategy is None: 1195 raise ValueError("input_workers and strategy are required arguments") 1196 if dataset is not None and components is not None: 1197 raise ValueError("Only one of dataset or components should be present") 1198 if dataset is None and components is None: 1199 raise ValueError("At least one of dataset or components should be passed") 1200 1201 self._input_workers = input_workers 1202 self._strategy = strategy 1203 self._options = options 1204 self._input_context = input_context 1205 self._num_replicas_in_sync = num_replicas_in_sync 1206 1207 if dataset is not None: 1208 self._original_dataset = dataset 1209 self._built = False 1210 if build: 1211 self.build() 1212 else: 1213 if not build: 1214 raise ValueError( 1215 "When constructing DistributedDataset with components, build " 1216 "should not be False. This is an internal error. Please file a " 1217 "bug.") 1218 if enable_get_next_as_optional is None: 1219 raise ValueError( 1220 "When constructing DistributedDataset with components, " + 1221 "enable_get_next_as_optional should also be passed") 1222 self._cloned_datasets = components 1223 self._enable_get_next_as_optional = enable_get_next_as_optional 1224 1225 assert element_spec is not None 1226 if element_spec != _create_distributed_tensor_spec( 1227 self._strategy, self._cloned_datasets[0].element_spec): 1228 raise ValueError("Mismatched element_spec from the passed components") 1229 self._element_spec = element_spec 1230 1231 self._built = True 1232 1233 def build(self, dataset_to_replace=None): 1234 assert not self._built 1235 dataset = dataset_to_replace or self._original_dataset 1236 self._create_cloned_datasets_from_dataset(dataset, self._input_context, 1237 self._input_workers, 1238 self._strategy, 1239 self._num_replicas_in_sync) 1240 self._element_spec = _create_distributed_tensor_spec( 1241 self._strategy, self._cloned_datasets[0].element_spec) 1242 self._built = True 1243 1244 def _create_cloned_datasets_from_dataset(self, dataset, input_context, 1245 input_workers, strategy, 1246 num_replicas_in_sync): 1247 # We clone and shard the dataset on each worker. The current setup tries to 1248 # shard the dataset by files if possible so that each worker sees a 1249 # different subset of files. If that is not possible, will attempt to shard 1250 # the final input such that each worker will run the entire preprocessing 1251 # pipeline and only receive its own shard of the dataset. 1252 1253 # Additionally, we rebatch the dataset on each worker into 1254 # `num_replicas_in_sync` smaller batches to be distributed among that 1255 # worker's replicas, so that the batch size for a global step (across all 1256 # workers and replicas) adds up to the original dataset's batch size. 1257 if num_replicas_in_sync is not None: 1258 num_workers = input_context.num_input_pipelines if input_context else len( 1259 input_workers.worker_devices) 1260 rebatch_fn = self._make_rebatch_fn(dataset, num_workers, 1261 num_replicas_in_sync) 1262 else: 1263 rebatch_fn = None 1264 self._cloned_datasets = [] 1265 if input_context: 1266 # Between-graph where we rely on the input_context for sharding 1267 assert input_workers.num_workers == 1 1268 if rebatch_fn is not None: 1269 dataset = rebatch_fn(dataset, input_context.input_pipeline_id) 1270 dataset = input_ops.auto_shard_dataset(dataset, 1271 input_context.num_input_pipelines, 1272 input_context.input_pipeline_id, 1273 num_replicas_in_sync) 1274 self._cloned_datasets.append(dataset) 1275 else: 1276 replicated_ds = distribute.replicate(dataset, 1277 input_workers.worker_devices) 1278 for i, worker in enumerate(input_workers.worker_devices): 1279 with ops.device(worker): 1280 cloned_dataset = replicated_ds[worker] 1281 if rebatch_fn is not None: 1282 cloned_dataset = rebatch_fn(cloned_dataset, i) 1283 cloned_dataset = input_ops.auto_shard_dataset( 1284 cloned_dataset, len(input_workers.worker_devices), i, 1285 num_replicas_in_sync) 1286 self._cloned_datasets.append(cloned_dataset) 1287 1288 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1289 strategy, dataset) 1290 1291 def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync): 1292 """Returns a callable that rebatches the input dataset. 1293 1294 Args: 1295 dataset: A `tf.data.Dataset` representing the dataset to be distributed. 1296 num_workers: An integer representing the number of workers to distribute 1297 `dataset` among. 1298 num_replicas_in_sync: An integer representing the number of replicas in 1299 sync across all workers. 1300 """ 1301 if num_replicas_in_sync % num_workers: 1302 raise ValueError( 1303 "tf.distribute expects every worker to have the same number of " 1304 "replicas. However, encountered `num_replicas_in_sync` ({}) that " 1305 "cannot be divided by `num_workers` ({})".format( 1306 num_replicas_in_sync, num_workers)) 1307 1308 num_replicas_per_worker = num_replicas_in_sync // num_workers 1309 with ops.colocate_with(dataset._variant_tensor): # pylint: disable=protected-access 1310 batch_size = distribute.compute_batch_size(dataset) 1311 1312 def rebatch_fn(dataset, worker_index): 1313 try: 1314 # pylint: disable=protected-access 1315 def apply_rebatch(): 1316 batch_sizes = distribute.batch_sizes_for_worker( 1317 batch_size, num_workers, num_replicas_per_worker, worker_index) 1318 return distribute._RebatchDataset( 1319 dataset, batch_sizes).prefetch(num_replicas_per_worker) 1320 1321 def apply_legacy_rebatch(): 1322 return distribute._LegacyRebatchDataset( 1323 dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker) 1324 1325 with ops.colocate_with(dataset._variant_tensor): 1326 return control_flow_ops.cond( 1327 math_ops.not_equal(batch_size, -1), 1328 true_fn=apply_rebatch, 1329 false_fn=apply_legacy_rebatch) 1330 except errors.InvalidArgumentError as e: 1331 if "without encountering a batch" in str(e): 1332 six.reraise( 1333 ValueError, 1334 ValueError( 1335 "Call the `batch` method on the input Dataset in order to be " 1336 "able to split your input across {} replicas.\n Please see " 1337 "the tf.distribute.Strategy guide. {}".format( 1338 num_replicas_in_sync, e)), 1339 sys.exc_info()[2]) 1340 else: 1341 raise 1342 1343 return rebatch_fn 1344 1345 def __iter__(self): 1346 if not (context.executing_eagerly() or 1347 ops.get_default_graph().building_function): 1348 raise RuntimeError("__iter__() is only supported inside of tf.function " 1349 "or when eager execution is enabled.") 1350 if not self._built: 1351 raise ValueError("To use this dataset, you need to pass this dataset to " 1352 "ClusterCoordinator.create_per_worker_dataset.") 1353 1354 # This is an optional flag that can be used to turn off using 1355 # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators 1356 # as a stop gap solution that will allow us to roll out this change. 1357 enable_legacy_iterators = getattr(self._strategy, 1358 "_enable_legacy_iterators", False) 1359 1360 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", 1361 True) 1362 1363 worker_iterators = _create_iterators_per_worker(self._cloned_datasets, 1364 self._input_workers, 1365 enable_legacy_iterators, 1366 self._options, 1367 canonicalize_devices) 1368 if enable_legacy_iterators: 1369 iterator = DistributedIteratorV1( 1370 self._input_workers, 1371 worker_iterators, 1372 self._strategy, 1373 enable_get_next_as_optional=self._enable_get_next_as_optional) 1374 else: 1375 iterator = DistributedIterator( 1376 self._input_workers, 1377 worker_iterators, 1378 self._strategy, 1379 enable_get_next_as_optional=self._enable_get_next_as_optional, 1380 options=self._options) 1381 iterator._element_spec = self._element_spec # pylint: disable=protected-access 1382 1383 # When async eager is enabled, sometimes the iterator may not finish 1384 # initialization before passing to a multi device function, add a sync point 1385 # here to make sure all underlying iterators are initialized. 1386 if context.executing_eagerly(): 1387 context.async_wait() 1388 1389 return iterator 1390 1391 @property 1392 def element_spec(self): 1393 """The type specification of an element of this dataset.""" 1394 # When partial batch handling is enabled, always set the batch dimension to 1395 # None, otherwise we just follow element_spec of the underlying dataset 1396 # (whose batch dimension may also be None). This is because with partial 1397 # batching handling we could always produce empty batches. 1398 if (self._enable_get_next_as_optional and 1399 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 1400 return nest.map_structure( 1401 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 1402 return self._element_spec 1403 1404 @property 1405 def _type_spec(self): 1406 return DistributedDatasetSpec(self._input_workers, self._element_spec, 1407 self._strategy, 1408 self._enable_get_next_as_optional, 1409 self._options) 1410 1411 1412class DistributedDatasetV1(DistributedDataset): 1413 """Distributed dataset that supports prefetching to multiple devices.""" 1414 1415 def __init__(self, 1416 dataset, 1417 input_workers, 1418 strategy, 1419 num_replicas_in_sync=None, 1420 input_context=None, 1421 options=None): 1422 self._input_workers = input_workers 1423 super(DistributedDatasetV1, self).__init__( 1424 input_workers, 1425 strategy, 1426 dataset, 1427 num_replicas_in_sync=num_replicas_in_sync, 1428 input_context=input_context, 1429 options=options) 1430 1431 def make_one_shot_iterator(self): 1432 """Get a one time use iterator for DistributedDatasetV1. 1433 1434 Note: This API is deprecated. Please use `for ... in dataset:` to iterate 1435 over the dataset or `iter` to create an iterator. 1436 1437 Returns: 1438 A DistributedIteratorV1 instance. 1439 """ 1440 return self._make_one_shot_iterator() 1441 1442 def _make_one_shot_iterator(self): 1443 """Get an iterator for DistributedDatasetV1.""" 1444 # Graph mode with one shot iterator is disabled because we have to call 1445 # `initialize` on the iterator which is only required if we are using a 1446 # tf.distribute strategy. 1447 if not context.executing_eagerly(): 1448 raise ValueError("Cannot create a one shot iterator. Please use " 1449 "`make_initializable_iterator()` instead.") 1450 return self._get_iterator() 1451 1452 def make_initializable_iterator(self): 1453 """Get an initializable iterator for DistributedDatasetV1. 1454 1455 Note: This API is deprecated. Please use 1456 `tf.compat.v1.data.make_initializable_iterator(dataset)` to create an 1457 initializable iterator. 1458 1459 Returns: 1460 A DistributedIteratorV1 instance. 1461 """ 1462 return self._make_initializable_iterator() 1463 1464 def _make_initializable_iterator(self, shared_name=None): # pylint: disable=unused-argument 1465 """Get an initializable iterator for DistributedDatasetV1.""" 1466 # Eager mode generates already initialized iterators. Hence we cannot create 1467 # an initializable iterator. 1468 if context.executing_eagerly(): 1469 raise ValueError("Cannot create initializable iterator in Eager mode. " 1470 "Please use `iter()` instead.") 1471 return self._get_iterator() 1472 1473 def _get_iterator(self): 1474 worker_iterators = _create_iterators_per_worker(self._cloned_datasets, 1475 self._input_workers, True, 1476 self._options) 1477 iterator = DistributedIteratorV1(self._input_workers, worker_iterators, 1478 self._strategy, 1479 self._enable_get_next_as_optional) 1480 iterator._element_spec = self.element_spec # pylint: disable=protected-access 1481 1482 # When async eager is enabled, sometimes the iterator may not finish 1483 # initialization before passing to a multi device function, add a sync point 1484 # here to make sure all underlying iterators are initialized. 1485 if context.executing_eagerly(): 1486 context.async_wait() 1487 1488 return iterator 1489 1490 def __iter__(self): 1491 if (ops.executing_eagerly_outside_functions() or 1492 ops.get_default_graph().building_function): 1493 return self._get_iterator() 1494 1495 raise RuntimeError("__iter__() is only supported inside of tf.function " 1496 "or when eager execution is enabled.") 1497 1498 1499class DistributedDatasetsFromFunctionSpec(DistributedDatasetAndIteratorSpec): 1500 """Type specification for `DistributedDatasetsFromFunction.""" 1501 1502 def __init__(self, input_workers, element_spec, strategy, options): 1503 super(DistributedDatasetsFromFunctionSpec, 1504 self).__init__(input_workers, element_spec, strategy, options) 1505 1506 @property 1507 def value_type(self): 1508 return DistributedDatasetsFromFunction 1509 1510 @property 1511 def _component_specs(self): 1512 specs = [] 1513 worker_device_pairs = self._input_workers._worker_device_pairs # pylint: disable=protected-access 1514 1515 for i, _ in enumerate(worker_device_pairs): 1516 element_spec = nest.map_structure( 1517 functools.partial(_replace_per_replica_spec, i=i), self._element_spec) 1518 specs.append(dataset_ops.DatasetSpec(element_spec)) 1519 return specs 1520 1521 # Overriding this method so that we can merge and reconstruct the spec object 1522 def most_specific_compatible_type(self, other): 1523 """Returns the most specific TypeSpec compatible with `self` and `other`. 1524 1525 Args: 1526 other: A `TypeSpec`. 1527 1528 Raises: 1529 ValueError: If there is no TypeSpec that is compatible with both `self` 1530 and `other`. 1531 """ 1532 # pylint: disable=protected-access 1533 self.sanity_check_type(other) 1534 element_spec = nest.map_structure( 1535 lambda a, b: a.most_specific_compatible_type(b), self._element_spec, 1536 other._element_spec) # pylint: disable=protected-access 1537 return DistributedDatasetsFromFunctionSpec(self._input_workers, 1538 element_spec, self._strategy, 1539 self._options) 1540 1541 def _to_components(self, value): 1542 return value._datasets # pylint: disable=protected-access 1543 1544 def _from_components(self, components): 1545 return DistributedDatasetsFromFunction( 1546 input_workers=self._input_workers, 1547 strategy=self._strategy, 1548 components=components, 1549 element_spec=self._element_spec, 1550 options=self._options) 1551 1552 @staticmethod 1553 def from_value(value): 1554 # pylint: disable=protected-access 1555 return DistributedDatasetsFromFunctionSpec( 1556 input_workers=value._input_workers, 1557 element_spec=value._element_spec, 1558 strategy=value._strategy, 1559 options=value._options) 1560 1561 1562# TODO(priyag): Add other replication modes. 1563class DistributedDatasetsFromFunction(_IterableInput, 1564 composite_tensor.CompositeTensor): 1565 """Inputs created from dataset function.""" 1566 1567 def __init__(self, 1568 input_workers, 1569 strategy, 1570 input_contexts=None, 1571 dataset_fn=None, 1572 options=None, 1573 components=None, 1574 element_spec=None, 1575 build=True): 1576 """Makes an iterable from datasets created by the given function. 1577 1578 Args: 1579 input_workers: an `InputWorkers` object. 1580 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1581 handle last partial batch. 1582 input_contexts: A list of `InputContext` instances to be passed to call(s) 1583 to `dataset_fn`. Length and order should match worker order in 1584 `worker_device_pairs`. 1585 dataset_fn: A function that returns a `Dataset` given an `InputContext`. 1586 Either dataset_fn or components should be passed to construct 1587 DistributedDatasetsFromFunction. Use this when constructing 1588 DistributedDataset using a function. Use components when constructing 1589 using DistributedDatasetsFromFunctionSpec. 1590 options: `tf.distribute.InputOptions` used to control options on how this 1591 dataset is distributed. 1592 components: datasets when DistributedDatasetsFromFunction is constructed 1593 from DistributedDatasetsFromFunctionSpec. Only one of dataset or 1594 components should be passed. 1595 element_spec: element spec for DistributedDataset when constructing from 1596 DistributedDatasetSpec. This will be used to set the element_spec for 1597 DistributedDatasetsFromFunctionSpec and verified against element_spec 1598 from components. 1599 build: whether to build underlying datasets when this object is created. 1600 This is only useful for `ParameterServerStrategy` now. 1601 """ 1602 super(DistributedDatasetsFromFunction, self).__init__( 1603 input_workers=input_workers) 1604 self._input_workers = input_workers 1605 self._strategy = strategy 1606 self._options = options 1607 if dataset_fn is not None and components is not None: 1608 raise ValueError("Only one of dataset_fn or components should be set") 1609 if dataset_fn is None and components is None: 1610 raise ValueError("At least one of dataset_fn or components should be set") 1611 1612 if dataset_fn is not None: 1613 if input_workers.num_workers != len(input_contexts): 1614 raise ValueError( 1615 "Number of input workers (%d) is not same as number of " 1616 "input_contexts (%d)" % 1617 (input_workers.num_workers, len(input_contexts))) 1618 self._input_contexts = input_contexts 1619 self._dataset_fn = dataset_fn 1620 self._built = False 1621 if build: 1622 self.build() 1623 else: 1624 if element_spec is None: 1625 raise ValueError( 1626 "element_spec should also be passed when passing components") 1627 if not build: 1628 raise ValueError( 1629 "When constructing DistributedDatasetFromFunction with components, " 1630 "build should not be False. This is an internal error. Please file " 1631 "a bug.") 1632 self._element_spec = element_spec 1633 self._datasets = components 1634 self._built = True 1635 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1636 self._strategy, self._datasets[0]) 1637 1638 def build(self): 1639 assert not self._built 1640 self._datasets, element_spec = ( 1641 _create_datasets_from_function_with_input_context( 1642 self._input_contexts, self._input_workers, self._dataset_fn)) 1643 self._element_spec = _create_distributed_tensor_spec( 1644 self._strategy, element_spec) 1645 self._enable_get_next_as_optional = _enable_get_next_as_optional( 1646 self._strategy, self._datasets[0]) 1647 self._built = True 1648 1649 def __iter__(self): 1650 if not (ops.executing_eagerly_outside_functions() or 1651 ops.get_default_graph().building_function): 1652 raise RuntimeError("__iter__() is only supported inside of tf.function " 1653 "or when eager execution is enabled.") 1654 1655 if not self._built: 1656 raise ValueError("You need to use this dataset in " 1657 "ClusterCoordinator.create_per_worker_dataset.") 1658 1659 # This is an optional flag that can be used to turn off using 1660 # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators 1661 # as a stop gap solution that will allow us to roll out this change. 1662 enable_legacy_iterators = getattr(self._strategy, 1663 "_enable_legacy_iterators", False) 1664 canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", 1665 True) 1666 1667 iterators = _create_iterators_per_worker(self._datasets, 1668 self._input_workers, 1669 enable_legacy_iterators, 1670 self._options, 1671 canonicalize_devices) 1672 if enable_legacy_iterators: 1673 iterator = DistributedIteratorV1( 1674 self._input_workers, 1675 iterators, 1676 self._strategy, 1677 enable_get_next_as_optional=self._enable_get_next_as_optional) 1678 else: 1679 iterator = DistributedIterator( 1680 input_workers=self._input_workers, 1681 iterators=iterators, 1682 strategy=self._strategy, 1683 enable_get_next_as_optional=self._enable_get_next_as_optional, 1684 options=self._options) 1685 iterator._element_spec = self._element_spec # pylint: disable=protected-access 1686 1687 # When async eager is enabled, sometimes the iterator may not finish 1688 # initialization before passing to a multi device function, add a sync 1689 # point here to make sure all underlying iterators are initialized. 1690 if context.executing_eagerly(): 1691 context.async_wait() 1692 1693 return iterator 1694 1695 @property 1696 def element_spec(self): 1697 """The type specification of an element of this dataset.""" 1698 # When partial batch handling is enabled, always set the batch dimension to 1699 # None, otherwise we just follow element_spec of the underlying dataset 1700 # (whose batch dimension may also be None). This is because with partial 1701 # batching handling we could always produce empty batches. 1702 if (self._enable_get_next_as_optional and 1703 self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access 1704 return nest.map_structure( 1705 _rebatch_as_dynamic, self._element_spec, expand_composites=False) 1706 return self._element_spec 1707 1708 @property 1709 def _type_spec(self): 1710 return DistributedDatasetsFromFunctionSpec(self._input_workers, 1711 self._element_spec, 1712 self._strategy, self._options) 1713 1714 1715class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): 1716 """Inputs created from dataset function.""" 1717 1718 def _make_initializable_iterator(self, shared_name=None): 1719 """Get an initializable iterator for DistributedDatasetsFromFunctionV1.""" 1720 del shared_name # Unused 1721 # Eager mode generates already initialized iterators. Hence we cannot create 1722 # an initializable iterator. 1723 if context.executing_eagerly(): 1724 raise ValueError("Cannot create initializable iterator in Eager mode. " 1725 "Please use `iter()` instead.") 1726 return self._get_iterator() 1727 1728 def _make_one_shot_iterator(self): 1729 """Get an iterator for iterating over DistributedDatasetsFromFunctionV1.""" 1730 # Graph mode with one shot iterator is disabled because we have to call 1731 # `initialize` on the iterator which is only required if we are using a 1732 # tf.distribute strategy. 1733 if not context.executing_eagerly(): 1734 raise ValueError("Cannot create a one shot iterator. Please use " 1735 "`make_initializable_iterator()` instead.") 1736 return self._get_iterator() 1737 1738 def _get_iterator(self): 1739 iterators = _create_iterators_per_worker(self._datasets, 1740 self._input_workers, True, 1741 self._options) 1742 iterator = DistributedIteratorV1(self._input_workers, iterators, 1743 self._strategy, 1744 self._enable_get_next_as_optional) 1745 iterator._element_spec = self._element_spec # pylint: disable=protected-access 1746 1747 # When async eager is enabled, sometimes the iterator may not finish 1748 # initialization before passing to a multi device function, add a sync point 1749 # here to make sure all underlying iterators are initialized. 1750 if context.executing_eagerly(): 1751 context.async_wait() 1752 1753 return iterator 1754 1755 def __iter__(self): 1756 if (ops.executing_eagerly_outside_functions() or 1757 ops.get_default_graph().building_function): 1758 return self._get_iterator() 1759 1760 raise RuntimeError("__iter__() is only supported inside of tf.function " 1761 "or when eager execution is enabled.") 1762 1763 1764# TODO(anjalisridhar): This class will be soon removed in favor of newer 1765# APIs. 1766class InputFunctionIterator(DistributedIteratorV1): 1767 """Iterator created from input function.""" 1768 1769 def __init__(self, input_fn, input_workers, input_contexts, strategy): 1770 """Make an iterator for input provided via an input function. 1771 1772 Currently implements PER_WORKER mode, in which the `input_fn` is called 1773 once on each worker. 1774 1775 TODO(priyag): Add other replication modes. 1776 1777 Args: 1778 input_fn: Input function that returns a `tf.data.Dataset` object. 1779 input_workers: an `InputWorkers` object. 1780 input_contexts: A list of `InputContext` instances to be passed to call(s) 1781 to `input_fn`. Length and order should match worker order in 1782 `worker_device_pairs`. 1783 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1784 handle last partial batch. 1785 """ 1786 assert isinstance(input_workers, InputWorkers) 1787 if input_workers.num_workers != len(input_contexts): 1788 raise ValueError( 1789 "Number of input workers (%d) is not same as number of " 1790 "input_contexts (%d)" % 1791 (input_workers.num_workers, len(input_contexts))) 1792 1793 iterators = [] 1794 for i, ctx in enumerate(input_contexts): 1795 worker = input_workers.worker_devices[i] 1796 with ops.device(worker): 1797 result = input_fn(ctx) 1798 devices = input_workers.compute_devices_for_worker(i) 1799 if isinstance(result, dataset_ops.DatasetV2): 1800 iterator = _SingleWorkerDatasetIterator(result, worker, devices) 1801 elif callable(result): 1802 iterator = _SingleWorkerCallableIterator(result, worker, devices) 1803 else: 1804 raise ValueError( 1805 "input_fn must return a tf.data.Dataset or a callable.") 1806 iterators.append(iterator) 1807 1808 super(InputFunctionIterator, self).__init__( 1809 input_workers, iterators, strategy, enable_get_next_as_optional=False) 1810 self._enable_get_next_as_optional = False 1811 1812 1813# TODO(anjalisridhar): This class will soon be removed and users should move 1814# to using DistributedIterator. 1815class DatasetIterator(DistributedIteratorV1): 1816 """Iterator created from input dataset.""" 1817 1818 def __init__(self, 1819 dataset, 1820 input_workers, 1821 strategy, 1822 num_replicas_in_sync=None, 1823 input_context=None): 1824 """Make an iterator for the dataset on given devices. 1825 1826 If `num_replicas_in_sync` is not None, we split each batch of the dataset 1827 into `num_replicas_in_sync` smaller batches, to be distributed among that 1828 worker's replicas, so that the batch size for a global step (across all 1829 workers and replicas) is as expected. 1830 1831 Args: 1832 dataset: `tf.data.Dataset` that will be used as the input source. 1833 input_workers: an `InputWorkers` object. 1834 strategy: a `tf.distribute.Strategy` object, used to run all-reduce to 1835 handle last partial batch. 1836 num_replicas_in_sync: Optional integer. If this is not None, the value is 1837 used to decide how to rebatch datasets into smaller batches so that the 1838 total batch size for each step (across all workers and replicas) adds up 1839 to `dataset`'s batch size. 1840 input_context: `InputContext` for sharding. Only pass this in for between 1841 graph multi-worker cases where there is only one `input_worker`. In 1842 these cases, we will shard based on the `input_pipeline_id` and 1843 `num_input_pipelines` in the `InputContext`. 1844 """ 1845 dist_dataset = DistributedDatasetV1( 1846 dataset, 1847 input_workers, 1848 strategy, 1849 num_replicas_in_sync=num_replicas_in_sync, 1850 input_context=input_context) 1851 worker_iterators = _create_iterators_per_worker( 1852 dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access 1853 super(DatasetIterator, 1854 self).__init__(input_workers, worker_iterators, strategy, 1855 dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access 1856 self._element_spec = dist_dataset.element_spec 1857 1858 1859def _dummy_tensor_fn(value_structure): 1860 """A function to create dummy tensors from `value_structure`.""" 1861 1862 def create_dummy_tensor(spec): 1863 """Create a dummy tensor with possible batch dimensions set to 0.""" 1864 if hasattr(spec, "_create_empty_value"): 1865 # Type spec may overwrite default dummy values behavior by declaring the 1866 # `_create_empty_value(self)` method. This method must return a value 1867 # compatible with the type spec with batch dimensions set to 0 or fail if 1868 # such a value does not exist. This allows a composite tensor to customize 1869 # dummy values creation as, in general, its dummy value is not composed 1870 # from dummy components (e.g. `row_splits` tensor of a RaggedTensor is 1871 # never allowed to be empty). See b/183969859 for more discussions. 1872 # TODO(b/186079336): reconsider CompositeTensor support. 1873 return spec._create_empty_value() # pylint: disable=protected-access 1874 1875 if isinstance(spec, ragged_tensor.RaggedTensorSpec): 1876 # Splice out the ragged dimensions. 1877 # pylint: disable=protected-access 1878 feature_shape = spec._shape[:1].concatenate( 1879 spec._shape[(1 + spec._ragged_rank):]) 1880 feature_type = spec._dtype 1881 # pylint: enable=protected-access 1882 else: 1883 feature_shape = spec.shape 1884 feature_type = spec.dtype 1885 # Ideally we should set the batch dimension to 0, however as in 1886 # DistributionStrategy we don't know the batch dimension, we try to 1887 # guess it as much as possible. If the feature has unknown dimensions, we 1888 # will set them to 0. If the feature shape is already static, we guess the 1889 # first dimension as batch dimension and set it to 0. 1890 dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()] 1891 if feature_shape else []) 1892 if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or 1893 feature_shape.is_fully_defined()): 1894 dims[0] = tensor_shape.Dimension(0) 1895 1896 if isinstance(spec, sparse_tensor.SparseTensorSpec): 1897 return sparse_tensor.SparseTensor( 1898 values=array_ops.zeros(0, feature_type), 1899 indices=array_ops.zeros((0, len(dims)), dtypes.int64), 1900 dense_shape=dims) 1901 1902 # Create the dummy tensor. 1903 dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type) 1904 if isinstance(spec, ragged_tensor.RaggedTensorSpec): 1905 # Reinsert the ragged dimensions with size 0. 1906 # pylint: disable=protected-access 1907 row_splits = array_ops.zeros(1, spec._row_splits_dtype) 1908 dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits( 1909 dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False) 1910 # pylint: enable=protected-access 1911 return dummy_tensor 1912 1913 return nest.map_structure(create_dummy_tensor, value_structure) 1914 1915 1916def _recover_shape_fn(data, value_structure): 1917 """Recover the shape of `data` the same as shape of `value_structure`.""" 1918 1919 flattened_data = nest.flatten(data) 1920 for i, spec in enumerate(nest.flatten(value_structure)): 1921 for target, source in zip( 1922 nest.flatten(flattened_data[i], expand_composites=True), 1923 nest.flatten(spec, expand_composites=True)): 1924 target.set_shape(source.shape) 1925 # `SparseTensor` shape is not determined by the shape of its component 1926 # tensors. Rather, its shape depends on a tensor's values. 1927 if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape: 1928 dense_shape = spec.shape 1929 with ops.device(flattened_data[i].op.device): 1930 # For partially defined shapes, fill in missing values from tensor. 1931 if not dense_shape.is_fully_defined(): 1932 dense_shape = array_ops.stack([ 1933 flattened_data[i].dense_shape[j] if dim is None else dim 1934 for j, dim in enumerate(dense_shape.as_list()) 1935 ]) 1936 flattened_data[i] = sparse_tensor.SparseTensor( 1937 indices=flattened_data[i].indices, 1938 values=flattened_data[i].values, 1939 dense_shape=dense_shape) 1940 data = nest.pack_sequence_as(data, flattened_data) 1941 return data 1942 1943 1944class _SingleWorkerDatasetIteratorBase(object): 1945 """Iterator for a single `tf.data.Dataset`.""" 1946 1947 def __init__(self, dataset, worker, devices, options=None): 1948 """Create iterator for the `dataset` to fetch data to worker's `devices` . 1949 1950 A `MultiDeviceIterator` or `OwnedMultiDeviceIterator` is used to prefetch 1951 input to the devices on the given worker. 1952 1953 Args: 1954 dataset: A `tf.data.Dataset` instance. 1955 worker: Worker on which ops should be created. 1956 devices: Distribute data from `dataset` to these devices. 1957 options: options. 1958 """ 1959 self._dataset = dataset 1960 self._worker = worker 1961 self._devices = devices 1962 self._element_spec = dataset.element_spec 1963 self._options = options 1964 self._make_iterator() 1965 1966 def _make_iterator(self): 1967 raise NotImplementedError("must be implemented in descendants") 1968 1969 def _format_data_list_with_options(self, data_list): 1970 """Change the data in to a list type if required. 1971 1972 The OwnedMultiDeviceIterator returns the list data type, 1973 while the PER_REPLICA iterator (when used with prefetch disabled) 1974 returns without the enclosed list. This is to fix the inconsistency. 1975 Args: 1976 data_list: data_list 1977 Returns: 1978 list 1979 """ 1980 if (self._options and self._options.experimental_replication_mode == 1981 InputReplicationMode.PER_REPLICA and 1982 not self._options.experimental_fetch_to_device): 1983 return [data_list] 1984 else: 1985 return data_list 1986 1987 def get_next(self, device, name=None): 1988 """Get next element for the given device.""" 1989 del name 1990 with ops.device(self._worker): 1991 if _should_use_multi_device_iterator(self._options): 1992 return self._iterator.get_next(device) 1993 else: 1994 return self._iterator.get_next() 1995 1996 def get_next_as_list_static_shapes(self, name=None): 1997 """Get next element from the underlying iterator. 1998 1999 Runs the iterator get_next() within a device scope. Since this doesn't use 2000 get_next_as_optional(), it is considerably faster than get_next_as_list() 2001 (but can only be used when the shapes are static). 2002 2003 Args: 2004 name: not used. 2005 2006 Returns: 2007 A list consisting of the next data from each device. 2008 """ 2009 del name 2010 with ops.device(self._worker): 2011 return self._format_data_list_with_options(self._iterator.get_next()) 2012 2013 def get_next_as_list(self, name=None): 2014 """Get next element from underlying iterator. 2015 2016 If there is no data left, a list of dummy tensors with possible batch 2017 dimensions set to 0 will be returned. Use of get_next_as_optional() and 2018 extra logic adds overhead compared to get_next_as_list_static_shapes(), but 2019 allows us to handle non-static shapes. 2020 2021 Args: 2022 name: not used. 2023 2024 Returns: 2025 A boolean tensor indicates whether there is any data in next element and 2026 the real data as the next element or a list of dummy tensors if no data 2027 left. 2028 """ 2029 del name 2030 with ops.device(self._worker): 2031 data_list = self._format_data_list_with_options( 2032 self._iterator.get_next_as_optional()) 2033 result = [] 2034 for i, data in enumerate(data_list): 2035 # Place the condition op in the same device as the data so the data 2036 # doesn't need to be sent back to the worker. 2037 with ops.device(self._devices[i]): 2038 # Data will be fetched in order, so we only need to check if the first 2039 # replica has value to see whether there is data left for this single 2040 # worker. 2041 if i == 0: 2042 worker_has_value = data.has_value() 2043 2044 # pylint: disable=unnecessary-lambda 2045 # pylint: disable=cell-var-from-loop 2046 real_data = control_flow_ops.cond( 2047 data.has_value(), 2048 lambda: data.get_value(), 2049 lambda: _dummy_tensor_fn(data.element_spec), 2050 strict=True, 2051 ) 2052 # Some dimensions in `replicas` will become unknown after we 2053 # conditionally return the real tensors or the dummy tensors. Recover 2054 # the shapes from `data.element_spec`. We only need to do this in 2055 # non eager mode because we always know the runtime shape of the 2056 # tensors in eager mode. 2057 if not context.executing_eagerly(): 2058 real_data = _recover_shape_fn(real_data, data.element_spec) 2059 result.append(real_data) 2060 # pylint: enable=cell-var-from-loop 2061 # pylint: enable=unnecessary-lambda 2062 2063 return worker_has_value, result 2064 2065 2066class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): 2067 """Type specification for `_SingleWorkerOwnedDatasetIterator`.""" 2068 2069 __slots__ = [ 2070 "_worker", "_devices", "_element_spec", "_options", 2071 "_canonicalize_devices" 2072 ] 2073 2074 def __init__(self, worker, devices, element_spec, options, 2075 canonicalize_devices=True): 2076 self._worker = worker 2077 if canonicalize_devices: 2078 self._devices = tuple(device_util.canonicalize(d) for d in devices) 2079 else: 2080 self._devices = tuple( 2081 device_util.canonicalize_without_job_and_task(d) for d in devices) 2082 self._element_spec = element_spec 2083 # `self._options` intentionally made not `None` for proper serialization. 2084 self._options = (options if options is not None else 2085 distribute_lib.InputOptions()) 2086 self._canonicalize_devices = canonicalize_devices 2087 2088 @property 2089 def value_type(self): 2090 return _SingleWorkerOwnedDatasetIterator 2091 2092 def _serialize(self): 2093 return (self._worker, self._devices, self._element_spec, self._options, 2094 self._canonicalize_devices) 2095 2096 def _get_multi_device_iterator_spec(self, specs): 2097 device_scope = device_util.canonicalize(self._worker, device_util.current()) 2098 host_device = device_util.get_host_for_device(device_scope) 2099 # source_device while creating iterator governs the worker device in 2100 # iterator spec. 2101 worker = host_device 2102 specs.append( 2103 multi_device_iterator_ops.MultiDeviceIteratorSpec( 2104 self._devices, worker, element_spec=self._element_spec)) 2105 2106 @property 2107 def _component_specs(self): 2108 specs = [] 2109 if _should_use_multi_device_iterator(self._options): 2110 self._get_multi_device_iterator_spec(specs) 2111 else: 2112 specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec)) 2113 return specs 2114 2115 def _to_components(self, value): 2116 return [value._iterator] # pylint: disable=protected-access 2117 2118 def _from_components(self, components): 2119 return _SingleWorkerOwnedDatasetIterator( 2120 dataset=None, 2121 worker=self._worker, 2122 devices=self._devices, 2123 components=components, 2124 element_spec=self._element_spec, 2125 options=self._options, 2126 canonicalize_devices=self._canonicalize_devices) 2127 2128 @staticmethod 2129 def from_value(value): 2130 # pylint: disable=protected-access 2131 return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices, 2132 value._element_spec, value._options, 2133 value._canonicalize_devices) 2134 2135 2136class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, 2137 composite_tensor.CompositeTensor): 2138 """Iterator for a DistributedDataset instance.""" 2139 2140 def __init__(self, 2141 dataset=None, 2142 worker=None, 2143 devices=None, 2144 components=None, 2145 element_spec=None, 2146 options=None, 2147 canonicalize_devices=None): 2148 """Create iterator for the `dataset` to fetch data to worker's `devices` . 2149 2150 `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the 2151 given worker. The lifetime of this iterator is tied to the encompassing 2152 python object. Once we go out of scope of the python object or return from 2153 a tf.function the underlying iterator resource is deleted. 2154 2155 Args: 2156 dataset: A `tf.data.Dataset` instance. 2157 worker: Worker on which ops should be created. 2158 devices: Distribute data from `dataset` to these devices. 2159 components: Tensor components to construct the 2160 _SingleWorkerOwnedDatasetIterator from. 2161 element_spec: A nested structure of `TypeSpec` objects that represents the 2162 type specification of elements of the iterator. 2163 options: `tf.distribute.InputOptions` used to control options on how this 2164 dataset is distributed. 2165 canonicalize_devices: Whether to canonicalize devices for workers fully or 2166 partially. If False, it will partially canonicalize devices by removing 2167 job and task. 2168 """ 2169 if worker is None or devices is None: 2170 raise ValueError("Both `worker` and `devices` should be provided") 2171 2172 error_message = ("Either `dataset` or both `components` and `element_spec` " 2173 "need to be provided.") 2174 2175 self._options = options 2176 self._canonicalize_devices = canonicalize_devices 2177 if dataset is None: 2178 if (components is None or element_spec is None): 2179 raise ValueError(error_message) 2180 self._element_spec = element_spec 2181 self._worker = worker 2182 self._devices = devices 2183 self._iterator = components[0] 2184 else: 2185 if (components is not None or element_spec is not None): 2186 raise ValueError(error_message) 2187 super(_SingleWorkerOwnedDatasetIterator, 2188 self).__init__(dataset, worker, devices, self._options) 2189 2190 def _create_owned_multi_device_iterator(self): 2191 # If the worker devices are already canonicalized, canonicalizing again 2192 # would have no impact. 2193 # For strategies running on remote workers such as PS Strategy, the device 2194 # scope will be derived from current worker, if used under init_scope(). 2195 device_scope = device_util.canonicalize(self._worker, 2196 device_util.current()) 2197 host_device = device_util.get_host_for_device(device_scope) 2198 with ops.device(device_scope): 2199 if self._options is not None: 2200 self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( 2201 self._dataset, 2202 self._devices, 2203 source_device=host_device, 2204 max_buffer_size=self._options 2205 .experimental_per_replica_buffer_size, 2206 prefetch_buffer_size=self._options 2207 .experimental_per_replica_buffer_size) 2208 else: 2209 self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( 2210 self._dataset, self._devices, source_device=host_device) 2211 2212 def _make_iterator(self): 2213 """Make appropriate iterator on the dataset.""" 2214 if not self._worker: 2215 raise ValueError("Worker device must be specified when creating an " 2216 "owned iterator.") 2217 if _should_use_multi_device_iterator(self._options): 2218 self._create_owned_multi_device_iterator() 2219 else: 2220 with ops.device(self._worker): 2221 self._iterator = iter(self._dataset) 2222 2223 @property 2224 def element_spec(self): 2225 return self._element_spec 2226 2227 @property 2228 def _type_spec(self): 2229 return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices, 2230 self._element_spec, self._options, 2231 self._canonicalize_devices) 2232 2233 @property 2234 def output_classes(self): 2235 """Returns the class of each component of an element of this iterator. 2236 2237 The expected values are `tf.Tensor` and `tf.SparseTensor`. 2238 2239 Returns: 2240 A nested structure of Python `type` objects corresponding to each 2241 component of an element of this dataset. 2242 """ 2243 return nest.map_structure( 2244 lambda component_spec: component_spec._to_legacy_output_classes(), # pylint: disable=protected-access 2245 self._element_spec) 2246 2247 @property 2248 def output_shapes(self): 2249 """Returns the shape of each component of an element of this iterator. 2250 2251 Returns: 2252 A nested structure of `tf.TensorShape` objects corresponding to each 2253 component of an element of this dataset. 2254 """ 2255 return nest.map_structure( 2256 lambda component_spec: component_spec._to_legacy_output_shapes(), # pylint: disable=protected-access 2257 self._element_spec) 2258 2259 @property 2260 def output_types(self): 2261 """Returns the type of each component of an element of this iterator. 2262 2263 Returns: 2264 A nested structure of `tf.DType` objects corresponding to each component 2265 of an element of this dataset. 2266 """ 2267 return nest.map_structure( 2268 lambda component_spec: component_spec._to_legacy_output_types(), # pylint: disable=protected-access 2269 self._element_spec) 2270 2271 2272class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase): 2273 """Iterator for a single DistributedDatasetV1 instance.""" 2274 2275 def _make_iterator(self): 2276 """Make appropriate iterator on the dataset.""" 2277 with ops.device(self._worker): 2278 if self._options is not None: 2279 self._iterator = multi_device_iterator_ops.MultiDeviceIterator( 2280 self._dataset, 2281 self._devices, 2282 max_buffer_size=self._options.experimental_per_replica_buffer_size, 2283 prefetch_buffer_size=self._options 2284 .experimental_per_replica_buffer_size) 2285 else: 2286 self._iterator = multi_device_iterator_ops.MultiDeviceIterator( 2287 self._dataset, 2288 self._devices, 2289 ) 2290 2291 def initialize(self): 2292 """Initialize underlying iterator. 2293 2294 In eager execution, this simply recreates the underlying iterator. 2295 In graph execution, it returns the initializer ops for the underlying 2296 iterator. 2297 2298 Returns: 2299 A list of any initializer ops that should be run. 2300 """ 2301 if ops.executing_eagerly_outside_functions(): 2302 self._iterator._eager_reset() # pylint: disable=protected-access 2303 return [] 2304 else: 2305 return [self._iterator.initializer] 2306 2307 @property 2308 def output_classes(self): 2309 return dataset_ops.get_legacy_output_classes(self._iterator) 2310 2311 @property 2312 def output_shapes(self): 2313 return dataset_ops.get_legacy_output_shapes(self._iterator) 2314 2315 @property 2316 def output_types(self): 2317 return dataset_ops.get_legacy_output_types(self._iterator) 2318 2319 2320class _SingleWorkerCallableIterator(object): 2321 """Iterator for a single tensor-returning callable.""" 2322 2323 def __init__(self, fn, worker, devices): 2324 self._fn = fn 2325 self._worker = worker 2326 self._devices = devices 2327 2328 def get_next(self, device, name=None): 2329 """Get next element for the given device from the callable.""" 2330 del device, name 2331 with ops.device(self._worker): 2332 return self._fn() 2333 2334 def get_next_as_list_static_shapes(self, name=None): 2335 """Get next element from the callable.""" 2336 del name 2337 with ops.device(self._worker): 2338 data_list = [self._fn() for _ in self._devices] 2339 return data_list 2340 2341 def get_next_as_list(self, name=None): 2342 """Get next element from the callable.""" 2343 del name 2344 with ops.device(self._worker): 2345 data_list = [self._fn() for _ in self._devices] 2346 return constant_op.constant(True), data_list 2347 2348 def initialize(self): 2349 # TODO(petebu) Should this throw an exception instead? 2350 return [] 2351 2352 2353def _create_iterators_per_worker(worker_datasets, 2354 input_workers, 2355 enable_legacy_iterators, 2356 options=None, 2357 canonicalize_devices=False): 2358 """Create a multidevice iterator on each of the workers.""" 2359 assert isinstance(input_workers, InputWorkers) 2360 assert len(worker_datasets) == len(input_workers.worker_devices) 2361 iterators = [] 2362 for i, worker in enumerate(input_workers.worker_devices): 2363 with ops.device(worker): 2364 worker_devices = input_workers.compute_devices_for_worker(i) 2365 if tf2.enabled() and not enable_legacy_iterators: 2366 iterator = _SingleWorkerOwnedDatasetIterator( 2367 dataset=worker_datasets[i], 2368 worker=worker, 2369 devices=worker_devices, 2370 options=options, 2371 canonicalize_devices=canonicalize_devices) 2372 else: 2373 iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, 2374 worker_devices, options) 2375 iterators.append(iterator) 2376 return iterators 2377 2378 2379def _create_datasets_from_function_with_input_context(input_contexts, 2380 input_workers, 2381 dataset_fn): 2382 """Create device datasets per worker given a dataset function.""" 2383 datasets = [] 2384 for i, ctx in enumerate(input_contexts): 2385 worker = input_workers.worker_devices[i] 2386 with ops.device(worker): 2387 dataset = dataset_fn(ctx) 2388 datasets.append(dataset) 2389 return datasets, dataset.element_spec 2390 2391 2392# TODO(sourabhbajaj): Remove this in lieu of distributed datasets 2393def _get_batched_dataset(d): 2394 """Get the batched dataset from `d`.""" 2395 # pylint: disable=protected-access 2396 if isinstance(d, dataset_ops.DatasetV1Adapter): 2397 d = d._dataset 2398 2399 if isinstance(d, (dataset_ops.BatchDataset, batching._MapAndBatchDataset)): 2400 return d 2401 elif isinstance(d, (dataset_ops.PrefetchDataset, 2402 dataset_ops._OptionsDataset)): 2403 return _get_batched_dataset(d._input_dataset) 2404 2405 raise ValueError( 2406 "Unable to get batched dataset from the input dataset. `batch` " 2407 "`map_and_batch` need to be the last operations on the dataset. " 2408 "The batch operations can be followed by a prefetch.") 2409 2410 2411def _get_batched_dataset_attributes(d): 2412 """Get `batch_size`, `drop_remainder` of dataset.""" 2413 # pylint: disable=protected-access 2414 assert isinstance(d, 2415 (dataset_ops.BatchDataset, batching._MapAndBatchDataset)) 2416 if isinstance(d, dataset_ops.BatchDataset): 2417 batch_size = d._batch_size 2418 drop_remainder = d._drop_remainder 2419 elif isinstance(d, batching._MapAndBatchDataset): 2420 batch_size = d._batch_size_t 2421 drop_remainder = d._drop_remainder_t 2422 # pylint: enable=protected-access 2423 2424 if tensor_util.is_tf_type(batch_size): 2425 batch_size = tensor_util.constant_value(batch_size) 2426 2427 if tensor_util.is_tf_type(drop_remainder): 2428 drop_remainder = tensor_util.constant_value(drop_remainder) 2429 2430 return batch_size, drop_remainder 2431 2432 2433# TODO(sourabhbajaj): Remove this in lieu of distributed datasets 2434def _get_dataset_attributes(dataset): 2435 """Get the underlying attributes from the dataset object.""" 2436 # pylint: disable=protected-access 2437 2438 # First, get batch_size and drop_remainder from the dataset. We need 2439 # to walk back the dataset creation process and find the batched version in 2440 # order to get the attributes. 2441 batched_dataset = _get_batched_dataset(dataset) 2442 batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset) 2443 2444 # Second, prefetch buffer should be get from the original dataset. 2445 prefetch_buffer = None 2446 if isinstance(dataset, dataset_ops.PrefetchDataset): 2447 prefetch_buffer = dataset._buffer_size 2448 elif (isinstance(dataset, dataset_ops.DatasetV1Adapter) 2449 and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)): 2450 prefetch_buffer = dataset._dataset._buffer_size 2451 2452 return batch_size, drop_remainder, prefetch_buffer 2453 2454 2455def _should_use_multi_device_iterator(options): 2456 """Determine whether to use multi_device_iterator_ops.""" 2457 if (options is None or 2458 options.experimental_replication_mode == InputReplicationMode.PER_WORKER 2459 or 2460 (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA 2461 and options.experimental_fetch_to_device)): 2462 return True 2463 return False 2464 2465 2466class MultiStepContext(object): 2467 """A context object that can be used to capture things when running steps. 2468 2469 This context object is useful when running multiple steps at a time using the 2470 `experimental_run_steps_on_iterator` API. For e.g. it allows the user's step 2471 function to specify which outputs to emit at what frequency. Currently it 2472 supports capturing output from the last step, as well as capturing non tensor 2473 outputs. In the future it will be augmented to support other use cases such 2474 as output each N steps. 2475 """ 2476 2477 def __init__(self): 2478 """Initialize an output context. 2479 2480 Returns: 2481 A context object. 2482 """ 2483 self._last_step_outputs = {} 2484 self._last_step_outputs_reduce_ops = {} 2485 self._non_tensor_outputs = {} 2486 2487 @property 2488 def last_step_outputs(self): 2489 """A dictionary consisting of outputs to be captured on last step. 2490 2491 Keys in the dictionary are names of tensors to be captured, as specified 2492 when `set_last_step_output` is called. 2493 Values in the dictionary are the tensors themselves. If 2494 `set_last_step_output` was called with a `reduce_op` for this output, 2495 then the value is the reduced value. 2496 2497 Returns: 2498 A dictionary with last step outputs. 2499 """ 2500 return self._last_step_outputs 2501 2502 def _set_last_step_outputs(self, outputs): 2503 """Replace the entire dictionary of last step outputs.""" 2504 if not isinstance(outputs, dict): 2505 raise ValueError("Need a dictionary to set last_step_outputs.") 2506 self._last_step_outputs = outputs 2507 2508 def set_last_step_output(self, name, output, reduce_op=None): 2509 """Set `output` with `name` to be outputted from the last step. 2510 2511 Args: 2512 name: String, name to identify the output. Doesn't need to match tensor 2513 name. 2514 output: The tensors that should be outputted with `name`. See below for 2515 actual types supported. 2516 reduce_op: Reduction method to use to reduce outputs from multiple 2517 replicas. Required if `set_last_step_output` is called in a replica 2518 context. Optional in cross_replica_context. 2519 When present, the outputs from all the replicas are reduced using the 2520 current distribution strategy's `reduce` method. Hence, the type of 2521 `output` must be what's supported by the corresponding `reduce` method. 2522 For e.g. if using MirroredStrategy and reduction is set, output 2523 must be a `PerReplica` value. 2524 The reduce method is also recorded in a dictionary 2525 `_last_step_outputs_reduce_ops` for later interpreting of the 2526 outputs as already reduced or not. 2527 """ 2528 if distribution_strategy_context.in_cross_replica_context(): 2529 self._last_step_outputs_reduce_ops[name] = reduce_op 2530 if reduce_op is None: 2531 self._last_step_outputs[name] = output 2532 else: 2533 distribution = distribution_strategy_context.get_strategy() 2534 self._last_step_outputs[name] = distribution.reduce(reduce_op, output, 2535 axis=None) 2536 else: 2537 assert reduce_op is not None 2538 def merge_fn(distribution, value): 2539 self._last_step_outputs[name] = distribution.reduce(reduce_op, value, 2540 axis=None) 2541 # Setting this inside the `merge_fn` because all replicas share the same 2542 # context object, so it's more robust to set it only once (even if all 2543 # the replicas are trying to set the same value). 2544 self._last_step_outputs_reduce_ops[name] = reduce_op 2545 2546 distribution_strategy_context.get_replica_context().merge_call( 2547 merge_fn, args=(output,)) 2548 2549 @property 2550 def non_tensor_outputs(self): 2551 """A dictionary consisting of any non tensor outputs to be captured.""" 2552 return self._non_tensor_outputs 2553 2554 def set_non_tensor_output(self, name, output): 2555 """Set `output` with `name` to be captured as a non tensor output.""" 2556 if distribution_strategy_context.in_cross_replica_context(): 2557 self._non_tensor_outputs[name] = output 2558 else: 2559 def merge_fn(distribution, value): 2560 # NOTE(priyag): For non tensor outputs, we simply return all the values 2561 # in a list as reduction doesn't make sense on non tensors. 2562 self._non_tensor_outputs[name] = ( 2563 distribution.experimental_local_results(value)) 2564 distribution_strategy_context.get_replica_context().merge_call( 2565 merge_fn, args=(output,)) 2566 2567 2568def _create_distributed_tensor_spec(strategy, tensor_spec): 2569 """Create a `tf.TypeSpec` for a given strategy and input `tensor_spec`. 2570 2571 Args: 2572 strategy: The given `tf.distribute` strategy. 2573 tensor_spec: `tf.TensorSpec` of a given value. The batch dimension of the 2574 shape should be None if you have partial batches. 2575 2576 Returns: 2577 A `tf.TypeSpec` that matches the values produced by a given strategy. This 2578 can be a `tf.TensorSpec` or a `PerRelicaSpec`. 2579 """ 2580 num_replicas = len(strategy.extended.worker_devices) 2581 2582 # For one device strategy that is not MultiWorkerMirroredStrategy, return the 2583 # tensor_spec as is, since we don't wrap the output with PerReplica in this 2584 # case. 2585 # TODO(b/166464552): remove after we always wrap for all strategies. 2586 if not _always_wrap(strategy): 2587 return tensor_spec 2588 2589 # For other cases we assume the input to tf.function is a per replica type. 2590 def _get_value_per_replica(tensor_spec_per_input): 2591 value_specs = [tensor_spec_per_input for _ in range(num_replicas)] 2592 return values.PerReplicaSpec(*value_specs) 2593 2594 return nest.map_structure(_get_value_per_replica, tensor_spec) 2595 2596 2597def _replace_per_replica_spec(spec, i): 2598 """If `spec` is a `PerReplicaSpec`, then return its `i`th value_spec.""" 2599 if isinstance(spec, values.PerReplicaSpec): 2600 return spec._value_specs[i] # pylint: disable=protected-access 2601 else: 2602 return spec 2603 2604 2605def _enable_get_next_as_optional(strategy, dataset): 2606 """Returns whether to enable using partial batch handling.""" 2607 # TODO(b/133073708): we currently need a flag to control the usage because 2608 # there is a performance difference between get_next() and 2609 # get_next_as_optional(). And we only enable get_next_as_optional when the 2610 # output shapes are not static. 2611 # 2612 # TODO(rxsang): We want to always enable the get_next_as_optional behavior 2613 # when user passed input_fn instead of dataset. 2614 if not getattr( 2615 strategy.extended, "enable_partial_batch_handling", 2616 getattr(strategy.extended, "experimental_enable_get_next_as_optional", 2617 False)): 2618 return False 2619 2620 if context.executing_eagerly(): 2621 # If the dataset is infinite, we don't need to enable last partial batch 2622 # support. Currently the logic only applies to the case that distributed 2623 # dataset is created in eager mode, as we need to evaluate the dataset 2624 # cardinality. 2625 with ops.device(dataset._variant_tensor.device): # pylint: disable=protected-access 2626 if dataset.cardinality().numpy() == cardinality.INFINITE: 2627 return False 2628 2629 return not _is_statically_shaped( 2630 dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access 2631 2632 2633def _create_per_replica(value_list, strategy): 2634 """Creates a PerReplica. 2635 2636 For strategies other than OneDeviceStrategy, it creates a PerReplica whose 2637 type spec is set to the element spec of the dataset. This helps avoid 2638 retracing for partial batches. Retracing is problematic for multi client when 2639 different client retraces different time, since retracing changes the 2640 collective keys in the tf.function, and causes mismatches among clients. 2641 2642 For single client strategies, this simply calls distribute_utils.regroup(). 2643 2644 Args: 2645 value_list: a list of values, one for each replica. 2646 strategy: the `tf.distribute.Strategy`. 2647 2648 Returns: 2649 a structure of PerReplica. 2650 2651 """ 2652 # TODO(b/166464552): always wrap for all one device strategies as well. 2653 always_wrap = _always_wrap(strategy) 2654 per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap) 2655 return per_replicas 2656 2657 2658def _always_wrap(strategy): 2659 """Returns whether to always wrap the values in a DistributedValues.""" 2660 return strategy.extended._in_multi_worker_mode() or len( # pylint: disable=protected-access 2661 strategy.extended.worker_devices) > 1 2662 2663 2664def _rebatch_as_dynamic(per_replica_spec): 2665 """Rebatch the spec to have a dynamic batch dimension.""" 2666 assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec 2667 2668 # pylint: disable=protected-access 2669 def _rebatch(spec): 2670 # Rebatch if possible. 2671 try: 2672 return spec._unbatch()._batch(None) 2673 except ValueError: 2674 pass 2675 return spec 2676 2677 return values.PerReplicaSpec( 2678 *nest.map_structure(_rebatch, per_replica_spec._value_specs)) 2679 # pylint: enable=protected-access 2680