# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
"""Library for running a computation across multiple devices.

The intent of this library is that you can write an algorithm in a stylized way
and it will be usable with a variety of different `tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy for
distributing the algorithm across multiple devices/machines. Furthermore, these
changes can be hidden inside the specific layers and other library classes that
need special treatment to run in a distributed setting, so that most users'
model definition code can run unchanged. The `tf.distribute.Strategy` API works
the same way with eager and graph execution.

*Guides*

* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training)
* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb)

*Tutorials*

* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/)

  The tutorials cover how to use `tf.distribute.Strategy` to do distributed
  training with native Keras APIs, custom training loops,
  and Estimator APIs. They also cover how to save/load models when using
  `tf.distribute.Strategy`.

*Glossary*

* _Data parallelism_ is where we run multiple copies of the model
  on different slices of the input data. This is in contrast to
  _model parallelism_ where we divide up a single copy of a model
  across multiple devices.
  Note: we only support data parallelism for now, but
  hope to add support for model parallelism in the future.
* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that
  TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple
  devices on a single machine, or be connected to devices on multiple
  machines. Devices used to run computations are called _worker devices_.
  Devices used to store variables are _parameter devices_. For some strategies,
  such as `tf.distribute.MirroredStrategy`, the worker and parameter devices
  will be the same (see mirrored variables below). For others they will be
  different. For example, `tf.distribute.experimental.CentralStorageStrategy`
  puts the variables on a single device (which may be a worker device or may be
  the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the
  variables on separate machines called _parameter servers_ (see below).
* A _replica_ is one copy of the model, running on one slice of the
  input data. Right now each replica is executed on its own
  worker device, but once we add support for model parallelism
  a replica may span multiple worker devices.
* A _host_ is the CPU device on a machine with worker devices, typically
  used for running input pipelines.
* A _worker_ is defined to be the physical machine(s) containing the physical
  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
  worker may contain one or more replicas, but contains at least one
  replica. Typically one worker will correspond to one machine, but in the case
  of very large models with model parallelism, one worker may span multiple
  machines. We typically run one input pipeline per worker, feeding all the
  replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
  each replica are aggregated together before updating the model variables. This
  is in contrast to _asynchronous_, or _async_ training, where each replica
  updates the model variables independently. You may also have replicas
  partitioned into groups which are in sync within each group but async between
  groups.
* _Parameter servers_: These are machines that hold a single copy of
  parameters/variables, used by some strategies (right now just
  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
  to operate on a variable retrieve it at the beginning of a step and send an
  update to be applied at the end of the step. These can in principle support
  either sync or async training, but right now we only have support for async
  training with parameter servers. Compare to
  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
  on a single device on the same machine (and does sync training), and
  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
  (see below).

* _Replica context_ vs. _Cross-replica context_ vs. _Update context_

  A _replica context_ applies
  when you execute the computation function that was called with `strategy.run`.
  Conceptually, you're in replica context when executing the computation
  function that is being replicated.

  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
  call.

  A _cross-replica context_ is entered when you enter a `strategy.scope`. This
  is useful for calling `tf.distribute.Strategy` methods which operate across
  the replicas (like `reduce_to()`). By default you start in a _replica context_
  (the "default single _replica context_") and then some methods can switch you
  back and forth.

* _Distributed value_: Distributed values are represented by the base class
  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful
  for representing values on multiple devices, and it contains a map from
  replica id to values. Two representative kinds of
  `tf.distribute.DistributedValues` are "PerReplica" and "Mirrored" values.

  "PerReplica" values exist on the worker
  devices, with a different value for each replica. They are produced by
  iterating through a distributed dataset returned by
  `tf.distribute.Strategy.experimental_distribute_dataset` and
  `tf.distribute.Strategy.distribute_datasets_from_function`. They
  are also the typical result returned by
  `tf.distribute.Strategy.run`.

  "Mirrored" values are like "PerReplica" values, except we know that the value
  on all replicas is the same. We can safely read a "Mirrored" value in a
  cross-replica context by using the value on any replica.

* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple
  replicas, like `strategy.run(fn, args=[w])` with an
  argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will
  have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc.
  `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on
  device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return
  values from `fn()`, which leads to one common object if the returned values
  are the same object from every replica, or a `DistributedValues` object
  otherwise.

* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating
  multiple values into one value, like "sum" or "mean". If a strategy is doing
  sync training, we will perform a reduction on the gradients to a parameter
  from all replicas before applying the update. _All-reduce_ is an algorithm for
  performing a reduction on values from multiple devices and making the result
  available on all of those devices.

* _Mirrored variables_: These are variables that are created on multiple
  devices, where we keep the variables in sync by applying the same
  updates to every copy. Mirrored variables are created with
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`.
  Normally they are only used in synchronous training.

* _SyncOnRead variables_

  _SyncOnRead variables_ are created by
  `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and
  they are created on multiple devices. In replica context, each
  component variable on the local replica can perform reads and writes without
  synchronization with each other. When the
  _SyncOnRead variable_ is read in cross-replica context, the values from
  component variables are aggregated and returned.

  _SyncOnRead variables_ bring a lot of custom configuration difficulty to the
  underlying logic, so we do not encourage users to instantiate and use
  _SyncOnRead variables_ on their own. We have mainly used _SyncOnRead
  variables_ for use cases such as batch norm and metrics. For performance
  reasons, we often don't need to keep these statistics in sync every step and
  they can be accumulated on each replica independently. The only time we want
  to sync them is when reporting or checkpointing, which typically happens in
  cross-replica context. _SyncOnRead variables_ are also often used by advanced
  users who want to control when variable values are aggregated. For example,
  users sometimes want to maintain gradients independently on each replica for a
  couple of steps without aggregation.

* _Distribute-aware layers_

  Layers are generally called in a replica context, except when defining a
  Keras functional model. `tf.distribute.in_cross_replica_context` will let you
  determine which case you are in. If in a replica context,
  the `tf.distribute.get_replica_context` function will return the default
  replica context outside a strategy scope, `None` within a strategy scope, and
  a `tf.distribute.ReplicaContext` object inside a strategy scope and within a
  `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an
  `all_reduce` method for aggregating across all replicas.


Note that we provide a default version of `tf.distribute.Strategy` that is
used when no other strategy is in scope, which provides the same API with
reasonable default behavior.
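
The pieces above fit together roughly as follows. This is a minimal,
illustrative sketch only (it assumes two local GPUs are available and is not a
complete training program):

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

with strategy.scope():        # enters the cross-replica context
  v = tf.Variable(1.0)        # a mirrored variable: one copy per replica

@tf.function
def step(x):
  def replica_fn(x):          # runs in a replica context, once per replica
    ctx = tf.distribute.get_replica_context()
    return ctx.all_reduce("sum", x * v)
  return strategy.run(replica_fn, args=(x,))

# `strategy.reduce` aggregates the per-replica results back in the
# cross-replica context.
result = strategy.reduce(
    tf.distribute.ReduceOp.SUM, step(tf.constant(2.0)), axis=None)
```

Here `v` is a mirrored variable, `replica_fn` executes in a replica context on
every replica, and the final `reduce` happens in the cross-replica context.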
"""
# pylint: enable=line-too-long

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import enum  # pylint: disable=g-bad-import-order
import functools
import threading
import weakref

import six

from tensorflow.python.autograph.core import ag_ctx as autograph_ctx
from tensorflow.python.autograph.impl import api as autograph
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.eager import context as eager_context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls


# ------------------------------------------------------------------------------
# Context tracking whether in a strategy.update() or .update_non_slot() call.


_update_replica_id = threading.local()


def get_update_replica_id():
  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
  try:
    return _update_replica_id.current
  except AttributeError:
    return None


class UpdateContext(object):
  """Context manager when you are in `update()` or `update_non_slot()`."""

  __slots__ = ["_replica_id", "_old_replica_id"]

  def __init__(self, replica_id):
    self._replica_id = replica_id
    self._old_replica_id = None

  def __enter__(self):
    self._old_replica_id = get_update_replica_id()
    _update_replica_id.current = self._replica_id

  def __exit__(self, exception_type, exception_value, traceback):
    del exception_type, exception_value, traceback
    _update_replica_id.current = self._old_replica_id


# ------------------------------------------------------------------------------
# Public utility functions.


@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
  """`tf.distribute.ReduceOp` corresponding to the last loss reduction.

  This is used to decide whether loss should be scaled in optimizer (used only
  for estimator + v1 optimizer use case).

  Returns:
    `tf.distribute.ReduceOp` corresponding to the last loss reduction for
    estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise.
  """
  if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator:  # pylint: disable=protected-access
    # If we are not in Estimator context then return 'SUM'. We do not need to
    # scale loss in the optimizer.
    return reduce_util.ReduceOp.SUM
  last_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
  if (last_reduction == losses_impl.Reduction.SUM or
      last_reduction == "sum"):  # Check for tf.keras.losses.Reduction.SUM
    return reduce_util.ReduceOp.SUM
  return reduce_util.ReduceOp.MEAN


# ------------------------------------------------------------------------------
# Internal API for validating the current thread mode


def _require_cross_replica_or_default_context_extended(extended,
                                                       error_message=None):
  """Verify in cross-replica context."""
  context = _get_per_thread_mode()
  cross_replica = context.cross_replica_context
  if cross_replica is not None and cross_replica.extended is extended:
    return
  if context is _get_default_replica_mode():
    return
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  # We have an error to report, figure out the right message.
  if context.strategy is not strategy:
    _wrong_strategy_scope(strategy, context)
  assert cross_replica is None
  if not error_message:
    error_message = ("Method requires being in cross-replica context, use "
                     "get_replica_context().merge_call()")
  raise RuntimeError(error_message)


def _wrong_strategy_scope(strategy, context):
  # Figure out the right error message.
  if not distribution_strategy_context.has_strategy():
    raise RuntimeError(
        'Need to be inside "with strategy.scope()" for %s' %
        (strategy,))
  else:
    raise RuntimeError(
        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
        (context.strategy, strategy))


def require_replica_context(replica_ctx):
  """Verify in `replica_ctx` replica context."""
  context = _get_per_thread_mode()
  if context.replica_context is replica_ctx: return
  # We have an error to report, figure out the right message.
  if context.replica_context is None:
    raise RuntimeError("Need to be inside `call_for_each_replica()`")
  if context.strategy is replica_ctx.strategy:
    # Two different ReplicaContexts with the same tf.distribute.Strategy.
    raise RuntimeError("Mismatching ReplicaContext.")
  raise RuntimeError(
      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
      (context.strategy, replica_ctx.strategy))


def _require_strategy_scope_strategy(strategy):
  """Verify in a `strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy is strategy: return
  _wrong_strategy_scope(strategy, context)


def _require_strategy_scope_extended(extended):
  """Verify in a `distribution_strategy.scope()` in this thread."""
  context = _get_per_thread_mode()
  if context.strategy.extended is extended: return
  # Report error.
  strategy = extended._container_strategy()  # pylint: disable=protected-access
  _wrong_strategy_scope(strategy, context)


# ------------------------------------------------------------------------------
# Internal context managers used to implement the DistributionStrategy
# base class


class _CurrentDistributionContext(object):
  """Context manager setting the current `tf.distribute.Strategy`.

  Also: overrides the variable creator and optionally the current device.
  """

  def __init__(self,
               strategy,
               var_creator_scope,
               var_scope=None,
               resource_creator_scope=None,
               default_device=None):
    self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
        strategy)
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._resource_creator_scope = resource_creator_scope
    if default_device:
      self._device_scope = ops.device(default_device)
    else:
      self._device_scope = None
    self._same_scope_again_count = 0

  def __enter__(self):
    # Allow this scope to be entered if this strategy is already in scope.
    if distribution_strategy_context.has_strategy():
      _require_cross_replica_or_default_context_extended(
          self._context.strategy.extended)
      self._same_scope_again_count += 1
    else:
      _push_per_thread_mode(self._context)
    if self._var_scope:
      self._var_scope.__enter__()
    self._var_creator_scope.__enter__()
    if self._resource_creator_scope:
      nest.map_structure(lambda scope: scope.__enter__(),
                         self._resource_creator_scope)
    if self._device_scope:
      self._device_scope.__enter__()
    return self._context.strategy

  def __exit__(self, exception_type, exception_value, traceback):
    if self._same_scope_again_count > 0:
      self._same_scope_again_count -= 1
      return
    if self._device_scope:
      try:
        self._device_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Device scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)

    try:
      self._var_creator_scope.__exit__(
          exception_type, exception_value, traceback)
    except RuntimeError as e:
      six.raise_from(
          RuntimeError("Variable creator scope nesting error: move call to "
                       "tf.distribute.set_strategy() out of `with` scope."),
          e)

    if self._resource_creator_scope:
      try:
        self._resource_creator_scope.__exit__(exception_type, exception_value,
                                              traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Resource creator scope nesting error: move call "
                         "to tf.distribute.set_strategy() out of `with` "
                         "scope."), e)

    if self._var_scope:
      try:
        self._var_scope.__exit__(exception_type, exception_value, traceback)
      except RuntimeError as e:
        six.raise_from(
            RuntimeError("Variable scope nesting error: move call to "
                         "tf.distribute.set_strategy() out of `with` scope."),
            e)
    _pop_per_thread_mode()


# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
  """Replication mode for input function.

  * `PER_WORKER`: The input function will be called on each worker
    independently, creating as many input pipelines as the number of workers.
    Replicas will dequeue from the local Dataset on their worker.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  * `PER_REPLICA`: The input function will be called on each replica separately.
    `tf.distribute.Strategy` doesn't manage any state sharing between such
    separate input pipelines.
  """
  PER_WORKER = "PER_WORKER"
  PER_REPLICA = "PER_REPLICA"


@tf_export("distribute.InputContext")
class InputContext(object):
  """A class wrapping information needed by an input function.

  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the local batch
  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (e.g. to shard the input pipeline, use a different input
  source, etc.).
  """

  __slots__ = [
      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
  ]

  def __init__(self,
               num_input_pipelines=1,
               input_pipeline_id=0,
               num_replicas_in_sync=1):
    """Initializes an InputContext object.

    Args:
      num_input_pipelines: the number of input pipelines in a cluster.
      input_pipeline_id: the current input pipeline id, should be an int in
        [0,`num_input_pipelines`).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._num_input_pipelines = num_input_pipelines
    self._input_pipeline_id = input_pipeline_id
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def input_pipeline_id(self):
    """Returns the input pipeline ID."""
    return self._input_pipeline_id

  @property
  def num_input_pipelines(self):
    """Returns the number of input pipelines."""
    return self._num_input_pipelines

  def get_per_replica_batch_size(self, global_batch_size):
    """Returns the per-replica batch size.

    Args:
      global_batch_size: the global batch size which should be divisible by
        `num_replicas_in_sync`.

    Returns:
      the per-replica batch size.

    Raises:
      ValueError: if `global_batch_size` not divisible by
        `num_replicas_in_sync`.
    """
    if global_batch_size % self._num_replicas_in_sync != 0:
      raise ValueError("The `global_batch_size` %r is not divisible by "
                       "`num_replicas_in_sync` %r " %
                       (global_batch_size, self._num_replicas_in_sync))
    return global_batch_size // self._num_replicas_in_sync

  def __str__(self):
    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
        self.input_pipeline_id, self.num_input_pipelines)


@tf_export("distribute.experimental.ValueContext", v1=[])
class ValueContext(object):
  """A class wrapping information needed by a distribute function.

  This is a context class that is passed to the `value_fn` in
  `strategy.experimental_distribute_values_from_function` and contains
  information about the compute replicas. The `num_replicas_in_sync` and
  `replica_id` can be used to customize the value on each replica.

  Example usage:

  1. Directly constructed.

  >>> def value_fn(context):
  ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
  >>> context = tf.distribute.experimental.ValueContext(
  ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
  >>> per_replica_value = value_fn(context)
  >>> per_replica_value
  0.5

  2. Passed in by `experimental_distribute_values_from_function`.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> def value_fn(value_context):
  ...   return value_context.num_replicas_in_sync
  >>> distributed_values = (
  ...   strategy.experimental_distribute_values_from_function(
  ...     value_fn))
  >>> local_result = strategy.experimental_local_results(distributed_values)
  >>> local_result
  (2, 2)

  """

  __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"]

  def __init__(self,
               replica_id_in_sync_group=0,
               num_replicas_in_sync=1):
    """Initializes a ValueContext object.

    Args:
      replica_id_in_sync_group: the current replica_id, should be an int in
        [0,`num_replicas_in_sync`).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._replica_id_in_sync_group = replica_id_in_sync_group
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def replica_id_in_sync_group(self):
    """Returns the replica ID."""
    return self._replica_id_in_sync_group

  def __str__(self):
    return (("tf.distribute.ValueContext(replica id {}, "
             " total replicas in sync: ""{})")
            .format(self.replica_id_in_sync_group, self.num_replicas_in_sync))


@tf_export("distribute.RunOptions")
class RunOptions(
    collections.namedtuple("RunOptions", [
        "experimental_enable_dynamic_batch_size",
        "experimental_bucketizing_dynamic_shape",
        "experimental_xla_options",
    ])):
  """Run options for `strategy.run`.

  This can be used to hold some strategy-specific configs.

  Attributes:
    experimental_enable_dynamic_batch_size: Boolean. Only applies to
      TPUStrategy. Defaults to True. If True, TPUStrategy will enable dynamic
      padder to support dynamic batch size for the inputs. Otherwise only static
      shape inputs are allowed.
    experimental_bucketizing_dynamic_shape: Boolean. Only applies to
      TPUStrategy. Defaults to False. If True, TPUStrategy will automatically
      bucketize inputs passed into `run` if the input shape is
      dynamic. This is a performance optimization to reduce XLA recompilation,
      which should not have an impact on correctness.
    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
      TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to None.
  """

  def __new__(cls,
              experimental_enable_dynamic_batch_size=True,
              experimental_bucketizing_dynamic_shape=False,
              experimental_xla_options=None):
    return super(RunOptions,
                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
                              experimental_bucketizing_dynamic_shape,
                              experimental_xla_options)


@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_fetch_to_device",
        "experimental_replication_mode",
        "experimental_place_dataset_on_device",
        "experimental_per_replica_buffer_size",
    ])):
  """Run options for `experimental_distribute_dataset(s_from_function)`.

  This can be used to hold some strategy-specific configs.

  ```python
  # Setup TPUStrategy
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  dataset = tf.data.Dataset.range(16)
  distributed_dataset_on_host = (
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1)))
  ```

  Attributes:
    experimental_fetch_to_device: Boolean. If True, dataset
      elements will be prefetched to accelerator device memory. When False,
      dataset elements are prefetched to host device memory. Must be False when
      using TPUEmbedding API. experimental_fetch_to_device can only be used
      with experimental_replication_mode=PER_WORKER. Default behavior is the
      same as setting it to True.
    experimental_replication_mode: Replication mode for the input function.
      Currently, InputReplicationMode.PER_REPLICA is only supported with
      tf.distribute.MirroredStrategy and
      experimental_distribute_datasets_from_function.
      The default value is InputReplicationMode.PER_WORKER.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When True,
      dataset will be placed on the device, otherwise it will remain on the
      host. experimental_place_dataset_on_device=True can only be used with
      experimental_replication_mode=PER_REPLICA.
    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
      prefetch buffer size in the replica device memory. Users can set it
      to 0 to completely disable prefetching behavior, or a number greater than
      1 to enable larger buffer size. Note that this option is still
      valid with `experimental_fetch_to_device=False`.
  """

  def __new__(cls,
              experimental_fetch_to_device=None,
              experimental_replication_mode=InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1):
    if experimental_fetch_to_device is None:
      experimental_fetch_to_device = True

    return super(InputOptions,
                 cls).__new__(cls, experimental_fetch_to_device,
                              experimental_replication_mode,
                              experimental_place_dataset_on_device,
                              experimental_per_replica_buffer_size)

# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.


# Base class for v1 Strategy and v2 Strategy classes. For API's specific to
# v1/v2 Strategy, add to implementing classes of StrategyBase.
# pylint: disable=line-too-long
class StrategyBase(object):
  """A state & compute distribution policy on a list of devices.

  See [the guide](https://www.tensorflow.org/guide/distributed_training)
  for overview and examples. See `tf.distribute.StrategyExtended` and
  [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute)
  for a glossary of concepts mentioned on this page such as "per-replica",
  _replica_, and _reduce_.

  In short:

  * To use it with Keras `compile`/`fit`,
    [please
    read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras).
  * You may pass a descendant of `tf.distribute.Strategy` to
    `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator`
    should distribute its computation. See
    [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support).
  * Otherwise, use `tf.distribute.Strategy.scope` to specify that a
    strategy should be used when building and executing your model.
    (This puts you in the "cross-replica context" for this strategy, which
    means the strategy is put in control of things like variable placement.)
  * If you are writing a custom training loop, you will need to call a few more
    methods,
    [see the
    guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops):

    * Start by creating a `tf.data.Dataset` normally.
    * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert
      a `tf.data.Dataset` to something that produces "per-replica" values.
      If you want to manually specify how the dataset should be partitioned
      across replicas, use
      `tf.distribute.Strategy.distribute_datasets_from_function`
      instead.
    * Use `tf.distribute.Strategy.run` to run a function
      once per replica, taking values that may be "per-replica" (e.g.
      from a `tf.distribute.DistributedDataset` object) and returning
      "per-replica" values.
      This function is executed in "replica context", which means each
      operation is performed separately on each replica.
    * Finally use a method (such as `tf.distribute.Strategy.reduce`) to
      convert the resulting "per-replica" values into ordinary `Tensor`s.

  A custom training loop can be as simple as:

  ```
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  This takes an ordinary `dataset` and `replica_fn` and runs it
  distributed using a particular `tf.distribute.Strategy` named
  `my_strategy` above. Any variables created in `replica_fn` are created
  using `my_strategy`'s policy, and library functions called by
  `replica_fn` can use the `get_replica_context()` API to implement
  distributed-specific behavior.

  You can use the `reduce` API to aggregate results across replicas and use
  this as a return value from one iteration over a
  `tf.distribute.DistributedDataset`. Or
  you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to
  accumulate metrics across steps in a given epoch, as shown in the sketch
  below.

  See the
  [custom training loop
  tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
  for a more detailed example.

  Note: `tf.distribute.Strategy` does not currently support TensorFlow's
  partitioned variables (where a single variable is split across multiple
  devices).
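
  For illustration only, the following sketch accumulates a
  `tf.keras.metrics.Mean` across replicas and steps inside such a loop. It
  assumes `my_strategy`, `dist_dataset` and a `compute_loss` function exist as
  in the example above; it is not a complete program:

  ```
  with my_strategy.scope():
    loss_metric = tf.keras.metrics.Mean(name="train_loss")

  @tf.function
  def train_epoch(dist_dataset):
    def replica_fn(x):
      loss = compute_loss(x)          # per-replica loss
      loss_metric.update_state(loss)  # metric state is updated per replica
      return loss

    for x in dist_dataset:
      my_strategy.run(replica_fn, args=(x,))

  # Reading the metric in the cross-replica context aggregates across
  # replicas; reset it between epochs.
  epoch_loss = loss_metric.result()
  loss_metric.reset_state()
  ```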
  """
  # pylint: enable=line-too-long

  # TODO(josh11b): Partitioned computations, state; sharding
  # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling

  def __init__(self, extended):
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

    if not hasattr(extended, "_retrace_functions_for_each_device"):
      # pylint: disable=protected-access
      # `extended._retrace_functions_for_each_device` dictates
      # whether the same function will be retraced when it is called on
      # different devices.
      try:
        extended._retrace_functions_for_each_device = (
            len(extended.worker_devices) > 1)
        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
            self.num_replicas_in_sync)
      except:  # pylint: disable=bare-except
        # Default for the case where extended.worker_devices can't return
        # a sensible value.
        extended._retrace_functions_for_each_device = True

    # Below are the dicts of axis(int) -> `tf.function`.
    self._mean_reduce_helper_fns = {}
    self._reduce_sum_fns = {}

    # Whether this strategy is designed to work with `ClusterCoordinator`.
    self._should_use_with_coordinator = False

  @property
  def extended(self):
    """`tf.distribute.StrategyExtended` with additional methods."""
    return self._extended

  @tf_contextlib.contextmanager
  def _scale_loss_for_estimator_enabled(self):
    """Scope which sets a flag used for scaling losses in optimizer.

    Yields:
      `_scale_loss_for_estimator_enabled` is a context manager with a
      side effect, but doesn't return a value.
    """
    self._scale_loss_for_estimator = True
    try:
      yield
    finally:
      self._scale_loss_for_estimator = False

  # pylint: disable=line-too-long
  def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

    _What happens when Strategy.scope is entered?_

    * `strategy` is installed in the global context as the "current" strategy.
      Inside this scope, `tf.distribute.get_strategy()` will now return this
      strategy. Outside this scope, it returns the default no-op strategy.
    * Entering the scope also enters the "cross-replica context". See
      `tf.distribute.StrategyExtended` for an explanation on cross-replica and
      replica contexts.
    * Variable creation inside `scope` is intercepted by the strategy. Each
      strategy defines how it wants to affect the variable creation.
      Sync strategies like `MirroredStrategy`, `TPUStrategy` and
      `MultiWorkerMirroredStrategy` create variables replicated on each replica,
      whereas `ParameterServerStrategy` creates variables on the parameter
      servers. This is done using a custom `tf.variable_creator_scope`.
    * In some strategies, a default device scope may also be entered: in
      `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
      entered on each worker.

    Note: Entering a scope does not automatically distribute a computation, except
    in the case of high level training frameworks like Keras `model.fit`. If
    you're not using `model.fit`, you
    need to use the `strategy.run` API to explicitly distribute that computation.
    See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).


    _What should be in scope and what should be outside?_

    There are a number of requirements on what needs to happen inside the scope.
    However, in places where we have information about which strategy is in use,
    we often enter the scope for the user, so they don't have to do it
    explicitly (i.e. calling those either inside or outside the scope is OK).

    * Anything that creates variables that should be distributed variables
      must be called in a `strategy.scope`. This can be accomplished either by
      directly calling the variable creating function within the scope context,
      or by relying on another API like `strategy.run` or `keras.Model.fit` to
      automatically enter it for you. Any variable that is created outside scope
      will not be distributed and may have performance implications. Some common
      objects that create variables in TF are Models, Optimizers, Metrics. Such
      objects should always be initialized in the scope, and any functions
      that may lazily create variables (e.g., `Model.__call__()`, tracing a
      `tf.function`, etc.) should similarly be called within scope. Another
      source of variable creation can be a checkpoint restore - when variables
      are created lazily. Note that any variable created inside a strategy
      captures the strategy information. So reading and writing to these
      variables outside the `strategy.scope` can also work seamlessly, without
      the user having to enter the scope.
    * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which
      require being in a strategy's scope, enter the scope automatically, which
      means when using those APIs you don't need to explicitly enter the scope
      yourself.
    * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model
      object captures the scope information. When high level training framework
      methods such as `model.compile`, `model.fit`, etc. are then called, the
      captured scope will be automatically entered, and the associated strategy
      will be used to distribute the training etc. See a detailed example in
      [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras).
      WARNING: Simply calling `model(..)` does not automatically enter the
      captured scope -- only high level training framework APIs support this
      behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict`
      and `model.save` can all be called inside or outside the scope.
    * The following can be either inside or outside the scope:
      * Creating the input datasets
      * Defining `tf.function`s that represent your training step
      * Saving APIs such as `tf.saved_model.save`. Loading creates variables,
        so that should go inside the scope if you want to train the model in a
        distributed way.
      * Checkpoint saving. As mentioned above - `checkpoint.restore` may
        sometimes need to be inside scope if it creates variables.

    Returns:
      A context manager.
    """
    return self._extended._scope(self)  # pylint: disable=protected-access
  # pylint: enable=line-too-long

  @doc_controls.do_not_doc_inheritable  # DEPRECATED, moving to `extended`
  @deprecated(None, "use extended.colocate_vars_with() instead.")
  def colocate_vars_with(self, colocate_with_variable):
    """DEPRECATED: use extended.colocate_vars_with() instead."""
    return self._extended.colocate_vars_with(colocate_with_variable)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_dataset_iterator(self, dataset):
    """DEPRECATED TF 1.x ONLY."""
    return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  def make_input_fn_iterator(self,
                             input_fn,
                             replication_mode=InputReplicationMode.PER_WORKER):
    """DEPRECATED TF 1.x ONLY."""
    if replication_mode != InputReplicationMode.PER_WORKER:
      raise ValueError(
          "Input replication mode not supported: %r" % replication_mode)
    with self.scope():
      return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
          input_fn, replication_mode=replication_mode)

  @doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
  @deprecated(None, "use run() instead")
  def experimental_run(self, fn, input_iterator=None):
    """DEPRECATED TF 1.x ONLY."""
    with self.scope():
      args = (input_iterator.get_next(),) if input_iterator is not None else ()
      return self.run(fn, args=args)

  def experimental_distribute_dataset(self, dataset, options=None):
    # pylint: disable=line-too-long
    """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`.

    The returned `tf.distribute.DistributedDataset` can be iterated over
    similar to regular datasets.
    NOTE: The user cannot add any more transformations to a
    `tf.distribute.DistributedDataset`. You can only create an iterator or
    examine the `tf.TypeSpec` of the data generated by it. See API docs of
    `tf.distribute.DistributedDataset` to learn more.

    The following is an example:

    >>> global_batch_size = 2
    >>> # Passing the devices is optional.
    ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
    >>> # Create a dataset
    ... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
    >>> # Distribute that dataset
    ... dist_dataset = strategy.experimental_distribute_dataset(dataset)
    >>> @tf.function
    ... def replica_fn(input):
    ...   return input*2
    >>> result = []
    >>> # Iterate over the `tf.distribute.DistributedDataset`
    ... for x in dist_dataset:
    ...   # process dataset elements
    ...   result.append(strategy.run(replica_fn, args=(x,)))
    >>> print(result)
    [PerReplica:{
      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
    }, PerReplica:{
      0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
      1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
    }]


    Three key actions happening under the hood of this method are batching,
    sharding, and prefetching.

    In the code snippet above, `dataset` is batched by `global_batch_size`, and
    calling `experimental_distribute_dataset` on it rebatches `dataset` to a
    new batch size that is equal to the global batch size divided by the number
    of replicas in sync. We iterate through it using a Pythonic for loop.
    `x` is a `tf.distribute.DistributedValues` containing data for all replicas,
    and each replica gets data of the new batch size.
    `tf.distribute.Strategy.run` will take care of feeding the right per-replica
    data in `x` to the right `replica_fn` executed on each replica.

    Sharding covers autosharding across multiple workers and within every
    worker. First, in multi-worker distributed training (i.e. when you use
    `tf.distribute.experimental.MultiWorkerMirroredStrategy`
    or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of
    workers means that each worker is assigned a subset of the entire dataset
    (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to
    ensure that at each step, a global batch size of non-overlapping dataset
    elements will be processed by each worker. Autosharding has a couple of
    different options that can be specified using
    `tf.data.experimental.DistributeOptions`. Then, sharding within each worker
    means the method will split the data among all the worker devices (if more
    than one is present). This will happen regardless of multi-worker
    autosharding.

    Note: for autosharding across multiple workers, the default mode is
    `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
    will attempt to shard the input dataset by files if the dataset is
    being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
    `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
    where each of the workers will read the entire dataset and only process the
    shard assigned to it. However, if you have fewer than one input file per
    worker, we suggest that you disable dataset autosharding across workers by
    setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be
    `tf.data.experimental.AutoShardPolicy.OFF`.

    By default, this method adds a prefetch transformation at the end of the
    user provided `tf.data.Dataset` instance. The `buffer_size` argument of the
    prefetch transformation is equal to the number of replicas in sync.

    If the above batch splitting and dataset sharding logic is undesirable,
    please use
    `tf.distribute.Strategy.distribute_datasets_from_function`
    instead, which does not do any automatic batching or sharding for you.

    Note: If you are using TPUStrategy, the order in which the data is processed
    by the workers when using
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function` is
    not guaranteed. Ordering is typically required if you are using
    `tf.distribute` to scale prediction. You can however insert an index for
    each element in the batch and order outputs accordingly. Refer to [this
    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
    for an example of how to order outputs.

    Note: Stateful dataset transformations are currently not supported with
    `tf.distribute.experimental_distribute_dataset` or
    `tf.distribute.distribute_datasets_from_function`. Any stateful
    ops that the dataset may have are currently ignored. For example, if your
    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
    then you have a dataset graph that depends on state (i.e. the random seed) on
    the local machine where the python process is being executed.

    For a tutorial on more usage and properties of this method, refer to the
    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset).
    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).

    Args:
      dataset: `tf.data.Dataset` that will be sharded across all replicas using
        the rules stated above.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A `tf.distribute.DistributedDataset`.
    """
    distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "distribute_dataset").increase_by(1)
    # pylint: enable=line-too-long
    return self._extended._experimental_distribute_dataset(dataset, options)  # pylint: disable=protected-access

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    # pylint: disable=line-too-long
    """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.

    The argument `dataset_fn` that users pass in is an input function that has a
    `tf.distribute.InputContext` argument and returns a `tf.data.Dataset`
    instance. It is expected that the returned dataset from `dataset_fn` is
    already batched by per-replica batch size (i.e. global batch size divided by
    the number of replicas in sync) and sharded.
    `tf.distribute.Strategy.distribute_datasets_from_function` does
    not batch or shard the `tf.data.Dataset` instance
    returned from the input function. `dataset_fn` will be called on the CPU
    device of each of the workers and each generates a dataset where every
    replica on that worker will dequeue one batch of inputs (i.e. if a worker
    has two replicas, two batches will be dequeued from the `Dataset` every
    step).

    This method can be used for several purposes. First, it allows you to
    specify your own batching and sharding logic, as in the sketch below. (In
    contrast, `tf.distribute.experimental_distribute_dataset` does batching and
    sharding for you.) For example, where
    `experimental_distribute_dataset` is unable to shard the input files, this
    method might be used to manually shard the dataset (avoiding the slow
    fallback behavior in `experimental_distribute_dataset`). In cases where the
    dataset is infinite, this sharding can be done by creating dataset replicas
    that differ only in their random seed.
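
    As an illustration only, a `dataset_fn` along these lines shards by input
    pipeline id and batches to the per-replica batch size. The global batch
    size of 64, the synthetic `range` dataset, and the surrounding `strategy`
    object are placeholders for this sketch:

    ```python
    def dataset_fn(input_context):
      # Derive the per-replica batch size from an assumed global batch size.
      batch_size = input_context.get_per_replica_batch_size(
          global_batch_size=64)
      dataset = tf.data.Dataset.range(1000)
      # Manually shard by input pipeline, then batch per replica.
      dataset = dataset.shard(input_context.num_input_pipelines,
                              input_context.input_pipeline_id)
      return dataset.batch(batch_size).prefetch(2)

    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
    ```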

    The `dataset_fn` should take a `tf.distribute.InputContext` instance where
    information about batching and input replication can be accessed.

    You can use the `element_spec` property of the
    `tf.distribute.DistributedDataset` returned by this API to query the
    `tf.TypeSpec` of the elements returned by the iterator. This can be used to
    set the `input_signature` property of a `tf.function`. See
    `tf.distribute.DistributedDataset.element_spec` for an example.

    IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a
    per-replica batch size, unlike `experimental_distribute_dataset`, which uses
    the global batch size. This may be computed using
    `input_context.get_per_replica_batch_size`.

    Note: If you are using TPUStrategy, the order in which the data is processed
    by the workers when using
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function` is
    not guaranteed. Ordering is typically required if you are using
    `tf.distribute` to scale prediction. You can however insert an index for
    each element in the batch and order outputs accordingly. Refer to [this
    snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
    for an example of how to order outputs.

    Note: Stateful dataset transformations are currently not supported with
    `tf.distribute.experimental_distribute_dataset` or
    `tf.distribute.distribute_datasets_from_function`. Any stateful
    ops that the dataset may have are currently ignored. For example, if your
    dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image,
    then you have a dataset graph that depends on state (i.e. the random seed) on
    the local machine where the python process is being executed.

    For a tutorial on more usage and properties of this method, refer to the
    [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function).
    If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches).

    Args:
      dataset_fn: A function taking a `tf.distribute.InputContext` instance and
        returning a `tf.data.Dataset`.
      options: `tf.distribute.InputOptions` used to control options on how this
        dataset is distributed.

    Returns:
      A `tf.distribute.DistributedDataset`.
    """
    distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__,
        "distribute_datasets_from_function").increase_by(1)
    # pylint: enable=line-too-long
    return self._extended._distribute_datasets_from_function(  # pylint: disable=protected-access
        dataset_fn, options)

  # TODO(b/162776748): Remove deprecated symbol.
  @doc_controls.do_not_doc_inheritable
  @deprecation.deprecated(None, "rename to distribute_datasets_from_function")
  def experimental_distribute_datasets_from_function(self,
                                                     dataset_fn,
                                                     options=None):
    return self.distribute_datasets_from_function(dataset_fn, options)

  def run(self, fn, args=(), kwargs=None, options=None):
    """Invokes `fn` on each replica, with the given arguments.

    This method is the primary way to distribute your computation with a
    tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs`
    have `tf.distribute.DistributedValues`, such as those produced by a
    `tf.distribute.DistributedDataset` from
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function`,
    when `fn` is executed on a particular replica, it will be executed with the
    component of `tf.distribute.DistributedValues` that corresponds to that
    replica.

    `fn` is invoked under a replica context. `fn` may call
    `tf.distribute.get_replica_context()` to access members such as
    `all_reduce`. Please see the module-level docstring of tf.distribute for the
    concept of replica context.

    All arguments in `args` or `kwargs` can be a nested structure of tensors,
    e.g. a list of tensors, in which case `args` and `kwargs` will be passed to
    the `fn` invoked on each replica. Or `args` or `kwargs` can be
    `tf.distribute.DistributedValues` containing tensors or composite tensors,
    i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call
    will get the component of a `tf.distribute.DistributedValues` corresponding
    to its replica. Note that arbitrary Python values that are not of the types
    above are not supported.

    IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and
    whether eager execution is enabled, `fn` may be called one or more times. If
    `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is
    called inside a `tf.function` (eager execution is disabled inside a
    `tf.function` by default), `fn` is called once per replica to generate a
    Tensorflow graph, which will then be reused for execution with new inputs.
    Otherwise, if eager execution is enabled, `fn` will be called once per
    replica every step just like regular Python code.

    Example usage:

    1. Constant tensor input.

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> tensor_input = tf.constant(3.0)
    >>> @tf.function
    ... def replica_fn(input):
    ...   return input*2.0
    >>> result = strategy.run(replica_fn, args=(tensor_input,))
    >>> result
    PerReplica:{
      0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>,
      1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
    }

    2. DistributedValues input.

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> @tf.function
    ... def run():
    ...   def value_fn(value_context):
    ...     return value_context.num_replicas_in_sync
    ...   distributed_values = (
    ...       strategy.experimental_distribute_values_from_function(
    ...           value_fn))
    ...   def replica_fn2(input):
    ...     return input*2
    ...   return strategy.run(replica_fn2, args=(distributed_values,))
    >>> result = run()
    >>> result
    <tf.Tensor: shape=(), dtype=int32, numpy=4>

    3. Use `tf.distribute.ReplicaContext` to allreduce values.

    >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"])
    >>> @tf.function
    ... def run():
    ...   def value_fn(value_context):
    ...     return tf.constant(value_context.replica_id_in_sync_group)
    ...   distributed_values = (
    ...       strategy.experimental_distribute_values_from_function(
    ...           value_fn))
    ...   def replica_fn(input):
    ...     return tf.distribute.get_replica_context().all_reduce("sum", input)
    ...   return strategy.run(replica_fn, args=(distributed_values,))
return strategy.run(replica_fn, args=(distributed_values,)) 1275 >>> result = run() 1276 >>> result 1277 PerReplica:{ 1278 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>, 1279 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 1280 } 1281 1282 Args: 1283 fn: The function to run on each replica. 1284 args: Optional positional arguments to `fn`. Its element can be a tensor, 1285 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1286 kwargs: Optional keyword arguments to `fn`. Its element can be a tensor, 1287 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1288 options: An optional instance of `tf.distribute.RunOptions` specifying 1289 the options to run `fn`. 1290 1291 Returns: 1292 Merged return value of `fn` across replicas. The structure of the return 1293 value is the same as the return value from `fn`. Each element in the 1294 structure can either be `tf.distribute.DistributedValues`, `Tensor` 1295 objects, or `Tensor`s (for example, if running on a single replica). 1296 """ 1297 del options 1298 1299 if not isinstance(args, (list, tuple)): 1300 raise ValueError( 1301 "positional args must be a list or tuple, got {}".format(type(args))) 1302 1303 with self.scope(): 1304 # tf.distribute supports Eager functions, so AutoGraph should not be 1305 # applied when the caller is also in Eager mode. 1306 fn = autograph.tf_convert( 1307 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 1308 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 1309 1310 def reduce(self, reduce_op, value, axis): 1311 """Reduce `value` across replicas and return result on current device. 1312 1313 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1314 >>> def step_fn(): 1315 ... i = tf.distribute.get_replica_context().replica_id_in_sync_group 1316 ... return tf.identity(i) 1317 >>> 1318 >>> per_replica_result = strategy.run(step_fn) 1319 >>> total = strategy.reduce("SUM", per_replica_result, axis=None) 1320 >>> total 1321 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1322 1323 To see how this would look with multiple replicas, consider the same 1324 example with MirroredStrategy with 2 GPUs: 1325 1326 ```python 1327 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 1328 def step_fn(): 1329 i = tf.distribute.get_replica_context().replica_id_in_sync_group 1330 return tf.identity(i) 1331 1332 per_replica_result = strategy.run(step_fn) 1333 # Check devices on which per replica result is: 1334 strategy.experimental_local_results(per_replica_result)[0].device 1335 # /job:localhost/replica:0/task:0/device:GPU:0 1336 strategy.experimental_local_results(per_replica_result)[1].device 1337 # /job:localhost/replica:0/task:0/device:GPU:1 1338 1339 total = strategy.reduce("SUM", per_replica_result, axis=None) 1340 # Check device on which reduced result is: 1341 total.device 1342 # /job:localhost/replica:0/task:0/device:CPU:0 1343 1344 ``` 1345 1346 This API is typically used for aggregating the results returned from 1347 different replicas, for reporting etc. For example, loss computed from 1348 different replicas can be averaged using this API before printing. 1349 1350 Note: The result is copied to the "current" device - which would typically 1351 be the CPU of the worker on which the program is running. For `TPUStrategy`, 1352 it is the first TPU host. For multi client `MultiWorkerMirroredStrategy`, 1353 this is CPU of each worker. 
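As a minimal sketch of the reporting use case above (assuming a per-replica
`compute_loss` step function and a distributed `dist_dataset`, e.g. from
`experimental_distribute_dataset`, already exist), per-replica losses can be
averaged on the host before logging:

```python
@tf.function
def train_step(batch):
  per_replica_losses = strategy.run(compute_loss, args=(batch,))
  # Average across replicas; `axis=None` because each replica returns a
  # scalar, so there is no batch dimension left to reduce.
  return strategy.reduce(tf.distribute.ReduceOp.MEAN,
                         per_replica_losses, axis=None)

for batch in dist_dataset:
  print("mean loss:", train_step(batch))
```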
1354 1355 There are a number of different tf.distribute APIs for reducing values 1356 across replicas: 1357 * `tf.distribute.ReplicaContext.all_reduce`: This differs from 1358 `Strategy.reduce` in that it is for replica context and does 1359 not copy the results to the host device. `all_reduce` should be typically 1360 used for reductions inside the training step such as gradients. 1361 * `tf.distribute.StrategyExtended.reduce_to` and 1362 `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more 1363 advanced versions of `Strategy.reduce` as they allow customizing the 1364 destination of the result. They are also called in cross replica context. 1365 1366 _What should axis be?_ 1367 1368 Given a per-replica value returned by `run`, say a 1369 per-example loss, the batch will be divided across all the replicas. This 1370 function allows you to aggregate across replicas and optionally also across 1371 batch elements by specifying the axis parameter accordingly. 1372 1373 For example, if you have a global batch size of 8 and 2 1374 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and 1375 `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will 1376 aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. 1377 This is useful when each replica is computing a scalar or some other value 1378 that doesn't have a "batch" dimension (like a gradient or loss). 1379 ``` 1380 strategy.reduce("sum", per_replica_result, axis=None) 1381 ``` 1382 1383 Sometimes, you will want to aggregate across both the global batch _and_ 1384 all replicas. You can get this behavior by specifying the batch 1385 dimension as the `axis`, typically `axis=0`. In this case it would return a 1386 scalar `0+1+2+3+4+5+6+7`. 1387 ``` 1388 strategy.reduce("sum", per_replica_result, axis=0) 1389 ``` 1390 1391 If there is a last partial batch, you will need to specify an axis so 1392 that the resulting shape is consistent across replicas. So if the last 1393 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you 1394 would get a shape mismatch unless you specify `axis=0`. If you specify 1395 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct 1396 denominator of 6. Contrast this with computing `reduce_mean` to get a 1397 scalar value on each replica and this function to average those means, 1398 which will weigh some values `1/8` and others `1/4`. 1399 1400 Args: 1401 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 1402 be combined. Allows using string representation of the enum such as 1403 "SUM", "MEAN". 1404 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1405 `Strategy.run`, to be combined into a single tensor. It can also be a 1406 regular tensor when used with `OneDeviceStrategy` or default strategy. 1407 axis: specifies the dimension to reduce along within each 1408 replica's tensor. Should typically be set to the batch dimension, or 1409 `None` to only reduce across replicas (e.g. if the tensor has no batch 1410 dimension). 1411 1412 Returns: 1413 A `Tensor`. 1414 """ 1415 # TODO(josh11b): support `value` being a nest. 
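    # Implementation outline for the code below: a string `reduce_op` is first
    # converted to a `tf.distribute.ReduceOp`. With `axis=None` the value is
    # reduced across replicas only, via the extended strategy's `_reduce`.
    # With an integer (or tuple) `axis`, a SUM additionally sums each
    # replica's tensor along `axis` before the cross-replica sum, and a MEAN
    # computes a per-replica (numerator, denominator) pair so that the final
    # division uses the true global element count (which matters for a last
    # partial batch). In eager mode the per-replica step is wrapped in a
    # `tf.function` cached by `axis`, since some strategies (e.g. TPUStrategy)
    # do not support pure eager execution.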
1416 _require_cross_replica_or_default_context_extended(self._extended) 1417 if isinstance(reduce_op, six.string_types): 1418 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 1419 if axis is None: 1420 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1421 if reduce_op == reduce_util.ReduceOp.SUM: 1422 1423 def reduce_sum(v): 1424 return math_ops.reduce_sum(v, axis=axis) 1425 1426 if eager_context.executing_eagerly(): 1427 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1428 # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be 1429 # run from eager mode. Cache the tf.function by `axis` to avoid the 1430 # same function to be traced again. 1431 if axis not in self._reduce_sum_fns: 1432 1433 def reduce_sum_fn(v): 1434 return self.run(reduce_sum, args=(v,)) 1435 1436 self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) 1437 value = self._reduce_sum_fns[axis](value) 1438 else: 1439 value = self.run(reduce_sum, args=(value,)) 1440 1441 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1442 if reduce_op != reduce_util.ReduceOp.MEAN: 1443 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " 1444 "not: %r" % reduce_op) 1445 1446 def mean_reduce_helper(v, axes=axis): 1447 """Computes the numerator and denominator on each replica.""" 1448 numer = math_ops.reduce_sum(v, axis=axes) 1449 def dimension(axis): 1450 if v.shape.rank is not None: 1451 # Note(joshl): We support axis < 0 to be consistent with the 1452 # tf.math.reduce_* operations. 1453 if axis < 0: 1454 if axis + v.shape.rank < 0: 1455 raise ValueError( 1456 "`axis` = %r out of range for `value` with rank %d" % 1457 (axis, v.shape.rank)) 1458 axis += v.shape.rank 1459 elif axis >= v.shape.rank: 1460 raise ValueError( 1461 "`axis` = %r out of range for `value` with rank %d" % 1462 (axis, v.shape.rank)) 1463 # TF v2 returns `None` for unknown dimensions and an integer for 1464 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) 1465 # or tensor_shape.Dimension(integer). `dimension_value` hides this 1466 # difference, always returning `None` or an integer. 1467 dim = tensor_shape.dimension_value(v.shape[axis]) 1468 if dim is not None: 1469 # By returning a python value in the static shape case, we can 1470 # maybe get a fast path for reducing the denominator. 1471 # TODO(b/151871486): Remove array_ops.identity after we fallback to 1472 # simple reduction if inputs are all on CPU. 1473 return array_ops.identity( 1474 constant_op.constant(dim, dtype=dtypes.int64)) 1475 elif axis < 0: 1476 axis = axis + array_ops.rank(v) 1477 # TODO(b/151871486): Remove array_ops.identity after we fallback to 1478 # simple reduction if inputs are all on CPU. 1479 return array_ops.identity( 1480 array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) 1481 if isinstance(axis, six.integer_types): 1482 denom = dimension(axis) 1483 elif isinstance(axis, (tuple, list)): 1484 denom = math_ops.reduce_prod([dimension(a) for a in axes]) 1485 else: 1486 raise TypeError( 1487 "Expected `axis` to be an integer, tuple or list not: %r" % axis) 1488 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the 1489 # reduce is complete? 1490 return numer, denom 1491 1492 if eager_context.executing_eagerly(): 1493 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1494 # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can 1495 # be run from eager mode. 
Cache the tf.function by `axis` to avoid the 1496 # same function to be traced again. 1497 if axis not in self._mean_reduce_helper_fns: 1498 1499 def mean_reduce_fn(v): 1500 return self.run(mean_reduce_helper, args=(v,)) 1501 1502 self._mean_reduce_helper_fns[axis] = def_function.function( 1503 mean_reduce_fn) 1504 numer, denom = self._mean_reduce_helper_fns[axis](value) 1505 else: 1506 numer, denom = self.run(mean_reduce_helper, args=(value,)) 1507 1508 # TODO(josh11b): Should batch reduce here instead of doing two. 1509 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access 1510 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access 1511 denom = math_ops.cast(denom, numer.dtype) 1512 return math_ops.truediv(numer, denom) 1513 1514 @doc_controls.do_not_doc_inheritable # DEPRECATED 1515 @deprecated(None, "use `experimental_local_results` instead.") 1516 def unwrap(self, value): 1517 """Returns the list of all local per-replica values contained in `value`. 1518 1519 DEPRECATED: Please use `experimental_local_results` instead. 1520 1521 Note: This only returns values on the workers initiated by this client. 1522 When using a `tf.distribute.Strategy` like 1523 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1524 will be its own client, and this function will only return values 1525 computed on that worker. 1526 1527 Args: 1528 value: A value returned by `experimental_run()`, 1529 `extended.call_for_each_replica()`, or a variable created in `scope`. 1530 1531 Returns: 1532 A tuple of values contained in `value`. If `value` represents a single 1533 value, this returns `(value,).` 1534 """ 1535 return self._extended._local_results(value) # pylint: disable=protected-access 1536 1537 def experimental_local_results(self, value): 1538 """Returns the list of all local per-replica values contained in `value`. 1539 1540 Note: This only returns values on the worker initiated by this client. 1541 When using a `tf.distribute.Strategy` like 1542 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1543 will be its own client, and this function will only return values 1544 computed on that worker. 1545 1546 Args: 1547 value: A value returned by `experimental_run()`, `run(), or a variable 1548 created in `scope`. 1549 1550 Returns: 1551 A tuple of values contained in `value` where ith element corresponds to 1552 ith replica. If `value` represents a single value, this returns 1553 `(value,).` 1554 """ 1555 return self._extended._local_results(value) # pylint: disable=protected-access 1556 1557 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only 1558 def group(self, value, name=None): 1559 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 1560 return self._extended._group(value, name) # pylint: disable=protected-access 1561 1562 @property 1563 def num_replicas_in_sync(self): 1564 """Returns number of replicas over which gradients are aggregated.""" 1565 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 1566 1567 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string 1568 @deprecated(None, "use `update_config_proto` instead.") 1569 def configure(self, 1570 session_config=None, 1571 cluster_spec=None, 1572 task_type=None, 1573 task_id=None): 1574 # pylint: disable=g-doc-return-or-yield,g-doc-args 1575 """DEPRECATED: use `update_config_proto` instead. 1576 1577 Configures the strategy class. 
1578 1579 DEPRECATED: This method's functionality has been split into the strategy 1580 constructor and `update_config_proto`. In the future, we will allow passing 1581 cluster and config_proto to the constructor to configure the strategy. And 1582 `update_config_proto` can be used to update the config_proto based on the 1583 specific strategy. 1584 """ 1585 return self._extended._configure( # pylint: disable=protected-access 1586 session_config, cluster_spec, task_type, task_id) 1587 1588 @doc_controls.do_not_generate_docs # DEPRECATED 1589 def update_config_proto(self, config_proto): 1590 """DEPRECATED TF 1.x ONLY.""" 1591 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 1592 1593 def __deepcopy__(self, memo): 1594 # First do a regular deepcopy of `self`. 1595 cls = self.__class__ 1596 result = cls.__new__(cls) 1597 memo[id(self)] = result 1598 for k, v in self.__dict__.items(): 1599 setattr(result, k, copy.deepcopy(v, memo)) 1600 # One little fix-up: we want `result._extended` to reference `result` 1601 # instead of `self`. 1602 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 1603 return result 1604 1605 def __copy__(self): 1606 raise RuntimeError("Must only deepcopy DistributionStrategy.") 1607 1608 @property 1609 def cluster_resolver(self): 1610 """Returns the cluster resolver associated with this strategy. 1611 1612 In general, when using a multi-worker `tf.distribute` strategy such as 1613 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 1614 `tf.distribute.TPUStrategy()`, there is a 1615 `tf.distribute.cluster_resolver.ClusterResolver` associated with the 1616 strategy used, and such an instance is returned by this property. 1617 1618 Strategies that intend to have an associated 1619 `tf.distribute.cluster_resolver.ClusterResolver` must set the 1620 relevant attribute, or override this property; otherwise, `None` is returned 1621 by default. Those strategies should also provide information regarding what 1622 is returned by this property. 1623 1624 Single-worker strategies usually do not have a 1625 `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this 1626 property will return `None`. 1627 1628 The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the 1629 user needs to access information such as the cluster spec, task type or task 1630 id. For example, 1631 1632 ```python 1633 1634 os.environ['TF_CONFIG'] = json.dumps({ 1635 'cluster': { 1636 'worker': ["localhost:12345", "localhost:23456"], 1637 'ps': ["localhost:34567"] 1638 }, 1639 'task': {'type': 'worker', 'index': 0} 1640 }) 1641 1642 # This implicitly uses TF_CONFIG for the cluster and current task info. 1643 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 1644 1645 ... 1646 1647 if strategy.cluster_resolver.task_type == 'worker': 1648 # Perform something that's only applicable on workers. Since we set this 1649 # as a worker above, this block will run on this particular instance. 1650 elif strategy.cluster_resolver.task_type == 'ps': 1651 # Perform something that's only applicable on parameter servers. Since we 1652 # set this as a worker above, this block will not run on this particular 1653 # instance. 1654 ``` 1655 1656 For more information, please see 1657 `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. 1658 1659 Returns: 1660 The cluster resolver associated with this strategy. 
Returns `None` if a 1661 cluster resolver is not applicable or available in this strategy. 1662 """ 1663 if hasattr(self.extended, "_cluster_resolver"): 1664 return self.extended._cluster_resolver # pylint: disable=protected-access 1665 return None 1666 1667 1668@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring 1669class Strategy(StrategyBase): 1670 1671 __doc__ = StrategyBase.__doc__ 1672 1673 def experimental_distribute_values_from_function(self, value_fn): 1674 """Generates `tf.distribute.DistributedValues` from `value_fn`. 1675 1676 This function is to generate `tf.distribute.DistributedValues` to pass 1677 into `run`, `reduce`, or other methods that take 1678 distributed values when not using datasets. 1679 1680 Args: 1681 value_fn: The function to run to generate values. It is called for 1682 each replica with `tf.distribute.ValueContext` as the sole argument. It 1683 must return a Tensor or a type that can be converted to a Tensor. 1684 Returns: 1685 A `tf.distribute.DistributedValues` containing a value for each replica. 1686 1687 Example usage: 1688 1689 1. Return constant value per replica: 1690 1691 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1692 >>> def value_fn(ctx): 1693 ... return tf.constant(1.) 1694 >>> distributed_values = ( 1695 ... strategy.experimental_distribute_values_from_function( 1696 ... value_fn)) 1697 >>> local_result = strategy.experimental_local_results(distributed_values) 1698 >>> local_result 1699 (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 1700 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) 1701 1702 2. Distribute values in array based on replica_id: 1703 1704 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1705 >>> array_value = np.array([3., 2., 1.]) 1706 >>> def value_fn(ctx): 1707 ... return array_value[ctx.replica_id_in_sync_group] 1708 >>> distributed_values = ( 1709 ... strategy.experimental_distribute_values_from_function( 1710 ... value_fn)) 1711 >>> local_result = strategy.experimental_local_results(distributed_values) 1712 >>> local_result 1713 (3.0, 2.0) 1714 1715 3. Specify values using num_replicas_in_sync: 1716 1717 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1718 >>> def value_fn(ctx): 1719 ... return ctx.num_replicas_in_sync 1720 >>> distributed_values = ( 1721 ... strategy.experimental_distribute_values_from_function( 1722 ... value_fn)) 1723 >>> local_result = strategy.experimental_local_results(distributed_values) 1724 >>> local_result 1725 (2, 2) 1726 1727 4. Place values on devices and distribute: 1728 1729 ``` 1730 strategy = tf.distribute.TPUStrategy() 1731 worker_devices = strategy.extended.worker_devices 1732 multiple_values = [] 1733 for i in range(strategy.num_replicas_in_sync): 1734 with tf.device(worker_devices[i]): 1735 multiple_values.append(tf.constant(1.0)) 1736 1737 def value_fn(ctx): 1738 return multiple_values[ctx.replica_id_in_sync_group] 1739 1740 distributed_values = strategy. 1741 experimental_distribute_values_from_function( 1742 value_fn) 1743 ``` 1744 1745 """ 1746 return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access 1747 value_fn) 1748 1749 def gather(self, value, axis): 1750 # pylint: disable=line-too-long, protected-access 1751 """Gather `value` across replicas along `axis` to the current device. 
1752 1753 Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like 1754 object `value`, this API gathers and concatenates `value` across replicas 1755 along the `axis`-th dimension. The result is copied to the "current" device, 1756 which would typically be the CPU of the worker on which the program is 1757 running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For 1758 multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of 1759 each worker. 1760 1761 This API can only be called in the cross-replica context. For a counterpart 1762 in the replica context, see `tf.distribute.ReplicaContext.all_gather`. 1763 1764 Note: For all strategies except `tf.distribute.TPUStrategy`, the input 1765 `value` on different replicas must have the same rank, and their shapes must 1766 be the same in all dimensions except the `axis`-th dimension. In other 1767 words, their shapes cannot be different in a dimension `d` where `d` does 1768 not equal to the `axis` argument. For example, given a 1769 `tf.distribute.DistributedValues` with component tensors of shape 1770 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 1771 `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or 1772 `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, 1773 all tensors must have exactly the same rank and same shape. 1774 1775 Note: Given a `tf.distribute.DistributedValues` `value`, its component 1776 tensors must have a non-zero rank. Otherwise, consider using 1777 `tf.expand_dims` before gathering them. 1778 1779 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1780 >>> # A DistributedValues with component tensor of shape (2, 1) on each replica 1781 ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) 1782 >>> @tf.function 1783 ... def run(): 1784 ... return strategy.gather(distributed_values, axis=0) 1785 >>> run() 1786 <tf.Tensor: shape=(4, 1), dtype=int32, numpy= 1787 array([[1], 1788 [2], 1789 [1], 1790 [2]], dtype=int32)> 1791 1792 1793 Consider the following example for more combinations: 1794 1795 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) 1796 >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) 1797 >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) 1798 >>> @tf.function 1799 ... def run(axis): 1800 ... return strategy.gather(distributed_values, axis=axis) 1801 >>> axis=0 1802 >>> run(axis) 1803 <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy= 1804 array([[[0, 1, 2], 1805 [3, 4, 5]], 1806 [[0, 1, 2], 1807 [3, 4, 5]], 1808 [[0, 1, 2], 1809 [3, 4, 5]], 1810 [[0, 1, 2], 1811 [3, 4, 5]]], dtype=int32)> 1812 >>> axis=1 1813 >>> run(axis) 1814 <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy= 1815 array([[[0, 1, 2], 1816 [3, 4, 5], 1817 [0, 1, 2], 1818 [3, 4, 5], 1819 [0, 1, 2], 1820 [3, 4, 5], 1821 [0, 1, 2], 1822 [3, 4, 5]]], dtype=int32)> 1823 >>> axis=2 1824 >>> run(axis) 1825 <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy= 1826 array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], 1827 [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)> 1828 1829 1830 Args: 1831 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1832 `Strategy.run`, to be combined into a single tensor. It can also be a 1833 regular tensor when used with `tf.distribute.OneDeviceStrategy` or the 1834 default strategy. 
The tensors that constitute the DistributedValues 1835 can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. 1836 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 1837 range [0, rank(value)). 1838 1839 Returns: 1840 A `Tensor` that's the concatenation of `value` across replicas along 1841 `axis` dimension. 1842 """ 1843 # pylint: enable=line-too-long 1844 error_message = ("tf.distribute.Strategy.gather method requires " 1845 "cross-replica context, use " 1846 "get_replica_context().all_gather() instead") 1847 _require_cross_replica_or_default_context_extended(self._extended, 1848 error_message) 1849 dst = device_util.current( 1850 ) or self._extended._default_device or "/device:CPU:0" 1851 if isinstance(value, ops.IndexedSlices): 1852 raise NotImplementedError("gather does not support IndexedSlices") 1853 return self._extended._local_results( 1854 self._extended._gather_to(value, dst, axis))[0] 1855 1856 1857# TF v1.x version has additional deprecated APIs 1858@tf_export(v1=["distribute.Strategy"]) 1859class StrategyV1(StrategyBase): 1860 """A list of devices with a state & compute distribution policy. 1861 1862 See [the guide](https://www.tensorflow.org/guide/distribute_strategy) 1863 for overview and examples. 1864 1865 Note: Not all `tf.distribute.Strategy` implementations currently support 1866 TensorFlow's partitioned variables (where a single variable is split across 1867 multiple devices) at this time. 1868 """ 1869 1870 def make_dataset_iterator(self, dataset): 1871 """Makes an iterator for input provided via `dataset`. 1872 1873 DEPRECATED: This method is not available in TF 2.x. 1874 1875 Data from the given dataset will be distributed evenly across all the 1876 compute replicas. We will assume that the input dataset is batched by the 1877 global batch size. With this assumption, we will make a best effort to 1878 divide each batch across all the replicas (one or more workers). 1879 If this effort fails, an error will be thrown, and the user should instead 1880 use `make_input_fn_iterator` which provides more control to the user, and 1881 does not try to divide a batch across replicas. 1882 1883 The user could also use `make_input_fn_iterator` if they want to 1884 customize which input is fed to which replica/worker etc. 1885 1886 Args: 1887 dataset: `tf.data.Dataset` that will be distributed evenly across all 1888 replicas. 1889 1890 Returns: 1891 An `tf.distribute.InputIterator` which returns inputs for each step of the 1892 computation. User should call `initialize` on the returned iterator. 1893 """ 1894 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 1895 1896 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation 1897 input_fn, 1898 replication_mode=InputReplicationMode.PER_WORKER): 1899 """Returns an iterator split across replicas created from an input function. 1900 1901 DEPRECATED: This method is not available in TF 2.x. 
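In TF 2.x, a roughly equivalent sketch uses
`tf.distribute.Strategy.distribute_datasets_from_function` instead (here
`strategy` and `global_batch_size` are assumed to already exist):

```python
def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
  return d.shard(input_context.num_input_pipelines,
                 input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
```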
1902 1903 The `input_fn` should take an `tf.distribute.InputContext` object where 1904 information about batching and input sharding can be accessed: 1905 1906 ``` 1907 def input_fn(input_context): 1908 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 1909 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 1910 return d.shard(input_context.num_input_pipelines, 1911 input_context.input_pipeline_id) 1912 with strategy.scope(): 1913 iterator = strategy.make_input_fn_iterator(input_fn) 1914 replica_results = strategy.experimental_run(replica_fn, iterator) 1915 ``` 1916 1917 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 1918 batch size, which may be computed using 1919 `input_context.get_per_replica_batch_size`. 1920 1921 Args: 1922 input_fn: A function taking a `tf.distribute.InputContext` object and 1923 returning a `tf.data.Dataset`. 1924 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 1925 Only `PER_WORKER` is supported currently, which means there will be 1926 a single call to `input_fn` per worker. Replicas will dequeue from the 1927 local `tf.data.Dataset` on their worker. 1928 1929 Returns: 1930 An iterator object that should first be `.initialize()`-ed. It may then 1931 either be passed to `strategy.experimental_run()` or you can 1932 `iterator.get_next()` to get the next value to pass to 1933 `strategy.extended.call_for_each_replica()`. 1934 """ 1935 return super(StrategyV1, self).make_input_fn_iterator( 1936 input_fn, replication_mode) 1937 1938 def experimental_make_numpy_dataset(self, numpy_input, session=None): 1939 """Makes a tf.data.Dataset for input provided via a numpy array. 1940 1941 This avoids adding `numpy_input` as a large constant in the graph, 1942 and copies the data to the machine or machines that will be processing 1943 the input. 1944 1945 Note that you will likely need to use 1946 tf.distribute.Strategy.experimental_distribute_dataset 1947 with the returned dataset to further distribute it with the strategy. 1948 1949 Example: 1950 ``` 1951 numpy_input = np.ones([10], dtype=np.float32) 1952 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 1953 dist_dataset = strategy.experimental_distribute_dataset(dataset) 1954 ``` 1955 1956 Args: 1957 numpy_input: A nest of NumPy input arrays that will be converted into a 1958 dataset. Note that lists of Numpy arrays are stacked, as that is normal 1959 `tf.data.Dataset` behavior. 1960 session: (TensorFlow v1.x graph execution only) A session used for 1961 initialization. 1962 1963 Returns: 1964 A `tf.data.Dataset` representing `numpy_input`. 1965 """ 1966 return self.extended.experimental_make_numpy_dataset( 1967 numpy_input, session=session) 1968 1969 @deprecated( 1970 None, 1971 "This method is not available in TF 2.x. Please switch to using `run` instead." 1972 ) 1973 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation 1974 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 1975 1976 DEPRECATED: This method is not available in TF 2.x. Please switch 1977 to using `run` instead. 1978 1979 When eager execution is enabled, executes ops specified by `fn` on each 1980 replica. Otherwise, builds a graph to execute the ops on each replica. 1981 1982 Each replica will take a single, different input from the inputs provided by 1983 one `get_next` call on the input iterator. 
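As a rough sketch of the intended flow (reusing an `input_fn` like the one
shown above for `make_input_fn_iterator`, plus a per-replica `replica_fn`):

```python
iterator = strategy.make_input_fn_iterator(input_fn)
# In TF 1.x graph execution, first run the initialization ops returned by
# `iterator.initialize()` before requesting values.
per_replica_outputs = strategy.experimental_run(replica_fn, iterator)
```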
1984 1985 `fn` may call `tf.distribute.get_replica_context()` to access members such 1986 as `replica_id_in_sync_group`. 1987 1988 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 1989 used, and whether eager execution is enabled, `fn` may be called one or more 1990 times (once for each replica). 1991 1992 Args: 1993 fn: The function to run. The inputs to the function must match the outputs 1994 of `input_iterator.get_next()`. The output must be a `tf.nest` of 1995 `Tensor`s. 1996 input_iterator: (Optional) input iterator from which the inputs are taken. 1997 1998 Returns: 1999 Merged return value of `fn` across replicas. The structure of the return 2000 value is the same as the return value from `fn`. Each element in the 2001 structure can either be `PerReplica` (if the values are unsynchronized), 2002 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 2003 single replica). 2004 """ 2005 return super(StrategyV1, self).experimental_run( 2006 fn, input_iterator) 2007 2008 def reduce(self, reduce_op, value, axis=None): 2009 return super(StrategyV1, self).reduce(reduce_op, value, axis) 2010 2011 reduce.__doc__ = StrategyBase.reduce.__doc__ 2012 2013 def update_config_proto(self, config_proto): 2014 """Returns a copy of `config_proto` modified for use with this strategy. 2015 2016 DEPRECATED: This method is not available in TF 2.x. 2017 2018 The updated config has something needed to run a strategy, e.g. 2019 configuration to run collective ops, or device filters to improve 2020 distributed training performance. 2021 2022 Args: 2023 config_proto: a `tf.ConfigProto` object. 2024 2025 Returns: 2026 The updated copy of the `config_proto`. 2027 """ 2028 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 2029 2030 2031# NOTE(josh11b): For any strategy that needs to support tf.compat.v1, 2032# instead descend from StrategyExtendedV1. 2033@tf_export("distribute.StrategyExtended", v1=[]) 2034class StrategyExtendedV2(object): 2035 """Additional APIs for algorithms that need to be distribution-aware. 2036 2037 Note: For most usage of `tf.distribute.Strategy`, there should be no need to 2038 call these methods, since TensorFlow libraries (such as optimizers) already 2039 call these methods when needed on your behalf. 2040 2041 2042 Some common use cases of functions on this page: 2043 2044 * _Locality_ 2045 2046 `tf.distribute.DistributedValues` can have the same _locality_ as a 2047 _distributed variable_, which leads to a mirrored value residing on the same 2048 devices as the variable (as opposed to the compute devices). Such values may 2049 be passed to a call to `tf.distribute.StrategyExtended.update` to update the 2050 value of a variable. You may use 2051 `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the 2052 same locality as another variable. You may convert a "PerReplica" value to a 2053 variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or 2054 `tf.distribute.StrategyExtended.batch_reduce_to`. 2055 2056 * _How to update a distributed variable_ 2057 2058 A distributed variable is variables created on multiple devices. As discussed 2059 in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), 2060 mirrored variable and SyncOnRead variable are two examples. The standard 2061 pattern for updating distributed variables is to: 2062 2063 1. In your function passed to `tf.distribute.Strategy.run`, 2064 compute a list of (update, variable) pairs. 
For example, the update might 2065 be a gradient of the loss with respect to the variable. 2066 2. Switch to cross-replica mode by calling 2067 `tf.distribute.get_replica_context().merge_call()` with the updates and 2068 variables as arguments. 2069 3. Call 2070 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` 2071 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` 2072 (for a list of variables) to sum the updates. 2073 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update 2074 its value. 2075 2076 Steps 2 through 4 are done automatically by class 2077 `tf.keras.optimizers.Optimizer` if you call its 2078 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. 2079 2080 In fact, a higher-level solution to update a distributed variable is by 2081 calling `assign` on the variable as you would do to a regular `tf.Variable`. 2082 You can call the method in both _replica context_ and _cross-replica context_. 2083 For a _mirrored variable_, calling `assign` in _replica context_ requires you 2084 to specify the `aggregation` type in the variable constructor. In that case, 2085 the context switching and sync described in steps 2 through 4 are handled for 2086 you. If you call `assign` on _mirrored variable_ in _cross-replica context_, 2087 you can only assign a single value or assign values from another mirrored 2088 variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead 2089 variable_, in _replica context_, you can simply call `assign` on it and no 2090 aggregation happens under the hood. In _cross-replica context_, you can only 2091 assign a single value to a SyncOnRead variable. One example case is restoring 2092 from a checkpoint: if the `aggregation` type of the variable is 2093 `tf.VariableAggregation.SUM`, it is assumed that replica values were added 2094 before checkpointing, so at the time of restoring, the value is divided by 2095 the number of replicas and then assigned to each replica; if the `aggregation` 2096 type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica 2097 directly. 2098 2099 """ 2100 2101 def __init__(self, container_strategy): 2102 self._container_strategy_weakref = weakref.ref(container_strategy) 2103 self._default_device = None 2104 # This property is used to determine if we should set drop_remainder=True 2105 # when creating Datasets from numpy array inputs. 2106 self._require_static_shapes = False 2107 2108 def _resource_creator_scope(self): 2109 """Returns one or more ops.resource_creator_scope for some Strategy.""" 2110 return None 2111 2112 def _container_strategy(self): 2113 """Get the containing `tf.distribute.Strategy`. 2114 2115 This should not generally be needed except when creating a new 2116 `ReplicaContext` and to validate that the caller is in the correct 2117 `scope()`. 2118 2119 Returns: 2120 The `tf.distribute.Strategy` such that `strategy.extended` is `self`. 
2121 """ 2122 container_strategy = self._container_strategy_weakref() 2123 assert container_strategy is not None 2124 return container_strategy 2125 2126 def _scope(self, strategy): 2127 """Implementation of tf.distribute.Strategy.scope().""" 2128 2129 def creator_with_resource_vars(next_creator, **kwargs): 2130 """Variable creator to use in `_CurrentDistributionContext`.""" 2131 _require_strategy_scope_extended(self) 2132 kwargs["use_resource"] = True 2133 kwargs["distribute_strategy"] = strategy 2134 2135 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid 2136 # dereferencing a `Tensor` that is without a `name`. We still need to 2137 # propagate the metadata it's holding. 2138 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): 2139 checkpoint_restore_uid = kwargs[ 2140 "initial_value"].checkpoint_position.restore_uid 2141 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value 2142 elif isinstance(kwargs["initial_value"], 2143 trackable.CheckpointInitialValueCallable): 2144 checkpoint_restore_uid = kwargs[ 2145 "initial_value"].checkpoint_position.restore_uid 2146 elif (isinstance(kwargs["initial_value"], functools.partial) and 2147 isinstance(kwargs["initial_value"].func, 2148 trackable.CheckpointInitialValueCallable)): 2149 # Some libraries (e.g, Keras) create partial function out of initializer 2150 # to bind shape/dtype, for example: 2151 # initial_val = functools.partial(initializer, shape, dtype=dtype) 2152 # Therefore to get the restore_uid we need to examine the "func" of 2153 # the partial function. 2154 checkpoint_restore_uid = kwargs[ 2155 "initial_value"].func.checkpoint_position.restore_uid 2156 else: 2157 checkpoint_restore_uid = None 2158 2159 created = self._create_variable(next_creator, **kwargs) 2160 2161 if checkpoint_restore_uid is not None: 2162 # pylint: disable=protected-access 2163 # Let the checkpointing infrastructure know that the variable was 2164 # already restored so it doesn't waste memory loading the value again. 2165 # In this case of CheckpointInitialValueCallable this may already be 2166 # done by the final variable creator, but it doesn't hurt to do it 2167 # again. 2168 created._maybe_initialize_trackable() 2169 created._update_uid = checkpoint_restore_uid 2170 # pylint: enable=protected-access 2171 return created 2172 2173 def distributed_getter(getter, *args, **kwargs): 2174 if not self._allow_variable_partition(): 2175 if kwargs.pop("partitioner", None) is not None: 2176 tf_logging.log_first_n( 2177 tf_logging.WARN, "Partitioned variables are disabled when using " 2178 "current tf.distribute.Strategy.", 1) 2179 return getter(*args, **kwargs) 2180 2181 return _CurrentDistributionContext( 2182 strategy, 2183 variable_scope.variable_creator_scope(creator_with_resource_vars), 2184 variable_scope.variable_scope( 2185 variable_scope.get_variable_scope(), 2186 custom_getter=distributed_getter), 2187 strategy.extended._resource_creator_scope(), # pylint: disable=protected-access 2188 self._default_device) 2189 2190 def _allow_variable_partition(self): 2191 return False 2192 2193 def _create_variable(self, next_creator, **kwargs): 2194 # Note: should support "colocate_with" argument. 2195 raise NotImplementedError("must be implemented in descendants") 2196 2197 def variable_created_in_scope(self, v): 2198 """Tests whether `v` was created while this strategy scope was active. 
2199 2200 Variables created inside the strategy scope are "owned" by it: 2201 2202 >>> strategy = tf.distribute.MirroredStrategy() 2203 >>> with strategy.scope(): 2204 ... v = tf.Variable(1.) 2205 >>> strategy.extended.variable_created_in_scope(v) 2206 True 2207 2208 Variables created outside the strategy are not owned by it: 2209 2210 >>> strategy = tf.distribute.MirroredStrategy() 2211 >>> v = tf.Variable(1.) 2212 >>> strategy.extended.variable_created_in_scope(v) 2213 False 2214 2215 Args: 2216 v: A `tf.Variable` instance. 2217 2218 Returns: 2219 True if `v` was created inside the scope, False if not. 2220 """ 2221 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 2222 2223 def colocate_vars_with(self, colocate_with_variable): 2224 """Scope that controls which devices variables will be created on. 2225 2226 No operations should be added to the graph inside this scope, it 2227 should only be used when creating variables (some implementations 2228 work by changing variable creation, others work by using a 2229 tf.compat.v1.colocate_with() scope). 2230 2231 This may only be used inside `self.scope()`. 2232 2233 Example usage: 2234 2235 ``` 2236 with strategy.scope(): 2237 var1 = tf.Variable(...) 2238 with strategy.extended.colocate_vars_with(var1): 2239 # var2 and var3 will be created on the same device(s) as var1 2240 var2 = tf.Variable(...) 2241 var3 = tf.Variable(...) 2242 2243 def fn(v1, v2, v3): 2244 # operates on v1 from var1, v2 from var2, and v3 from var3 2245 2246 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 2247 # too. 2248 strategy.extended.update(var1, fn, args=(var2, var3)) 2249 ``` 2250 2251 Args: 2252 colocate_with_variable: A variable created in this strategy's `scope()`. 2253 Variables created while in the returned context manager will be on the 2254 same set of devices as `colocate_with_variable`. 2255 2256 Returns: 2257 A context manager. 2258 """ 2259 2260 def create_colocated_variable(next_creator, **kwargs): 2261 _require_strategy_scope_extended(self) 2262 kwargs["use_resource"] = True 2263 kwargs["colocate_with"] = colocate_with_variable 2264 return next_creator(**kwargs) 2265 2266 _require_strategy_scope_extended(self) 2267 self._validate_colocate_with_variable(colocate_with_variable) 2268 return variable_scope.variable_creator_scope(create_colocated_variable) 2269 2270 def _validate_colocate_with_variable(self, colocate_with_variable): 2271 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 2272 pass 2273 2274 def _make_dataset_iterator(self, dataset): 2275 raise NotImplementedError("must be implemented in descendants") 2276 2277 def _make_input_fn_iterator(self, input_fn, replication_mode): 2278 raise NotImplementedError("must be implemented in descendants") 2279 2280 def _experimental_distribute_dataset(self, dataset, options): 2281 raise NotImplementedError("must be implemented in descendants") 2282 2283 def _distribute_datasets_from_function(self, dataset_fn, options): 2284 raise NotImplementedError("must be implemented in descendants") 2285 2286 def _experimental_distribute_values_from_function(self, value_fn): 2287 raise NotImplementedError("must be implemented in descendants") 2288 2289 def _reduce(self, reduce_op, value): 2290 # Default implementation until we have an implementation for each strategy. 
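    # Default path: reduce to the current device (or this strategy's default
    # device, falling back to CPU) via `reduce_to`, then unwrap the single
    # local result.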
2291 dst = device_util.current() or self._default_device or "/device:CPU:0" 2292 return self._local_results(self.reduce_to(reduce_op, value, dst))[0] 2293 2294 def reduce_to(self, reduce_op, value, destinations, options=None): 2295 """Combine (via e.g. sum or mean) values across replicas. 2296 2297 `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed 2298 variables. It supports both dense values and `tf.IndexedSlices`. 2299 2300 This API currently can only be called in cross-replica context. Other 2301 variants to reduce values across replicas are: 2302 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of 2303 this API. 2304 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2305 in replica context. It supports both batched and non-batched all-reduce. 2306 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2307 to the host in cross-replica context. 2308 2309 `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can 2310 also pass in a `Tensor`, and the destinations will be the device of that 2311 tensor. For all-reduce, pass the same to `value` and `destinations`. 2312 2313 It can be used in `tf.distribute.ReplicaContext.merge_call` to write code 2314 that works for all `tf.distribute.Strategy`. 2315 2316 >>> @tf.function 2317 ... def step_fn(var): 2318 ... 2319 ... def merge_fn(strategy, value, var): 2320 ... # All-reduce the value. Note that `value` here is a 2321 ... # `tf.distribute.DistributedValues`. 2322 ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, 2323 ... value, destinations=var) 2324 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2325 ... args=(reduced,)) 2326 ... 2327 ... value = tf.identity(1.) 2328 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2329 ... args=(value, var)) 2330 >>> 2331 >>> def run(strategy): 2332 ... with strategy.scope(): 2333 ... v = tf.Variable(0.) 2334 ... strategy.run(step_fn, args=(v,)) 2335 ... return v 2336 >>> 2337 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2338 MirroredVariable:{ 2339 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2340 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2341 } 2342 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2343 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2344 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2345 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2346 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2347 2348 Args: 2349 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2350 be combined. Allows using string representation of the enum such as 2351 "SUM", "MEAN". 2352 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 2353 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2354 `tf.Tensor` alike object, or a device string. It specifies the devices 2355 to reduce to. To perform an all-reduce, pass the same to `value` and 2356 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2357 to the devices of that variable, and this method doesn't update the 2358 variable. 2359 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2360 perform collective operations. This overrides the default options if the 2361 `tf.distribute.Strategy` takes one in the constructor. 
See 2362 `tf.distribute.experimental.CommunicationOptions` for details of the 2363 options. 2364 2365 Returns: 2366 A tensor or value reduced to `destinations`. 2367 """ 2368 if options is None: 2369 options = collective_util.Options() 2370 _require_cross_replica_or_default_context_extended(self) 2371 assert not isinstance(destinations, (list, tuple)) 2372 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2373 if isinstance(reduce_op, six.string_types): 2374 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2375 assert (reduce_op == reduce_util.ReduceOp.SUM or 2376 reduce_op == reduce_util.ReduceOp.MEAN) 2377 return self._reduce_to(reduce_op, value, destinations, options) 2378 2379 def _reduce_to(self, reduce_op, value, destinations, options): 2380 raise NotImplementedError("must be implemented in descendants") 2381 2382 def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): 2383 """Combine multiple `reduce_to` calls into one for faster execution. 2384 2385 Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. 2386 It's more efficient than reduce each value separately. 2387 2388 This API currently can only be called in cross-replica context. Other 2389 variants to reduce values across replicas are: 2390 * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of 2391 this API. 2392 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2393 in replica context. It supports both batched and non-batched all-reduce. 2394 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2395 to the host in cross-replica context. 2396 2397 See `reduce_to` for more information. 2398 2399 >>> @tf.function 2400 ... def step_fn(var): 2401 ... 2402 ... def merge_fn(strategy, value, var): 2403 ... # All-reduce the value. Note that `value` here is a 2404 ... # `tf.distribute.DistributedValues`. 2405 ... reduced = strategy.extended.batch_reduce_to( 2406 ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] 2407 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2408 ... args=(reduced,)) 2409 ... 2410 ... value = tf.identity(1.) 2411 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2412 ... args=(value, var)) 2413 >>> 2414 >>> def run(strategy): 2415 ... with strategy.scope(): 2416 ... v = tf.Variable(0.) 2417 ... strategy.run(step_fn, args=(v,)) 2418 ... return v 2419 >>> 2420 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2421 MirroredVariable:{ 2422 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2423 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2424 } 2425 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2426 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2427 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2428 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2429 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2430 2431 Args: 2432 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2433 be combined. Allows using string representation of the enum such as 2434 "SUM", "MEAN". 2435 value_destination_pairs: a sequence of (value, destinations) pairs. See 2436 `tf.distribute.Strategy.reduce_to` for descriptions. 2437 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2438 perform collective operations. This overrides the default options if the 2439 `tf.distribute.Strategy` takes one in the constructor. 
See 2440 `tf.distribute.experimental.CommunicationOptions` for details of the 2441 options. 2442 2443 Returns: 2444 A list of reduced values, one per pair in `value_destination_pairs`. 2445 """ 2446 if options is None: 2447 options = collective_util.Options() 2448 _require_cross_replica_or_default_context_extended(self) 2449 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2450 if isinstance(reduce_op, six.string_types): 2451 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2452 return self._batch_reduce_to(reduce_op, value_destination_pairs, options) 2453 2454 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): 2455 return [ 2456 self.reduce_to(reduce_op, t, destinations=v, options=options) 2457 for t, v in value_destination_pairs 2458 ] 2459 2460 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 2461 """All-reduce `value` across all replicas so that all get the final result. 2462 2463 If `value` is a nested structure of tensors, all-reduces of these tensors 2464 will be batched when possible. `options` can be set to hint the batching 2465 behavior. 2466 2467 This API must be called in a replica context. 2468 2469 Args: 2470 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 2471 be combined. 2472 value: Value to be reduced. A tensor or a nested structure of tensors. 2473 options: A `tf.distribute.experimental.CommunicationOptions`. Options to 2474 perform collective operations. This overrides the default options if the 2475 `tf.distribute.Strategy` takes one in the constructor. 2476 2477 Returns: 2478 A tensor or a nested strucutre of tensors with the reduced values. The 2479 structure is the same as `value`. 2480 """ 2481 if options is None: 2482 options = collective_util.Options() 2483 replica_context = distribution_strategy_context.get_replica_context() 2484 assert replica_context, ( 2485 "`StrategyExtended._replica_ctx_all_reduce` must be called in" 2486 " a replica context") 2487 2488 def merge_fn(_, flat_value): 2489 return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value], 2490 options) 2491 2492 reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) 2493 return nest.pack_sequence_as(value, reduced) 2494 2495 def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True): 2496 """Run `fn` with `args` and `kwargs` to update `var`.""" 2497 # This method is called by ReplicaContext.update. Strategies who'd like to 2498 # remove merge_call in this path should override this method. 2499 replica_context = distribution_strategy_context.get_replica_context() 2500 if not replica_context: 2501 raise ValueError("`StrategyExtended._replica_ctx_update` must be called " 2502 "in a replica context.") 2503 2504 def merge_fn(_, *merged_args, **merged_kwargs): 2505 return self.update(var, fn, merged_args, merged_kwargs, group=group) 2506 2507 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 2508 2509 def _gather_to(self, value, destinations, axis, options=None): 2510 """Gather `value` across replicas along axis-th dimension to `destinations`. 2511 2512 `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like 2513 object, along `axis`-th dimension. It supports only dense tensors but NOT 2514 sparse tensor. This API can only be called in cross-replica context. 2515 2516 Args: 2517 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 
2518 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2519 `tf.Tensor` alike object, or a device string. It specifies the devices 2520 to reduce to. To perform an all-gather, pass the same to `value` and 2521 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2522 to the devices of that variable, and this method doesn't update the 2523 variable. 2524 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 2525 range [0, rank(value)). 2526 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2527 perform collective operations. This overrides the default options if the 2528 `tf.distribute.Strategy` takes one in the constructor. See 2529 `tf.distribute.experimental.CommunicationOptions` for details of the 2530 options. 2531 2532 Returns: 2533 A tensor or value gathered to `destinations`. 2534 """ 2535 _require_cross_replica_or_default_context_extended(self) 2536 assert not isinstance(destinations, (list, tuple)) 2537 if options is None: 2538 options = collective_util.Options() 2539 return self._gather_to_implementation(value, destinations, axis, options) 2540 2541 def _gather_to_implementation(self, value, destinations, axis, options): 2542 raise NotImplementedError("_gather_to must be implemented in descendants") 2543 2544 def _batch_gather_to(self, value_destination_pairs, axis, options=None): 2545 _require_cross_replica_or_default_context_extended(self) 2546 if options is None: 2547 options = collective_util.Options() 2548 return [ 2549 self._gather_to(t, destinations=v, axis=axis, options=options) 2550 for t, v in value_destination_pairs 2551 ] 2552 2553 def update(self, var, fn, args=(), kwargs=None, group=True): 2554 """Run `fn` to update `var` using inputs mirrored to the same devices. 2555 2556 `tf.distribute.StrategyExtended.update` takes a distributed variable `var` 2557 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It 2558 applies `fn` to each component variable of `var` and passes corresponding 2559 values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain 2560 per-replica values. If they contain mirrored values, they will be unwrapped 2561 before calling `fn`. For example, `fn` can be `assign_add` and `args` can be 2562 a mirrored DistributedValues where each component contains the value to be 2563 added to this mirrored variable `var`. Calling `update` will call 2564 `assign_add` on each component variable of `var` with the corresponding 2565 tensor value on that device. 2566 2567 Example usage: 2568 2569 ```python 2570 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 2571 devices 2572 with strategy.scope(): 2573 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 2574 def update_fn(v): 2575 return v.assign(1.0) 2576 result = strategy.extended.update(v, update_fn) 2577 # result is 2578 # Mirrored:{ 2579 # 0: tf.Tensor(1.0, shape=(), dtype=float32), 2580 # 1: tf.Tensor(1.0, shape=(), dtype=float32) 2581 # } 2582 ``` 2583 2584 If `var` is mirrored across multiple devices, then this method implements 2585 logic as following: 2586 2587 ```python 2588 results = {} 2589 for device, v in var: 2590 with tf.device(device): 2591 # args and kwargs will be unwrapped if they are mirrored. 2592 results[device] = fn(v, *args, **kwargs) 2593 return merged(results) 2594 ``` 2595 2596 Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with 2597 `var`. 
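As an additional sketch, `args` can carry a mirrored value whose per-device
components are passed to the matching component variables of `var`; here the
mirrored value is produced with `reduce_to`:

```python
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
with strategy.scope():
  v = tf.Variable(0.0)
# Mirror a value onto the devices of `v`, then add the matching component
# to each component variable of `v`.
mirrored_delta = strategy.extended.reduce_to(
    tf.distribute.ReduceOp.MEAN, tf.constant(2.0), destinations=v)
result = strategy.extended.update(
    v, lambda var, delta: var.assign_add(delta), args=(mirrored_delta,))
```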
2598 2599 Args: 2600 var: Variable, possibly mirrored to multiple devices, to operate on. 2601 fn: Function to call. Should take the variable as the first argument. 2602 args: Tuple or list. Additional positional arguments to pass to `fn()`. 2603 kwargs: Dict with keyword arguments to pass to `fn()`. 2604 group: Boolean. Defaults to True. If False, the return value will be 2605 unwrapped. 2606 2607 Returns: 2608 By default, the merged return value of `fn` across all replicas. The 2609 merged result has dependencies to make sure that if it is evaluated at 2610 all, the side effects (updates) will happen on every replica. If instead 2611 "group=False" is specified, this function will return a nest of lists 2612 where each list has an element per replica, and the caller is responsible 2613 for ensuring all elements are executed. 2614 """ 2615 # TODO(b/178944108): Update the documentation to relfect the fact that 2616 # `update` can be called in a replica context. 2617 if kwargs is None: 2618 kwargs = {} 2619 replica_context = distribution_strategy_context.get_replica_context() 2620 # pylint: disable=protected-access 2621 if (replica_context is None or replica_context is 2622 distribution_strategy_context._get_default_replica_context()): 2623 fn = autograph.tf_convert( 2624 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2625 with self._container_strategy().scope(): 2626 return self._update(var, fn, args, kwargs, group) 2627 else: 2628 return self._replica_ctx_update( 2629 var, fn, args=args, kwargs=kwargs, group=group) 2630 2631 def _update(self, var, fn, args, kwargs, group): 2632 raise NotImplementedError("must be implemented in descendants") 2633 2634 def _local_results(self, val): 2635 """Returns local results per replica as a tuple.""" 2636 if isinstance(val, values.DistributedValues): 2637 return val._values # pylint: disable=protected-access 2638 2639 if nest.is_nested(val): 2640 replica_values = [] 2641 2642 def get_values(x, index): 2643 if isinstance(x, values.DistributedValues): 2644 return x._values[index] # pylint: disable=protected-access 2645 return x 2646 2647 for i in range(len(self.worker_devices)): 2648 replica_values.append( 2649 nest.map_structure( 2650 lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop 2651 val)) 2652 return tuple(replica_values) 2653 return (val,) 2654 2655 def value_container(self, value): 2656 """Returns the container that this per-replica `value` belongs to. 2657 2658 Args: 2659 value: A value returned by `run()` or a variable created in `scope()`. 2660 2661 Returns: 2662 A container that `value` belongs to. 2663 If value does not belong to any container (including the case of 2664 container having been destroyed), returns the value itself. 2665 `value in experimental_local_results(value_container(value))` will 2666 always be true. 2667 """ 2668 raise NotImplementedError("must be implemented in descendants") 2669 2670 def _group(self, value, name=None): 2671 """Implementation of `group`.""" 2672 value = nest.flatten(self._local_results(value)) 2673 2674 if len(value) != 1 or name is not None: 2675 return control_flow_ops.group(value, name=name) 2676 # Special handling for the common case of one op. 
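# Return the lone op directly (for a tensor, its producing op) instead of
# wrapping it in a redundant `tf.group` node.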
2677 v, = value 2678 if hasattr(v, "op"): 2679 v = v.op 2680 return v 2681 2682 @property 2683 def experimental_require_static_shapes(self): 2684 """Returns `True` if static shape is required; `False` otherwise.""" 2685 return self._require_static_shapes 2686 2687 @property 2688 def _num_replicas_in_sync(self): 2689 """Returns number of replicas over which gradients are aggregated.""" 2690 raise NotImplementedError("must be implemented in descendants") 2691 2692 @property 2693 def worker_devices(self): 2694 """Returns the tuple of all devices used to for compute replica execution. 2695 """ 2696 # TODO(josh11b): More docstring 2697 raise NotImplementedError("must be implemented in descendants") 2698 2699 @property 2700 def parameter_devices(self): 2701 """Returns the tuple of all devices used to place variables.""" 2702 # TODO(josh11b): More docstring 2703 raise NotImplementedError("must be implemented in descendants") 2704 2705 def _configure(self, 2706 session_config=None, 2707 cluster_spec=None, 2708 task_type=None, 2709 task_id=None): 2710 """Configures the strategy class.""" 2711 del session_config, cluster_spec, task_type, task_id 2712 2713 def _update_config_proto(self, config_proto): 2714 return copy.deepcopy(config_proto) 2715 2716 def _in_multi_worker_mode(self): 2717 """Whether this strategy indicates working in multi-worker settings. 2718 2719 Multi-worker training refers to the setup where the training is 2720 distributed across multiple workers, as opposed to the case where 2721 only a local process performs the training. This function is 2722 used by higher-level APIs such as Keras' `model.fit()` to infer 2723 for example whether or not a distribute coordinator should be run, 2724 and thus TensorFlow servers should be started for communication 2725 with other servers in the cluster, or whether or not saving/restoring 2726 checkpoints is relevant for preemption fault tolerance. 2727 2728 Subclasses should override this to provide whether the strategy is 2729 currently in multi-worker setup. 2730 2731 Experimental. Signature and implementation are subject to change. 2732 """ 2733 raise NotImplementedError("must be implemented in descendants") 2734 2735 2736@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring 2737class StrategyExtendedV1(StrategyExtendedV2): 2738 2739 __doc__ = StrategyExtendedV2.__doc__ 2740 2741 def experimental_make_numpy_dataset(self, numpy_input, session=None): 2742 """Makes a dataset for input provided via a numpy array. 2743 2744 This avoids adding `numpy_input` as a large constant in the graph, 2745 and copies the data to the machine or machines that will be processing 2746 the input. 2747 2748 Args: 2749 numpy_input: A nest of NumPy input arrays that will be distributed evenly 2750 across all replicas. Note that lists of Numpy arrays are stacked, as 2751 that is normal `tf.data.Dataset` behavior. 2752 session: (TensorFlow v1.x graph execution only) A session used for 2753 initialization. 2754 2755 Returns: 2756 A `tf.data.Dataset` representing `numpy_input`. 2757 """ 2758 _require_cross_replica_or_default_context_extended(self) 2759 return self._experimental_make_numpy_dataset(numpy_input, session=session) 2760 2761 def _experimental_make_numpy_dataset(self, numpy_input, session): 2762 raise NotImplementedError("must be implemented in descendants") 2763 2764 def broadcast_to(self, tensor, destinations): 2765 """Mirror a tensor on one device to all worker devices. 2766 2767 Args: 2768 tensor: A Tensor value to broadcast. 
2769 destinations: A mirrored variable or device string specifying the 2770 destination devices to copy `tensor` to. 2771 2772 Returns: 2773 A value mirrored to `destinations` devices. 2774 """ 2775 assert destinations is not None # from old strategy.broadcast() 2776 # TODO(josh11b): More docstring 2777 _require_cross_replica_or_default_context_extended(self) 2778 assert not isinstance(destinations, (list, tuple)) 2779 return self._broadcast_to(tensor, destinations) 2780 2781 def _broadcast_to(self, tensor, destinations): 2782 raise NotImplementedError("must be implemented in descendants") 2783 2784 @deprecated(None, "please use `run` instead.") 2785 def experimental_run_steps_on_iterator(self, 2786 fn, 2787 iterator, 2788 iterations=1, 2789 initial_loop_values=None): 2790 """DEPRECATED: please use `run` instead. 2791 2792 Run `fn` with input from `iterator` for `iterations` times. 2793 2794 This method can be used to run a step function for training a number of 2795 times using input from a dataset. 2796 2797 Args: 2798 fn: function to run using this distribution strategy. The function must 2799 have the following signature: `def fn(context, inputs)`. `context` is an 2800 instance of `MultiStepContext` that will be passed when `fn` is run. 2801 `context` can be used to specify the outputs to be returned from `fn` 2802 by calling `context.set_last_step_output`. It can also be used to 2803 capture non tensor outputs by `context.set_non_tensor_output`. See 2804 `MultiStepContext` documentation for more information. `inputs` will 2805 have same type/structure as `iterator.get_next()`. Typically, `fn` 2806 will use `call_for_each_replica` method of the strategy to distribute 2807 the computation over multiple replicas. 2808 iterator: Iterator of a dataset that represents the input for `fn`. The 2809 caller is responsible for initializing the iterator as needed. 2810 iterations: (Optional) Number of iterations that `fn` should be run. 2811 Defaults to 1. 2812 initial_loop_values: (Optional) Initial values to be passed into the 2813 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 2814 initial_loop_values argument when we have a mechanism to infer the 2815 outputs of `fn`. 2816 2817 Returns: 2818 Returns the `MultiStepContext` object which has the following properties, 2819 among other things: 2820 - run_op: An op that runs `fn` `iterations` times. 2821 - last_step_outputs: A dictionary containing tensors set using 2822 `context.set_last_step_output`. Evaluating this returns the value of 2823 the tensors after the last iteration. 2824 - non_tensor_outputs: A dictionary containing anything that was set by 2825 `fn` by calling `context.set_non_tensor_output`. 2826 """ 2827 _require_cross_replica_or_default_context_extended(self) 2828 with self._container_strategy().scope(): 2829 return self._experimental_run_steps_on_iterator(fn, iterator, iterations, 2830 initial_loop_values) 2831 2832 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 2833 initial_loop_values): 2834 raise NotImplementedError("must be implemented in descendants") 2835 2836 def call_for_each_replica(self, fn, args=(), kwargs=None): 2837 """Run `fn` once per replica. 2838 2839 `fn` may call `tf.get_replica_context()` to access methods such as 2840 `replica_id_in_sync_group` and `merge_call()`. 2841 2842 `merge_call()` is used to communicate between the replicas and 2843 re-enter the cross-replica context. All replicas pause their execution 2844 having encountered a `merge_call()` call. 
After that the 2845 `merge_fn`-function is executed. Its results are then unwrapped and 2846 given back to each replica call. After that execution resumes until 2847 `fn` is complete or encounters another `merge_call()`. Example: 2848 2849 ```python 2850 # Called once in "cross-replica" context. 2851 def merge_fn(distribution, three_plus_replica_id): 2852 # sum the values across replicas 2853 return sum(distribution.experimental_local_results(three_plus_replica_id)) 2854 2855 # Called once per replica in `distribution`, in a "replica" context. 2856 def fn(three): 2857 replica_ctx = tf.get_replica_context() 2858 v = three + replica_ctx.replica_id_in_sync_group 2859 # Computes the sum of the `v` values across all replicas. 2860 s = replica_ctx.merge_call(merge_fn, args=(v,)) 2861 return s + v 2862 2863 with distribution.scope(): 2864 # in "cross-replica" context 2865 ... 2866 merged_results = distribution.run(fn, args=[3]) 2867 # merged_results has the values from every replica execution of `fn`. 2868 # This statement prints a list: 2869 print(distribution.experimental_local_results(merged_results)) 2870 ``` 2871 2872 Args: 2873 fn: function to run (will be run once per replica). 2874 args: Tuple or list with positional arguments for `fn`. 2875 kwargs: Dict with keyword arguments for `fn`. 2876 2877 Returns: 2878 Merged return value of `fn` across all replicas. 2879 """ 2880 _require_cross_replica_or_default_context_extended(self) 2881 if kwargs is None: 2882 kwargs = {} 2883 with self._container_strategy().scope(): 2884 return self._call_for_each_replica(fn, args, kwargs) 2885 2886 def _call_for_each_replica(self, fn, args, kwargs): 2887 raise NotImplementedError("must be implemented in descendants") 2888 2889 def read_var(self, v): 2890 """Reads the value of a variable. 2891 2892 Returns the aggregate value of a replica-local variable, or the 2893 (read-only) value of any other variable. 2894 2895 Args: 2896 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 2897 2898 Returns: 2899 A tensor representing the value of `v`, aggregated across replicas if 2900 necessary. 2901 """ 2902 raise NotImplementedError("must be implemented in descendants") 2903 2904 def update_non_slot( 2905 self, colocate_with, fn, args=(), kwargs=None, group=True): 2906 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 2907 2908 Used to update non-slot variables. 2909 2910 DEPRECATED: TF 1.x ONLY. 2911 2912 Args: 2913 colocate_with: Devices returned by `non_slot_devices()`. 2914 fn: Function to execute. 2915 args: Tuple or list. Positional arguments to pass to `fn()`. 2916 kwargs: Dict with keyword arguments to pass to `fn()`. 2917 group: Boolean. Defaults to True. If False, the return value will be 2918 unwrapped. 2919 2920 Returns: 2921 Return value of `fn`, possibly merged across devices. 2922 """ 2923 _require_cross_replica_or_default_context_extended(self) 2924 if kwargs is None: 2925 kwargs = {} 2926 fn = autograph.tf_convert( 2927 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2928 with self._container_strategy().scope(): 2929 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 2930 2931 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 2932 raise NotImplementedError("must be implemented in descendants") 2933 2934 def non_slot_devices(self, var_list): 2935 """Device(s) for non-slot variables. 2936 2937 DEPRECATED: TF 1.x ONLY. 2938 2939 This method returns non-slot devices where non-slot variables are placed. 
2940 Users can create non-slot variables on these devices by using a block: 2941 2942 ```python 2943 with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): 2944 ... 2945 ``` 2946 2947 Args: 2948 var_list: The list of variables being optimized, needed with the 2949 default `tf.distribute.Strategy`. 2950 Returns: 2951 A sequence of devices for non-slot variables. 2952 """ 2953 raise NotImplementedError("must be implemented in descendants") 2954 2955 def _use_merge_call(self): 2956 """Whether to use merge-calls inside the distributed strategy.""" 2957 return True 2958 2959 @property 2960 def experimental_between_graph(self): 2961 """Whether the strategy uses between-graph replication or not. 2962 2963 This is expected to return a constant value that will not be changed 2964 throughout its life cycle. 2965 """ 2966 raise NotImplementedError("must be implemented in descendants") 2967 2968 @property 2969 def experimental_should_init(self): 2970 """Whether initialization is needed.""" 2971 raise NotImplementedError("must be implemented in descendants") 2972 2973 @property 2974 def should_checkpoint(self): 2975 """Whether checkpointing is needed.""" 2976 raise NotImplementedError("must be implemented in descendants") 2977 2978 @property 2979 def should_save_summary(self): 2980 """Whether saving summaries is needed.""" 2981 raise NotImplementedError("must be implemented in descendants") 2982 2983 2984# A note about the difference between the context managers 2985# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 2986# (defined above) used by `tf.distribute.Strategy.scope()`: 2987# 2988# * a ReplicaContext is only present during a `run()` 2989# call (except during a `merge_run` call) and in such a scope it 2990# will be returned by calls to `get_replica_context()`. Implementers of new 2991# Strategy descendants will frequently also need to 2992# define a descendant of ReplicaContext, and are responsible for 2993# entering and exiting this context. 2994# 2995# * Strategy.scope() sets up a variable_creator scope that 2996# changes variable creation calls (e.g. to make mirrored 2997# variables). This is intended as an outer scope that users enter once 2998# around their model creation and graph definition. There is no 2999# anticipated need to define descendants of _CurrentDistributionContext. 3000# It sets the current Strategy for purposes of 3001# `get_strategy()` and `has_strategy()` 3002# and switches the thread mode to a "cross-replica context". 3003class ReplicaContextBase(object): 3004 """A class with a collection of APIs that can be called in a replica context. 3005 3006 You can use `tf.distribute.get_replica_context` to get an instance of 3007 `ReplicaContext`, which can only be called inside the function passed to 3008 `tf.distribute.Strategy.run`. 3009 3010 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) 3011 >>> def func(): 3012 ... replica_context = tf.distribute.get_replica_context() 3013 ... return replica_context.replica_id_in_sync_group 3014 >>> strategy.run(func) 3015 PerReplica:{ 3016 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, 3017 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 3018 } 3019 """ 3020 3021 def __init__(self, strategy, replica_id_in_sync_group): 3022 """Creates a ReplicaContext. 3023 3024 Args: 3025 strategy: A `tf.distribute.Strategy`. 3026 replica_id_in_sync_group: An integer, a `Tensor` or None. 
Prefer an 3027 integer whenever possible to avoid issues with nested `tf.function`. It 3028 accepts a `Tensor` only to be compatible with `tpu.replicate`. 3029 """ 3030 self._strategy = strategy 3031 self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access 3032 self) 3033 if not (replica_id_in_sync_group is None or 3034 tensor_util.is_tf_type(replica_id_in_sync_group) or 3035 isinstance(replica_id_in_sync_group, int)): 3036 raise ValueError( 3037 "replica_id_in_sync_group can only be an integer, a Tensor or None.") 3038 self._replica_id_in_sync_group = replica_id_in_sync_group 3039 # We need this check because TPUContext extends from ReplicaContext and 3040 # does not pass a strategy object since it is used by TPUEstimator. 3041 if strategy: 3042 self._local_replica_id = strategy.extended._get_local_replica_id( 3043 replica_id_in_sync_group) 3044 self._summary_recording_distribution_strategy = None 3045 3046 @doc_controls.do_not_generate_docs 3047 def __enter__(self): 3048 _push_per_thread_mode(self._thread_context) 3049 3050 def replica_id_is_zero(): 3051 return math_ops.equal(self.replica_id_in_sync_group, 3052 constant_op.constant(0)) 3053 3054 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 3055 self._summary_recording_distribution_strategy = ( 3056 summary_state.is_recording_distribution_strategy) 3057 summary_state.is_recording_distribution_strategy = replica_id_is_zero 3058 3059 @doc_controls.do_not_generate_docs 3060 def __exit__(self, exception_type, exception_value, traceback): 3061 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 3062 summary_state.is_recording_distribution_strategy = ( 3063 self._summary_recording_distribution_strategy) 3064 _pop_per_thread_mode() 3065 3066 def merge_call(self, merge_fn, args=(), kwargs=None): 3067 """Merge args across replicas and run `merge_fn` in a cross-replica context. 3068 3069 This allows communication and coordination when there are multiple calls 3070 to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. 3071 3072 See `tf.distribute.Strategy.run` for an explanation. 3073 3074 If not inside a distributed scope, this is equivalent to: 3075 3076 ``` 3077 strategy = tf.distribute.get_strategy() 3078 with cross-replica-context(strategy): 3079 return merge_fn(strategy, *args, **kwargs) 3080 ``` 3081 3082 Args: 3083 merge_fn: Function that joins arguments from threads that are given as 3084 PerReplica. It accepts `tf.distribute.Strategy` object as 3085 the first argument. 3086 args: List or tuple with positional per-thread arguments for `merge_fn`. 3087 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 3088 3089 Returns: 3090 The return value of `merge_fn`, except for `PerReplica` values which are 3091 unpacked. 
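For example, a minimal sketch that sums a per-replica value so every replica
receives the same total (assumes a synchronous multi-replica strategy; the
names `merge_fn` and `step_fn` are illustrative):

```python
def merge_fn(strategy, per_replica_value):
  # Runs once, in cross-replica context, with the values from all replicas
  # wrapped together as a single distributed value.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_value,
                         axis=None)

def step_fn():
  ctx = tf.distribute.get_replica_context()
  value = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  # Every replica resumes with the same reduced result.
  return ctx.merge_call(merge_fn, args=(value,))
```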
3092 """ 3093 require_replica_context(self) 3094 if kwargs is None: 3095 kwargs = {} 3096 3097 merge_fn = autograph.tf_convert( 3098 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 3099 return self._merge_call(merge_fn, args, kwargs) 3100 3101 def _merge_call(self, merge_fn, args, kwargs): 3102 """Default implementation for single replica.""" 3103 _push_per_thread_mode( # thread-local, so not needed with multiple threads 3104 distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 3105 try: 3106 return merge_fn(self._strategy, *args, **kwargs) 3107 finally: 3108 _pop_per_thread_mode() 3109 3110 @property 3111 def num_replicas_in_sync(self): 3112 """Returns number of replicas that are kept in sync.""" 3113 return self._strategy.num_replicas_in_sync 3114 3115 @property 3116 def replica_id_in_sync_group(self): 3117 """Returns the id of the replica. 3118 3119 This identifies the replica among all replicas that are kept in sync. The 3120 value of the replica id can range from 0 to 3121 `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. 3122 3123 NOTE: This is not guaranteed to be the same ID as the XLA replica ID use 3124 for low-level operations such as collective_permute. 3125 3126 Returns: 3127 a `Tensor`. 3128 """ 3129 # It's important to prefer making the Tensor at call time whenever possible. 3130 # Keeping Tensors in global states doesn't work well with nested 3131 # tf.function, since it's possible that the tensor is generated in one func 3132 # graph, and gets captured by another, which will result in a subtle "An op 3133 # outside of the function building code is being passed a Graph tensor" 3134 # error. Making the tensor at call time to ensure it is the same graph where 3135 # it's used. However to be compatible with tpu.replicate(), 3136 # self._replica_id_in_sync_group can also be a Tensor. 3137 if tensor_util.is_tf_type(self._replica_id_in_sync_group): 3138 return self._replica_id_in_sync_group 3139 return constant_op.constant( 3140 self._replica_id_in_sync_group, 3141 dtypes.int32, 3142 name="replica_id_in_sync_group") 3143 3144 @property 3145 def _replica_id(self): 3146 """This is the local replica id in a given sync group.""" 3147 return self._local_replica_id 3148 3149 @property 3150 def strategy(self): 3151 """The current `tf.distribute.Strategy` object.""" 3152 return self._strategy 3153 3154 @property 3155 @deprecation.deprecated(None, "Please avoid relying on devices property.") 3156 def devices(self): 3157 """Returns the devices this replica is to be executed on, as a tuple of strings. 3158 3159 NOTE: For `tf.distribute.MirroredStrategy` and 3160 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a 3161 nested 3162 list of device strings, e.g, [["GPU:0"]]. 3163 """ 3164 require_replica_context(self) 3165 return (device_util.current(),) 3166 3167 def all_reduce(self, reduce_op, value, options=None): 3168 """All-reduces `value` across all replicas. 3169 3170 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3171 >>> def step_fn(): 3172 ... ctx = tf.distribute.get_replica_context() 3173 ... value = tf.identity(1.) 3174 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) 3175 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3176 (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3177 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) 3178 3179 It supports batched operations. 
You can pass a list of values and it 3180 attempts to batch them when possible. You can also specify `options` 3181 to indicate the desired batching behavior, e.g. batch the values into 3182 multiple packs so that they can better overlap with computations. 3183 3184 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3185 >>> def step_fn(): 3186 ... ctx = tf.distribute.get_replica_context() 3187 ... value1 = tf.identity(1.) 3188 ... value2 = tf.identity(2.) 3189 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) 3190 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3191 ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3192 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>], 3193 [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3194 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>]) 3195 3196 Note that all replicas need to participate in the all-reduce, otherwise this 3197 operation hangs. Note that if there're multiple all-reduces, they need to 3198 execute in the same order on all replicas. Dispatching all-reduce based on 3199 conditions is usually error-prone. 3200 3201 Known limitation: if `value` contains `tf.IndexedSlices`, attempting to 3202 compute gradient w.r.t `value` would result in an error. 3203 3204 This API currently can only be called in the replica context. Other 3205 variants to reduce values across replicas are: 3206 * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API 3207 in the cross-replica context. 3208 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and 3209 all-reduce API in the cross-replica context. 3210 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 3211 to the host in cross-replica context. 3212 3213 Args: 3214 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 3215 be combined. Allows using string representation of the enum such as 3216 "SUM", "MEAN". 3217 value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which 3218 `tf.nest.flatten` accepts. The structure and the shapes of `value` need to be 3219 same on all replicas. 3220 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3221 perform collective operations. This overrides the default options if the 3222 `tf.distribute.Strategy` takes one in the constructor. See 3223 `tf.distribute.experimental.CommunicationOptions` for details of the 3224 options. 3225 3226 Returns: 3227 A nested structure of `tf.Tensor` with the reduced values. The structure 3228 is the same as `value`. 3229 """ 3230 flattened_value = nest.flatten(value) 3231 has_indexed_slices = False 3232 3233 for v in flattened_value: 3234 if isinstance(v, ops.IndexedSlices): 3235 has_indexed_slices = True 3236 3237 if isinstance(reduce_op, six.string_types): 3238 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 3239 if options is None: 3240 options = collective_util.Options() 3241 3242 def batch_all_reduce(strategy, *value_flat): 3243 return strategy.extended.batch_reduce_to( 3244 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], 3245 options) 3246 3247 # Due to the use of `capture_call_time_value` in collective ops, we have 3248 # to maintain two branches: one w/ merge_call and one w/o. Details can be 3249 # found in b/184009754. 3250 if self._strategy.extended._use_merge_call(): # pylint: disable=protected-access 3251 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 
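# `tf.IndexedSlices` inputs are reduced via `merge_call` outside of the
# `custom_gradient` wrapper below, which is why gradients w.r.t. such values
# are not supported (see the known limitation in the docstring).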
3252 if has_indexed_slices: 3253 return nest.pack_sequence_as( 3254 value, 3255 self.merge_call(batch_all_reduce, args=flattened_value)) 3256 3257 @custom_gradient.custom_gradient 3258 def grad_wrapper(*xs): 3259 ys = self.merge_call(batch_all_reduce, args=xs) 3260 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 3261 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 3262 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 3263 else: 3264 if has_indexed_slices: 3265 return nest.pack_sequence_as( 3266 value, 3267 self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 3268 reduce_op, flattened_value, options)) 3269 3270 @custom_gradient.custom_gradient 3271 def grad_wrapper(*xs): 3272 ys = self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 3273 reduce_op, xs, options) 3274 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 3275 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 3276 3277 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 3278 3279 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 3280 # all-reduce. It would return a function returning the result of reducing `t` 3281 # across all replicas. The caller would wait to call this function until they 3282 # needed the reduce result, allowing an efficient implementation: 3283 # * With eager execution, the reduction could be performed asynchronously 3284 # in the background, not blocking until the result was needed. 3285 # * When constructing a graph, it could batch up all reduction requests up 3286 # to that point that the first result is needed. Most likely this can be 3287 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 3288 3289 3290@tf_export("distribute.ReplicaContext", v1=[]) 3291class ReplicaContext(ReplicaContextBase): 3292 3293 __doc__ = ReplicaContextBase.__doc__ 3294 3295 def all_gather(self, value, axis, options=None): 3296 """All-gathers `value` across all replicas along `axis`. 3297 3298 Note: An `all_gather` method can only be called in replica context. For 3299 a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. 3300 All replicas need to participate in the all-gather, otherwise this 3301 operation hangs. So if `all_gather` is called in any replica, it must be 3302 called in all replicas. 3303 3304 Note: If there are multiple `all_gather` calls, they need to be executed in 3305 the same order on all replicas. Dispatching `all_gather` based on conditions 3306 is usually error-prone. 3307 3308 For all strategies except `tf.distribute.TPUStrategy`, the input 3309 `value` on different replicas must have the same rank, and their shapes must 3310 be the same in all dimensions except the `axis`-th dimension. In other 3311 words, their shapes cannot be different in a dimension `d` where `d` does 3312 not equal to the `axis` argument. For example, given a 3313 `tf.distribute.DistributedValues` with component tensors of shape 3314 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 3315 `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` 3316 or `all_gather(..., axis=2, ...)`. However, with 3317 `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and 3318 same shape. 3319 3320 Note: The input `value` must have a non-zero rank. Otherwise, consider using 3321 `tf.expand_dims` before gathering them. 
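For instance, a per-replica scalar can be given a leading dimension and then
gathered along it. A minimal sketch (the name `gather_scalar` is
illustrative, assuming a two-replica strategy):

```python
def gather_scalar():
  ctx = tf.distribute.get_replica_context()
  scalar = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  # Scalars have rank 0, so add an axis before gathering.
  return ctx.all_gather(tf.expand_dims(scalar, axis=0), axis=0)

# With two replicas, `strategy.run(gather_scalar)` yields [0.0, 1.0] on each
# replica.
```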
3322 3323 You can pass in a single tensor to all-gather: 3324 3325 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3326 >>> @tf.function 3327 ... def gather_value(): 3328 ... ctx = tf.distribute.get_replica_context() 3329 ... local_value = tf.constant([1, 2, 3]) 3330 ... return ctx.all_gather(local_value, axis=0) 3331 >>> result = strategy.run(gather_value) 3332 >>> result 3333 PerReplica:{ 3334 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3335 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3336 } 3337 >>> strategy.experimental_local_results(result) 3338 (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3339 dtype=int32)>, 3340 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3341 dtype=int32)>) 3342 3343 3344 You can also pass in a nested structure of tensors to all-gather, say, a 3345 list: 3346 3347 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3348 >>> @tf.function 3349 ... def gather_nest(): 3350 ... ctx = tf.distribute.get_replica_context() 3351 ... value_1 = tf.constant([1, 2, 3]) 3352 ... value_2 = tf.constant([[1, 2], [3, 4]]) 3353 ... # all_gather a nest of `tf.distribute.DistributedValues` 3354 ... return ctx.all_gather([value_1, value_2], axis=0) 3355 >>> result = strategy.run(gather_nest) 3356 >>> result 3357 [PerReplica:{ 3358 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3359 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3360 }, PerReplica:{ 3361 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3362 array([[1, 2], 3363 [3, 4], 3364 [1, 2], 3365 [3, 4]], dtype=int32)>, 3366 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3367 array([[1, 2], 3368 [3, 4], 3369 [1, 2], 3370 [3, 4]], dtype=int32)> 3371 }] 3372 >>> strategy.experimental_local_results(result) 3373 ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3374 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3375 array([[1, 2], 3376 [3, 4], 3377 [1, 2], 3378 [3, 4]], dtype=int32)>], 3379 [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3380 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3381 array([[1, 2], 3382 [3, 4], 3383 [1, 2], 3384 [3, 4]], dtype=int32)>]) 3385 3386 3387 What if you are all-gathering tensors with different shapes on different 3388 replicas? Consider the following example with two replicas, where you have 3389 `value` as a nested structure consisting of two items to all-gather, `a` and 3390 `b`. 3391 3392 On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`. 3393 3394 On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`. 3395 3396 Result for `all_gather` with `axis`=0 (on each of the replicas) is: 3397 3398 ```{'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}``` 3399 3400 Args: 3401 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, 3402 or a `tf.distribute.DistributedValues` instance. The structure of the 3403 `tf.Tensor` need to be same on all replicas. The underlying tensor 3404 constructs can only be dense tensors with non-zero rank, NOT 3405 `tf.IndexedSlices`. 3406 axis: 0-D int32 Tensor. Dimension along which to gather. 3407 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3408 perform collective operations. This overrides the default options if the 3409 `tf.distribute.Strategy` takes one in the constructor. 
See 3410 `tf.distribute.experimental.CommunicationOptions` for details of the 3411 options. 3412 3413 Returns: 3414 A nested structure of `tf.Tensor` with the gathered values. The structure 3415 is the same as `value`. 3416 """ 3417 for v in nest.flatten(value): 3418 if isinstance(v, ops.IndexedSlices): 3419 raise NotImplementedError("all_gather does not support IndexedSlices") 3420 3421 if options is None: 3422 options = collective_util.Options() 3423 3424 def batch_all_gather(strategy, *value_flat): 3425 return strategy.extended._batch_gather_to( # pylint: disable=protected-access 3426 [(v, _batch_reduce_destination(v)) for v in value_flat], axis, 3427 options) 3428 3429 @custom_gradient.custom_gradient 3430 def grad_wrapper(*xs): 3431 ys = self.merge_call(batch_all_gather, args=xs) 3432 3433 def grad(*dy_s): 3434 grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) 3435 new_grads = [] 3436 for i, grad in enumerate(grads): 3437 input_shape = array_ops.shape(xs[i]) 3438 axis_dim = array_ops.reshape(input_shape[axis], [1]) 3439 with ops.control_dependencies([array_ops.identity(grads)]): 3440 d = self.all_gather(axis_dim, axis=0) 3441 begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) 3442 end_dim = begin_dim + array_ops.shape(xs[i])[axis] 3443 new_grad = array_ops.gather( 3444 grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) 3445 new_grads.append(new_grad) 3446 return new_grads 3447 3448 return ys, grad 3449 3450 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 3451 3452 def _update(self, var, fn, args=(), kwargs=None, group=True): 3453 """Run `fn` to update `var` with `args` and `kwargs` in replica context. 3454 3455 `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var` 3456 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. 3457 `fn` applies to each component variable of `var` with corresponding input 3458 values from `args` and `kwargs`. 3459 3460 Example usage: 3461 3462 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 3463 >>> with strategy.scope(): 3464 ... distributed_variable = tf.Variable(5.0) 3465 >>> distributed_variable 3466 MirroredVariable:{ 3467 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>, 3468 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0> 3469 } 3470 >>> def replica_fn(v): 3471 ... value = tf.identity(1.0) 3472 ... replica_context = tf.distribute.get_replica_context() 3473 ... update_fn = lambda var, value: var.assign(value) 3474 ... replica_context._update(v, update_fn, args=(value,)) 3475 >>> strategy.run(replica_fn, args=(distributed_variable,)) 3476 >>> distributed_variable 3477 MirroredVariable:{ 3478 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, 3479 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> 3480 } 3481 3482 This API must be called in a replica context. 3483 3484 Note that if `var` is a MirroredVariable (i.e., the type of variable created 3485 under the scope of a synchronous strategy, and is synchronized on-write, see 3486 `tf.VariableSynchronization` for more information) and `args`/`kwargs` 3487 contains different values for different replicas, `var` will be dangerously 3488 out of synchronization. Thus we recommend using `variable.assign(value)` as 3489 long as you can, which under the hood aggregates the updates and guarantees 3490 the synchronization. 
The case where you actually want this API instead of 3491 `variable.assign(value)` is that before assigning `value` to the `variable`, 3492 you'd like to conduct some pre-`assign` computation colocated with the 3493 variable devices (i.e. where variables reside, for MirroredStrategy they are 3494 the same as the compute device, for ParameterServerStrategy they refer to 3495 parameter servers). E.g., 3496 3497 ```python 3498 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 3499 with strategy.scope(): 3500 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 3501 def replica_fn(inputs): 3502 value = computation(inputs) 3503 replica_context = tf.distribute.get_replica_context() 3504 reduced_value = replica_context.all_reduce(value) 3505 3506 def update_fn(var, value): 3507 # this computation will colocate with `var`'s device 3508 updated_value = post_reduce_pre_update_computation(value) 3509 var.assign(value) 3510 3511 replica_context._update(v, update_fn, args=(reduced_value,)) 3512 3513 strategy.run(replica_fn, args=(inputs,)) 3514 ``` 3515 3516 This code snippet is consistent across all strategies. If you directly 3517 compute and use `assign` in the replica context instead of wrapping it with 3518 `update`, for strategies with fewer variable devices than compute devices 3519 (e.g., parameter server strategy, usually), the 3520 `post_reduce_pre_update_computation` will happen 3521 N==number_of_compute_devices times which is less performant. 3522 3523 3524 Args: 3525 var: Variable, possibly distributed to multiple devices, to operate on. 3526 fn: Function to call. Should take the variable as the first argument. 3527 args: Tuple or list. Additional positional arguments to pass to `fn()`. 3528 kwargs: Dict with keyword arguments to pass to `fn()`. 3529 group: Boolean. Defaults to True. Most strategies enter a merge_call to 3530 conduct update in cross-replica context, and group=True guarantees updates 3531 on all replicas is executed. 3532 3533 Returns: 3534 The return value of `fn` for the local replica. 3535 """ 3536 if kwargs is None: 3537 kwargs = {} 3538 return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access 3539 3540 3541@tf_export(v1=["distribute.ReplicaContext"]) 3542class ReplicaContextV1(ReplicaContextBase): 3543 __doc__ = ReplicaContextBase.__doc__ 3544 3545 3546def _batch_reduce_destination(x): 3547 """Returns the destinations for batch all-reduce.""" 3548 if isinstance(x, ops.Tensor): 3549 # If this is a one device strategy. 
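# In that case the input is a plain `Tensor` rather than a
# `DistributedValues`, so use the tensor's own device as the destination.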
3550 return x.device 3551 else: 3552 return x 3553 3554 3555# ------------------------------------------------------------------------------ 3556 3557 3558_creating_default_strategy_singleton = False 3559 3560 3561class _DefaultDistributionStrategyV1(StrategyV1): 3562 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3563 3564 def __init__(self): 3565 if not _creating_default_strategy_singleton: 3566 raise RuntimeError("Should only create a single instance of " 3567 "_DefaultDistributionStrategy") 3568 super(_DefaultDistributionStrategyV1, 3569 self).__init__(_DefaultDistributionExtended(self)) 3570 3571 def __deepcopy__(self, memo): 3572 del memo 3573 raise RuntimeError("Should only create a single instance of " 3574 "_DefaultDistributionStrategy") 3575 3576 3577class _DefaultDistributionStrategy(Strategy): 3578 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3579 3580 def __init__(self): 3581 if not _creating_default_strategy_singleton: 3582 raise RuntimeError("Should only create a single instance of " 3583 "_DefaultDistributionStrategy") 3584 super(_DefaultDistributionStrategy, self).__init__( 3585 _DefaultDistributionExtended(self)) 3586 3587 def __deepcopy__(self, memo): 3588 del memo 3589 raise RuntimeError("Should only create a single instance of " 3590 "_DefaultDistributionStrategy") 3591 3592 3593class _DefaultDistributionContext(object): 3594 """Context manager setting the default `tf.distribute.Strategy`.""" 3595 3596 __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] 3597 3598 def __init__(self, strategy): 3599 3600 def creator(next_creator, **kwargs): 3601 _require_strategy_scope_strategy(strategy) 3602 return next_creator(**kwargs) 3603 3604 self._var_creator_scope = variable_scope.variable_creator_scope(creator) 3605 self._strategy = strategy 3606 self._nested_count = 0 3607 3608 def __enter__(self): 3609 # Allow this scope to be entered if this strategy is already in scope. 
3610 if distribution_strategy_context.has_strategy(): 3611 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 3612 if self._nested_count == 0: 3613 self._var_creator_scope.__enter__() 3614 self._nested_count += 1 3615 return self._strategy 3616 3617 def __exit__(self, exception_type, exception_value, traceback): 3618 self._nested_count -= 1 3619 if self._nested_count == 0: 3620 try: 3621 self._var_creator_scope.__exit__( 3622 exception_type, exception_value, traceback) 3623 except RuntimeError as e: 3624 six.raise_from( 3625 RuntimeError("Variable creator scope nesting error: move call to " 3626 "tf.distribute.set_strategy() out of `with` scope."), 3627 e) 3628 3629 3630class _DefaultDistributionExtended(StrategyExtendedV1): 3631 """Implementation of _DefaultDistributionStrategy.""" 3632 3633 def __init__(self, container_strategy): 3634 super(_DefaultDistributionExtended, self).__init__(container_strategy) 3635 self._retrace_functions_for_each_device = False 3636 3637 def _scope(self, strategy): 3638 """Context manager setting a variable creator and `self` as current.""" 3639 return _DefaultDistributionContext(strategy) 3640 3641 def colocate_vars_with(self, colocate_with_variable): 3642 """Does not require `self.scope`.""" 3643 _require_strategy_scope_extended(self) 3644 return ops.colocate_with(colocate_with_variable) 3645 3646 def variable_created_in_scope(self, v): 3647 return v._distribute_strategy is None # pylint: disable=protected-access 3648 3649 def _experimental_distribute_dataset(self, dataset, options): 3650 return dataset 3651 3652 def _distribute_datasets_from_function(self, dataset_fn, options): 3653 return dataset_fn(InputContext()) 3654 3655 def _experimental_distribute_values_from_function(self, value_fn): 3656 return value_fn(ValueContext()) 3657 3658 def _make_dataset_iterator(self, dataset): 3659 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3660 3661 def _make_input_fn_iterator(self, 3662 input_fn, 3663 replication_mode=InputReplicationMode.PER_WORKER): 3664 dataset = input_fn(InputContext()) 3665 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3666 3667 def _experimental_make_numpy_dataset(self, numpy_input, session): 3668 numpy_flat = nest.flatten(numpy_input) 3669 vars_flat = tuple( 3670 variable_scope.variable(array_ops.zeros(i.shape, i.dtype), 3671 trainable=False, use_resource=True) 3672 for i in numpy_flat 3673 ) 3674 for v, i in zip(vars_flat, numpy_flat): 3675 numpy_dataset.init_var_from_numpy(v, i, session) 3676 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 3677 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 3678 3679 def _broadcast_to(self, tensor, destinations): 3680 if destinations is None: 3681 return tensor 3682 else: 3683 raise NotImplementedError("TODO") 3684 3685 def _call_for_each_replica(self, fn, args, kwargs): 3686 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): 3687 return fn(*args, **kwargs) 3688 3689 def _reduce_to(self, reduce_op, value, destinations, options): 3690 # TODO(josh11b): Use destinations? 3691 del reduce_op, destinations, options 3692 return value 3693 3694 def _gather_to_implementation(self, value, destinations, axis, options): 3695 del destinations, axis, options 3696 return value 3697 3698 def _update(self, var, fn, args, kwargs, group): 3699 # The implementations of _update() and _update_non_slot() are identical 3700 # except _update() passes `var` as the first argument to `fn()`. 
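# Prepending `var` to the positional args makes the shared path below call
# `fn(var, *args, **kwargs)` inside `UpdateContext(var)`.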
3701 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 3702 3703 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 3704 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 3705 # once that value is used for something. 3706 with UpdateContext(colocate_with): 3707 result = fn(*args, **kwargs) 3708 if should_group: 3709 return result 3710 else: 3711 return nest.map_structure(self._local_results, result) 3712 3713 def read_var(self, replica_local_var): 3714 return array_ops.identity(replica_local_var) 3715 3716 def _local_results(self, distributed_value): 3717 return (distributed_value,) 3718 3719 def value_container(self, value): 3720 return value 3721 3722 @property 3723 def _num_replicas_in_sync(self): 3724 return 1 3725 3726 @property 3727 def worker_devices(self): 3728 raise RuntimeError("worker_devices() method unsupported by default " 3729 "tf.distribute.Strategy.") 3730 3731 @property 3732 def parameter_devices(self): 3733 raise RuntimeError("parameter_devices() method unsupported by default " 3734 "tf.distribute.Strategy.") 3735 3736 def non_slot_devices(self, var_list): 3737 return min(var_list, key=lambda x: x.name) 3738 3739 def _in_multi_worker_mode(self): 3740 """Whether this strategy indicates working in multi-worker settings.""" 3741 # Default strategy doesn't indicate multi-worker training. 3742 return False 3743 3744 @property 3745 def should_checkpoint(self): 3746 return True 3747 3748 @property 3749 def should_save_summary(self): 3750 return True 3751 3752 def _get_local_replica_id(self, replica_id_in_sync_group): 3753 return replica_id_in_sync_group 3754 3755 def _get_replica_id_in_sync_group(self, replica_id): 3756 return replica_id 3757 3758 # TODO(priyag): This should inherit from `InputIterator`, once dependency 3759 # issues have been resolved. 3760 class DefaultInputIterator(object): 3761 """Default implementation of `InputIterator` for default strategy.""" 3762 3763 def __init__(self, dataset): 3764 self._dataset = dataset 3765 if eager_context.executing_eagerly(): 3766 self._iterator = dataset_ops.make_one_shot_iterator(dataset) 3767 else: 3768 self._iterator = dataset_ops.make_initializable_iterator(dataset) 3769 3770 def get_next(self): 3771 return self._iterator.get_next() 3772 3773 def get_next_as_optional(self): 3774 return self._iterator.get_next_as_optional() 3775 3776 @deprecated(None, "Use the iterator's `initializer` property instead.") 3777 def initialize(self): 3778 """Initialize underlying iterators. 3779 3780 Returns: 3781 A list of any initializer ops that should be run. 3782 """ 3783 if eager_context.executing_eagerly(): 3784 self._iterator = self._dataset.make_one_shot_iterator() 3785 return [] 3786 else: 3787 return [self._iterator.initializer] 3788 3789 @property 3790 def initializer(self): 3791 """Returns a list of ops that initialize the iterator.""" 3792 return self.initialize() 3793 3794 # TODO(priyag): Delete this once all strategies use global batch size. 3795 @property 3796 def _global_batch_size(self): 3797 """Global and per-replica batching are equivalent for this strategy.""" 3798 return True 3799 3800 3801class _DefaultReplicaContext(ReplicaContext): 3802 """ReplicaContext for _DefaultDistributionStrategy.""" 3803 3804 @property 3805 def replica_id_in_sync_group(self): 3806 # Return 0 instead of a constant tensor to avoid creating a new node for 3807 # users who don't use distribution strategy. 
3808 return 0 3809 3810 3811# ------------------------------------------------------------------------------ 3812# We haven't yet implemented deserialization for DistributedVariables. 3813# So here we catch any attempts to deserialize variables 3814# when using distribution strategies. 3815# pylint: disable=protected-access 3816_original_from_proto = resource_variable_ops._from_proto_fn 3817 3818 3819def _from_proto_fn(v, import_scope=None): 3820 if distribution_strategy_context.has_strategy(): 3821 raise NotImplementedError( 3822 "Deserialization of variables is not yet supported when using a " 3823 "tf.distribute.Strategy.") 3824 else: 3825 return _original_from_proto(v, import_scope=import_scope) 3826 3827resource_variable_ops._from_proto_fn = _from_proto_fn 3828# pylint: enable=protected-access 3829 3830 3831#------------------------------------------------------------------------------- 3832# Shorthand for some methods from distribution_strategy_context. 3833_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access 3834_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access 3835_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access 3836_get_default_replica_mode = ( 3837 distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access 3838 3839 3840# ------------------------------------------------------------------------------ 3841# Metrics to track which distribution strategy is being called 3842distribution_strategy_gauge = monitoring.StringGauge( 3843 "/tensorflow/api/distribution_strategy", 3844 "Gauge to track the type of distribution strategy used.", "TFVersion") 3845distribution_strategy_replica_gauge = monitoring.IntGauge( 3846 "/tensorflow/api/distribution_strategy/replica", 3847 "Gauge to track the number of replica each distribution strategy used.", 3848 "CountType") 3849distribution_strategy_input_api_counter = monitoring.Counter( 3850 "/tensorflow/api/distribution_strategy/input_api", 3851 "Counter to track the usage of the input APIs", "strategy", "api") 3852
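# For reference, a sketch of how concrete strategy implementations typically
# record these metrics (the label values below are only illustrative; the
# actual call sites live in the individual strategy classes):
#
#   distribution_strategy_gauge.get_cell("V2").set("MirroredStrategy")
#   distribution_strategy_replica_gauge.get_cell(
#       "num_replicas_per_worker").set(2)
#   distribution_strategy_input_api_counter.get_cell(
#       "MirroredStrategy", "experimental_distribute_dataset").increase_by(1)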