1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=line-too-long 16"""Library for running a computation across multiple devices. 17 18The intent of this library is that you can write an algorithm in a stylized way 19and it will be usable with a variety of different `tf.distribute.Strategy` 20implementations. Each descendant will implement a different strategy for 21distributing the algorithm across multiple devices/machines. Furthermore, these 22changes can be hidden inside the specific layers and other library classes that 23need special treatment to run in a distributed setting, so that most users' 24model definition code can run unchanged. The `tf.distribute.Strategy` API works 25the same way with eager and graph execution. 26 27*Guides* 28 29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) 30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) 31 32*Tutorials* 33 34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) 35 36 The tutorials cover how to use `tf.distribute.Strategy` to do distributed 37 training with native Keras APIs, custom training loops, 38 and Estimator APIs. They also cover how to save/load model when using 39 `tf.distribute.Strategy`. 40 41*Glossary* 42 43* _Data parallelism_ is where we run multiple copies of the model 44 on different slices of the input data. This is in contrast to 45 _model parallelism_ where we divide up a single copy of a model 46 across multiple devices. 47 Note: we only support data parallelism for now, but 48 hope to add support for model parallelism in the future. 49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that 50 TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple 51 devices on a single machine, or be connected to devices on multiple 52 machines. Devices used to run computations are called _worker devices_. 53 Devices used to store variables are _parameter devices_. For some strategies, 54 such as `tf.distribute.MirroredStrategy`, the worker and parameter devices 55 will be the same (see mirrored variables below). For others they will be 56 different. For example, `tf.distribute.experimental.CentralStorageStrategy` 57 puts the variables on a single device (which may be a worker device or may be 58 the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the 59 variables on separate machines called _parameter servers_ (see below). 60* A _replica_ is one copy of the model, running on one slice of the 61 input data. Right now each replica is executed on its own 62 worker device, but once we add support for model parallelism 63 a replica may span multiple worker devices. 64* A _host_ is the CPU device on a machine with worker devices, typically 65 used for running input pipelines. 
* A _worker_ is defined to be the physical machine(s) containing the physical
  devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A
  worker may contain one or more replicas, but contains at least one
  replica. Typically one worker will correspond to one machine, but in the case
  of very large models with model parallelism, one worker may span multiple
  machines. We typically run one input pipeline per worker, feeding all the
  replicas on that worker.
* _Synchronous_, or more commonly _sync_, training is where the updates from
  each replica are aggregated together before updating the model variables.
  This is in contrast to _asynchronous_, or _async_ training, where each
  replica updates the model variables independently. You may also have replicas
  partitioned into groups which are in sync within each group but async between
  groups.
* _Parameter servers_: These are machines that hold a single copy of
  parameters/variables, used by some strategies (right now just
  `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want
  to operate on a variable retrieve it at the beginning of a step and send an
  update to be applied at the end of the step. These can in principle support
  either sync or async training, but right now we only have support for async
  training with parameter servers. Compare to
  `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables
  on a single device on the same machine (and does sync training), and
  `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices
  (see below).

* _Replica context_ vs. _Cross-replica context_ vs. _Update context_

  A _replica context_ applies when you execute the computation function that
  was called with `strategy.run`. Conceptually, you're in replica context when
  executing the computation function that is being replicated. (See the sketch
  at the end of this glossary for how these contexts relate in code.)

  An _update context_ is entered in a `tf.distribute.StrategyExtended.update`
  call.

  A _cross-replica context_ is entered when you enter a `strategy.scope`. This
  is useful for calling `tf.distribute.Strategy` methods which operate across
  the replicas (like `reduce_to()`). By default you start in a _replica
  context_ (the "default single _replica context_") and then some methods can
  switch you back and forth.

* _Distributed value_: A distributed value is represented by the base class
  `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is
  useful for representing values on multiple devices, and it contains a map
  from replica id to values. Two representative types of
  `tf.distribute.DistributedValues` are `tf.types.experimental.PerReplica` and
  `tf.types.experimental.Mirrored` values.

  `PerReplica` values exist on the worker devices, with a different value for
  each replica. They are produced by iterating through a distributed dataset
  returned by `tf.distribute.Strategy.experimental_distribute_dataset` and
  `tf.distribute.Strategy.distribute_datasets_from_function`. They are also the
  typical result returned by `tf.distribute.Strategy.run`.

  `Mirrored` values are like `PerReplica` values, except we know that the
  values on all replicas are the same. `Mirrored` values are kept synchronized
  by the distribution strategy in use, while `PerReplica` values are left
  unsynchronized. `Mirrored` values typically represent model weights.
We can 124 safely read a `Mirrored` value in a cross-replica context by using the value 125 on any replica, while PerReplica values can only be read within a replica 126 context. 127 128* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple 129 replicas, like `strategy.run(fn, args=[w])` with an 130 argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will 131 have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. 132 `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on 133 device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return 134 values from `fn()`, which leads to one common object if the returned values 135 are the same object from every replica, or a `DistributedValues` object 136 otherwise. 137 138* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating 139 multiple values into one value, like "sum" or "mean". If a strategy is doing 140 sync training, we will perform a reduction on the gradients to a parameter 141 from all replicas before applying the update. _All-reduce_ is an algorithm for 142 performing a reduction on values from multiple devices and making the result 143 available on all of those devices. 144 145* _Mirrored variables_: These are variables that are created on multiple 146 devices, where we keep the variables in sync by applying the same 147 updates to every copy. Mirrored variables are created with 148 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. 149 Normally they are only used in synchronous training. 150 151* _SyncOnRead variables_ 152 153 _SyncOnRead variables_ are created by 154 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and 155 they are created on multiple devices. In replica context, each 156 component variable on the local replica can perform reads and writes without 157 synchronization with each other. When the 158 _SyncOnRead variable_ is read in cross-replica context, the values from 159 component variables are aggregated and returned. 160 161 _SyncOnRead variables_ bring a lot of custom configuration difficulty to the 162 underlying logic, so we do not encourage users to instantiate and use 163 _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead 164 variables_ for use cases such as batch norm and metrics. For performance 165 reasons, we often don't need to keep these statistics in sync every step and 166 they can be accumulated on each replica independently. The only time we want 167 to sync them is reporting or checkpointing, which typically happens in 168 cross-replica context. _SyncOnRead variables_ are also often used by advanced 169 users who want to control when variable values are aggregated. For example, 170 users sometimes want to maintain gradients independently on each replica for a 171 couple of steps without aggregation. 172 173* _Distribute-aware layers_ 174 175 Layers are generally called in a replica context, except when defining a 176 Keras functional model. `tf.distribute.in_cross_replica_context` will let you 177 determine which case you are in. If in a replica context, 178 the `tf.distribute.get_replica_context` function will return the default 179 replica context outside a strategy scope, `None` within a strategy scope, and 180 a `tf.distribute.ReplicaContext` object inside a strategy scope and within a 181 `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an 182 `all_reduce` method for aggregating across all replicas. 
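

*Example*

  A minimal sketch of how these glossary terms show up in code, assuming two
  local GPUs are available:

  ```python
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

  with strategy.scope():          # cross-replica context
    v = tf.Variable(1.0)          # a mirrored variable, one copy per replica

  def replica_fn(x):              # runs in replica context, once per replica
    return x * v

  # `run` returns a per-replica value; `reduce` aggregates it into one tensor.
  per_replica = strategy.run(replica_fn, args=(tf.constant(2.0),))
  total = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
  ```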
183 184 185Note that we provide a default version of `tf.distribute.Strategy` that is 186used when no other strategy is in scope, that provides the same API with 187reasonable default behavior. 188""" 189# pylint: enable=line-too-long 190 191import collections 192import copy 193import enum # pylint: disable=g-bad-import-order 194import functools 195import threading 196import weakref 197 198import six 199 200from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 201from tensorflow.python.autograph.impl import api as autograph 202from tensorflow.python.data.ops import dataset_ops 203from tensorflow.python.distribute import collective_util 204from tensorflow.python.distribute import device_util 205from tensorflow.python.distribute import distribution_strategy_context 206from tensorflow.python.distribute import numpy_dataset 207from tensorflow.python.distribute import reduce_util 208from tensorflow.python.distribute import values 209from tensorflow.python.eager import context as eager_context 210from tensorflow.python.eager import def_function 211from tensorflow.python.eager import monitoring 212from tensorflow.python.framework import constant_op 213from tensorflow.python.framework import dtypes 214from tensorflow.python.framework import indexed_slices 215from tensorflow.python.framework import ops 216from tensorflow.python.framework import tensor_shape 217from tensorflow.python.framework import tensor_util 218from tensorflow.python.ops import array_ops 219from tensorflow.python.ops import control_flow_ops 220from tensorflow.python.ops import custom_gradient 221from tensorflow.python.ops import math_ops 222from tensorflow.python.ops import resource_variable_ops 223from tensorflow.python.ops import summary_ops_v2 224from tensorflow.python.ops import variable_scope 225from tensorflow.python.ops.losses import losses_impl 226from tensorflow.python.platform import tf_logging 227from tensorflow.python.trackable import base as trackable 228from tensorflow.python.util import deprecation 229from tensorflow.python.util import nest 230from tensorflow.python.util import tf_contextlib 231from tensorflow.python.util.deprecation import deprecated 232from tensorflow.python.util.tf_export import tf_export 233from tensorflow.tools.docs import doc_controls 234 235# ------------------------------------------------------------------------------ 236# Context tracking whether in a strategy.update() or .update_non_slot() call. 237 238 239_update_replica_id = threading.local() 240 241 242def get_update_replica_id(): 243 """Get the current device if in a `tf.distribute.Strategy.update()` call.""" 244 try: 245 return _update_replica_id.current 246 except AttributeError: 247 return None 248 249 250class UpdateContext(object): 251 """Context manager when you are in `update()` or `update_non_slot()`.""" 252 253 __slots__ = ["_replica_id", "_old_replica_id"] 254 255 def __init__(self, replica_id): 256 self._replica_id = replica_id 257 self._old_replica_id = None 258 259 def __enter__(self): 260 self._old_replica_id = get_update_replica_id() 261 _update_replica_id.current = self._replica_id 262 263 def __exit__(self, exception_type, exception_value, traceback): 264 del exception_type, exception_value, traceback 265 _update_replica_id.current = self._old_replica_id 266 267 268# ------------------------------------------------------------------------------ 269# Public utility functions. 
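
# The context manager above is what makes `get_update_replica_id()` work: an
# update routine is expected to wrap each per-replica variable update in an
# `UpdateContext`. A minimal sketch (the helper name below is hypothetical and
# not part of this module):
#
#   def _apply_update(update_fn, var, replica_id):
#     with UpdateContext(replica_id):
#       # Inside the `with` block, get_update_replica_id() == replica_id.
#       return update_fn(var)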
270 271 272@tf_export(v1=["distribute.get_loss_reduction"]) 273def get_loss_reduction(): 274 """`tf.distribute.ReduceOp` corresponding to the last loss reduction. 275 276 This is used to decide whether loss should be scaled in optimizer (used only 277 for estimator + v1 optimizer use case). 278 279 Returns: 280 `tf.distribute.ReduceOp` corresponding to the last loss reduction for 281 estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. 282 """ 283 if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access 284 # If we are not in Estimator context then return 'SUM'. We do not need to 285 # scale loss in the optimizer. 286 return reduce_util.ReduceOp.SUM 287 last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access 288 if (last_reduction == losses_impl.Reduction.SUM or 289 last_reduction == "sum"): # Check for tf.keras.losses.Reduction.SUM 290 return reduce_util.ReduceOp.SUM 291 return reduce_util.ReduceOp.MEAN 292 293 294# ------------------------------------------------------------------------------ 295# Internal API for validating the current thread mode 296 297 298def _require_cross_replica_or_default_context_extended(extended, 299 error_message=None): 300 """Verify in cross-replica context.""" 301 context = _get_per_thread_mode() 302 cross_replica = context.cross_replica_context 303 if cross_replica is not None and cross_replica.extended is extended: 304 return 305 if context is _get_default_replica_mode(): 306 return 307 strategy = extended._container_strategy() # pylint: disable=protected-access 308 # We have an error to report, figure out the right message. 309 if context.strategy is not strategy: 310 _wrong_strategy_scope(strategy, context) 311 assert cross_replica is None 312 if not error_message: 313 error_message = ("Method requires being in cross-replica context, use " 314 "get_replica_context().merge_call()") 315 raise RuntimeError(error_message) 316 317 318def _wrong_strategy_scope(strategy, context): 319 # Figure out the right error message. 320 if not distribution_strategy_context.has_strategy(): 321 raise RuntimeError( 322 'Need to be inside "with strategy.scope()" for %s' % 323 (strategy,)) 324 else: 325 raise RuntimeError( 326 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 327 (context.strategy, strategy)) 328 329 330def require_replica_context(replica_ctx): 331 """Verify in `replica_ctx` replica context.""" 332 context = _get_per_thread_mode() 333 if context.replica_context is replica_ctx: return 334 # We have an error to report, figure out the right message. 335 if context.replica_context is None: 336 raise RuntimeError("Need to be inside `call_for_each_replica()`") 337 if context.strategy is replica_ctx.strategy: 338 # Two different ReplicaContexts with the same tf.distribute.Strategy. 339 raise RuntimeError("Mismatching ReplicaContext.") 340 raise RuntimeError( 341 "Mismatching tf.distribute.Strategy objects: %s is not %s." 
% 342 (context.strategy, replica_ctx.strategy)) 343 344 345def _require_strategy_scope_strategy(strategy): 346 """Verify in a `strategy.scope()` in this thread.""" 347 context = _get_per_thread_mode() 348 if context.strategy is strategy: return 349 _wrong_strategy_scope(strategy, context) 350 351 352def _require_strategy_scope_extended(extended): 353 """Verify in a `distribution_strategy.scope()` in this thread.""" 354 context = _get_per_thread_mode() 355 if context.strategy.extended is extended: return 356 # Report error. 357 strategy = extended._container_strategy() # pylint: disable=protected-access 358 _wrong_strategy_scope(strategy, context) 359 360 361# ------------------------------------------------------------------------------ 362# Internal context managers used to implement the DistributionStrategy 363# base class 364 365 366class _CurrentDistributionContext(object): 367 """Context manager setting the current `tf.distribute.Strategy`. 368 369 Also: overrides the variable creator and optionally the current device. 370 """ 371 372 def __init__(self, 373 strategy, 374 var_creator_scope, 375 var_scope=None, 376 resource_creator_scope=None, 377 default_device=None): 378 self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access 379 strategy) 380 self._var_creator_scope = var_creator_scope 381 self._var_scope = var_scope 382 self._resource_creator_scope = resource_creator_scope 383 if default_device: 384 self._device_scope = ops.device(default_device) 385 else: 386 self._device_scope = None 387 self._same_scope_again_count = 0 388 389 def __enter__(self): 390 # Allow this scope to be entered if this strategy is already in scope. 391 if distribution_strategy_context.has_strategy(): 392 _require_cross_replica_or_default_context_extended( 393 self._context.strategy.extended) 394 self._same_scope_again_count += 1 395 else: 396 _push_per_thread_mode(self._context) 397 if self._var_scope: 398 self._var_scope.__enter__() 399 self._var_creator_scope.__enter__() 400 if self._resource_creator_scope: 401 nest.map_structure(lambda scope: scope.__enter__(), 402 self._resource_creator_scope) 403 if self._device_scope: 404 self._device_scope.__enter__() 405 return self._context.strategy 406 407 def __exit__(self, exception_type, exception_value, traceback): 408 if self._same_scope_again_count > 0: 409 self._same_scope_again_count -= 1 410 return 411 if self._device_scope: 412 try: 413 self._device_scope.__exit__(exception_type, exception_value, traceback) 414 except RuntimeError as e: 415 six.raise_from( 416 RuntimeError("Device scope nesting error: move call to " 417 "tf.distribute.set_strategy() out of `with` scope."), 418 e) 419 420 try: 421 self._var_creator_scope.__exit__( 422 exception_type, exception_value, traceback) 423 except RuntimeError as e: 424 six.raise_from( 425 RuntimeError("Variable creator scope nesting error: move call to " 426 "tf.distribute.set_strategy() out of `with` scope."), 427 e) 428 429 if self._resource_creator_scope: 430 try: 431 if isinstance(self._resource_creator_scope, list): 432 reversed_resource_creator_scope = self._resource_creator_scope[::-1] 433 nest.map_structure( 434 lambda scope: scope.__exit__(exception_type, exception_value, # pylint:disable=g-long-lambda 435 traceback), 436 reversed_resource_creator_scope) 437 438 else: 439 self._resource_creator_scope.__exit__(exception_type, exception_value, 440 traceback) 441 except RuntimeError as e: 442 six.raise_from( 443 RuntimeError("Resource creator scope nesting 
error: move call " 444 "to tf.distribute.set_strategy() out of `with` " 445 "scope."), e) 446 447 if self._var_scope: 448 try: 449 self._var_scope.__exit__(exception_type, exception_value, traceback) 450 except RuntimeError as e: 451 six.raise_from( 452 RuntimeError("Variable scope nesting error: move call to " 453 "tf.distribute.set_strategy() out of `with` scope."), 454 e) 455 _pop_per_thread_mode() 456 457 458# TODO(yuefengz): add more replication modes. 459@tf_export("distribute.InputReplicationMode") 460class InputReplicationMode(enum.Enum): 461 """Replication mode for input function. 462 463 * `PER_WORKER`: The input function will be called on each worker 464 independently, creating as many input pipelines as number of workers. 465 Replicas will dequeue from the local Dataset on their worker. 466 `tf.distribute.Strategy` doesn't manage any state sharing between such 467 separate input pipelines. 468 * `PER_REPLICA`: The input function will be called on each replica separately. 469 `tf.distribute.Strategy` doesn't manage any state sharing between such 470 separate input pipelines. 471 """ 472 PER_WORKER = "PER_WORKER" 473 PER_REPLICA = "PER_REPLICA" 474 475 476@tf_export("distribute.InputContext") 477class InputContext(object): 478 """A class wrapping information needed by an input function. 479 480 This is a context class that is passed to the user's input function and 481 contains information about the compute replicas and input pipelines. The 482 number of compute replicas (in sync training) helps compute the local batch 483 size from the desired global batch size for each replica. The input pipeline 484 information can be used to return a different subset of the input in each 485 replica (for e.g. shard the input pipeline, use a different input 486 source etc). 487 """ 488 489 __slots__ = [ 490 "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" 491 ] 492 493 def __init__(self, 494 num_input_pipelines=1, 495 input_pipeline_id=0, 496 num_replicas_in_sync=1): 497 """Initializes an InputContext object. 498 499 Args: 500 num_input_pipelines: the number of input pipelines in a cluster. 501 input_pipeline_id: the current input pipeline id, should be an int in 502 [0,`num_input_pipelines`). 503 num_replicas_in_sync: the number of replicas that are in sync. 504 """ 505 self._num_input_pipelines = num_input_pipelines 506 self._input_pipeline_id = input_pipeline_id 507 self._num_replicas_in_sync = num_replicas_in_sync 508 509 @property 510 def num_replicas_in_sync(self): 511 """Returns the number of compute replicas in sync.""" 512 return self._num_replicas_in_sync 513 514 @property 515 def input_pipeline_id(self): 516 """Returns the input pipeline ID.""" 517 return self._input_pipeline_id 518 519 @property 520 def num_input_pipelines(self): 521 """Returns the number of input pipelines.""" 522 return self._num_input_pipelines 523 524 def get_per_replica_batch_size(self, global_batch_size): 525 """Returns the per-replica batch size. 526 527 Args: 528 global_batch_size: the global batch size which should be divisible by 529 `num_replicas_in_sync`. 530 531 Returns: 532 the per-replica batch size. 533 534 Raises: 535 ValueError: if `global_batch_size` not divisible by 536 `num_replicas_in_sync`. 
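
    For example, with two input pipelines and four replicas in sync (a
    minimal illustrative sketch):

    >>> input_context = tf.distribute.InputContext(
    ...     num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=4)
    >>> input_context.get_per_replica_batch_size(32)
    8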
537 """ 538 if global_batch_size % self._num_replicas_in_sync != 0: 539 raise ValueError("The `global_batch_size` %r is not divisible by " 540 "`num_replicas_in_sync` %r " % 541 (global_batch_size, self._num_replicas_in_sync)) 542 return global_batch_size // self._num_replicas_in_sync 543 544 def __str__(self): 545 return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( 546 self.input_pipeline_id, self.num_input_pipelines) 547 548 549@tf_export("distribute.experimental.ValueContext", v1=[]) 550class ValueContext(object): 551 """A class wrapping information needed by a distribute function. 552 553 This is a context class that is passed to the `value_fn` in 554 `strategy.experimental_distribute_values_from_function` and contains 555 information about the compute replicas. The `num_replicas_in_sync` and 556 `replica_id` can be used to customize the value on each replica. 557 558 Example usage: 559 560 1. Directly constructed. 561 562 >>> def value_fn(context): 563 ... return context.replica_id_in_sync_group/context.num_replicas_in_sync 564 >>> context = tf.distribute.experimental.ValueContext( 565 ... replica_id_in_sync_group=2, num_replicas_in_sync=4) 566 >>> per_replica_value = value_fn(context) 567 >>> per_replica_value 568 0.5 569 570 2. Passed in by `experimental_distribute_values_from_function`. 571 572 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 573 >>> def value_fn(value_context): 574 ... return value_context.num_replicas_in_sync 575 >>> distributed_values = ( 576 ... strategy.experimental_distribute_values_from_function( 577 ... value_fn)) 578 >>> local_result = strategy.experimental_local_results(distributed_values) 579 >>> local_result 580 (2, 2) 581 582 """ 583 584 __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] 585 586 def __init__(self, 587 replica_id_in_sync_group=0, 588 num_replicas_in_sync=1): 589 """Initializes an ValueContext object. 590 591 Args: 592 replica_id_in_sync_group: the current replica_id, should be an int in 593 [0,`num_replicas_in_sync`). 594 num_replicas_in_sync: the number of replicas that are in sync. 595 """ 596 self._replica_id_in_sync_group = replica_id_in_sync_group 597 self._num_replicas_in_sync = num_replicas_in_sync 598 599 @property 600 def num_replicas_in_sync(self): 601 """Returns the number of compute replicas in sync.""" 602 return self._num_replicas_in_sync 603 604 @property 605 def replica_id_in_sync_group(self): 606 """Returns the replica ID.""" 607 return self._replica_id_in_sync_group 608 609 def __str__(self): 610 return (("tf.distribute.ValueContext(replica id {}, " 611 " total replicas in sync: ""{})") 612 .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) 613 614 615@tf_export("distribute.RunOptions") 616class RunOptions( 617 collections.namedtuple("RunOptions", [ 618 "experimental_enable_dynamic_batch_size", 619 "experimental_bucketizing_dynamic_shape", 620 "experimental_xla_options", 621 ])): 622 """Run options for `strategy.run`. 623 624 This can be used to hold some strategy specific configs. 625 626 Attributes: 627 experimental_enable_dynamic_batch_size: Boolean. Only applies to 628 TPUStrategy. Default to True. If True, TPUStrategy will enable dynamic 629 padder to support dynamic batch size for the inputs. Otherwise only static 630 shape inputs are allowed. 631 experimental_bucketizing_dynamic_shape: Boolean. Only applies to 632 TPUStrategy. Default to False. 
      If True, TPUStrategy will automatically
      bucketize inputs passed into `run` if the input shape is
      dynamic. This is a performance optimization to reduce XLA recompilation,
      which should not have an impact on correctness.
    experimental_xla_options: A `tf.tpu.XLAOptions` instance. Only applies to
      TPUStrategy. Controls the XLA compiling options on TPUs. Defaults to
      None.
  """

  def __new__(cls,
              experimental_enable_dynamic_batch_size=True,
              experimental_bucketizing_dynamic_shape=False,
              experimental_xla_options=None):
    return super(RunOptions,
                 cls).__new__(cls, experimental_enable_dynamic_batch_size,
                              experimental_bucketizing_dynamic_shape,
                              experimental_xla_options)


@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_fetch_to_device",
        "experimental_replication_mode",
        "experimental_place_dataset_on_device",
        "experimental_per_replica_buffer_size",
    ])):
  """Run options for `experimental_distribute_dataset(s_from_function)`.

  This can be used to hold some strategy-specific configs.

  ```python
  # Setup TPUStrategy
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)

  dataset = tf.data.Dataset.range(16)
  distributed_dataset_on_host = (
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_replication_mode=
              tf.distribute.InputReplicationMode.PER_WORKER,
              experimental_place_dataset_on_device=False,
              experimental_per_replica_buffer_size=1)))
  ```

  Attributes:
    experimental_fetch_to_device: Boolean. If True, dataset
      elements will be prefetched to accelerator device memory. When False,
      dataset elements are prefetched to host device memory. Must be False when
      using the TPUEmbedding API. experimental_fetch_to_device can only be used
      with experimental_replication_mode=PER_WORKER. The default behavior is
      the same as setting it to True.
    experimental_replication_mode: Replication mode for the input function.
      Currently, `InputReplicationMode.PER_REPLICA` is only supported with
      `tf.distribute.MirroredStrategy` and
      `distribute_datasets_from_function`. The default value is
      `InputReplicationMode.PER_WORKER`.
    experimental_place_dataset_on_device: Boolean. Defaults to False. When
      True, the dataset will be placed on the device, otherwise it will remain
      on the host. experimental_place_dataset_on_device=True can only be used
      with experimental_replication_mode=PER_REPLICA.
    experimental_per_replica_buffer_size: Integer. Defaults to 1. Indicates the
      prefetch buffer size in the replica device memory. Users can set it
      to 0 to completely disable prefetching behavior, or a number greater than
      1 to enable a larger buffer size. Note that this option is still
      valid with `experimental_fetch_to_device=False`.
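
  For example, the `experimental_fetch_to_device` option described above can
  be used to keep prefetched elements in host memory (a minimal sketch,
  assuming two local GPUs):

  ```python
  strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  dataset = tf.data.Dataset.range(16).batch(4)
  distributed_dataset = strategy.experimental_distribute_dataset(
      dataset,
      tf.distribute.InputOptions(experimental_fetch_to_device=False))
  ```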
701 """ 702 703 def __new__(cls, 704 experimental_fetch_to_device=None, 705 experimental_replication_mode=InputReplicationMode.PER_WORKER, 706 experimental_place_dataset_on_device=False, 707 experimental_per_replica_buffer_size=1): 708 if experimental_fetch_to_device is None: 709 experimental_fetch_to_device = True 710 711 return super(InputOptions, 712 cls).__new__(cls, experimental_fetch_to_device, 713 experimental_replication_mode, 714 experimental_place_dataset_on_device, 715 experimental_per_replica_buffer_size) 716 717# ------------------------------------------------------------------------------ 718# Base classes for all distribution strategies. 719 720 721# Base class for v1 Strategy and v2 Strategy classes. For API's specific to 722# v1/v2 Strategy, add to implementing classes of StrategyBase. 723# pylint: disable=line-too-long 724class StrategyBase(object): 725 """A state & compute distribution policy on a list of devices. 726 727 See [the guide](https://www.tensorflow.org/guide/distributed_training) 728 for overview and examples. See `tf.distribute.StrategyExtended` and 729 [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) 730 for a glossary of concepts mentioned on this page such as "per-replica", 731 _replica_, and _reduce_. 732 733 In short: 734 735 * To use it with Keras `compile`/`fit`, 736 [please 737 read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). 738 * You may pass descendant of `tf.distribute.Strategy` to 739 `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` 740 should distribute its computation. See 741 [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). 742 * Otherwise, use `tf.distribute.Strategy.scope` to specify that a 743 strategy should be used when building an executing your model. 744 (This puts you in the "cross-replica context" for this strategy, which 745 means the strategy is put in control of things like variable placement.) 746 * If you are writing a custom training loop, you will need to call a few more 747 methods, 748 [see the 749 guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): 750 751 * Start by creating a `tf.data.Dataset` normally. 752 * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert 753 a `tf.data.Dataset` to something that produces "per-replica" values. 754 If you want to manually specify how the dataset should be partitioned 755 across replicas, use 756 `tf.distribute.Strategy.distribute_datasets_from_function` 757 instead. 758 * Use `tf.distribute.Strategy.run` to run a function 759 once per replica, taking values that may be "per-replica" (e.g. 760 from a `tf.distribute.DistributedDataset` object) and returning 761 "per-replica" values. 762 This function is executed in "replica context", which means each 763 operation is performed separately on each replica. 764 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to 765 convert the resulting "per-replica" values into ordinary `Tensor`s. 
766 767 A custom training loop can be as simple as: 768 769 ``` 770 with my_strategy.scope(): 771 @tf.function 772 def distribute_train_epoch(dataset): 773 def replica_fn(input): 774 # process input and return result 775 return result 776 777 total_result = 0 778 for x in dataset: 779 per_replica_result = my_strategy.run(replica_fn, args=(x,)) 780 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, 781 per_replica_result, axis=None) 782 return total_result 783 784 dist_dataset = my_strategy.experimental_distribute_dataset(dataset) 785 for _ in range(EPOCHS): 786 train_result = distribute_train_epoch(dist_dataset) 787 ``` 788 789 This takes an ordinary `dataset` and `replica_fn` and runs it 790 distributed using a particular `tf.distribute.Strategy` named 791 `my_strategy` above. Any variables created in `replica_fn` are created 792 using `my_strategy`'s policy, and library functions called by 793 `replica_fn` can use the `get_replica_context()` API to implement 794 distributed-specific behavior. 795 796 You can use the `reduce` API to aggregate results across replicas and use 797 this as a return value from one iteration over a 798 `tf.distribute.DistributedDataset`. Or 799 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to 800 accumulate metrics across steps in a given epoch. 801 802 See the 803 [custom training loop 804 tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) 805 for a more detailed example. 806 807 Note: `tf.distribute.Strategy` currently does not support TensorFlow's 808 partitioned variables (where a single variable is split across multiple 809 devices) at this time. 810 """ 811 # pylint: enable=line-too-long 812 813 # TODO(josh11b): Partitioned computations, state; sharding 814 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling 815 816 def __init__(self, extended): 817 self._extended = extended 818 819 # Flag that is used to indicate whether distribution strategy is used with 820 # Estimator. This is required for backward compatibility of loss scaling 821 # when using v1 optimizer with estimator. 822 self._scale_loss_for_estimator = False 823 824 if not hasattr(extended, "_retrace_functions_for_each_device"): 825 # pylint: disable=protected-access 826 # `extended._retrace_functions_for_each_device` dictates 827 # whether the same function will be retraced when it is called on 828 # different devices. 829 try: 830 extended._retrace_functions_for_each_device = ( 831 len(extended.worker_devices) > 1) 832 distribution_strategy_replica_gauge.get_cell("num_replicas").set( 833 self.num_replicas_in_sync) 834 except: # pylint: disable=bare-except 835 # Default for the case where extended.worker_devices can't return 836 # a sensible value. 837 extended._retrace_functions_for_each_device = True 838 839 # Below are the dicts of axis(int) -> `tf.function`. 840 self._mean_reduce_helper_fns = {} 841 self._reduce_sum_fns = {} 842 843 # Whether this strategy is designed to work with `ClusterCoordinator`. 844 self._should_use_with_coordinator = False 845 846 @property 847 def extended(self): 848 """`tf.distribute.StrategyExtended` with additional methods.""" 849 return self._extended 850 851 @tf_contextlib.contextmanager 852 def _scale_loss_for_estimator_enabled(self): 853 """Scope which sets a flag used for scaling losses in optimizer. 854 855 Yields: 856 `_scale_loss_for_estimator_enabled` is a context manager with a 857 side effect, but doesn't return a value. 
    """
    self._scale_loss_for_estimator = True
    try:
      yield
    finally:
      self._scale_loss_for_estimator = False

  # pylint: disable=line-too-long
  def scope(self):
    """Context manager to make the strategy current and distribute variables.

    This method returns a context manager, and is used as follows:

    >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    >>> # Variable created inside scope:
    >>> with strategy.scope():
    ...   mirrored_variable = tf.Variable(1.)
    >>> mirrored_variable
    MirroredVariable:{
      0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
      1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
    }
    >>> # Variable created outside scope:
    >>> regular_variable = tf.Variable(1.)
    >>> regular_variable
    <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

    _What happens when Strategy.scope is entered?_

    * `strategy` is installed in the global context as the "current" strategy.
      Inside this scope, `tf.distribute.get_strategy()` will now return this
      strategy. Outside this scope, it returns the default no-op strategy.
    * Entering the scope also enters the "cross-replica context". See
      `tf.distribute.StrategyExtended` for an explanation on cross-replica and
      replica contexts.
    * Variable creation inside `scope` is intercepted by the strategy. Each
      strategy defines how it wants to affect the variable creation. Sync
      strategies like `MirroredStrategy`, `TPUStrategy` and
      `MultiWorkerMirroredStrategy` create variables replicated on each
      replica, whereas `ParameterServerStrategy` creates variables on the
      parameter servers. This is done using a custom
      `tf.variable_creator_scope`.
    * In some strategies, a default device scope may also be entered: in
      `MultiWorkerMirroredStrategy`, a default device scope of "/CPU:0" is
      entered on each worker.

    Note: Entering a scope does not automatically distribute a computation,
    except when using a high-level training framework such as Keras
    `model.fit`. If you're not using `model.fit`, you need to use the
    `strategy.run` API to explicitly distribute that computation. See an
    example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training).

    _What should be in scope and what should be outside?_

    There are a number of requirements on what needs to happen inside the
    scope. However, in places where we have information about which strategy
    is in use, we often enter the scope for the user, so they don't have to do
    it explicitly (i.e. calling those either inside or outside the scope is
    OK).

    * Anything that creates variables that should be distributed variables
      must be called in a `strategy.scope`. This can be accomplished either by
      directly calling the variable creating function within the scope context,
      or by relying on another API like `strategy.run` or `keras.Model.fit` to
      automatically enter it for you. Any variable that is created outside the
      scope will not be distributed and may have performance implications. Some
      common objects that create variables in TF are Models, Optimizers, and
      Metrics. Such objects should always be initialized in the scope, and any
      functions that may lazily create variables (e.g., `Model.__call__()`,
      tracing a `tf.function`, etc.) should similarly be called within scope.
Another 927 source of variable creation can be a checkpoint restore - when variables 928 are created lazily. Note that any variable created inside a strategy 929 captures the strategy information. So reading and writing to these 930 variables outside the `strategy.scope` can also work seamlessly, without 931 the user having to enter the scope. 932 * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which 933 require to be in a strategy's scope, enter the scope automatically, which 934 means when using those APIs you don't need to explicitly enter the scope 935 yourself. 936 * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model 937 object captures the scope information. When high level training framework 938 methods such as `model.compile`, `model.fit`, etc. are then called, the 939 captured scope will be automatically entered, and the associated strategy 940 will be used to distribute the training etc. See a detailed example in 941 [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). 942 WARNING: Simply calling `model(..)` does not automatically enter the 943 captured scope -- only high level training framework APIs support this 944 behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict` 945 and `model.save` can all be called inside or outside the scope. 946 * The following can be either inside or outside the scope: 947 * Creating the input datasets 948 * Defining `tf.function`s that represent your training step 949 * Saving APIs such as `tf.saved_model.save`. Loading creates variables, 950 so that should go inside the scope if you want to train the model in a 951 distributed way. 952 * Checkpoint saving. As mentioned above - `checkpoint.restore` may 953 sometimes need to be inside scope if it creates variables. 954 955 Returns: 956 A context manager. 
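
    A minimal sketch of the pattern described above, assuming two local GPUs
    and a Keras model:

    ```python
    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
    with strategy.scope():
      # Variable-creating objects (model, optimizer, metrics) belong inside
      # the scope so that their variables become distributed variables.
      model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
      optimizer = tf.keras.optimizers.SGD()
    # Input pipelines can be created outside the scope.
    dataset = tf.data.Dataset.from_tensor_slices(([[1.0]], [[2.0]])).batch(1)
    ```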
957 """ 958 return self._extended._scope(self) # pylint: disable=protected-access 959 # pylint: enable=line-too-long 960 961 @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` 962 @deprecated(None, "use extended.colocate_vars_with() instead.") 963 def colocate_vars_with(self, colocate_with_variable): 964 """DEPRECATED: use extended.colocate_vars_with() instead.""" 965 return self._extended.colocate_vars_with(colocate_with_variable) 966 967 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 968 def make_dataset_iterator(self, dataset): 969 """DEPRECATED TF 1.x ONLY.""" 970 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 971 972 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 973 def make_input_fn_iterator(self, 974 input_fn, 975 replication_mode=InputReplicationMode.PER_WORKER): 976 """DEPRECATED TF 1.x ONLY.""" 977 if replication_mode != InputReplicationMode.PER_WORKER: 978 raise ValueError( 979 "Input replication mode not supported: %r" % replication_mode) 980 with self.scope(): 981 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access 982 input_fn, replication_mode=replication_mode) 983 984 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 985 @deprecated(None, "use run() instead") 986 def experimental_run(self, fn, input_iterator=None): 987 """DEPRECATED TF 1.x ONLY.""" 988 with self.scope(): 989 args = (input_iterator.get_next(),) if input_iterator is not None else () 990 return self.run(fn, args=args) 991 992 def experimental_distribute_dataset(self, dataset, options=None): 993 # pylint: disable=line-too-long 994 """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. 995 996 The returned `tf.distribute.DistributedDataset` can be iterated over 997 similar to regular datasets. 998 NOTE: The user cannot add any more transformations to a 999 `tf.distribute.DistributedDataset`. You can only create an iterator or 1000 examine the `tf.TypeSpec` of the data generated by it. See API docs of 1001 `tf.distribute.DistributedDataset` to learn more. 1002 1003 The following is an example: 1004 1005 >>> global_batch_size = 2 1006 >>> # Passing the devices is optional. 1007 ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 1008 >>> # Create a dataset 1009 ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) 1010 >>> # Distribute that dataset 1011 ... dist_dataset = strategy.experimental_distribute_dataset(dataset) 1012 >>> @tf.function 1013 ... def replica_fn(input): 1014 ... return input*2 1015 >>> result = [] 1016 >>> # Iterate over the `tf.distribute.DistributedDataset` 1017 ... for x in dist_dataset: 1018 ... # process dataset elements 1019 ... result.append(strategy.run(replica_fn, args=(x,))) 1020 >>> print(result) 1021 [PerReplica:{ 1022 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, 1023 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])> 1024 }, PerReplica:{ 1025 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>, 1026 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])> 1027 }] 1028 1029 1030 Three key actions happening under the hood of this method are batching, 1031 sharding, and prefetching. 1032 1033 In the code snippet above, `dataset` is batched by `global_batch_size`, and 1034 calling `experimental_distribute_dataset` on it rebatches `dataset` to a 1035 new batch size that is equal to the global batch size divided by the number 1036 of replicas in sync. 
    We iterate through it using a Pythonic for loop. `x` is a
    `tf.distribute.DistributedValues` containing data for all replicas, and
    each replica gets data of the new batch size.
    `tf.distribute.Strategy.run` will take care of feeding the right
    per-replica data in `x` to the right `replica_fn` executed on each replica.

    Sharding covers autosharding across multiple workers as well as sharding
    within each worker. First, in multi-worker distributed training (i.e. when
    you use `tf.distribute.experimental.MultiWorkerMirroredStrategy` or
    `tf.distribute.TPUStrategy`), autosharding a dataset over a set of workers
    means that each worker is assigned a subset of the entire dataset (if the
    right `tf.data.experimental.AutoShardPolicy` is set). This is to ensure
    that at each step, a global batch size of non-overlapping dataset elements
    will be processed by each worker. Autosharding has a couple of different
    options that can be specified using
    `tf.data.experimental.DistributeOptions`. Then, sharding within each
    worker means the method will split the data among all the worker devices
    (if more than one is present). This will happen regardless of multi-worker
    autosharding.

    Note: for autosharding across multiple workers, the default mode is
    `tf.data.experimental.AutoShardPolicy.AUTO`. This mode
    will attempt to shard the input dataset by files if the dataset is
    being created out of reader datasets (e.g. `tf.data.TFRecordDataset`,
    `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data,
    where each of the workers will read the entire dataset and only process
    the shard assigned to it. However, if you have fewer than one input file
    per worker, we suggest that you disable dataset autosharding across
    workers by setting the
    `tf.data.experimental.DistributeOptions.auto_shard_policy` to
    `tf.data.experimental.AutoShardPolicy.OFF`.

    By default, this method adds a prefetch transformation at the end of the
    user-provided `tf.data.Dataset` instance. The `buffer_size` argument of
    the prefetch transformation is equal to the number of replicas in sync.

    If the above batch splitting and dataset sharding logic is undesirable,
    please use `tf.distribute.Strategy.distribute_datasets_from_function`
    instead, which does not do any automatic batching or sharding for you.

    Note: If you are using TPUStrategy, the order in which the data is
    processed by the workers when using
    `tf.distribute.Strategy.experimental_distribute_dataset` or
    `tf.distribute.Strategy.distribute_datasets_from_function` is
    not guaranteed. A deterministic order is typically required if you are
    using `tf.distribute` to scale prediction. You can however insert an index
    for each element in the batch and order outputs accordingly. Refer to
    [this snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats)
    for an example of how to order outputs.

    Note: Stateful dataset transformations are currently not supported with
    `tf.distribute.experimental_distribute_dataset` or
    `tf.distribute.distribute_datasets_from_function`. Any stateful
    ops that the dataset may have are currently ignored.
For example, if your 1091 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 1092 then you have a dataset graph that depends on state (i.e the random seed) on 1093 the local machine where the python process is being executed. 1094 1095 For a tutorial on more usage and properties of this method, refer to the 1096 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). 1097 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 1098 1099 Args: 1100 dataset: `tf.data.Dataset` that will be sharded across all replicas using 1101 the rules stated above. 1102 options: `tf.distribute.InputOptions` used to control options on how this 1103 dataset is distributed. 1104 1105 Returns: 1106 A `tf.distribute.DistributedDataset`. 1107 """ 1108 distribution_strategy_input_api_counter.get_cell( 1109 self.__class__.__name__, "distribute_dataset").increase_by(1) 1110 # pylint: enable=line-too-long 1111 return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access 1112 1113 def distribute_datasets_from_function(self, dataset_fn, options=None): 1114 # pylint: disable=line-too-long 1115 """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. 1116 1117 The argument `dataset_fn` that users pass in is an input function that has a 1118 `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` 1119 instance. It is expected that the returned dataset from `dataset_fn` is 1120 already batched by per-replica batch size (i.e. global batch size divided by 1121 the number of replicas in sync) and sharded. 1122 `tf.distribute.Strategy.distribute_datasets_from_function` does 1123 not batch or shard the `tf.data.Dataset` instance 1124 returned from the input function. `dataset_fn` will be called on the CPU 1125 device of each of the workers and each generates a dataset where every 1126 replica on that worker will dequeue one batch of inputs (i.e. if a worker 1127 has two replicas, two batches will be dequeued from the `Dataset` every 1128 step). 1129 1130 This method can be used for several purposes. First, it allows you to 1131 specify your own batching and sharding logic. (In contrast, 1132 `tf.distribute.experimental_distribute_dataset` does batching and sharding 1133 for you.) For example, where 1134 `experimental_distribute_dataset` is unable to shard the input files, this 1135 method might be used to manually shard the dataset (avoiding the slow 1136 fallback behavior in `experimental_distribute_dataset`). In cases where the 1137 dataset is infinite, this sharding can be done by creating dataset replicas 1138 that differ only in their random seed. 1139 1140 The `dataset_fn` should take an `tf.distribute.InputContext` instance where 1141 information about batching and input replication can be accessed. 1142 1143 You can use `element_spec` property of the 1144 `tf.distribute.DistributedDataset` returned by this API to query the 1145 `tf.TypeSpec` of the elements returned by the iterator. This can be used to 1146 set the `input_signature` property of a `tf.function`. Follow 1147 `tf.distribute.DistributedDataset.element_spec` to see an example. 1148 1149 IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a 1150 per-replica batch size, unlike `experimental_distribute_dataset`, which uses 1151 the global batch size. 
This may be computed using 1152 `input_context.get_per_replica_batch_size`. 1153 1154 Note: If you are using TPUStrategy, the order in which the data is processed 1155 by the workers when using 1156 `tf.distribute.Strategy.experimental_distribute_dataset` or 1157 `tf.distribute.Strategy.distribute_datasets_from_function` is 1158 not guaranteed. This is typically required if you are using 1159 `tf.distribute` to scale prediction. You can however insert an index for 1160 each element in the batch and order outputs accordingly. Refer to [this 1161 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) 1162 for an example of how to order outputs. 1163 1164 Note: Stateful dataset transformations are currently not supported with 1165 `tf.distribute.experimental_distribute_dataset` or 1166 `tf.distribute.distribute_datasets_from_function`. Any stateful 1167 ops that the dataset may have are currently ignored. For example, if your 1168 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 1169 then you have a dataset graph that depends on state (i.e the random seed) on 1170 the local machine where the python process is being executed. 1171 1172 For a tutorial on more usage and properties of this method, refer to the 1173 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function)). 1174 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 1175 1176 Args: 1177 dataset_fn: A function taking a `tf.distribute.InputContext` instance and 1178 returning a `tf.data.Dataset`. 1179 options: `tf.distribute.InputOptions` used to control options on how this 1180 dataset is distributed. 1181 1182 Returns: 1183 A `tf.distribute.DistributedDataset`. 1184 """ 1185 distribution_strategy_input_api_counter.get_cell( 1186 self.__class__.__name__, 1187 "distribute_datasets_from_function").increase_by(1) 1188 # pylint: enable=line-too-long 1189 return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access 1190 dataset_fn, options) 1191 1192 # TODO(b/162776748): Remove deprecated symbol. 1193 @doc_controls.do_not_doc_inheritable 1194 @deprecation.deprecated(None, "rename to distribute_datasets_from_function") 1195 def experimental_distribute_datasets_from_function(self, 1196 dataset_fn, 1197 options=None): 1198 return self.distribute_datasets_from_function(dataset_fn, options) 1199 1200 def run(self, fn, args=(), kwargs=None, options=None): 1201 """Invokes `fn` on each replica, with the given arguments. 1202 1203 This method is the primary way to distribute your computation with a 1204 tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs` 1205 have `tf.distribute.DistributedValues`, such as those produced by a 1206 `tf.distribute.DistributedDataset` from 1207 `tf.distribute.Strategy.experimental_distribute_dataset` or 1208 `tf.distribute.Strategy.distribute_datasets_from_function`, 1209 when `fn` is executed on a particular replica, it will be executed with the 1210 component of `tf.distribute.DistributedValues` that correspond to that 1211 replica. 1212 1213 `fn` is invoked under a replica context. `fn` may call 1214 `tf.distribute.get_replica_context()` to access members such as 1215 `all_reduce`. Please see the module-level docstring of tf.distribute for the 1216 concept of replica context. 
1217 1218 All arguments in `args` or `kwargs` can be a nested structure of tensors, 1219 e.g. a list of tensors, in which case `args` and `kwargs` will be passed to 1220 the `fn` invoked on each replica. Or `args` or `kwargs` can be 1221 `tf.distribute.DistributedValues` containing tensors or composite tensors, 1222 i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call 1223 will get the component of a `tf.distribute.DistributedValues` corresponding 1224 to its replica. Note that arbitrary Python values that are not of the types 1225 above are not supported. 1226 1227 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and 1228 whether eager execution is enabled, `fn` may be called one or more times. If 1229 `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is 1230 called inside a `tf.function` (eager execution is disabled inside a 1231 `tf.function` by default), `fn` is called once per replica to generate a 1232 Tensorflow graph, which will then be reused for execution with new inputs. 1233 Otherwise, if eager execution is enabled, `fn` will be called once per 1234 replica every step just like regular python code. 1235 1236 Example usage: 1237 1238 1. Constant tensor input. 1239 1240 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1241 >>> tensor_input = tf.constant(3.0) 1242 >>> @tf.function 1243 ... def replica_fn(input): 1244 ... return input*2.0 1245 >>> result = strategy.run(replica_fn, args=(tensor_input,)) 1246 >>> result 1247 PerReplica:{ 1248 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>, 1249 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 1250 } 1251 1252 2. DistributedValues input. 1253 1254 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1255 >>> @tf.function 1256 ... def run(): 1257 ... def value_fn(value_context): 1258 ... return value_context.num_replicas_in_sync 1259 ... distributed_values = ( 1260 ... strategy.experimental_distribute_values_from_function( 1261 ... value_fn)) 1262 ... def replica_fn2(input): 1263 ... return input*2 1264 ... return strategy.run(replica_fn2, args=(distributed_values,)) 1265 >>> result = run() 1266 >>> result 1267 <tf.Tensor: shape=(), dtype=int32, numpy=4> 1268 1269 3. Use `tf.distribute.ReplicaContext` to allreduce values. 1270 1271 >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"]) 1272 >>> @tf.function 1273 ... def run(): 1274 ... def value_fn(value_context): 1275 ... return tf.constant(value_context.replica_id_in_sync_group) 1276 ... distributed_values = ( 1277 ... strategy.experimental_distribute_values_from_function( 1278 ... value_fn)) 1279 ... def replica_fn(input): 1280 ... return tf.distribute.get_replica_context().all_reduce("sum", input) 1281 ... return strategy.run(replica_fn, args=(distributed_values,)) 1282 >>> result = run() 1283 >>> result 1284 PerReplica:{ 1285 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>, 1286 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 1287 } 1288 1289 Args: 1290 fn: The function to run on each replica. 1291 args: Optional positional arguments to `fn`. Its element can be a tensor, 1292 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1293 kwargs: Optional keyword arguments to `fn`. Its element can be a tensor, 1294 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1295 options: An optional instance of `tf.distribute.RunOptions` specifying 1296 the options to run `fn`. 1297 1298 Returns: 1299 Merged return value of `fn` across replicas. 
The structure of the return 1300 value is the same as the return value from `fn`. Each element in the 1301 structure can either be `tf.distribute.DistributedValues`, `Tensor` 1302 objects, or `Tensor`s (for example, if running on a single replica). 1303 """ 1304 del options 1305 1306 if not isinstance(args, (list, tuple)): 1307 raise ValueError( 1308 "positional args must be a list or tuple, got {}".format(type(args))) 1309 1310 with self.scope(): 1311 # tf.distribute supports Eager functions, so AutoGraph should not be 1312 # applied when the caller is also in Eager mode. 1313 fn = autograph.tf_convert( 1314 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 1315 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 1316 1317 def reduce(self, reduce_op, value, axis): 1318 """Reduce `value` across replicas and return result on current device. 1319 1320 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1321 >>> def step_fn(): 1322 ... i = tf.distribute.get_replica_context().replica_id_in_sync_group 1323 ... return tf.identity(i) 1324 >>> 1325 >>> per_replica_result = strategy.run(step_fn) 1326 >>> total = strategy.reduce("SUM", per_replica_result, axis=None) 1327 >>> total 1328 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1329 1330 To see how this would look with multiple replicas, consider the same 1331 example with MirroredStrategy with 2 GPUs: 1332 1333 ```python 1334 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 1335 def step_fn(): 1336 i = tf.distribute.get_replica_context().replica_id_in_sync_group 1337 return tf.identity(i) 1338 1339 per_replica_result = strategy.run(step_fn) 1340 # Check devices on which per replica result is: 1341 strategy.experimental_local_results(per_replica_result)[0].device 1342 # /job:localhost/replica:0/task:0/device:GPU:0 1343 strategy.experimental_local_results(per_replica_result)[1].device 1344 # /job:localhost/replica:0/task:0/device:GPU:1 1345 1346 total = strategy.reduce("SUM", per_replica_result, axis=None) 1347 # Check device on which reduced result is: 1348 total.device 1349 # /job:localhost/replica:0/task:0/device:CPU:0 1350 1351 ``` 1352 1353 This API is typically used for aggregating the results returned from 1354 different replicas, for reporting etc. For example, loss computed from 1355 different replicas can be averaged using this API before printing. 1356 1357 Note: The result is copied to the "current" device - which would typically 1358 be the CPU of the worker on which the program is running. For `TPUStrategy`, 1359 it is the first TPU host. For multi client `MultiWorkerMirroredStrategy`, 1360 this is CPU of each worker. 1361 1362 There are a number of different tf.distribute APIs for reducing values 1363 across replicas: 1364 * `tf.distribute.ReplicaContext.all_reduce`: This differs from 1365 `Strategy.reduce` in that it is for replica context and does 1366 not copy the results to the host device. `all_reduce` should be typically 1367 used for reductions inside the training step such as gradients. 1368 * `tf.distribute.StrategyExtended.reduce_to` and 1369 `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more 1370 advanced versions of `Strategy.reduce` as they allow customizing the 1371 destination of the result. They are also called in cross replica context. 1372 1373 _What should axis be?_ 1374 1375 Given a per-replica value returned by `run`, say a 1376 per-example loss, the batch will be divided across all the replicas. 
This 1377 function allows you to aggregate across replicas and optionally also across 1378 batch elements by specifying the axis parameter accordingly. 1379 1380 For example, if you have a global batch size of 8 and 2 1381 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and 1382 `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will 1383 aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. 1384 This is useful when each replica is computing a scalar or some other value 1385 that doesn't have a "batch" dimension (like a gradient or loss). 1386 ``` 1387 strategy.reduce("sum", per_replica_result, axis=None) 1388 ``` 1389 1390 Sometimes, you will want to aggregate across both the global batch _and_ 1391 all replicas. You can get this behavior by specifying the batch 1392 dimension as the `axis`, typically `axis=0`. In this case it would return a 1393 scalar `0+1+2+3+4+5+6+7`. 1394 ``` 1395 strategy.reduce("sum", per_replica_result, axis=0) 1396 ``` 1397 1398 If there is a last partial batch, you will need to specify an axis so 1399 that the resulting shape is consistent across replicas. So if the last 1400 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you 1401 would get a shape mismatch unless you specify `axis=0`. If you specify 1402 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct 1403 denominator of 6. Contrast this with computing `reduce_mean` to get a 1404 scalar value on each replica and this function to average those means, 1405 which will weigh some values `1/8` and others `1/4`. 1406 1407 Args: 1408 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 1409 be combined. Allows using string representation of the enum such as 1410 "SUM", "MEAN". 1411 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1412 `Strategy.run`, to be combined into a single tensor. It can also be a 1413 regular tensor when used with `OneDeviceStrategy` or default strategy. 1414 axis: specifies the dimension to reduce along within each 1415 replica's tensor. Should typically be set to the batch dimension, or 1416 `None` to only reduce across replicas (e.g. if the tensor has no batch 1417 dimension). 1418 1419 Returns: 1420 A `Tensor`. 1421 """ 1422 # TODO(josh11b): support `value` being a nest. 1423 _require_cross_replica_or_default_context_extended(self._extended) 1424 if isinstance(reduce_op, six.string_types): 1425 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 1426 if axis is None: 1427 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1428 if reduce_op == reduce_util.ReduceOp.SUM: 1429 1430 def reduce_sum(v): 1431 return math_ops.reduce_sum(v, axis=axis) 1432 1433 if eager_context.executing_eagerly(): 1434 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1435 # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be 1436 # run from eager mode. Cache the tf.function by `axis` to avoid the 1437 # same function to be traced again. 
1438 if axis not in self._reduce_sum_fns: 1439 1440 def reduce_sum_fn(v): 1441 return self.run(reduce_sum, args=(v,)) 1442 1443 self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) 1444 value = self._reduce_sum_fns[axis](value) 1445 else: 1446 value = self.run(reduce_sum, args=(value,)) 1447 1448 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1449 if reduce_op != reduce_util.ReduceOp.MEAN: 1450 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " 1451 "not: %r" % reduce_op) 1452 1453 def mean_reduce_helper(v, axes=axis): 1454 """Computes the numerator and denominator on each replica.""" 1455 numer = math_ops.reduce_sum(v, axis=axes) 1456 def dimension(axis): 1457 if v.shape.rank is not None: 1458 # Note(joshl): We support axis < 0 to be consistent with the 1459 # tf.math.reduce_* operations. 1460 if axis < 0: 1461 if axis + v.shape.rank < 0: 1462 raise ValueError( 1463 "`axis` = %r out of range for `value` with rank %d" % 1464 (axis, v.shape.rank)) 1465 axis += v.shape.rank 1466 elif axis >= v.shape.rank: 1467 raise ValueError( 1468 "`axis` = %r out of range for `value` with rank %d" % 1469 (axis, v.shape.rank)) 1470 # TF v2 returns `None` for unknown dimensions and an integer for 1471 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) 1472 # or tensor_shape.Dimension(integer). `dimension_value` hides this 1473 # difference, always returning `None` or an integer. 1474 dim = tensor_shape.dimension_value(v.shape[axis]) 1475 if dim is not None: 1476 # By returning a python value in the static shape case, we can 1477 # maybe get a fast path for reducing the denominator. 1478 # TODO(b/151871486): Remove array_ops.identity after we fallback to 1479 # simple reduction if inputs are all on CPU. 1480 return array_ops.identity( 1481 constant_op.constant(dim, dtype=dtypes.int64)) 1482 elif axis < 0: 1483 axis = axis + array_ops.rank(v) 1484 # TODO(b/151871486): Remove array_ops.identity after we fallback to 1485 # simple reduction if inputs are all on CPU. 1486 return array_ops.identity( 1487 array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) 1488 if isinstance(axis, six.integer_types): 1489 denom = dimension(axis) 1490 elif isinstance(axis, (tuple, list)): 1491 denom = math_ops.reduce_prod([dimension(a) for a in axes]) 1492 else: 1493 raise TypeError( 1494 "Expected `axis` to be an integer, tuple or list not: %r" % axis) 1495 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the 1496 # reduce is complete? 1497 return numer, denom 1498 1499 if eager_context.executing_eagerly(): 1500 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1501 # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can 1502 # be run from eager mode. Cache the tf.function by `axis` to avoid the 1503 # same function to be traced again. 1504 if axis not in self._mean_reduce_helper_fns: 1505 1506 def mean_reduce_fn(v): 1507 return self.run(mean_reduce_helper, args=(v,)) 1508 1509 self._mean_reduce_helper_fns[axis] = def_function.function( 1510 mean_reduce_fn) 1511 numer, denom = self._mean_reduce_helper_fns[axis](value) 1512 else: 1513 numer, denom = self.run(mean_reduce_helper, args=(value,)) 1514 1515 # TODO(josh11b): Should batch reduce here instead of doing two. 
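# Sum the per-replica numerators and denominators across replicas before
# dividing, so the mean is weighted by the number of elements each replica
# actually contributed (this matters when the last batch is partial).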
1516 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access 1517 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access 1518 denom = math_ops.cast(denom, numer.dtype) 1519 return math_ops.truediv(numer, denom) 1520 1521 @doc_controls.do_not_doc_inheritable # DEPRECATED 1522 @deprecated(None, "use `experimental_local_results` instead.") 1523 def unwrap(self, value): 1524 """Returns the list of all local per-replica values contained in `value`. 1525 1526 DEPRECATED: Please use `experimental_local_results` instead. 1527 1528 Note: This only returns values on the workers initiated by this client. 1529 When using a `tf.distribute.Strategy` like 1530 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1531 will be its own client, and this function will only return values 1532 computed on that worker. 1533 1534 Args: 1535 value: A value returned by `experimental_run()`, 1536 `extended.call_for_each_replica()`, or a variable created in `scope`. 1537 1538 Returns: 1539 A tuple of values contained in `value`. If `value` represents a single 1540 value, this returns `(value,).` 1541 """ 1542 return self._extended._local_results(value) # pylint: disable=protected-access 1543 1544 def experimental_local_results(self, value): 1545 """Returns the list of all local per-replica values contained in `value`. 1546 1547 Note: This only returns values on the worker initiated by this client. 1548 When using a `tf.distribute.Strategy` like 1549 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1550 will be its own client, and this function will only return values 1551 computed on that worker. 1552 1553 Args: 1554 value: A value returned by `experimental_run()`, `run(), or a variable 1555 created in `scope`. 1556 1557 Returns: 1558 A tuple of values contained in `value` where ith element corresponds to 1559 ith replica. If `value` represents a single value, this returns 1560 `(value,).` 1561 """ 1562 return self._extended._local_results(value) # pylint: disable=protected-access 1563 1564 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only 1565 def group(self, value, name=None): 1566 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 1567 return self._extended._group(value, name) # pylint: disable=protected-access 1568 1569 @property 1570 def num_replicas_in_sync(self): 1571 """Returns number of replicas over which gradients are aggregated.""" 1572 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 1573 1574 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string 1575 @deprecated(None, "use `update_config_proto` instead.") 1576 def configure(self, 1577 session_config=None, 1578 cluster_spec=None, 1579 task_type=None, 1580 task_id=None): 1581 # pylint: disable=g-doc-return-or-yield,g-doc-args 1582 """DEPRECATED: use `update_config_proto` instead. 1583 1584 Configures the strategy class. 1585 1586 DEPRECATED: This method's functionality has been split into the strategy 1587 constructor and `update_config_proto`. In the future, we will allow passing 1588 cluster and config_proto to the constructor to configure the strategy. And 1589 `update_config_proto` can be used to update the config_proto based on the 1590 specific strategy. 
1591 """ 1592 return self._extended._configure( # pylint: disable=protected-access 1593 session_config, cluster_spec, task_type, task_id) 1594 1595 @doc_controls.do_not_generate_docs # DEPRECATED 1596 def update_config_proto(self, config_proto): 1597 """DEPRECATED TF 1.x ONLY.""" 1598 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 1599 1600 def __deepcopy__(self, memo): 1601 # First do a regular deepcopy of `self`. 1602 cls = self.__class__ 1603 result = cls.__new__(cls) 1604 memo[id(self)] = result 1605 for k, v in self.__dict__.items(): 1606 setattr(result, k, copy.deepcopy(v, memo)) 1607 # One little fix-up: we want `result._extended` to reference `result` 1608 # instead of `self`. 1609 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 1610 return result 1611 1612 def __copy__(self): 1613 raise RuntimeError("Must only deepcopy DistributionStrategy.") 1614 1615 @property 1616 def cluster_resolver(self): 1617 """Returns the cluster resolver associated with this strategy. 1618 1619 In general, when using a multi-worker `tf.distribute` strategy such as 1620 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 1621 `tf.distribute.TPUStrategy()`, there is a 1622 `tf.distribute.cluster_resolver.ClusterResolver` associated with the 1623 strategy used, and such an instance is returned by this property. 1624 1625 Strategies that intend to have an associated 1626 `tf.distribute.cluster_resolver.ClusterResolver` must set the 1627 relevant attribute, or override this property; otherwise, `None` is returned 1628 by default. Those strategies should also provide information regarding what 1629 is returned by this property. 1630 1631 Single-worker strategies usually do not have a 1632 `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this 1633 property will return `None`. 1634 1635 The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the 1636 user needs to access information such as the cluster spec, task type or task 1637 id. For example, 1638 1639 ```python 1640 1641 os.environ['TF_CONFIG'] = json.dumps({ 1642 'cluster': { 1643 'worker': ["localhost:12345", "localhost:23456"], 1644 'ps': ["localhost:34567"] 1645 }, 1646 'task': {'type': 'worker', 'index': 0} 1647 }) 1648 1649 # This implicitly uses TF_CONFIG for the cluster and current task info. 1650 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 1651 1652 ... 1653 1654 if strategy.cluster_resolver.task_type == 'worker': 1655 # Perform something that's only applicable on workers. Since we set this 1656 # as a worker above, this block will run on this particular instance. 1657 elif strategy.cluster_resolver.task_type == 'ps': 1658 # Perform something that's only applicable on parameter servers. Since we 1659 # set this as a worker above, this block will not run on this particular 1660 # instance. 1661 ``` 1662 1663 For more information, please see 1664 `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. 1665 1666 Returns: 1667 The cluster resolver associated with this strategy. Returns `None` if a 1668 cluster resolver is not applicable or available in this strategy. 
1669 """ 1670 if hasattr(self.extended, "_cluster_resolver"): 1671 return self.extended._cluster_resolver # pylint: disable=protected-access 1672 return None 1673 1674 1675@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring 1676class Strategy(StrategyBase): 1677 1678 __doc__ = StrategyBase.__doc__ 1679 1680 def experimental_distribute_values_from_function(self, value_fn): 1681 """Generates `tf.distribute.DistributedValues` from `value_fn`. 1682 1683 This function is to generate `tf.distribute.DistributedValues` to pass 1684 into `run`, `reduce`, or other methods that take 1685 distributed values when not using datasets. 1686 1687 Args: 1688 value_fn: The function to run to generate values. It is called for 1689 each replica with `tf.distribute.ValueContext` as the sole argument. It 1690 must return a Tensor or a type that can be converted to a Tensor. 1691 Returns: 1692 A `tf.distribute.DistributedValues` containing a value for each replica. 1693 1694 Example usage: 1695 1696 1. Return constant value per replica: 1697 1698 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1699 >>> def value_fn(ctx): 1700 ... return tf.constant(1.) 1701 >>> distributed_values = ( 1702 ... strategy.experimental_distribute_values_from_function( 1703 ... value_fn)) 1704 >>> local_result = strategy.experimental_local_results(distributed_values) 1705 >>> local_result 1706 (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 1707 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) 1708 1709 2. Distribute values in array based on replica_id: 1710 1711 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1712 >>> array_value = np.array([3., 2., 1.]) 1713 >>> def value_fn(ctx): 1714 ... return array_value[ctx.replica_id_in_sync_group] 1715 >>> distributed_values = ( 1716 ... strategy.experimental_distribute_values_from_function( 1717 ... value_fn)) 1718 >>> local_result = strategy.experimental_local_results(distributed_values) 1719 >>> local_result 1720 (3.0, 2.0) 1721 1722 3. Specify values using num_replicas_in_sync: 1723 1724 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1725 >>> def value_fn(ctx): 1726 ... return ctx.num_replicas_in_sync 1727 >>> distributed_values = ( 1728 ... strategy.experimental_distribute_values_from_function( 1729 ... value_fn)) 1730 >>> local_result = strategy.experimental_local_results(distributed_values) 1731 >>> local_result 1732 (2, 2) 1733 1734 4. Place values on devices and distribute: 1735 1736 ``` 1737 strategy = tf.distribute.TPUStrategy() 1738 worker_devices = strategy.extended.worker_devices 1739 multiple_values = [] 1740 for i in range(strategy.num_replicas_in_sync): 1741 with tf.device(worker_devices[i]): 1742 multiple_values.append(tf.constant(1.0)) 1743 1744 def value_fn(ctx): 1745 return multiple_values[ctx.replica_id_in_sync_group] 1746 1747 distributed_values = strategy. 1748 experimental_distribute_values_from_function( 1749 value_fn) 1750 ``` 1751 1752 """ 1753 return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access 1754 value_fn) 1755 1756 def gather(self, value, axis): 1757 # pylint: disable=line-too-long, protected-access 1758 """Gather `value` across replicas along `axis` to the current device. 1759 1760 Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like 1761 object `value`, this API gathers and concatenates `value` across replicas 1762 along the `axis`-th dimension. 
The result is copied to the "current" device, 1763 which would typically be the CPU of the worker on which the program is 1764 running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For 1765 multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of 1766 each worker. 1767 1768 This API can only be called in the cross-replica context. For a counterpart 1769 in the replica context, see `tf.distribute.ReplicaContext.all_gather`. 1770 1771 Note: For all strategies except `tf.distribute.TPUStrategy`, the input 1772 `value` on different replicas must have the same rank, and their shapes must 1773 be the same in all dimensions except the `axis`-th dimension. In other 1774 words, their shapes cannot be different in a dimension `d` where `d` does 1775 not equal to the `axis` argument. For example, given a 1776 `tf.distribute.DistributedValues` with component tensors of shape 1777 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 1778 `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or 1779 `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, 1780 all tensors must have exactly the same rank and same shape. 1781 1782 Note: Given a `tf.distribute.DistributedValues` `value`, its component 1783 tensors must have a non-zero rank. Otherwise, consider using 1784 `tf.expand_dims` before gathering them. 1785 1786 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1787 >>> # A DistributedValues with component tensor of shape (2, 1) on each replica 1788 ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) 1789 >>> @tf.function 1790 ... def run(): 1791 ... return strategy.gather(distributed_values, axis=0) 1792 >>> run() 1793 <tf.Tensor: shape=(4, 1), dtype=int32, numpy= 1794 array([[1], 1795 [2], 1796 [1], 1797 [2]], dtype=int32)> 1798 1799 1800 Consider the following example for more combinations: 1801 1802 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) 1803 >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) 1804 >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) 1805 >>> @tf.function 1806 ... def run(axis): 1807 ... return strategy.gather(distributed_values, axis=axis) 1808 >>> axis=0 1809 >>> run(axis) 1810 <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy= 1811 array([[[0, 1, 2], 1812 [3, 4, 5]], 1813 [[0, 1, 2], 1814 [3, 4, 5]], 1815 [[0, 1, 2], 1816 [3, 4, 5]], 1817 [[0, 1, 2], 1818 [3, 4, 5]]], dtype=int32)> 1819 >>> axis=1 1820 >>> run(axis) 1821 <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy= 1822 array([[[0, 1, 2], 1823 [3, 4, 5], 1824 [0, 1, 2], 1825 [3, 4, 5], 1826 [0, 1, 2], 1827 [3, 4, 5], 1828 [0, 1, 2], 1829 [3, 4, 5]]], dtype=int32)> 1830 >>> axis=2 1831 >>> run(axis) 1832 <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy= 1833 array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], 1834 [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)> 1835 1836 1837 Args: 1838 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1839 `Strategy.run`, to be combined into a single tensor. It can also be a 1840 regular tensor when used with `tf.distribute.OneDeviceStrategy` or the 1841 default strategy. The tensors that constitute the DistributedValues 1842 can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. 1843 axis: 0-D int32 Tensor. Dimension along which to gather. 
Must be in the 1844 range [0, rank(value)). 1845 1846 Returns: 1847 A `Tensor` that's the concatenation of `value` across replicas along 1848 the `axis` dimension. 1849 """ 1850 # pylint: enable=line-too-long 1851 error_message = ("tf.distribute.Strategy.gather method requires " 1852 "cross-replica context, use " 1853 "get_replica_context().all_gather() instead") 1854 _require_cross_replica_or_default_context_extended(self._extended, 1855 error_message) 1856 dst = device_util.current( 1857 ) or self._extended._default_device or "/device:CPU:0" 1858 if isinstance(value, indexed_slices.IndexedSlices): 1859 raise NotImplementedError("gather does not support IndexedSlices") 1860 return self._extended._local_results( 1861 self._extended._gather_to(value, dst, axis))[0] 1862 1863 1864# TF v1.x version has additional deprecated APIs 1865@tf_export(v1=["distribute.Strategy"]) 1866class StrategyV1(StrategyBase): 1867 """A list of devices with a state & compute distribution policy. 1868 1869 See [the guide](https://www.tensorflow.org/guide/distribute_strategy) 1870 for overview and examples. 1871 1872 Note: Not all `tf.distribute.Strategy` implementations currently support 1873 TensorFlow's partitioned variables (where a single variable is split across 1874 multiple devices). 1875 """ 1876 1877 def make_dataset_iterator(self, dataset): 1878 """Makes an iterator for input provided via `dataset`. 1879 1880 DEPRECATED: This method is not available in TF 2.x. 1881 1882 Data from the given dataset will be distributed evenly across all the 1883 compute replicas. We will assume that the input dataset is batched by the 1884 global batch size. With this assumption, we will make a best effort to 1885 divide each batch across all the replicas (one or more workers). 1886 If this effort fails, an error will be raised, and the user should instead 1887 use `make_input_fn_iterator`, which provides more control and 1888 does not try to divide a batch across replicas. 1889 1890 The user could also use `make_input_fn_iterator` if they want to 1891 customize which input is fed to which replica/worker, etc. 1892 1893 Args: 1894 dataset: `tf.data.Dataset` that will be distributed evenly across all 1895 replicas. 1896 1897 Returns: 1898 A `tf.distribute.InputIterator` which returns inputs for each step of the 1899 computation. The user should call `initialize` on the returned iterator. 1900 """ 1901 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 1902 1903 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation 1904 input_fn, 1905 replication_mode=InputReplicationMode.PER_WORKER): 1906 """Returns an iterator split across replicas created from an input function. 1907 1908 DEPRECATED: This method is not available in TF 2.x.
1909 1910 The `input_fn` should take an `tf.distribute.InputContext` object where 1911 information about batching and input sharding can be accessed: 1912 1913 ``` 1914 def input_fn(input_context): 1915 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 1916 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 1917 return d.shard(input_context.num_input_pipelines, 1918 input_context.input_pipeline_id) 1919 with strategy.scope(): 1920 iterator = strategy.make_input_fn_iterator(input_fn) 1921 replica_results = strategy.experimental_run(replica_fn, iterator) 1922 ``` 1923 1924 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 1925 batch size, which may be computed using 1926 `input_context.get_per_replica_batch_size`. 1927 1928 Args: 1929 input_fn: A function taking a `tf.distribute.InputContext` object and 1930 returning a `tf.data.Dataset`. 1931 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 1932 Only `PER_WORKER` is supported currently, which means there will be 1933 a single call to `input_fn` per worker. Replicas will dequeue from the 1934 local `tf.data.Dataset` on their worker. 1935 1936 Returns: 1937 An iterator object that should first be `.initialize()`-ed. It may then 1938 either be passed to `strategy.experimental_run()` or you can 1939 `iterator.get_next()` to get the next value to pass to 1940 `strategy.extended.call_for_each_replica()`. 1941 """ 1942 return super(StrategyV1, self).make_input_fn_iterator( 1943 input_fn, replication_mode) 1944 1945 def experimental_make_numpy_dataset(self, numpy_input, session=None): 1946 """Makes a tf.data.Dataset for input provided via a numpy array. 1947 1948 This avoids adding `numpy_input` as a large constant in the graph, 1949 and copies the data to the machine or machines that will be processing 1950 the input. 1951 1952 Note that you will likely need to use 1953 tf.distribute.Strategy.experimental_distribute_dataset 1954 with the returned dataset to further distribute it with the strategy. 1955 1956 Example: 1957 ``` 1958 numpy_input = np.ones([10], dtype=np.float32) 1959 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 1960 dist_dataset = strategy.experimental_distribute_dataset(dataset) 1961 ``` 1962 1963 Args: 1964 numpy_input: A nest of NumPy input arrays that will be converted into a 1965 dataset. Note that lists of Numpy arrays are stacked, as that is normal 1966 `tf.data.Dataset` behavior. 1967 session: (TensorFlow v1.x graph execution only) A session used for 1968 initialization. 1969 1970 Returns: 1971 A `tf.data.Dataset` representing `numpy_input`. 1972 """ 1973 return self.extended.experimental_make_numpy_dataset( 1974 numpy_input, session=session) 1975 1976 @deprecated( 1977 None, 1978 "This method is not available in TF 2.x. Please switch to using `run` instead." 1979 ) 1980 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation 1981 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 1982 1983 DEPRECATED: This method is not available in TF 2.x. Please switch 1984 to using `run` instead. 1985 1986 When eager execution is enabled, executes ops specified by `fn` on each 1987 replica. Otherwise, builds a graph to execute the ops on each replica. 1988 1989 Each replica will take a single, different input from the inputs provided by 1990 one `get_next` call on the input iterator. 
1991 1992 `fn` may call `tf.distribute.get_replica_context()` to access members such 1993 as `replica_id_in_sync_group`. 1994 1995 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 1996 used, and whether eager execution is enabled, `fn` may be called one or more 1997 times (once for each replica). 1998 1999 Args: 2000 fn: The function to run. The inputs to the function must match the outputs 2001 of `input_iterator.get_next()`. The output must be a `tf.nest` of 2002 `Tensor`s. 2003 input_iterator: (Optional) input iterator from which the inputs are taken. 2004 2005 Returns: 2006 Merged return value of `fn` across replicas. The structure of the return 2007 value is the same as the return value from `fn`. Each element in the 2008 structure can either be `PerReplica` (if the values are unsynchronized), 2009 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 2010 single replica). 2011 """ 2012 return super(StrategyV1, self).experimental_run( 2013 fn, input_iterator) 2014 2015 def reduce(self, reduce_op, value, axis=None): 2016 return super(StrategyV1, self).reduce(reduce_op, value, axis) 2017 2018 reduce.__doc__ = StrategyBase.reduce.__doc__ 2019 2020 def update_config_proto(self, config_proto): 2021 """Returns a copy of `config_proto` modified for use with this strategy. 2022 2023 DEPRECATED: This method is not available in TF 2.x. 2024 2025 The updated config has something needed to run a strategy, e.g. 2026 configuration to run collective ops, or device filters to improve 2027 distributed training performance. 2028 2029 Args: 2030 config_proto: a `tf.ConfigProto` object. 2031 2032 Returns: 2033 The updated copy of the `config_proto`. 2034 """ 2035 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 2036 2037 2038# NOTE(josh11b): For any strategy that needs to support tf.compat.v1, 2039# instead descend from StrategyExtendedV1. 2040@tf_export("distribute.StrategyExtended", v1=[]) 2041class StrategyExtendedV2(object): 2042 """Additional APIs for algorithms that need to be distribution-aware. 2043 2044 Note: For most usage of `tf.distribute.Strategy`, there should be no need to 2045 call these methods, since TensorFlow libraries (such as optimizers) already 2046 call these methods when needed on your behalf. 2047 2048 2049 Some common use cases of functions on this page: 2050 2051 * _Locality_ 2052 2053 `tf.distribute.DistributedValues` can have the same _locality_ as a 2054 _distributed variable_, which leads to a mirrored value residing on the same 2055 devices as the variable (as opposed to the compute devices). Such values may 2056 be passed to a call to `tf.distribute.StrategyExtended.update` to update the 2057 value of a variable. You may use 2058 `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the 2059 same locality as another variable. You may convert a "PerReplica" value to a 2060 variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or 2061 `tf.distribute.StrategyExtended.batch_reduce_to`. 2062 2063 * _How to update a distributed variable_ 2064 2065 A distributed variable is variables created on multiple devices. As discussed 2066 in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), 2067 mirrored variable and SyncOnRead variable are two examples. The standard 2068 pattern for updating distributed variables is to: 2069 2070 1. In your function passed to `tf.distribute.Strategy.run`, 2071 compute a list of (update, variable) pairs. 
For example, the update might 2072 be a gradient of the loss with respect to the variable. 2073 2. Switch to cross-replica mode by calling 2074 `tf.distribute.get_replica_context().merge_call()` with the updates and 2075 variables as arguments. 2076 3. Call 2077 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` 2078 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` 2079 (for a list of variables) to sum the updates. 2080 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update 2081 its value. 2082 2083 Steps 2 through 4 are done automatically by class 2084 `tf.keras.optimizers.Optimizer` if you call its 2085 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. 2086 2087 In fact, a higher-level solution to update a distributed variable is by 2088 calling `assign` on the variable as you would do to a regular `tf.Variable`. 2089 You can call the method in both _replica context_ and _cross-replica context_. 2090 For a _mirrored variable_, calling `assign` in _replica context_ requires you 2091 to specify the `aggregation` type in the variable constructor. In that case, 2092 the context switching and sync described in steps 2 through 4 are handled for 2093 you. If you call `assign` on _mirrored variable_ in _cross-replica context_, 2094 you can only assign a single value or assign values from another mirrored 2095 variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead 2096 variable_, in _replica context_, you can simply call `assign` on it and no 2097 aggregation happens under the hood. In _cross-replica context_, you can only 2098 assign a single value to a SyncOnRead variable. One example case is restoring 2099 from a checkpoint: if the `aggregation` type of the variable is 2100 `tf.VariableAggregation.SUM`, it is assumed that replica values were added 2101 before checkpointing, so at the time of restoring, the value is divided by 2102 the number of replicas and then assigned to each replica; if the `aggregation` 2103 type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica 2104 directly. 2105 2106 """ 2107 2108 def __init__(self, container_strategy): 2109 self._container_strategy_weakref = weakref.ref(container_strategy) 2110 self._default_device = None 2111 # This property is used to determine if we should set drop_remainder=True 2112 # when creating Datasets from numpy array inputs. 2113 self._require_static_shapes = False 2114 2115 def _resource_creator_scope(self): 2116 """Returns one or a list of ops.resource_creator_scope for some Strategy.""" 2117 return None 2118 2119 def _container_strategy(self): 2120 """Get the containing `tf.distribute.Strategy`. 2121 2122 This should not generally be needed except when creating a new 2123 `ReplicaContext` and to validate that the caller is in the correct 2124 `scope()`. 2125 2126 Returns: 2127 The `tf.distribute.Strategy` such that `strategy.extended` is `self`. 
2128 """ 2129 container_strategy = self._container_strategy_weakref() 2130 assert container_strategy is not None 2131 return container_strategy 2132 2133 def _scope(self, strategy): 2134 """Implementation of tf.distribute.Strategy.scope().""" 2135 2136 def creator_with_resource_vars(next_creator, **kwargs): 2137 """Variable creator to use in `_CurrentDistributionContext`.""" 2138 _require_strategy_scope_extended(self) 2139 kwargs["use_resource"] = True 2140 kwargs["distribute_strategy"] = strategy 2141 2142 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid 2143 # dereferencing a `Tensor` that is without a `name`. We still need to 2144 # propagate the metadata it's holding. 2145 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): 2146 checkpoint_restore_uid = kwargs[ 2147 "initial_value"].checkpoint_position.restore_uid 2148 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value 2149 elif isinstance(kwargs["initial_value"], 2150 trackable.CheckpointInitialValueCallable): 2151 checkpoint_restore_uid = kwargs[ 2152 "initial_value"].checkpoint_position.restore_uid 2153 elif (isinstance(kwargs["initial_value"], functools.partial) and 2154 isinstance(kwargs["initial_value"].func, 2155 trackable.CheckpointInitialValueCallable)): 2156 # Some libraries (e.g, Keras) create partial function out of initializer 2157 # to bind shape/dtype, for example: 2158 # initial_val = functools.partial(initializer, shape, dtype=dtype) 2159 # Therefore to get the restore_uid we need to examine the "func" of 2160 # the partial function. 2161 checkpoint_restore_uid = kwargs[ 2162 "initial_value"].func.checkpoint_position.restore_uid 2163 else: 2164 checkpoint_restore_uid = None 2165 2166 created = self._create_variable(next_creator, **kwargs) 2167 2168 if checkpoint_restore_uid is not None: 2169 # pylint: disable=protected-access 2170 # Let the checkpointing infrastructure know that the variable was 2171 # already restored so it doesn't waste memory loading the value again. 2172 # In this case of CheckpointInitialValueCallable this may already be 2173 # done by the final variable creator, but it doesn't hurt to do it 2174 # again. 2175 created._maybe_initialize_trackable() 2176 created._update_uid = checkpoint_restore_uid 2177 # pylint: enable=protected-access 2178 return created 2179 2180 def distributed_getter(getter, *args, **kwargs): 2181 if not self._allow_variable_partition(): 2182 if kwargs.pop("partitioner", None) is not None: 2183 tf_logging.log_first_n( 2184 tf_logging.WARN, "Partitioned variables are disabled when using " 2185 "current tf.distribute.Strategy.", 1) 2186 return getter(*args, **kwargs) 2187 2188 return _CurrentDistributionContext( 2189 strategy, 2190 variable_scope.variable_creator_scope(creator_with_resource_vars), 2191 variable_scope.variable_scope( 2192 variable_scope.get_variable_scope(), 2193 custom_getter=distributed_getter), 2194 strategy.extended._resource_creator_scope(), # pylint: disable=protected-access 2195 self._default_device) 2196 2197 def _allow_variable_partition(self): 2198 return False 2199 2200 def _create_variable(self, next_creator, **kwargs): 2201 # Note: should support "colocate_with" argument. 2202 raise NotImplementedError("must be implemented in descendants") 2203 2204 def variable_created_in_scope(self, v): 2205 """Tests whether `v` was created while this strategy scope was active. 
2206 2207 Variables created inside the strategy scope are "owned" by it: 2208 2209 >>> strategy = tf.distribute.MirroredStrategy() 2210 >>> with strategy.scope(): 2211 ... v = tf.Variable(1.) 2212 >>> strategy.extended.variable_created_in_scope(v) 2213 True 2214 2215 Variables created outside the strategy are not owned by it: 2216 2217 >>> strategy = tf.distribute.MirroredStrategy() 2218 >>> v = tf.Variable(1.) 2219 >>> strategy.extended.variable_created_in_scope(v) 2220 False 2221 2222 Args: 2223 v: A `tf.Variable` instance. 2224 2225 Returns: 2226 True if `v` was created inside the scope, False if not. 2227 """ 2228 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 2229 2230 def colocate_vars_with(self, colocate_with_variable): 2231 """Scope that controls which devices variables will be created on. 2232 2233 No operations should be added to the graph inside this scope, it 2234 should only be used when creating variables (some implementations 2235 work by changing variable creation, others work by using a 2236 tf.compat.v1.colocate_with() scope). 2237 2238 This may only be used inside `self.scope()`. 2239 2240 Example usage: 2241 2242 ``` 2243 with strategy.scope(): 2244 var1 = tf.Variable(...) 2245 with strategy.extended.colocate_vars_with(var1): 2246 # var2 and var3 will be created on the same device(s) as var1 2247 var2 = tf.Variable(...) 2248 var3 = tf.Variable(...) 2249 2250 def fn(v1, v2, v3): 2251 # operates on v1 from var1, v2 from var2, and v3 from var3 2252 2253 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 2254 # too. 2255 strategy.extended.update(var1, fn, args=(var2, var3)) 2256 ``` 2257 2258 Args: 2259 colocate_with_variable: A variable created in this strategy's `scope()`. 2260 Variables created while in the returned context manager will be on the 2261 same set of devices as `colocate_with_variable`. 2262 2263 Returns: 2264 A context manager. 2265 """ 2266 2267 def create_colocated_variable(next_creator, **kwargs): 2268 _require_strategy_scope_extended(self) 2269 kwargs["use_resource"] = True 2270 kwargs["colocate_with"] = colocate_with_variable 2271 return next_creator(**kwargs) 2272 2273 _require_strategy_scope_extended(self) 2274 self._validate_colocate_with_variable(colocate_with_variable) 2275 return variable_scope.variable_creator_scope(create_colocated_variable) 2276 2277 def _validate_colocate_with_variable(self, colocate_with_variable): 2278 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 2279 pass 2280 2281 def _make_dataset_iterator(self, dataset): 2282 raise NotImplementedError("must be implemented in descendants") 2283 2284 def _make_input_fn_iterator(self, input_fn, replication_mode): 2285 raise NotImplementedError("must be implemented in descendants") 2286 2287 def _experimental_distribute_dataset(self, dataset, options): 2288 raise NotImplementedError("must be implemented in descendants") 2289 2290 def _distribute_datasets_from_function(self, dataset_fn, options): 2291 raise NotImplementedError("must be implemented in descendants") 2292 2293 def _experimental_distribute_values_from_function(self, value_fn): 2294 raise NotImplementedError("must be implemented in descendants") 2295 2296 def _reduce(self, reduce_op, value): 2297 # Default implementation until we have an implementation for each strategy. 
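# Reduce to the current device (falling back to the strategy's default
# device, then CPU) and return the single local component of the result.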
2298 dst = device_util.current() or self._default_device or "/device:CPU:0" 2299 return self._local_results(self.reduce_to(reduce_op, value, dst))[0] 2300 2301 def reduce_to(self, reduce_op, value, destinations, options=None): 2302 """Combine (via e.g. sum or mean) values across replicas. 2303 2304 `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed 2305 variables. It supports both dense values and `tf.IndexedSlices`. 2306 2307 This API currently can only be called in cross-replica context. Other 2308 variants to reduce values across replicas are: 2309 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of 2310 this API. 2311 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2312 in replica context. It supports both batched and non-batched all-reduce. 2313 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2314 to the host in cross-replica context. 2315 2316 `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can 2317 also pass in a `Tensor`, and the destinations will be the device of that 2318 tensor. For all-reduce, pass the same to `value` and `destinations`. 2319 2320 It can be used in `tf.distribute.ReplicaContext.merge_call` to write code 2321 that works for all `tf.distribute.Strategy`. 2322 2323 >>> @tf.function 2324 ... def step_fn(var): 2325 ... 2326 ... def merge_fn(strategy, value, var): 2327 ... # All-reduce the value. Note that `value` here is a 2328 ... # `tf.distribute.DistributedValues`. 2329 ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, 2330 ... value, destinations=var) 2331 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2332 ... args=(reduced,)) 2333 ... 2334 ... value = tf.identity(1.) 2335 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2336 ... args=(value, var)) 2337 >>> 2338 >>> def run(strategy): 2339 ... with strategy.scope(): 2340 ... v = tf.Variable(0.) 2341 ... strategy.run(step_fn, args=(v,)) 2342 ... return v 2343 >>> 2344 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2345 MirroredVariable:{ 2346 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2347 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2348 } 2349 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2350 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2351 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2352 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2353 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2354 2355 Args: 2356 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2357 be combined. Allows using string representation of the enum such as 2358 "SUM", "MEAN". 2359 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 2360 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2361 `tf.Tensor` alike object, or a device string. It specifies the devices 2362 to reduce to. To perform an all-reduce, pass the same to `value` and 2363 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2364 to the devices of that variable, and this method doesn't update the 2365 variable. 2366 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2367 perform collective operations. This overrides the default options if the 2368 `tf.distribute.Strategy` takes one in the constructor. 
See 2369 `tf.distribute.experimental.CommunicationOptions` for details of the 2370 options. 2371 2372 Returns: 2373 A tensor or value reduced to `destinations`. 2374 """ 2375 if options is None: 2376 options = collective_util.Options() 2377 _require_cross_replica_or_default_context_extended(self) 2378 assert not isinstance(destinations, (list, tuple)) 2379 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2380 if isinstance(reduce_op, six.string_types): 2381 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2382 assert (reduce_op == reduce_util.ReduceOp.SUM or 2383 reduce_op == reduce_util.ReduceOp.MEAN) 2384 return self._reduce_to(reduce_op, value, destinations, options) 2385 2386 def _reduce_to(self, reduce_op, value, destinations, options): 2387 raise NotImplementedError("must be implemented in descendants") 2388 2389 def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): 2390 """Combine multiple `reduce_to` calls into one for faster execution. 2391 2392 Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. 2393 It's more efficient than reduce each value separately. 2394 2395 This API currently can only be called in cross-replica context. Other 2396 variants to reduce values across replicas are: 2397 * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of 2398 this API. 2399 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2400 in replica context. It supports both batched and non-batched all-reduce. 2401 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2402 to the host in cross-replica context. 2403 2404 See `reduce_to` for more information. 2405 2406 >>> @tf.function 2407 ... def step_fn(var): 2408 ... 2409 ... def merge_fn(strategy, value, var): 2410 ... # All-reduce the value. Note that `value` here is a 2411 ... # `tf.distribute.DistributedValues`. 2412 ... reduced = strategy.extended.batch_reduce_to( 2413 ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] 2414 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2415 ... args=(reduced,)) 2416 ... 2417 ... value = tf.identity(1.) 2418 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2419 ... args=(value, var)) 2420 >>> 2421 >>> def run(strategy): 2422 ... with strategy.scope(): 2423 ... v = tf.Variable(0.) 2424 ... strategy.run(step_fn, args=(v,)) 2425 ... return v 2426 >>> 2427 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2428 MirroredVariable:{ 2429 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2430 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2431 } 2432 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2433 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2434 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2435 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2436 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2437 2438 Args: 2439 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2440 be combined. Allows using string representation of the enum such as 2441 "SUM", "MEAN". 2442 value_destination_pairs: a sequence of (value, destinations) pairs. See 2443 `tf.distribute.Strategy.reduce_to` for descriptions. 2444 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2445 perform collective operations. This overrides the default options if the 2446 `tf.distribute.Strategy` takes one in the constructor. 
See 2447 `tf.distribute.experimental.CommunicationOptions` for details of the 2448 options. 2449 2450 Returns: 2451 A list of reduced values, one per pair in `value_destination_pairs`. 2452 """ 2453 if options is None: 2454 options = collective_util.Options() 2455 _require_cross_replica_or_default_context_extended(self) 2456 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2457 if isinstance(reduce_op, six.string_types): 2458 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2459 return self._batch_reduce_to(reduce_op, value_destination_pairs, options) 2460 2461 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): 2462 return [ 2463 self.reduce_to(reduce_op, t, destinations=v, options=options) 2464 for t, v in value_destination_pairs 2465 ] 2466 2467 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 2468 """All-reduce `value` across all replicas so that all get the final result. 2469 2470 If `value` is a nested structure of tensors, all-reduces of these tensors 2471 will be batched when possible. `options` can be set to hint the batching 2472 behavior. 2473 2474 This API must be called in a replica context. 2475 2476 Args: 2477 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 2478 be combined. 2479 value: Value to be reduced. A tensor or a nested structure of tensors. 2480 options: A `tf.distribute.experimental.CommunicationOptions`. Options to 2481 perform collective operations. This overrides the default options if the 2482 `tf.distribute.Strategy` takes one in the constructor. 2483 2484 Returns: 2485 A tensor or a nested strucutre of tensors with the reduced values. The 2486 structure is the same as `value`. 2487 """ 2488 if options is None: 2489 options = collective_util.Options() 2490 replica_context = distribution_strategy_context.get_replica_context() 2491 assert replica_context, ( 2492 "`StrategyExtended._replica_ctx_all_reduce` must be called in" 2493 " a replica context") 2494 2495 def merge_fn(_, flat_value): 2496 return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value], 2497 options) 2498 2499 reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) 2500 return nest.pack_sequence_as(value, reduced) 2501 2502 def _replica_ctx_update(self, var, fn, args=(), kwargs=None, group=True): 2503 """Run `fn` with `args` and `kwargs` to update `var`.""" 2504 # This method is called by ReplicaContext.update. Strategies who'd like to 2505 # remove merge_call in this path should override this method. 2506 replica_context = distribution_strategy_context.get_replica_context() 2507 if not replica_context: 2508 raise ValueError("`StrategyExtended._replica_ctx_update` must be called " 2509 "in a replica context.") 2510 2511 def merge_fn(_, *merged_args, **merged_kwargs): 2512 return self.update(var, fn, merged_args, merged_kwargs, group=group) 2513 2514 return replica_context.merge_call(merge_fn, args=args, kwargs=kwargs) 2515 2516 def _gather_to(self, value, destinations, axis, options=None): 2517 """Gather `value` across replicas along axis-th dimension to `destinations`. 2518 2519 `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like 2520 object, along `axis`-th dimension. It supports only dense tensors but NOT 2521 sparse tensor. This API can only be called in cross-replica context. 2522 2523 Args: 2524 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 
2525 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2526 `tf.Tensor` alike object, or a device string. It specifies the devices 2527 to reduce to. To perform an all-gather, pass the same to `value` and 2528 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2529 to the devices of that variable, and this method doesn't update the 2530 variable. 2531 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 2532 range [0, rank(value)). 2533 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2534 perform collective operations. This overrides the default options if the 2535 `tf.distribute.Strategy` takes one in the constructor. See 2536 `tf.distribute.experimental.CommunicationOptions` for details of the 2537 options. 2538 2539 Returns: 2540 A tensor or value gathered to `destinations`. 2541 """ 2542 _require_cross_replica_or_default_context_extended(self) 2543 assert not isinstance(destinations, (list, tuple)) 2544 if options is None: 2545 options = collective_util.Options() 2546 return self._gather_to_implementation(value, destinations, axis, options) 2547 2548 def _gather_to_implementation(self, value, destinations, axis, options): 2549 raise NotImplementedError("_gather_to must be implemented in descendants") 2550 2551 def _batch_gather_to(self, value_destination_pairs, axis, options=None): 2552 _require_cross_replica_or_default_context_extended(self) 2553 if options is None: 2554 options = collective_util.Options() 2555 return [ 2556 self._gather_to(t, destinations=v, axis=axis, options=options) 2557 for t, v in value_destination_pairs 2558 ] 2559 2560 def update(self, var, fn, args=(), kwargs=None, group=True): 2561 """Run `fn` to update `var` using inputs mirrored to the same devices. 2562 2563 `tf.distribute.StrategyExtended.update` takes a distributed variable `var` 2564 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It 2565 applies `fn` to each component variable of `var` and passes corresponding 2566 values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain 2567 per-replica values. If they contain mirrored values, they will be unwrapped 2568 before calling `fn`. For example, `fn` can be `assign_add` and `args` can be 2569 a mirrored DistributedValues where each component contains the value to be 2570 added to this mirrored variable `var`. Calling `update` will call 2571 `assign_add` on each component variable of `var` with the corresponding 2572 tensor value on that device. 2573 2574 Example usage: 2575 2576 ```python 2577 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 2578 devices 2579 with strategy.scope(): 2580 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 2581 def update_fn(v): 2582 return v.assign(1.0) 2583 result = strategy.extended.update(v, update_fn) 2584 # result is 2585 # Mirrored:{ 2586 # 0: tf.Tensor(1.0, shape=(), dtype=float32), 2587 # 1: tf.Tensor(1.0, shape=(), dtype=float32) 2588 # } 2589 ``` 2590 2591 If `var` is mirrored across multiple devices, then this method implements 2592 logic as following: 2593 2594 ```python 2595 results = {} 2596 for device, v in var: 2597 with tf.device(device): 2598 # args and kwargs will be unwrapped if they are mirrored. 2599 results[device] = fn(v, *args, **kwargs) 2600 return merged(results) 2601 ``` 2602 2603 Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with 2604 `var`. 
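
As a further illustration (a minimal sketch, not part of the original docstring; it assumes two GPUs), a mirrored value produced by `reduce_to` can be passed through `args` so that each component variable receives its per-device copy:

```python
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  v = tf.Variable(1.0)

# A per-replica value: 2.0 on every replica in this sketch.
per_replica = strategy.experimental_distribute_values_from_function(
    lambda ctx: tf.constant(2.0))
# Reduce it to the same locality as `v`; the result is mirrored on `v`'s devices.
delta = strategy.extended.reduce_to(
    tf.distribute.ReduceOp.SUM, per_replica, destinations=v)
# Each component variable of `v` gets `assign_add` with its local component
# of `delta`, so every replica of `v` ends up at 1.0 + 2.0 * num_replicas.
strategy.extended.update(v, lambda var, t: var.assign_add(t), args=(delta,))
```
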
2605 2606 Args: 2607 var: Variable, possibly mirrored to multiple devices, to operate on. 2608 fn: Function to call. Should take the variable as the first argument. 2609 args: Tuple or list. Additional positional arguments to pass to `fn()`. 2610 kwargs: Dict with keyword arguments to pass to `fn()`. 2611 group: Boolean. Defaults to True. If False, the return value will be 2612 unwrapped. 2613 2614 Returns: 2615 By default, the merged return value of `fn` across all replicas. The 2616 merged result has dependencies to make sure that if it is evaluated at 2617 all, the side effects (updates) will happen on every replica. If instead 2618 "group=False" is specified, this function will return a nest of lists 2619 where each list has an element per replica, and the caller is responsible 2620 for ensuring all elements are executed. 2621 """ 2622 # TODO(b/178944108): Update the documentation to relfect the fact that 2623 # `update` can be called in a replica context. 2624 if kwargs is None: 2625 kwargs = {} 2626 replica_context = distribution_strategy_context.get_replica_context() 2627 # pylint: disable=protected-access 2628 if (replica_context is None or replica_context is 2629 distribution_strategy_context._get_default_replica_context()): 2630 fn = autograph.tf_convert( 2631 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2632 with self._container_strategy().scope(): 2633 return self._update(var, fn, args, kwargs, group) 2634 else: 2635 return self._replica_ctx_update( 2636 var, fn, args=args, kwargs=kwargs, group=group) 2637 2638 def _update(self, var, fn, args, kwargs, group): 2639 raise NotImplementedError("must be implemented in descendants") 2640 2641 def _local_results(self, val): 2642 """Returns local results per replica as a tuple.""" 2643 if isinstance(val, values.DistributedValues): 2644 return val._values # pylint: disable=protected-access 2645 2646 if nest.is_nested(val): 2647 replica_values = [] 2648 2649 def get_values(x, index): 2650 if isinstance(x, values.DistributedValues): 2651 return x._values[index] # pylint: disable=protected-access 2652 return x 2653 2654 for i in range(len(self.worker_devices)): 2655 replica_values.append( 2656 nest.map_structure( 2657 lambda x: get_values(x, i), # pylint: disable=cell-var-from-loop 2658 val)) 2659 return tuple(replica_values) 2660 return (val,) 2661 2662 def value_container(self, value): 2663 """Returns the container that this per-replica `value` belongs to. 2664 2665 Args: 2666 value: A value returned by `run()` or a variable created in `scope()`. 2667 2668 Returns: 2669 A container that `value` belongs to. 2670 If value does not belong to any container (including the case of 2671 container having been destroyed), returns the value itself. 2672 `value in experimental_local_results(value_container(value))` will 2673 always be true. 2674 """ 2675 raise NotImplementedError("must be implemented in descendants") 2676 2677 def _group(self, value, name=None): 2678 """Implementation of `group`.""" 2679 value = nest.flatten(self._local_results(value)) 2680 2681 if len(value) != 1 or name is not None: 2682 return control_flow_ops.group(value, name=name) 2683 # Special handling for the common case of one op. 
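# With exactly one result and no explicit name, return its op (or the value
# itself if it has no `op` attribute) rather than creating a redundant group
# op.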
2684 v, = value 2685 if hasattr(v, "op"): 2686 v = v.op 2687 return v 2688 2689 @property 2690 def experimental_require_static_shapes(self): 2691 """Returns `True` if static shape is required; `False` otherwise.""" 2692 return self._require_static_shapes 2693 2694 @property 2695 def _num_replicas_in_sync(self): 2696 """Returns number of replicas over which gradients are aggregated.""" 2697 raise NotImplementedError("must be implemented in descendants") 2698 2699 @property 2700 def worker_devices(self): 2701 """Returns the tuple of all devices used to for compute replica execution. 2702 """ 2703 # TODO(josh11b): More docstring 2704 raise NotImplementedError("must be implemented in descendants") 2705 2706 @property 2707 def parameter_devices(self): 2708 """Returns the tuple of all devices used to place variables.""" 2709 # TODO(josh11b): More docstring 2710 raise NotImplementedError("must be implemented in descendants") 2711 2712 def _configure(self, 2713 session_config=None, 2714 cluster_spec=None, 2715 task_type=None, 2716 task_id=None): 2717 """Configures the strategy class.""" 2718 del session_config, cluster_spec, task_type, task_id 2719 2720 def _update_config_proto(self, config_proto): 2721 return copy.deepcopy(config_proto) 2722 2723 def _in_multi_worker_mode(self): 2724 """Whether this strategy indicates working in multi-worker settings. 2725 2726 Multi-worker training refers to the setup where the training is 2727 distributed across multiple workers, as opposed to the case where 2728 only a local process performs the training. This function is 2729 used by higher-level APIs such as Keras' `model.fit()` to infer 2730 for example whether or not a distribute coordinator should be run, 2731 and thus TensorFlow servers should be started for communication 2732 with other servers in the cluster, or whether or not saving/restoring 2733 checkpoints is relevant for preemption fault tolerance. 2734 2735 Subclasses should override this to provide whether the strategy is 2736 currently in multi-worker setup. 2737 2738 Experimental. Signature and implementation are subject to change. 2739 """ 2740 raise NotImplementedError("must be implemented in descendants") 2741 2742 2743@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring 2744class StrategyExtendedV1(StrategyExtendedV2): 2745 2746 __doc__ = StrategyExtendedV2.__doc__ 2747 2748 def experimental_make_numpy_dataset(self, numpy_input, session=None): 2749 """Makes a dataset for input provided via a numpy array. 2750 2751 This avoids adding `numpy_input` as a large constant in the graph, 2752 and copies the data to the machine or machines that will be processing 2753 the input. 2754 2755 Args: 2756 numpy_input: A nest of NumPy input arrays that will be distributed evenly 2757 across all replicas. Note that lists of Numpy arrays are stacked, as 2758 that is normal `tf.data.Dataset` behavior. 2759 session: (TensorFlow v1.x graph execution only) A session used for 2760 initialization. 2761 2762 Returns: 2763 A `tf.data.Dataset` representing `numpy_input`. 2764 """ 2765 _require_cross_replica_or_default_context_extended(self) 2766 return self._experimental_make_numpy_dataset(numpy_input, session=session) 2767 2768 def _experimental_make_numpy_dataset(self, numpy_input, session): 2769 raise NotImplementedError("must be implemented in descendants") 2770 2771 def broadcast_to(self, tensor, destinations): 2772 """Mirror a tensor on one device to all worker devices. 2773 2774 Args: 2775 tensor: A Tensor value to broadcast. 
2776 destinations: A mirrored variable or device string specifying the 2777 destination devices to copy `tensor` to. 2778 2779 Returns: 2780 A value mirrored to `destinations` devices. 2781 """ 2782 assert destinations is not None # from old strategy.broadcast() 2783 # TODO(josh11b): More docstring 2784 _require_cross_replica_or_default_context_extended(self) 2785 assert not isinstance(destinations, (list, tuple)) 2786 return self._broadcast_to(tensor, destinations) 2787 2788 def _broadcast_to(self, tensor, destinations): 2789 raise NotImplementedError("must be implemented in descendants") 2790 2791 @deprecated(None, "please use `run` instead.") 2792 def experimental_run_steps_on_iterator(self, 2793 fn, 2794 iterator, 2795 iterations=1, 2796 initial_loop_values=None): 2797 """DEPRECATED: please use `run` instead. 2798 2799 Run `fn` with input from `iterator` for `iterations` times. 2800 2801 This method can be used to run a step function for training a number of 2802 times using input from a dataset. 2803 2804 Args: 2805 fn: function to run using this distribution strategy. The function must 2806 have the following signature: `def fn(context, inputs)`. `context` is an 2807 instance of `MultiStepContext` that will be passed when `fn` is run. 2808 `context` can be used to specify the outputs to be returned from `fn` 2809 by calling `context.set_last_step_output`. It can also be used to 2810 capture non tensor outputs by `context.set_non_tensor_output`. See 2811 `MultiStepContext` documentation for more information. `inputs` will 2812 have same type/structure as `iterator.get_next()`. Typically, `fn` 2813 will use `call_for_each_replica` method of the strategy to distribute 2814 the computation over multiple replicas. 2815 iterator: Iterator of a dataset that represents the input for `fn`. The 2816 caller is responsible for initializing the iterator as needed. 2817 iterations: (Optional) Number of iterations that `fn` should be run. 2818 Defaults to 1. 2819 initial_loop_values: (Optional) Initial values to be passed into the 2820 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 2821 initial_loop_values argument when we have a mechanism to infer the 2822 outputs of `fn`. 2823 2824 Returns: 2825 Returns the `MultiStepContext` object which has the following properties, 2826 among other things: 2827 - run_op: An op that runs `fn` `iterations` times. 2828 - last_step_outputs: A dictionary containing tensors set using 2829 `context.set_last_step_output`. Evaluating this returns the value of 2830 the tensors after the last iteration. 2831 - non_tensor_outputs: A dictionary containing anything that was set by 2832 `fn` by calling `context.set_non_tensor_output`. 2833 """ 2834 _require_cross_replica_or_default_context_extended(self) 2835 with self._container_strategy().scope(): 2836 return self._experimental_run_steps_on_iterator(fn, iterator, iterations, 2837 initial_loop_values) 2838 2839 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 2840 initial_loop_values): 2841 raise NotImplementedError("must be implemented in descendants") 2842 2843 def call_for_each_replica(self, fn, args=(), kwargs=None): 2844 """Run `fn` once per replica. 2845 2846 `fn` may call `tf.get_replica_context()` to access methods such as 2847 `replica_id_in_sync_group` and `merge_call()`. 2848 2849 `merge_call()` is used to communicate between the replicas and 2850 re-enter the cross-replica context. All replicas pause their execution 2851 having encountered a `merge_call()` call. 
After that the 2852 `merge_fn`-function is executed. Its results are then unwrapped and 2853 given back to each replica call. After that execution resumes until 2854 `fn` is complete or encounters another `merge_call()`. Example: 2855 2856 ```python 2857 # Called once in "cross-replica" context. 2858 def merge_fn(distribution, three_plus_replica_id): 2859 # sum the values across replicas 2860 return sum(distribution.experimental_local_results(three_plus_replica_id)) 2861 2862 # Called once per replica in `distribution`, in a "replica" context. 2863 def fn(three): 2864 replica_ctx = tf.get_replica_context() 2865 v = three + replica_ctx.replica_id_in_sync_group 2866 # Computes the sum of the `v` values across all replicas. 2867 s = replica_ctx.merge_call(merge_fn, args=(v,)) 2868 return s + v 2869 2870 with distribution.scope(): 2871 # in "cross-replica" context 2872 ... 2873 merged_results = distribution.run(fn, args=[3]) 2874 # merged_results has the values from every replica execution of `fn`. 2875 # This statement prints a list: 2876 print(distribution.experimental_local_results(merged_results)) 2877 ``` 2878 2879 Args: 2880 fn: function to run (will be run once per replica). 2881 args: Tuple or list with positional arguments for `fn`. 2882 kwargs: Dict with keyword arguments for `fn`. 2883 2884 Returns: 2885 Merged return value of `fn` across all replicas. 2886 """ 2887 _require_cross_replica_or_default_context_extended(self) 2888 if kwargs is None: 2889 kwargs = {} 2890 with self._container_strategy().scope(): 2891 return self._call_for_each_replica(fn, args, kwargs) 2892 2893 def _call_for_each_replica(self, fn, args, kwargs): 2894 raise NotImplementedError("must be implemented in descendants") 2895 2896 def read_var(self, v): 2897 """Reads the value of a variable. 2898 2899 Returns the aggregate value of a replica-local variable, or the 2900 (read-only) value of any other variable. 2901 2902 Args: 2903 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 2904 2905 Returns: 2906 A tensor representing the value of `v`, aggregated across replicas if 2907 necessary. 2908 """ 2909 raise NotImplementedError("must be implemented in descendants") 2910 2911 def update_non_slot( 2912 self, colocate_with, fn, args=(), kwargs=None, group=True): 2913 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 2914 2915 Used to update non-slot variables. 2916 2917 DEPRECATED: TF 1.x ONLY. 2918 2919 Args: 2920 colocate_with: Devices returned by `non_slot_devices()`. 2921 fn: Function to execute. 2922 args: Tuple or list. Positional arguments to pass to `fn()`. 2923 kwargs: Dict with keyword arguments to pass to `fn()`. 2924 group: Boolean. Defaults to True. If False, the return value will be 2925 unwrapped. 2926 2927 Returns: 2928 Return value of `fn`, possibly merged across devices. 2929 """ 2930 _require_cross_replica_or_default_context_extended(self) 2931 if kwargs is None: 2932 kwargs = {} 2933 fn = autograph.tf_convert( 2934 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2935 with self._container_strategy().scope(): 2936 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 2937 2938 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 2939 raise NotImplementedError("must be implemented in descendants") 2940 2941 def non_slot_devices(self, var_list): 2942 """Device(s) for non-slot variables. 2943 2944 DEPRECATED: TF 1.x ONLY. 2945 2946 This method returns non-slot devices where non-slot variables are placed. 
2947 Users can create non-slot variables on these devices by using a block: 2948 2949 ```python 2950 with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): 2951 ... 2952 ``` 2953 2954 Args: 2955 var_list: The list of variables being optimized, needed with the 2956 default `tf.distribute.Strategy`. 2957 Returns: 2958 A sequence of devices for non-slot variables. 2959 """ 2960 raise NotImplementedError("must be implemented in descendants") 2961 2962 def _use_merge_call(self): 2963 """Whether to use merge-calls inside the distributed strategy.""" 2964 return True 2965 2966 @property 2967 def experimental_between_graph(self): 2968 """Whether the strategy uses between-graph replication or not. 2969 2970 This is expected to return a constant value that will not be changed 2971 throughout its life cycle. 2972 """ 2973 raise NotImplementedError("must be implemented in descendants") 2974 2975 @property 2976 def experimental_should_init(self): 2977 """Whether initialization is needed.""" 2978 raise NotImplementedError("must be implemented in descendants") 2979 2980 @property 2981 def should_checkpoint(self): 2982 """Whether checkpointing is needed.""" 2983 raise NotImplementedError("must be implemented in descendants") 2984 2985 @property 2986 def should_save_summary(self): 2987 """Whether saving summaries is needed.""" 2988 raise NotImplementedError("must be implemented in descendants") 2989 2990 2991# A note about the difference between the context managers 2992# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 2993# (defined above) used by `tf.distribute.Strategy.scope()`: 2994# 2995# * a ReplicaContext is only present during a `run()` 2996# call (except during a `merge_run` call) and in such a scope it 2997# will be returned by calls to `get_replica_context()`. Implementers of new 2998# Strategy descendants will frequently also need to 2999# define a descendant of ReplicaContext, and are responsible for 3000# entering and exiting this context. 3001# 3002# * Strategy.scope() sets up a variable_creator scope that 3003# changes variable creation calls (e.g. to make mirrored 3004# variables). This is intended as an outer scope that users enter once 3005# around their model creation and graph definition. There is no 3006# anticipated need to define descendants of _CurrentDistributionContext. 3007# It sets the current Strategy for purposes of 3008# `get_strategy()` and `has_strategy()` 3009# and switches the thread mode to a "cross-replica context". 3010class ReplicaContextBase(object): 3011 """A class with a collection of APIs that can be called in a replica context. 3012 3013 You can use `tf.distribute.get_replica_context` to get an instance of 3014 `ReplicaContext`, which can only be called inside the function passed to 3015 `tf.distribute.Strategy.run`. 3016 3017 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) 3018 >>> def func(): 3019 ... replica_context = tf.distribute.get_replica_context() 3020 ... return replica_context.replica_id_in_sync_group 3021 >>> strategy.run(func) 3022 PerReplica:{ 3023 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, 3024 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 3025 } 3026 """ 3027 3028 def __init__(self, strategy, replica_id_in_sync_group): 3029 """Creates a ReplicaContext. 3030 3031 Args: 3032 strategy: A `tf.distribute.Strategy`. 3033 replica_id_in_sync_group: An integer, a `Tensor` or None. 
Prefer an 3034 integer whenever possible to avoid issues with nested `tf.function`. It 3035 accepts a `Tensor` only to be compatible with `tpu.replicate`. 3036 """ 3037 self._strategy = strategy 3038 self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access 3039 self) 3040 if not (replica_id_in_sync_group is None or 3041 tensor_util.is_tf_type(replica_id_in_sync_group) or 3042 isinstance(replica_id_in_sync_group, int)): 3043 raise ValueError( 3044 "replica_id_in_sync_group can only be an integer, a Tensor or None.") 3045 self._replica_id_in_sync_group = replica_id_in_sync_group 3046 # We need this check because TPUContext extends from ReplicaContext and 3047 # does not pass a strategy object since it is used by TPUEstimator. 3048 if strategy: 3049 self._local_replica_id = strategy.extended._get_local_replica_id( 3050 replica_id_in_sync_group) 3051 self._summary_recording_distribution_strategy = None 3052 3053 @doc_controls.do_not_generate_docs 3054 def __enter__(self): 3055 _push_per_thread_mode(self._thread_context) 3056 3057 def replica_id_is_zero(): 3058 return math_ops.equal(self.replica_id_in_sync_group, 3059 constant_op.constant(0)) 3060 3061 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 3062 self._summary_recording_distribution_strategy = ( 3063 summary_state.is_recording_distribution_strategy) 3064 summary_state.is_recording_distribution_strategy = replica_id_is_zero 3065 3066 @doc_controls.do_not_generate_docs 3067 def __exit__(self, exception_type, exception_value, traceback): 3068 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 3069 summary_state.is_recording_distribution_strategy = ( 3070 self._summary_recording_distribution_strategy) 3071 _pop_per_thread_mode() 3072 3073 def merge_call(self, merge_fn, args=(), kwargs=None): 3074 """Merge args across replicas and run `merge_fn` in a cross-replica context. 3075 3076 This allows communication and coordination when there are multiple calls 3077 to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. 3078 3079 See `tf.distribute.Strategy.run` for an explanation. 3080 3081 If not inside a distributed scope, this is equivalent to: 3082 3083 ``` 3084 strategy = tf.distribute.get_strategy() 3085 with cross-replica-context(strategy): 3086 return merge_fn(strategy, *args, **kwargs) 3087 ``` 3088 3089 Args: 3090 merge_fn: Function that joins arguments from threads that are given as 3091 PerReplica. It accepts `tf.distribute.Strategy` object as 3092 the first argument. 3093 args: List or tuple with positional per-thread arguments for `merge_fn`. 3094 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 3095 3096 Returns: 3097 The return value of `merge_fn`, except for `PerReplica` values which are 3098 unpacked. 
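Example (a minimal sketch, assuming two mirrored GPU replicas; the helper names are illustrative, not part of the API):

```python
def merge_fn(strategy, per_replica_value):
  # Runs once, in cross-replica context; `per_replica_value` bundles the
  # values passed by each replica as a PerReplica.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_value,
                         axis=None)

def step_fn():
  ctx = tf.distribute.get_replica_context()
  value = tf.cast(ctx.replica_id_in_sync_group, tf.float32)
  # Every replica receives the same summed result from `merge_fn`.
  return ctx.merge_call(merge_fn, args=(value,))

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
strategy.run(step_fn)
```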
3099 """ 3100 require_replica_context(self) 3101 if kwargs is None: 3102 kwargs = {} 3103 3104 merge_fn = autograph.tf_convert( 3105 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 3106 return self._merge_call(merge_fn, args, kwargs) 3107 3108 def _merge_call(self, merge_fn, args, kwargs): 3109 """Default implementation for single replica.""" 3110 _push_per_thread_mode( # thread-local, so not needed with multiple threads 3111 distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 3112 try: 3113 return merge_fn(self._strategy, *args, **kwargs) 3114 finally: 3115 _pop_per_thread_mode() 3116 3117 @property 3118 def num_replicas_in_sync(self): 3119 """Returns number of replicas that are kept in sync.""" 3120 return self._strategy.num_replicas_in_sync 3121 3122 @property 3123 def replica_id_in_sync_group(self): 3124 """Returns the id of the replica. 3125 3126 This identifies the replica among all replicas that are kept in sync. The 3127 value of the replica id can range from 0 to 3128 `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. 3129 3130 NOTE: This is not guaranteed to be the same ID as the XLA replica ID use 3131 for low-level operations such as collective_permute. 3132 3133 Returns: 3134 a `Tensor`. 3135 """ 3136 # It's important to prefer making the Tensor at call time whenever possible. 3137 # Keeping Tensors in global states doesn't work well with nested 3138 # tf.function, since it's possible that the tensor is generated in one func 3139 # graph, and gets captured by another, which will result in a subtle "An op 3140 # outside of the function building code is being passed a Graph tensor" 3141 # error. Making the tensor at call time to ensure it is the same graph where 3142 # it's used. However to be compatible with tpu.replicate(), 3143 # self._replica_id_in_sync_group can also be a Tensor. 3144 if tensor_util.is_tf_type(self._replica_id_in_sync_group): 3145 return self._replica_id_in_sync_group 3146 return constant_op.constant( 3147 self._replica_id_in_sync_group, 3148 dtypes.int32, 3149 name="replica_id_in_sync_group") 3150 3151 @property 3152 def _replica_id(self): 3153 """This is the local replica id in a given sync group.""" 3154 return self._local_replica_id 3155 3156 @property 3157 def strategy(self): 3158 """The current `tf.distribute.Strategy` object.""" 3159 return self._strategy 3160 3161 @property 3162 @deprecation.deprecated(None, "Please avoid relying on devices property.") 3163 def devices(self): 3164 """Returns the devices this replica is to be executed on, as a tuple of strings. 3165 3166 NOTE: For `tf.distribute.MirroredStrategy` and 3167 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a 3168 nested 3169 list of device strings, e.g, [["GPU:0"]]. 3170 """ 3171 require_replica_context(self) 3172 return (device_util.current(),) 3173 3174 def all_reduce(self, reduce_op, value, options=None): 3175 """All-reduces `value` across all replicas. 3176 3177 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3178 >>> def step_fn(): 3179 ... ctx = tf.distribute.get_replica_context() 3180 ... value = tf.identity(1.) 3181 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) 3182 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3183 (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3184 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) 3185 3186 It supports batched operations. 
You can pass a list of values and it 3187 attempts to batch them when possible. You can also specify `options` 3188 to indicate the desired batching behavior, e.g. batch the values into 3189 multiple packs so that they can better overlap with computations. 3190 3191 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3192 >>> def step_fn(): 3193 ... ctx = tf.distribute.get_replica_context() 3194 ... value1 = tf.identity(1.) 3195 ... value2 = tf.identity(2.) 3196 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) 3197 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3198 ([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3199 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>], 3200 [<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3201 <tf.Tensor: shape=(), dtype=float32, numpy=4.0>]) 3202 3203 Note that all replicas need to participate in the all-reduce, otherwise this 3204 operation hangs. Note that if there're multiple all-reduces, they need to 3205 execute in the same order on all replicas. Dispatching all-reduce based on 3206 conditions is usually error-prone. 3207 3208 Known limitation: if `value` contains `tf.IndexedSlices`, attempting to 3209 compute gradient w.r.t `value` would result in an error. 3210 3211 This API currently can only be called in the replica context. Other 3212 variants to reduce values across replicas are: 3213 * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API 3214 in the cross-replica context. 3215 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and 3216 all-reduce API in the cross-replica context. 3217 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 3218 to the host in cross-replica context. 3219 3220 Args: 3221 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 3222 be combined. Allows using string representation of the enum such as 3223 "SUM", "MEAN". 3224 value: a potentially nested structure of `tf.Tensor` or `tf.IndexedSlices` which 3225 `tf.nest.flatten` accepts. The structure and the shapes of `value` need to be 3226 same on all replicas. 3227 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3228 perform collective operations. This overrides the default options if the 3229 `tf.distribute.Strategy` takes one in the constructor. See 3230 `tf.distribute.experimental.CommunicationOptions` for details of the 3231 options. 3232 3233 Returns: 3234 A nested structure of `tf.Tensor` with the reduced values. The structure 3235 is the same as `value`. 3236 """ 3237 flattened_value = nest.flatten(value) 3238 has_indexed_slices = False 3239 3240 for v in flattened_value: 3241 if isinstance(v, indexed_slices.IndexedSlices): 3242 has_indexed_slices = True 3243 3244 if isinstance(reduce_op, six.string_types): 3245 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 3246 if options is None: 3247 options = collective_util.Options() 3248 3249 def batch_all_reduce(strategy, *value_flat): 3250 return strategy.extended.batch_reduce_to( 3251 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], 3252 options) 3253 3254 # Due to the use of `capture_call_time_value` in collective ops, we have 3255 # to maintain two branches: one w/ merge_call and one w/o. Details can be 3256 # found in b/184009754. 3257 if self._strategy.extended._use_merge_call(): # pylint: disable=protected-access 3258 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 
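# `tf.IndexedSlices` inputs bypass the `custom_gradient` wrapper below, so
# gradients with respect to them are unsupported (see the known limitation
# noted in the docstring).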
3259 if has_indexed_slices: 3260 return nest.pack_sequence_as( 3261 value, 3262 self.merge_call(batch_all_reduce, args=flattened_value)) 3263 3264 @custom_gradient.custom_gradient 3265 def grad_wrapper(*xs): 3266 ys = self.merge_call(batch_all_reduce, args=xs) 3267 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 3268 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 3269 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 3270 else: 3271 if has_indexed_slices: 3272 return nest.pack_sequence_as( 3273 value, 3274 self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 3275 reduce_op, flattened_value, options)) 3276 3277 @custom_gradient.custom_gradient 3278 def grad_wrapper(*xs): 3279 ys = self._strategy.extended._replica_ctx_all_reduce( # pylint: disable=protected-access 3280 reduce_op, xs, options) 3281 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 3282 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 3283 3284 return nest.pack_sequence_as(value, grad_wrapper(*flattened_value)) 3285 3286 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 3287 # all-reduce. It would return a function returning the result of reducing `t` 3288 # across all replicas. The caller would wait to call this function until they 3289 # needed the reduce result, allowing an efficient implementation: 3290 # * With eager execution, the reduction could be performed asynchronously 3291 # in the background, not blocking until the result was needed. 3292 # * When constructing a graph, it could batch up all reduction requests up 3293 # to that point that the first result is needed. Most likely this can be 3294 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 3295 3296 3297@tf_export("distribute.ReplicaContext", v1=[]) 3298class ReplicaContext(ReplicaContextBase): 3299 3300 __doc__ = ReplicaContextBase.__doc__ 3301 3302 def all_gather(self, value, axis, options=None): 3303 """All-gathers `value` across all replicas along `axis`. 3304 3305 Note: An `all_gather` method can only be called in replica context. For 3306 a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. 3307 All replicas need to participate in the all-gather, otherwise this 3308 operation hangs. So if `all_gather` is called in any replica, it must be 3309 called in all replicas. 3310 3311 Note: If there are multiple `all_gather` calls, they need to be executed in 3312 the same order on all replicas. Dispatching `all_gather` based on conditions 3313 is usually error-prone. 3314 3315 For all strategies except `tf.distribute.TPUStrategy`, the input 3316 `value` on different replicas must have the same rank, and their shapes must 3317 be the same in all dimensions except the `axis`-th dimension. In other 3318 words, their shapes cannot be different in a dimension `d` where `d` does 3319 not equal to the `axis` argument. For example, given a 3320 `tf.distribute.DistributedValues` with component tensors of shape 3321 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 3322 `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` 3323 or `all_gather(..., axis=2, ...)`. However, with 3324 `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and 3325 same shape. 3326 3327 Note: The input `value` must have a non-zero rank. Otherwise, consider using 3328 `tf.expand_dims` before gathering them. 
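For instance (a sketch; `scalar_value` stands for a hypothetical rank-0 per-replica tensor), such a scalar can be gathered as `ctx.all_gather(tf.expand_dims(scalar_value, axis=0), axis=0)`.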
3329 3330 You can pass in a single tensor to all-gather: 3331 3332 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3333 >>> @tf.function 3334 ... def gather_value(): 3335 ... ctx = tf.distribute.get_replica_context() 3336 ... local_value = tf.constant([1, 2, 3]) 3337 ... return ctx.all_gather(local_value, axis=0) 3338 >>> result = strategy.run(gather_value) 3339 >>> result 3340 PerReplica:{ 3341 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3342 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3343 } 3344 >>> strategy.experimental_local_results(result) 3345 (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3346 dtype=int32)>, 3347 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3348 dtype=int32)>) 3349 3350 3351 You can also pass in a nested structure of tensors to all-gather, say, a 3352 list: 3353 3354 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3355 >>> @tf.function 3356 ... def gather_nest(): 3357 ... ctx = tf.distribute.get_replica_context() 3358 ... value_1 = tf.constant([1, 2, 3]) 3359 ... value_2 = tf.constant([[1, 2], [3, 4]]) 3360 ... # all_gather a nest of `tf.distribute.DistributedValues` 3361 ... return ctx.all_gather([value_1, value_2], axis=0) 3362 >>> result = strategy.run(gather_nest) 3363 >>> result 3364 [PerReplica:{ 3365 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3366 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3367 }, PerReplica:{ 3368 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3369 array([[1, 2], 3370 [3, 4], 3371 [1, 2], 3372 [3, 4]], dtype=int32)>, 3373 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3374 array([[1, 2], 3375 [3, 4], 3376 [1, 2], 3377 [3, 4]], dtype=int32)> 3378 }] 3379 >>> strategy.experimental_local_results(result) 3380 ([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3381 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3382 array([[1, 2], 3383 [3, 4], 3384 [1, 2], 3385 [3, 4]], dtype=int32)>], 3386 [<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3387 <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3388 array([[1, 2], 3389 [3, 4], 3390 [1, 2], 3391 [3, 4]], dtype=int32)>]) 3392 3393 3394 What if you are all-gathering tensors with different shapes on different 3395 replicas? Consider the following example with two replicas, where you have 3396 `value` as a nested structure consisting of two items to all-gather, `a` and 3397 `b`. 3398 3399 * On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`. 3400 * On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`. 3401 * Result for `all_gather` with `axis=0` (on each of the replicas) is: 3402 3403 ``` 3404 {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]} 3405 ``` 3406 3407 Args: 3408 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, 3409 or a `tf.distribute.DistributedValues` instance. The structure of the 3410 `tf.Tensor` needs to be the same on all replicas. The underlying tensor 3411 constructs can only be dense tensors with non-zero rank, NOT 3412 `tf.IndexedSlices`. 3413 axis: 0-D int32 Tensor. Dimension along which to gather. 3414 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3415 perform collective operations. This overrides the default options if the 3416 `tf.distribute.Strategy` takes one in the constructor.
See 3417 `tf.distribute.experimental.CommunicationOptions` for details of the 3418 options. 3419 3420 Returns: 3421 A nested structure of `tf.Tensor` with the gathered values. The structure 3422 is the same as `value`. 3423 """ 3424 for v in nest.flatten(value): 3425 if isinstance(v, indexed_slices.IndexedSlices): 3426 raise NotImplementedError("all_gather does not support IndexedSlices") 3427 3428 if options is None: 3429 options = collective_util.Options() 3430 3431 def batch_all_gather(strategy, *value_flat): 3432 return strategy.extended._batch_gather_to( # pylint: disable=protected-access 3433 [(v, _batch_reduce_destination(v)) for v in value_flat], axis, 3434 options) 3435 3436 @custom_gradient.custom_gradient 3437 def grad_wrapper(*xs): 3438 ys = self.merge_call(batch_all_gather, args=xs) 3439 3440 def grad(*dy_s): 3441 grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) 3442 new_grads = [] 3443 for i, grad in enumerate(grads): 3444 input_shape = array_ops.shape(xs[i]) 3445 axis_dim = array_ops.reshape(input_shape[axis], [1]) 3446 with ops.control_dependencies([array_ops.identity(grads)]): 3447 d = self.all_gather(axis_dim, axis=0) 3448 begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) 3449 end_dim = begin_dim + array_ops.shape(xs[i])[axis] 3450 new_grad = array_ops.gather( 3451 grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) 3452 new_grads.append(new_grad) 3453 return new_grads 3454 3455 return ys, grad 3456 3457 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 3458 3459 def _update(self, var, fn, args=(), kwargs=None, group=True): 3460 """Run `fn` to update `var` with `args` and `kwargs` in replica context. 3461 3462 `tf.distribute.ReplicaContext.update` takes a (distributed) variable `var` 3463 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. 3464 `fn` applies to each component variable of `var` with corresponding input 3465 values from `args` and `kwargs`. 3466 3467 Example usage: 3468 3469 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 3470 >>> with strategy.scope(): 3471 ... distributed_variable = tf.Variable(5.0) 3472 >>> distributed_variable 3473 MirroredVariable:{ 3474 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=5.0>, 3475 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=5.0> 3476 } 3477 >>> def replica_fn(v): 3478 ... value = tf.identity(1.0) 3479 ... replica_context = tf.distribute.get_replica_context() 3480 ... update_fn = lambda var, value: var.assign(value) 3481 ... replica_context._update(v, update_fn, args=(value,)) 3482 >>> strategy.run(replica_fn, args=(distributed_variable,)) 3483 >>> distributed_variable 3484 MirroredVariable:{ 3485 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, 3486 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> 3487 } 3488 3489 This API must be called in a replica context. 3490 3491 Note that if `var` is a MirroredVariable (i.e., the type of variable created 3492 under the scope of a synchronous strategy, and is synchronized on-write, see 3493 `tf.VariableSynchronization` for more information) and `args`/`kwargs` 3494 contains different values for different replicas, `var` will be dangerously 3495 out of synchronization. Thus we recommend using `variable.assign(value)` as 3496 long as you can, which under the hood aggregates the updates and guarantees 3497 the synchronization. 
The case where you actually want this API instead of 3498 `variable.assign(value)` is that before assigning `value` to the `variable`, 3499 you'd like to conduct some pre-`assign` computation colocated with the 3500 variable devices (i.e. where variables reside; for MirroredStrategy they are 3501 the same as the compute devices, for ParameterServerStrategy they refer to 3502 parameter servers). E.g., 3503 3504 ```python 3505 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # 2 replicas 3506 with strategy.scope(): 3507 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 3508 def replica_fn(inputs): 3509 value = computation(inputs) 3510 replica_context = tf.distribute.get_replica_context() 3511 reduced_value = replica_context.all_reduce(value) 3512 3513 def update_fn(var, value): 3514 # this computation will colocate with `var`'s device 3515 updated_value = post_reduce_pre_update_computation(value) 3516 var.assign(updated_value) 3517 3518 replica_context._update(v, update_fn, args=(reduced_value,)) 3519 3520 strategy.run(replica_fn, args=(inputs,)) 3521 ``` 3522 3523 This code snippet is consistent across all strategies. If you directly 3524 compute and use `assign` in the replica context instead of wrapping it with 3525 `update`, for strategies with fewer variable devices than compute devices 3526 (e.g., parameter server strategy, usually), the 3527 `post_reduce_pre_update_computation` will happen 3528 N times (N == the number of compute devices), which is less performant. 3529 3530 3531 Args: 3532 var: Variable, possibly distributed to multiple devices, to operate on. 3533 fn: Function to call. Should take the variable as the first argument. 3534 args: Tuple or list. Additional positional arguments to pass to `fn()`. 3535 kwargs: Dict with keyword arguments to pass to `fn()`. 3536 group: Boolean. Defaults to True. Most strategies enter a merge_call to 3537 conduct the update in cross-replica context, and group=True guarantees that updates 3538 on all replicas are executed. 3539 3540 Returns: 3541 The return value of `fn` for the local replica. 3542 """ 3543 if kwargs is None: 3544 kwargs = {} 3545 return self._strategy.extended._replica_ctx_update(var, fn, args=args, kwargs=kwargs, group=group) # pylint: disable=protected-access 3546 3547 3548@tf_export(v1=["distribute.ReplicaContext"]) 3549class ReplicaContextV1(ReplicaContextBase): 3550 __doc__ = ReplicaContextBase.__doc__ 3551 3552 3553def _batch_reduce_destination(x): 3554 """Returns the destinations for batch all-reduce.""" 3555 if isinstance(x, ops.Tensor): 3556 # If this is a one device strategy.
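# A plain tensor lives on a single device, so reduce back to that device.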
3557 return x.device 3558 else: 3559 return x 3560# ------------------------------------------------------------------------------ 3561 3562 3563_creating_default_strategy_singleton = False 3564 3565 3566class _DefaultDistributionStrategyV1(StrategyV1): 3567 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3568 3569 def __init__(self): 3570 if not _creating_default_strategy_singleton: 3571 raise RuntimeError("Should only create a single instance of " 3572 "_DefaultDistributionStrategy") 3573 super(_DefaultDistributionStrategyV1, 3574 self).__init__(_DefaultDistributionExtended(self)) 3575 3576 def __deepcopy__(self, memo): 3577 del memo 3578 raise RuntimeError("Should only create a single instance of " 3579 "_DefaultDistributionStrategy") 3580 3581 3582class _DefaultDistributionStrategy(Strategy): 3583 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3584 3585 def __init__(self): 3586 if not _creating_default_strategy_singleton: 3587 raise RuntimeError("Should only create a single instance of " 3588 "_DefaultDistributionStrategy") 3589 super(_DefaultDistributionStrategy, self).__init__( 3590 _DefaultDistributionExtended(self)) 3591 3592 def __deepcopy__(self, memo): 3593 del memo 3594 raise RuntimeError("Should only create a single instance of " 3595 "_DefaultDistributionStrategy") 3596 3597 3598class _DefaultDistributionContext(object): 3599 """Context manager setting the default `tf.distribute.Strategy`.""" 3600 3601 __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] 3602 3603 def __init__(self, strategy): 3604 3605 def creator(next_creator, **kwargs): 3606 _require_strategy_scope_strategy(strategy) 3607 return next_creator(**kwargs) 3608 3609 self._var_creator_scope = variable_scope.variable_creator_scope(creator) 3610 self._strategy = strategy 3611 self._nested_count = 0 3612 3613 def __enter__(self): 3614 # Allow this scope to be entered if this strategy is already in scope. 
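# Re-entering the default scope itself is fine (tracked by `_nested_count`
# below), but entering it while a non-default strategy is current is an
# error, which the `has_strategy()` check enforces.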
3615 if distribution_strategy_context.has_strategy(): 3616 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 3617 if self._nested_count == 0: 3618 self._var_creator_scope.__enter__() 3619 self._nested_count += 1 3620 return self._strategy 3621 3622 def __exit__(self, exception_type, exception_value, traceback): 3623 self._nested_count -= 1 3624 if self._nested_count == 0: 3625 try: 3626 self._var_creator_scope.__exit__( 3627 exception_type, exception_value, traceback) 3628 except RuntimeError as e: 3629 six.raise_from( 3630 RuntimeError("Variable creator scope nesting error: move call to " 3631 "tf.distribute.set_strategy() out of `with` scope."), 3632 e) 3633 3634 3635class _DefaultDistributionExtended(StrategyExtendedV1): 3636 """Implementation of _DefaultDistributionStrategy.""" 3637 3638 def __init__(self, container_strategy): 3639 super(_DefaultDistributionExtended, self).__init__(container_strategy) 3640 self._retrace_functions_for_each_device = False 3641 3642 def _scope(self, strategy): 3643 """Context manager setting a variable creator and `self` as current.""" 3644 return _DefaultDistributionContext(strategy) 3645 3646 def colocate_vars_with(self, colocate_with_variable): 3647 """Does not require `self.scope`.""" 3648 _require_strategy_scope_extended(self) 3649 return ops.colocate_with(colocate_with_variable) 3650 3651 def variable_created_in_scope(self, v): 3652 return v._distribute_strategy is None # pylint: disable=protected-access 3653 3654 def _experimental_distribute_dataset(self, dataset, options): 3655 return dataset 3656 3657 def _distribute_datasets_from_function(self, dataset_fn, options): 3658 return dataset_fn(InputContext()) 3659 3660 def _experimental_distribute_values_from_function(self, value_fn): 3661 return value_fn(ValueContext()) 3662 3663 def _make_dataset_iterator(self, dataset): 3664 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3665 3666 def _make_input_fn_iterator(self, 3667 input_fn, 3668 replication_mode=InputReplicationMode.PER_WORKER): 3669 dataset = input_fn(InputContext()) 3670 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3671 3672 def _experimental_make_numpy_dataset(self, numpy_input, session): 3673 numpy_flat = nest.flatten(numpy_input) 3674 vars_flat = tuple( 3675 variable_scope.variable(array_ops.zeros(i.shape, i.dtype), 3676 trainable=False, use_resource=True) 3677 for i in numpy_flat 3678 ) 3679 for v, i in zip(vars_flat, numpy_flat): 3680 numpy_dataset.init_var_from_numpy(v, i, session) 3681 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 3682 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 3683 3684 def _broadcast_to(self, tensor, destinations): 3685 if destinations is None: 3686 return tensor 3687 else: 3688 raise NotImplementedError("TODO") 3689 3690 def _call_for_each_replica(self, fn, args, kwargs): 3691 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): 3692 return fn(*args, **kwargs) 3693 3694 def _reduce_to(self, reduce_op, value, destinations, options): 3695 # TODO(josh11b): Use destinations? 3696 del reduce_op, destinations, options 3697 return value 3698 3699 def _gather_to_implementation(self, value, destinations, axis, options): 3700 del destinations, axis, options 3701 return value 3702 3703 def _update(self, var, fn, args, kwargs, group): 3704 # The implementations of _update() and _update_non_slot() are identical 3705 # except _update() passes `var` as the first argument to `fn()`. 
3706 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 3707 3708 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 3709 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 3710 # once that value is used for something. 3711 with UpdateContext(colocate_with): 3712 result = fn(*args, **kwargs) 3713 if should_group: 3714 return result 3715 else: 3716 return nest.map_structure(self._local_results, result) 3717 3718 def read_var(self, replica_local_var): 3719 return array_ops.identity(replica_local_var) 3720 3721 def _local_results(self, distributed_value): 3722 return (distributed_value,) 3723 3724 def value_container(self, value): 3725 return value 3726 3727 @property 3728 def _num_replicas_in_sync(self): 3729 return 1 3730 3731 @property 3732 def worker_devices(self): 3733 raise RuntimeError("worker_devices() method unsupported by default " 3734 "tf.distribute.Strategy.") 3735 3736 @property 3737 def parameter_devices(self): 3738 raise RuntimeError("parameter_devices() method unsupported by default " 3739 "tf.distribute.Strategy.") 3740 3741 def non_slot_devices(self, var_list): 3742 return min(var_list, key=lambda x: x.name) 3743 3744 def _in_multi_worker_mode(self): 3745 """Whether this strategy indicates working in multi-worker settings.""" 3746 # Default strategy doesn't indicate multi-worker training. 3747 return False 3748 3749 @property 3750 def should_checkpoint(self): 3751 return True 3752 3753 @property 3754 def should_save_summary(self): 3755 return True 3756 3757 def _get_local_replica_id(self, replica_id_in_sync_group): 3758 return replica_id_in_sync_group 3759 3760 def _get_replica_id_in_sync_group(self, replica_id): 3761 return replica_id 3762 3763 # TODO(priyag): This should inherit from `InputIterator`, once dependency 3764 # issues have been resolved. 3765 class DefaultInputIterator(object): 3766 """Default implementation of `InputIterator` for default strategy.""" 3767 3768 def __init__(self, dataset): 3769 self._dataset = dataset 3770 if eager_context.executing_eagerly(): 3771 self._iterator = dataset_ops.make_one_shot_iterator(dataset) 3772 else: 3773 self._iterator = dataset_ops.make_initializable_iterator(dataset) 3774 3775 def get_next(self): 3776 return self._iterator.get_next() 3777 3778 def get_next_as_optional(self): 3779 return self._iterator.get_next_as_optional() 3780 3781 @deprecated(None, "Use the iterator's `initializer` property instead.") 3782 def initialize(self): 3783 """Initialize underlying iterators. 3784 3785 Returns: 3786 A list of any initializer ops that should be run. 3787 """ 3788 if eager_context.executing_eagerly(): 3789 self._iterator = self._dataset.make_one_shot_iterator() 3790 return [] 3791 else: 3792 return [self._iterator.initializer] 3793 3794 @property 3795 def initializer(self): 3796 """Returns a list of ops that initialize the iterator.""" 3797 return self.initialize() 3798 3799 # TODO(priyag): Delete this once all strategies use global batch size. 3800 @property 3801 def _global_batch_size(self): 3802 """Global and per-replica batching are equivalent for this strategy.""" 3803 return True 3804 3805 3806class _DefaultReplicaContext(ReplicaContext): 3807 """ReplicaContext for _DefaultDistributionStrategy.""" 3808 3809 @property 3810 def replica_id_in_sync_group(self): 3811 # Return 0 instead of a constant tensor to avoid creating a new node for 3812 # users who don't use distribution strategy. 
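# The default strategy always has exactly one replica, so its id is 0.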
3813 return 0 3814 3815 3816# ------------------------------------------------------------------------------ 3817# We haven't yet implemented deserialization for DistributedVariables. 3818# So here we catch any attempts to deserialize variables 3819# when using distribution strategies. 3820# pylint: disable=protected-access 3821_original_from_proto = resource_variable_ops._from_proto_fn 3822 3823 3824def _from_proto_fn(v, import_scope=None): 3825 if distribution_strategy_context.has_strategy(): 3826 raise NotImplementedError( 3827 "Deserialization of variables is not yet supported when using a " 3828 "tf.distribute.Strategy.") 3829 else: 3830 return _original_from_proto(v, import_scope=import_scope) 3831 3832resource_variable_ops._from_proto_fn = _from_proto_fn 3833# pylint: enable=protected-access 3834 3835 3836#------------------------------------------------------------------------------- 3837# Shorthand for some methods from distribution_strategy_context. 3838_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access 3839_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access 3840_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access 3841_get_default_replica_mode = ( 3842 distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access 3843 3844 3845# ------------------------------------------------------------------------------ 3846# Metrics to track which distribution strategy is being called 3847distribution_strategy_gauge = monitoring.StringGauge( 3848 "/tensorflow/api/distribution_strategy", 3849 "Gauge to track the type of distribution strategy used.", "TFVersion") 3850distribution_strategy_replica_gauge = monitoring.IntGauge( 3851 "/tensorflow/api/distribution_strategy/replica", 3852 "Gauge to track the number of replica each distribution strategy used.", 3853 "CountType") 3854distribution_strategy_input_api_counter = monitoring.Counter( 3855 "/tensorflow/api/distribution_strategy/input_api", 3856 "Counter to track the usage of the input APIs", "strategy", "api") 3857
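# Example (illustrative only; each concrete strategy records its own labels):
# a strategy constructor typically reports usage with something like
#   distribution_strategy_gauge.get_cell("V2").set("MirroredStrategy")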