1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=line-too-long 16"""Library for running a computation across multiple devices. 17 18The intent of this library is that you can write an algorithm in a stylized way 19and it will be usable with a variety of different `tf.distribute.Strategy` 20implementations. Each descendant will implement a different strategy for 21distributing the algorithm across multiple devices/machines. Furthermore, these 22changes can be hidden inside the specific layers and other library classes that 23need special treatment to run in a distributed setting, so that most users' 24model definition code can run unchanged. The `tf.distribute.Strategy` API works 25the same way with eager and graph execution. 26 27*Guides* 28 29* [TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training) 30* [TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb) 31 32*Tutorials* 33 34* [Distributed Training Tutorials](https://www.tensorflow.org/tutorials/distribute/) 35 36 The tutorials cover how to use `tf.distribute.Strategy` to do distributed 37 training with native Keras APIs, custom training loops, 38 and Estimator APIs. They also cover how to save/load model when using 39 `tf.distribute.Strategy`. 40 41*Glossary* 42 43* _Data parallelism_ is where we run multiple copies of the model 44 on different slices of the input data. This is in contrast to 45 _model parallelism_ where we divide up a single copy of a model 46 across multiple devices. 47 Note: we only support data parallelism for now, but 48 hope to add support for model parallelism in the future. 49* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that 50 TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple 51 devices on a single machine, or be connected to devices on multiple 52 machines. Devices used to run computations are called _worker devices_. 53 Devices used to store variables are _parameter devices_. For some strategies, 54 such as `tf.distribute.MirroredStrategy`, the worker and parameter devices 55 will be the same (see mirrored variables below). For others they will be 56 different. For example, `tf.distribute.experimental.CentralStorageStrategy` 57 puts the variables on a single device (which may be a worker device or may be 58 the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the 59 variables on separate machines called _parameter servers_ (see below). 60* A _replica_ is one copy of the model, running on one slice of the 61 input data. Right now each replica is executed on its own 62 worker device, but once we add support for model parallelism 63 a replica may span multiple worker devices. 64* A _host_ is the CPU device on a machine with worker devices, typically 65 used for running input pipelines. 66* A _worker_ is defined to be the physical machine(s) containing the physical 67 devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A 68 worker may contain one or more replicas, but contains at least one 69 replica. Typically one worker will correspond to one machine, but in the case 70 of very large models with model parallelism, one worker may span multiple 71 machines. We typically run one input pipeline per worker, feeding all the 72 replicas on that worker. 73* _Synchronous_, or more commonly _sync_, training is where the updates from 74 each replica are aggregated together before updating the model variables. This 75 is in contrast to _asynchronous_, or _async_ training, where each replica 76 updates the model variables independently. You may also have replicas 77 partitioned into groups which are in sync within each group but async between 78 groups. 79* _Parameter servers_: These are machines that hold a single copy of 80 parameters/variables, used by some strategies (right now just 81 `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want 82 to operate on a variable retrieve it at the beginning of a step and send an 83 update to be applied at the end of the step. These can in principle support 84 either sync or async training, but right now we only have support for async 85 training with parameter servers. Compare to 86 `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables 87 on a single device on the same machine (and does sync training), and 88 `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices 89 (see below). 90 91* _Replica context_ vs. _Cross-replica context_ vs _Update context_ 92 93 A _replica context_ applies 94 when you execute the computation function that was called with `strategy.run`. 95 Conceptually, you're in replica context when executing the computation 96 function that is being replicated. 97 98 An _update context_ is entered in a `tf.distribute.StrategyExtended.update` 99 call. 100 101 An _cross-replica context_ is entered when you enter a `strategy.scope`. This 102 is useful for calling `tf.distribute.Strategy` methods which operate across 103 the replicas (like `reduce_to()`). By default you start in a _replica context_ 104 (the "default single _replica context_") and then some methods can switch you 105 back and forth. 106 107* _Distributed value_: Distributed value is represented by the base class 108 `tf.distribute.DistributedValues`. `tf.distribute.DistributedValues` is useful 109 to represent values on multiple devices, and it contains a map from replica id 110 to values. Two representative kinds of `tf.distribute.DistributedValues` are 111 "PerReplica" and "Mirrored" values. 112 113 "PerReplica" values exist on the worker 114 devices, with a different value for each replica. They are produced by 115 iterating through a distributed dataset returned by 116 `tf.distribute.Strategy.experimental_distribute_dataset` and 117 `tf.distribute.Strategy.distribute_datasets_from_function`. They 118 are also the typical result returned by 119 `tf.distribute.Strategy.run`. 120 121 "Mirrored" values are like "PerReplica" values, except we know that the value 122 on all replicas are the same. We can safely read a "Mirrored" value in a 123 cross-replica context by using the value on any replica. 124 125* _Unwrapping_ and _merging_: Consider calling a function `fn` on multiple 126 replicas, like `strategy.run(fn, args=[w])` with an 127 argument `w` that is a `tf.distribute.DistributedValues`. This means `w` will 128 have a map taking replica id `0` to `w0`, replica id `1` to `w1`, etc. 129 `strategy.run()` unwraps `w` before calling `fn`, so it calls `fn(w0)` on 130 device `d0`, `fn(w1)` on device `d1`, etc. It then merges the return 131 values from `fn()`, which leads to one common object if the returned values 132 are the same object from every replica, or a `DistributedValues` object 133 otherwise. 134 135* _Reductions_ and _all-reduce_: A _reduction_ is a method of aggregating 136 multiple values into one value, like "sum" or "mean". If a strategy is doing 137 sync training, we will perform a reduction on the gradients to a parameter 138 from all replicas before applying the update. _All-reduce_ is an algorithm for 139 performing a reduction on values from multiple devices and making the result 140 available on all of those devices. 141 142* _Mirrored variables_: These are variables that are created on multiple 143 devices, where we keep the variables in sync by applying the same 144 updates to every copy. Mirrored variables are created with 145 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_WRITE...)`. 146 Normally they are only used in synchronous training. 147 148* _SyncOnRead variables_ 149 150 _SyncOnRead variables_ are created by 151 `tf.Variable(...synchronization=tf.VariableSynchronization.ON_READ...)`, and 152 they are created on multiple devices. In replica context, each 153 component variable on the local replica can perform reads and writes without 154 synchronization with each other. When the 155 _SyncOnRead variable_ is read in cross-replica context, the values from 156 component variables are aggregated and returned. 157 158 _SyncOnRead variables_ bring a lot of custom configuration difficulty to the 159 underlying logic, so we do not encourage users to instantiate and use 160 _SyncOnRead variable_ on their own. We have mainly used _SyncOnRead 161 variables_ for use cases such as batch norm and metrics. For performance 162 reasons, we often don't need to keep these statistics in sync every step and 163 they can be accumulated on each replica independently. The only time we want 164 to sync them is reporting or checkpointing, which typically happens in 165 cross-replica context. _SyncOnRead variables_ are also often used by advanced 166 users who want to control when variable values are aggregated. For example, 167 users sometimes want to maintain gradients independently on each replica for a 168 couple of steps without aggregation. 169 170* _Distribute-aware layers_ 171 172 Layers are generally called in a replica context, except when defining a 173 Keras functional model. `tf.distribute.in_cross_replica_context` will let you 174 determine which case you are in. If in a replica context, 175 the `tf.distribute.get_replica_context` function will return the default 176 replica context outside a strategy scope, `None` within a strategy scope, and 177 a `tf.distribute.ReplicaContext` object inside a strategy scope and within a 178 `tf.distribute.Strategy.run` function. The `ReplicaContext` object has an 179 `all_reduce` method for aggregating across all replicas. 180 181 182Note that we provide a default version of `tf.distribute.Strategy` that is 183used when no other strategy is in scope, that provides the same API with 184reasonable default behavior. 185""" 186# pylint: enable=line-too-long 187 188from __future__ import absolute_import 189from __future__ import division 190from __future__ import print_function 191 192import collections 193import copy 194import enum # pylint: disable=g-bad-import-order 195import threading 196import weakref 197 198import six 199 200from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 201from tensorflow.python.autograph.impl import api as autograph 202from tensorflow.python.data.ops import dataset_ops 203from tensorflow.python.distribute import collective_util 204from tensorflow.python.distribute import device_util 205from tensorflow.python.distribute import distribution_strategy_context 206from tensorflow.python.distribute import numpy_dataset 207from tensorflow.python.distribute import reduce_util 208from tensorflow.python.eager import context as eager_context 209from tensorflow.python.eager import def_function 210from tensorflow.python.eager import monitoring 211from tensorflow.python.framework import constant_op 212from tensorflow.python.framework import dtypes 213from tensorflow.python.framework import ops 214from tensorflow.python.framework import tensor_shape 215from tensorflow.python.framework import tensor_util 216from tensorflow.python.ops import array_ops 217from tensorflow.python.ops import control_flow_ops 218from tensorflow.python.ops import custom_gradient 219from tensorflow.python.ops import math_ops 220from tensorflow.python.ops import resource_variable_ops 221from tensorflow.python.ops import summary_ops_v2 222from tensorflow.python.ops import variable_scope 223from tensorflow.python.ops.losses import losses_impl 224from tensorflow.python.platform import tf_logging 225from tensorflow.python.training.tracking import base as trackable 226from tensorflow.python.util import deprecation 227from tensorflow.python.util import nest 228from tensorflow.python.util import tf_contextlib 229from tensorflow.python.util.deprecation import deprecated 230from tensorflow.python.util.tf_export import tf_export 231from tensorflow.tools.docs import doc_controls 232 233 234# ------------------------------------------------------------------------------ 235# Context tracking whether in a strategy.update() or .update_non_slot() call. 236 237 238_update_replica_id = threading.local() 239 240 241def get_update_replica_id(): 242 """Get the current device if in a `tf.distribute.Strategy.update()` call.""" 243 try: 244 return _update_replica_id.current 245 except AttributeError: 246 return None 247 248 249class UpdateContext(object): 250 """Context manager when you are in `update()` or `update_non_slot()`.""" 251 252 __slots__ = ["_replica_id", "_old_replica_id"] 253 254 def __init__(self, replica_id): 255 self._replica_id = replica_id 256 self._old_replica_id = None 257 258 def __enter__(self): 259 self._old_replica_id = get_update_replica_id() 260 _update_replica_id.current = self._replica_id 261 262 def __exit__(self, exception_type, exception_value, traceback): 263 del exception_type, exception_value, traceback 264 _update_replica_id.current = self._old_replica_id 265 266 267# ------------------------------------------------------------------------------ 268# Public utility functions. 269 270 271@tf_export(v1=["distribute.get_loss_reduction"]) 272def get_loss_reduction(): 273 """`tf.distribute.ReduceOp` corresponding to the last loss reduction. 274 275 This is used to decide whether loss should be scaled in optimizer (used only 276 for estimator + v1 optimizer use case). 277 278 Returns: 279 `tf.distribute.ReduceOp` corresponding to the last loss reduction for 280 estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. 281 """ 282 if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access 283 # If we are not in Estimator context then return 'SUM'. We do not need to 284 # scale loss in the optimizer. 285 return reduce_util.ReduceOp.SUM 286 last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access 287 if (last_reduction == losses_impl.Reduction.SUM or 288 last_reduction == "sum"): # Check for tf.keras.losses.Reduction.SUM 289 return reduce_util.ReduceOp.SUM 290 return reduce_util.ReduceOp.MEAN 291 292 293# ------------------------------------------------------------------------------ 294# Internal API for validating the current thread mode 295 296 297def _require_cross_replica_or_default_context_extended(extended, 298 error_message=None): 299 """Verify in cross-replica context.""" 300 context = _get_per_thread_mode() 301 cross_replica = context.cross_replica_context 302 if cross_replica is not None and cross_replica.extended is extended: 303 return 304 if context is _get_default_replica_mode(): 305 return 306 strategy = extended._container_strategy() # pylint: disable=protected-access 307 # We have an error to report, figure out the right message. 308 if context.strategy is not strategy: 309 _wrong_strategy_scope(strategy, context) 310 assert cross_replica is None 311 if not error_message: 312 error_message = ("Method requires being in cross-replica context, use " 313 "get_replica_context().merge_call()") 314 raise RuntimeError(error_message) 315 316 317def _wrong_strategy_scope(strategy, context): 318 # Figure out the right error message. 319 if not distribution_strategy_context.has_strategy(): 320 raise RuntimeError( 321 'Need to be inside "with strategy.scope()" for %s' % 322 (strategy,)) 323 else: 324 raise RuntimeError( 325 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 326 (context.strategy, strategy)) 327 328 329def require_replica_context(replica_ctx): 330 """Verify in `replica_ctx` replica context.""" 331 context = _get_per_thread_mode() 332 if context.replica_context is replica_ctx: return 333 # We have an error to report, figure out the right message. 334 if context.replica_context is None: 335 raise RuntimeError("Need to be inside `call_for_each_replica()`") 336 if context.strategy is replica_ctx.strategy: 337 # Two different ReplicaContexts with the same tf.distribute.Strategy. 338 raise RuntimeError("Mismatching ReplicaContext.") 339 raise RuntimeError( 340 "Mismatching tf.distribute.Strategy objects: %s is not %s." % 341 (context.strategy, replica_ctx.strategy)) 342 343 344def _require_strategy_scope_strategy(strategy): 345 """Verify in a `strategy.scope()` in this thread.""" 346 context = _get_per_thread_mode() 347 if context.strategy is strategy: return 348 _wrong_strategy_scope(strategy, context) 349 350 351def _require_strategy_scope_extended(extended): 352 """Verify in a `distribution_strategy.scope()` in this thread.""" 353 context = _get_per_thread_mode() 354 if context.strategy.extended is extended: return 355 # Report error. 356 strategy = extended._container_strategy() # pylint: disable=protected-access 357 _wrong_strategy_scope(strategy, context) 358 359 360# ------------------------------------------------------------------------------ 361# Internal context managers used to implement the DistributionStrategy 362# base class 363 364 365class _CurrentDistributionContext(object): 366 """Context manager setting the current `tf.distribute.Strategy`. 367 368 Also: overrides the variable creator and optionally the current device. 369 """ 370 371 def __init__(self, 372 strategy, 373 var_creator_scope, 374 var_scope=None, 375 default_device=None): 376 self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access 377 strategy) 378 self._var_creator_scope = var_creator_scope 379 self._var_scope = var_scope 380 if default_device: 381 self._device_scope = ops.device(default_device) 382 else: 383 self._device_scope = None 384 self._same_scope_again_count = 0 385 386 def __enter__(self): 387 # Allow this scope to be entered if this strategy is already in scope. 388 if distribution_strategy_context.has_strategy(): 389 _require_cross_replica_or_default_context_extended( 390 self._context.strategy.extended) 391 self._same_scope_again_count += 1 392 else: 393 _push_per_thread_mode(self._context) 394 if self._var_scope: 395 self._var_scope.__enter__() 396 self._var_creator_scope.__enter__() 397 if self._device_scope: 398 self._device_scope.__enter__() 399 return self._context.strategy 400 401 def __exit__(self, exception_type, exception_value, traceback): 402 if self._same_scope_again_count > 0: 403 self._same_scope_again_count -= 1 404 return 405 if self._device_scope: 406 try: 407 self._device_scope.__exit__(exception_type, exception_value, traceback) 408 except RuntimeError as e: 409 six.raise_from( 410 RuntimeError("Device scope nesting error: move call to " 411 "tf.distribute.set_strategy() out of `with` scope."), 412 e) 413 414 try: 415 self._var_creator_scope.__exit__( 416 exception_type, exception_value, traceback) 417 except RuntimeError as e: 418 six.raise_from( 419 RuntimeError("Variable creator scope nesting error: move call to " 420 "tf.distribute.set_strategy() out of `with` scope."), 421 e) 422 423 if self._var_scope: 424 try: 425 self._var_scope.__exit__(exception_type, exception_value, traceback) 426 except RuntimeError as e: 427 six.raise_from( 428 RuntimeError("Variable scope nesting error: move call to " 429 "tf.distribute.set_strategy() out of `with` scope."), 430 e) 431 _pop_per_thread_mode() 432 433 434# TODO(yuefengz): add more replication modes. 435@tf_export("distribute.InputReplicationMode") 436class InputReplicationMode(enum.Enum): 437 """Replication mode for input function. 438 439 * `PER_WORKER`: The input function will be called on each worker 440 independently, creating as many input pipelines as number of workers. 441 Replicas will dequeue from the local Dataset on their worker. 442 `tf.distribute.Strategy` doesn't manage any state sharing between such 443 separate input pipelines. 444 * `PER_REPLICA`: The input function will be called on each replica separately. 445 `tf.distribute.Strategy` doesn't manage any state sharing between such 446 separate input pipelines. 447 """ 448 PER_WORKER = "PER_WORKER" 449 PER_REPLICA = "PER_REPLICA" 450 451 452@tf_export("distribute.InputContext") 453class InputContext(object): 454 """A class wrapping information needed by an input function. 455 456 This is a context class that is passed to the user's input function and 457 contains information about the compute replicas and input pipelines. The 458 number of compute replicas (in sync training) helps compute the local batch 459 size from the desired global batch size for each replica. The input pipeline 460 information can be used to return a different subset of the input in each 461 replica (for e.g. shard the input pipeline, use a different input 462 source etc). 463 """ 464 465 __slots__ = [ 466 "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync" 467 ] 468 469 def __init__(self, 470 num_input_pipelines=1, 471 input_pipeline_id=0, 472 num_replicas_in_sync=1): 473 """Initializes an InputContext object. 474 475 Args: 476 num_input_pipelines: the number of input pipelines in a cluster. 477 input_pipeline_id: the current input pipeline id, should be an int in 478 [0,`num_input_pipelines`). 479 num_replicas_in_sync: the number of replicas that are in sync. 480 """ 481 self._num_input_pipelines = num_input_pipelines 482 self._input_pipeline_id = input_pipeline_id 483 self._num_replicas_in_sync = num_replicas_in_sync 484 485 @property 486 def num_replicas_in_sync(self): 487 """Returns the number of compute replicas in sync.""" 488 return self._num_replicas_in_sync 489 490 @property 491 def input_pipeline_id(self): 492 """Returns the input pipeline ID.""" 493 return self._input_pipeline_id 494 495 @property 496 def num_input_pipelines(self): 497 """Returns the number of input pipelines.""" 498 return self._num_input_pipelines 499 500 def get_per_replica_batch_size(self, global_batch_size): 501 """Returns the per-replica batch size. 502 503 Args: 504 global_batch_size: the global batch size which should be divisible by 505 `num_replicas_in_sync`. 506 507 Returns: 508 the per-replica batch size. 509 510 Raises: 511 ValueError: if `global_batch_size` not divisible by 512 `num_replicas_in_sync`. 513 """ 514 if global_batch_size % self._num_replicas_in_sync != 0: 515 raise ValueError("The `global_batch_size` %r is not divisible by " 516 "`num_replicas_in_sync` %r " % 517 (global_batch_size, self._num_replicas_in_sync)) 518 return global_batch_size // self._num_replicas_in_sync 519 520 def __str__(self): 521 return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( 522 self.input_pipeline_id, self.num_input_pipelines) 523 524 525@tf_export("distribute.experimental.ValueContext", v1=[]) 526class ValueContext(object): 527 """A class wrapping information needed by a distribute function. 528 529 This is a context class that is passed to the `value_fn` in 530 `strategy.experimental_distribute_values_from_function` and contains 531 information about the compute replicas. The `num_replicas_in_sync` and 532 `replica_id` can be used to customize the value on each replica. 533 534 Example usage: 535 536 1. Directly constructed. 537 538 >>> def value_fn(context): 539 ... return context.replica_id_in_sync_group/context.num_replicas_in_sync 540 >>> context = tf.distribute.experimental.ValueContext( 541 ... replica_id_in_sync_group=2, num_replicas_in_sync=4) 542 >>> per_replica_value = value_fn(context) 543 >>> per_replica_value 544 0.5 545 546 2. Passed in by `experimental_distribute_values_from_function`. 547 548 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 549 >>> def value_fn(value_context): 550 ... return value_context.num_replicas_in_sync 551 >>> distributed_values = ( 552 ... strategy.experimental_distribute_values_from_function( 553 ... value_fn)) 554 >>> local_result = strategy.experimental_local_results(distributed_values) 555 >>> local_result 556 (2, 2) 557 558 """ 559 560 __slots__ = ["_replica_id_in_sync_group", "_num_replicas_in_sync"] 561 562 def __init__(self, 563 replica_id_in_sync_group=0, 564 num_replicas_in_sync=1): 565 """Initializes an ValueContext object. 566 567 Args: 568 replica_id_in_sync_group: the current replica_id, should be an int in 569 [0,`num_replicas_in_sync`). 570 num_replicas_in_sync: the number of replicas that are in sync. 571 """ 572 self._replica_id_in_sync_group = replica_id_in_sync_group 573 self._num_replicas_in_sync = num_replicas_in_sync 574 575 @property 576 def num_replicas_in_sync(self): 577 """Returns the number of compute replicas in sync.""" 578 return self._num_replicas_in_sync 579 580 @property 581 def replica_id_in_sync_group(self): 582 """Returns the replica ID.""" 583 return self._replica_id_in_sync_group 584 585 def __str__(self): 586 return (("tf.distribute.ValueContext(replica id {}, " 587 " total replicas in sync: ""{})") 588 .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) 589 590 591@tf_export("distribute.RunOptions") 592class RunOptions( 593 collections.namedtuple("RunOptions", [ 594 "experimental_enable_dynamic_batch_size", 595 "experimental_bucketizing_dynamic_shape", 596 ])): 597 """Run options for `strategy.run`. 598 599 This can be used to hold some strategy specific configs. 600 601 Attributes: 602 experimental_enable_dynamic_batch_size: Boolean. Only applies to 603 TPUStrategy. Default to True. If True, TPUStrategy will enable dynamic 604 padder to support dynamic batch size for the inputs. Otherwise only static 605 shape inputs are allowed. 606 experimental_bucketizing_dynamic_shape: Boolean. Only applies to 607 TPUStrategy. Default to False. If True, TPUStrategy will automatic 608 bucketize inputs passed into `run` if the input shape is 609 dynamic. This is a performance optimization to reduce XLA recompilation, 610 which should not have impact on correctness. 611 """ 612 613 def __new__(cls, 614 experimental_enable_dynamic_batch_size=True, 615 experimental_bucketizing_dynamic_shape=False): 616 return super(RunOptions, 617 cls).__new__(cls, experimental_enable_dynamic_batch_size, 618 experimental_bucketizing_dynamic_shape) 619 620 621@tf_export("distribute.InputOptions", v1=[]) 622class InputOptions( 623 collections.namedtuple("InputOptions", [ 624 "experimental_prefetch_to_device", 625 "experimental_replication_mode", 626 "experimental_place_dataset_on_device", 627 ])): 628 """Run options for `experimental_distribute_dataset(s_from_function)`. 629 630 This can be used to hold some strategy specific configs. 631 632 ```python 633 # Setup TPUStrategy 634 resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') 635 tf.config.experimental_connect_to_cluster(resolver) 636 tf.tpu.experimental.initialize_tpu_system(resolver) 637 strategy = tf.distribute.TPUStrategy(resolver) 638 639 dataset = tf.data.Dataset.range(16) 640 distributed_dataset_on_host = ( 641 strategy.experimental_distribute_dataset( 642 dataset, 643 tf.distribute.InputOptions( 644 experimental_replication_mode= 645 experimental_replication_mode.PER_WORKER, 646 experimental_place_dataset_on_device=False))) 647 ``` 648 649 Attributes: 650 experimental_prefetch_to_device: Boolean. Defaults to True. If True, dataset 651 elements will be prefetched to accelerator device memory. When False, 652 dataset elements are prefetched to host device memory. Must be False when 653 using TPUEmbedding API. experimental_prefetch_to_device can only be used 654 with experimental_replication_mode=PER_WORKER 655 experimental_replication_mode: Replication mode for the input function. 656 Currently, the InputReplicationMode.PER_REPLICA is only supported with 657 tf.distribute.MirroredStrategy. 658 experimental_distribute_datasets_from_function. 659 The default value is InputReplicationMode.PER_WORKER. 660 experimental_place_dataset_on_device: Boolean. Default to False. When True, 661 dataset will be placed on the device, otherwise it will remain on the 662 host. experimental_place_dataset_on_device=True can only be used with 663 experimental_replication_mode=PER_REPLICA 664 """ 665 666 def __new__(cls, 667 experimental_prefetch_to_device=True, 668 experimental_replication_mode=InputReplicationMode.PER_WORKER, 669 experimental_place_dataset_on_device=False): 670 return super(InputOptions, 671 cls).__new__(cls, experimental_prefetch_to_device, 672 experimental_replication_mode, 673 experimental_place_dataset_on_device) 674 675# ------------------------------------------------------------------------------ 676# Base classes for all distribution strategies. 677 678 679# Base class for v1 Strategy and v2 Strategy classes. For API's specific to 680# v1/v2 Strategy, add to implementing classes of StrategyBase. 681# pylint: disable=line-too-long 682class StrategyBase(object): 683 """A state & compute distribution policy on a list of devices. 684 685 See [the guide](https://www.tensorflow.org/guide/distributed_training) 686 for overview and examples. See `tf.distribute.StrategyExtended` and 687 [`tf.distribute`](https://www.tensorflow.org/api_docs/python/tf/distribute) 688 for a glossary of concepts mentioned on this page such as "per-replica", 689 _replica_, and _reduce_. 690 691 In short: 692 693 * To use it with Keras `compile`/`fit`, 694 [please 695 read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). 696 * You may pass descendant of `tf.distribute.Strategy` to 697 `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` 698 should distribute its computation. See 699 [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). 700 * Otherwise, use `tf.distribute.Strategy.scope` to specify that a 701 strategy should be used when building an executing your model. 702 (This puts you in the "cross-replica context" for this strategy, which 703 means the strategy is put in control of things like variable placement.) 704 * If you are writing a custom training loop, you will need to call a few more 705 methods, 706 [see the 707 guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): 708 709 * Start by creating a `tf.data.Dataset` normally. 710 * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert 711 a `tf.data.Dataset` to something that produces "per-replica" values. 712 If you want to manually specify how the dataset should be partitioned 713 across replicas, use 714 `tf.distribute.Strategy.distribute_datasets_from_function` 715 instead. 716 * Use `tf.distribute.Strategy.run` to run a function 717 once per replica, taking values that may be "per-replica" (e.g. 718 from a `tf.distribute.DistributedDataset` object) and returning 719 "per-replica" values. 720 This function is executed in "replica context", which means each 721 operation is performed separately on each replica. 722 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to 723 convert the resulting "per-replica" values into ordinary `Tensor`s. 724 725 A custom training loop can be as simple as: 726 727 ``` 728 with my_strategy.scope(): 729 @tf.function 730 def distribute_train_epoch(dataset): 731 def replica_fn(input): 732 # process input and return result 733 return result 734 735 total_result = 0 736 for x in dataset: 737 per_replica_result = my_strategy.run(replica_fn, args=(x,)) 738 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, 739 per_replica_result, axis=None) 740 return total_result 741 742 dist_dataset = my_strategy.experimental_distribute_dataset(dataset) 743 for _ in range(EPOCHS): 744 train_result = distribute_train_epoch(dist_dataset) 745 ``` 746 747 This takes an ordinary `dataset` and `replica_fn` and runs it 748 distributed using a particular `tf.distribute.Strategy` named 749 `my_strategy` above. Any variables created in `replica_fn` are created 750 using `my_strategy`'s policy, and library functions called by 751 `replica_fn` can use the `get_replica_context()` API to implement 752 distributed-specific behavior. 753 754 You can use the `reduce` API to aggregate results across replicas and use 755 this as a return value from one iteration over a 756 `tf.distribute.DistributedDataset`. Or 757 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to 758 accumulate metrics across steps in a given epoch. 759 760 See the 761 [custom training loop 762 tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) 763 for a more detailed example. 764 765 Note: `tf.distribute.Strategy` currently does not support TensorFlow's 766 partitioned variables (where a single variable is split across multiple 767 devices) at this time. 768 """ 769 # pylint: enable=line-too-long 770 771 # TODO(josh11b): Partitioned computations, state; sharding 772 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling 773 774 def __init__(self, extended): 775 self._extended = extended 776 777 # Flag that is used to indicate whether distribution strategy is used with 778 # Estimator. This is required for backward compatibility of loss scaling 779 # when using v1 optimizer with estimator. 780 self._scale_loss_for_estimator = False 781 782 if not hasattr(extended, "_retrace_functions_for_each_device"): 783 # pylint: disable=protected-access 784 # `extended._retrace_functions_for_each_device` dictates 785 # whether the same function will be retraced when it is called on 786 # different devices. 787 try: 788 extended._retrace_functions_for_each_device = ( 789 len(extended.worker_devices) > 1) 790 distribution_strategy_replica_gauge.get_cell("num_replicas").set( 791 self.num_replicas_in_sync) 792 except: # pylint: disable=bare-except 793 # Default for the case where extended.worker_devices can't return 794 # a sensible value. 795 extended._retrace_functions_for_each_device = True 796 797 # Below are the dicts of axis(int) -> `tf.function`. 798 self._mean_reduce_helper_fns = {} 799 self._reduce_sum_fns = {} 800 801 # Whether this strategy is designed to work with `ClusterCoordinator`. 802 self._should_use_with_coordinator = False 803 804 @property 805 def extended(self): 806 """`tf.distribute.StrategyExtended` with additional methods.""" 807 return self._extended 808 809 @tf_contextlib.contextmanager 810 def _scale_loss_for_estimator_enabled(self): 811 """Scope which sets a flag used for scaling losses in optimizer. 812 813 Yields: 814 `_scale_loss_for_estimator_enabled` is a context manager with a 815 side effect, but doesn't return a value. 816 """ 817 self._scale_loss_for_estimator = True 818 try: 819 yield 820 finally: 821 self._scale_loss_for_estimator = False 822 823 # pylint: disable=line-too-long 824 def scope(self): 825 """Context manager to make the strategy current and distribute variables. 826 827 This method returns a context manager, and is used as follows: 828 829 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 830 >>> # Variable created inside scope: 831 >>> with strategy.scope(): 832 ... mirrored_variable = tf.Variable(1.) 833 >>> mirrored_variable 834 MirroredVariable:{ 835 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>, 836 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0> 837 } 838 >>> # Variable created outside scope: 839 >>> regular_variable = tf.Variable(1.) 840 >>> regular_variable 841 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 842 843 _What happens when Strategy.scope is entered?_ 844 845 * `strategy` is installed in the global context as the "current" strategy. 846 Inside this scope, `tf.distribute.get_strategy()` will now return this 847 strategy. Outside this scope, it returns the default no-op strategy. 848 * Entering the scope also enters the "cross-replica context". See 849 `tf.distribute.StrategyExtended` for an explanation on cross-replica and 850 replica contexts. 851 * Variable creation inside `scope` is intercepted by the strategy. Each 852 strategy defines how it wants to affect the variable creation. Sync 853 strategies like `MirroredStrategy`, `TPUStrategy` and 854 `MultiWorkerMiroredStrategy` create variables replicated on each replica, 855 whereas `ParameterServerStrategy` creates variables on the parameter 856 servers. This is done using a custom `tf.variable_creator_scope`. 857 * In some strategies, a default device scope may also be entered: in 858 `MultiWorkerMiroredStrategy`, a default device scope of "/CPU:0" is 859 entered on each worker. 860 861 Note: Entering a scope does not automatically distribute a computation, except 862 in the case of high level training framework like keras `model.fit`. If 863 you're not using `model.fit`, you 864 need to use `strategy.run` API to explicitly distribute that computation. 865 See an example in the [custom training loop tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training). 866 867 868 _What should be in scope and what should be outside?_ 869 870 There are a number of requirements on what needs to happen inside the scope. 871 However, in places where we have information about which strategy is in use, 872 we often enter the scope for the user, so they don't have to do it 873 explicitly (i.e. calling those either inside or outside the scope is OK). 874 875 * Anything that creates variables that should be distributed variables 876 must be called in a `strategy.scope`. This can be accomplished either by 877 directly calling the variable creating function within the scope context, 878 or by relying on another API like `strategy.run` or `keras.Model.fit` to 879 automatically enter it for you. Any variable that is created outside scope 880 will not be distributed and may have performance implications. Some common 881 objects that create variables in TF are Models, Optimizers, Metrics. Such 882 objects should always be initiliazized in the scope, and any functions 883 that may lazily create variables (e.g., `Model.__call__()`, tracing a 884 `tf.function`, etc.) should similarly be called within scope. Another 885 source of variable creation can be a checkpoint restore - when variables 886 are created lazily. Note that any variable created inside a strategy 887 captures the strategy information. So reading and writing to these 888 variables outside the `strategy.scope` can also work seamlessly, without 889 the user having to enter the scope. 890 * Some strategy APIs (such as `strategy.run` and `strategy.reduce`) which 891 require to be in a strategy's scope, enter the scope automatically, which 892 means when using those APIs you don't need to explicitly enter the scope 893 yourself. 894 * When a `tf.keras.Model` is created inside a `strategy.scope`, the Model 895 object captures the scope information. When high level training framework 896 methods such as `model.compile`, `model.fit`, etc. are then called, the 897 captured scope will be automatically entered, and the associated strategy 898 will be used to distribute the training etc. See a detailed example in 899 [distributed keras tutorial](https://www.tensorflow.org/tutorials/distribute/keras). 900 WARNING: Simply calling `model(..)` does not automatically enter the 901 captured scope -- only high level training framework APIs support this 902 behavior: `model.compile`, `model.fit`, `model.evaluate`, `model.predict` 903 and `model.save` can all be called inside or outside the scope. 904 * The following can be either inside or outside the scope: 905 * Creating the input datasets 906 * Defining `tf.function`s that represent your training step 907 * Saving APIs such as `tf.saved_model.save`. Loading creates variables, 908 so that should go inside the scope if you want to train the model in a 909 distributed way. 910 * Checkpoint saving. As mentioned above - `checkpoint.restore` may 911 sometimes need to be inside scope if it creates variables. 912 913 Returns: 914 A context manager. 915 """ 916 return self._extended._scope(self) # pylint: disable=protected-access 917 # pylint: enable=line-too-long 918 919 @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` 920 def colocate_vars_with(self, colocate_with_variable): 921 """DEPRECATED: use extended.colocate_vars_with() instead.""" 922 return self._extended.colocate_vars_with(colocate_with_variable) 923 924 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 925 def make_dataset_iterator(self, dataset): 926 """DEPRECATED TF 1.x ONLY.""" 927 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 928 929 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 930 def make_input_fn_iterator(self, 931 input_fn, 932 replication_mode=InputReplicationMode.PER_WORKER): 933 """DEPRECATED TF 1.x ONLY.""" 934 if replication_mode != InputReplicationMode.PER_WORKER: 935 raise ValueError( 936 "Input replication mode not supported: %r" % replication_mode) 937 with self.scope(): 938 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access 939 input_fn, replication_mode=replication_mode) 940 941 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 942 def experimental_run(self, fn, input_iterator=None): 943 """DEPRECATED TF 1.x ONLY.""" 944 with self.scope(): 945 args = (input_iterator.get_next(),) if input_iterator is not None else () 946 return self.run(fn, args=args) 947 948 def experimental_distribute_dataset(self, dataset, options=None): 949 # pylint: disable=line-too-long 950 """Creates `tf.distribute.DistributedDataset` from `tf.data.Dataset`. 951 952 The returned `tf.distribute.DistributedDataset` can be iterated over 953 similar to regular datasets. 954 NOTE: The user cannot add any more transformations to a 955 `tf.distribute.DistributedDataset`. You can only create an iterator or 956 examine the `tf.TypeSpec` of the data generated by it. See API docs of 957 `tf.distribute.DistributedDataset` to learn more. 958 959 The following is an example: 960 961 >>> global_batch_size = 2 962 >>> # Passing the devices is optional. 963 ... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 964 >>> # Create a dataset 965 ... dataset = tf.data.Dataset.range(4).batch(global_batch_size) 966 >>> # Distribute that dataset 967 ... dist_dataset = strategy.experimental_distribute_dataset(dataset) 968 >>> @tf.function 969 ... def replica_fn(input): 970 ... return input*2 971 >>> result = [] 972 >>> # Iterate over the `tf.distribute.DistributedDataset` 973 ... for x in dist_dataset: 974 ... # process dataset elements 975 ... result.append(strategy.run(replica_fn, args=(x,))) 976 >>> print(result) 977 [PerReplica:{ 978 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>, 979 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])> 980 }, PerReplica:{ 981 0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>, 982 1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])> 983 }] 984 985 986 Three key actions happening under the hood of this method are batching, 987 sharding, and prefetching. 988 989 In the code snippet above, `dataset` is batched by `global_batch_size`, and 990 calling `experimental_distribute_dataset` on it rebatches `dataset` to a 991 new batch size that is equal to the global batch size divided by the number 992 of replicas in sync. We iterate through it using a Pythonic for loop. 993 `x` is a `tf.distribute.DistributedValues` containing data for all replicas, 994 and each replica gets data of the new batch size. 995 `tf.distribute.Strategy.run` will take care of feeding the right per-replica 996 data in `x` to the right `replica_fn` executed on each replica. 997 998 Sharding contains autosharding across multiple workers and within every 999 worker. First, in multi-worker distributed training (i.e. when you use 1000 `tf.distribute.experimental.MultiWorkerMirroredStrategy` 1001 or `tf.distribute.TPUStrategy`), autosharding a dataset over a set of 1002 workers means that each worker is assigned a subset of the entire dataset 1003 (if the right `tf.data.experimental.AutoShardPolicy` is set). This is to 1004 ensure that at each step, a global batch size of non-overlapping dataset 1005 elements will be processed by each worker. Autosharding has a couple of 1006 different options that can be specified using 1007 `tf.data.experimental.DistributeOptions`. Then, sharding within each worker 1008 means the method will split the data among all the worker devices (if more 1009 than one a present). This will happen regardless of multi-worker 1010 autosharding. 1011 1012 Note: for autosharding across multiple workers, the default mode is 1013 `tf.data.experimental.AutoShardPolicy.AUTO`. This mode 1014 will attempt to shard the input dataset by files if the dataset is 1015 being created out of reader datasets (e.g. `tf.data.TFRecordDataset`, 1016 `tf.data.TextLineDataset`, etc.) or otherwise shard the dataset by data, 1017 where each of the workers will read the entire dataset and only process the 1018 shard assigned to it. However, if you have less than one input file per 1019 worker, we suggest that you disable dataset autosharding across workers by 1020 setting the `tf.data.experimental.DistributeOptions.auto_shard_policy` to be 1021 `tf.data.experimental.AutoShardPolicy.OFF`. 1022 1023 By default, this method adds a prefetch transformation at the end of the 1024 user provided `tf.data.Dataset` instance. The argument to the prefetch 1025 transformation which is `buffer_size` is equal to the number of replicas in 1026 sync. 1027 1028 If the above batch splitting and dataset sharding logic is undesirable, 1029 please use 1030 `tf.distribute.Strategy.distribute_datasets_from_function` 1031 instead, which does not do any automatic batching or sharding for you. 1032 1033 Note: If you are using TPUStrategy, the order in which the data is processed 1034 by the workers when using 1035 `tf.distribute.Strategy.experimental_distribute_dataset` or 1036 `tf.distribute.Strategy.distribute_datasets_from_function` is 1037 not guaranteed. This is typically required if you are using 1038 `tf.distribute` to scale prediction. You can however insert an index for 1039 each element in the batch and order outputs accordingly. Refer to [this 1040 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) 1041 for an example of how to order outputs. 1042 1043 Note: Stateful dataset transformations are currently not supported with 1044 `tf.distribute.experimental_distribute_dataset` or 1045 `tf.distribute.distribute_datasets_from_function`. Any stateful 1046 ops that the dataset may have are currently ignored. For example, if your 1047 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 1048 then you have a dataset graph that depends on state (i.e the random seed) on 1049 the local machine where the python process is being executed. 1050 1051 For a tutorial on more usage and properties of this method, refer to the 1052 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_dataset). 1053 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 1054 1055 Args: 1056 dataset: `tf.data.Dataset` that will be sharded across all replicas using 1057 the rules stated above. 1058 options: `tf.distribute.InputOptions` used to control options on how this 1059 dataset is distributed. 1060 1061 Returns: 1062 A `tf.distribute.DistributedDataset`. 1063 """ 1064 # pylint: enable=line-too-long 1065 return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access 1066 1067 def distribute_datasets_from_function(self, dataset_fn, options=None): 1068 # pylint: disable=line-too-long 1069 """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. 1070 1071 The argument `dataset_fn` that users pass in is an input function that has a 1072 `tf.distribute.InputContext` argument and returns a `tf.data.Dataset` 1073 instance. It is expected that the returned dataset from `dataset_fn` is 1074 already batched by per-replica batch size (i.e. global batch size divided by 1075 the number of replicas in sync) and sharded. 1076 `tf.distribute.Strategy.distribute_datasets_from_function` does 1077 not batch or shard the `tf.data.Dataset` instance 1078 returned from the input function. `dataset_fn` will be called on the CPU 1079 device of each of the workers and each generates a dataset where every 1080 replica on that worker will dequeue one batch of inputs (i.e. if a worker 1081 has two replicas, two batches will be dequeued from the `Dataset` every 1082 step). 1083 1084 This method can be used for several purposes. First, it allows you to 1085 specify your own batching and sharding logic. (In contrast, 1086 `tf.distribute.experimental_distribute_dataset` does batching and sharding 1087 for you.) For example, where 1088 `experimental_distribute_dataset` is unable to shard the input files, this 1089 method might be used to manually shard the dataset (avoiding the slow 1090 fallback behavior in `experimental_distribute_dataset`). In cases where the 1091 dataset is infinite, this sharding can be done by creating dataset replicas 1092 that differ only in their random seed. 1093 1094 The `dataset_fn` should take an `tf.distribute.InputContext` instance where 1095 information about batching and input replication can be accessed. 1096 1097 You can use `element_spec` property of the 1098 `tf.distribute.DistributedDataset` returned by this API to query the 1099 `tf.TypeSpec` of the elements returned by the iterator. This can be used to 1100 set the `input_signature` property of a `tf.function`. Follow 1101 `tf.distribute.DistributedDataset.element_spec` to see an example. 1102 1103 IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a 1104 per-replica batch size, unlike `experimental_distribute_dataset`, which uses 1105 the global batch size. This may be computed using 1106 `input_context.get_per_replica_batch_size`. 1107 1108 Note: If you are using TPUStrategy, the order in which the data is processed 1109 by the workers when using 1110 `tf.distribute.Strategy.experimental_distribute_dataset` or 1111 `tf.distribute.Strategy.distribute_datasets_from_function` is 1112 not guaranteed. This is typically required if you are using 1113 `tf.distribute` to scale prediction. You can however insert an index for 1114 each element in the batch and order outputs accordingly. Refer to [this 1115 snippet](https://www.tensorflow.org/tutorials/distribute/input#caveats) 1116 for an example of how to order outputs. 1117 1118 Note: Stateful dataset transformations are currently not supported with 1119 `tf.distribute.experimental_distribute_dataset` or 1120 `tf.distribute.distribute_datasets_from_function`. Any stateful 1121 ops that the dataset may have are currently ignored. For example, if your 1122 dataset has a `map_fn` that uses `tf.random.uniform` to rotate an image, 1123 then you have a dataset graph that depends on state (i.e the random seed) on 1124 the local machine where the python process is being executed. 1125 1126 For a tutorial on more usage and properties of this method, refer to the 1127 [tutorial on distributed input](https://www.tensorflow.org/tutorials/distribute/input#tfdistributestrategyexperimental_distribute_datasets_from_function)). 1128 If you are interested in last partial batch handling, read [this section](https://www.tensorflow.org/tutorials/distribute/input#partial_batches). 1129 1130 Args: 1131 dataset_fn: A function taking a `tf.distribute.InputContext` instance and 1132 returning a `tf.data.Dataset`. 1133 options: `tf.distribute.InputOptions` used to control options on how this 1134 dataset is distributed. 1135 1136 Returns: 1137 A `tf.distribute.DistributedDataset`. 1138 """ 1139 # pylint: enable=line-too-long 1140 return self._extended._distribute_datasets_from_function( # pylint: disable=protected-access 1141 dataset_fn, options) 1142 1143 # TODO(b/162776748): Remove deprecated symbol. 1144 @doc_controls.do_not_doc_inheritable 1145 @deprecation.deprecated(None, "rename to distribute_datasets_from_function") 1146 def experimental_distribute_datasets_from_function(self, 1147 dataset_fn, 1148 options=None): 1149 return self.distribute_datasets_from_function(dataset_fn, options) 1150 1151 def run(self, fn, args=(), kwargs=None, options=None): 1152 """Invokes `fn` on each replica, with the given arguments. 1153 1154 This method is the primary way to distribute your computation with a 1155 tf.distribute object. It invokes `fn` on each replica. If `args` or `kwargs` 1156 have `tf.distribute.DistributedValues`, such as those produced by a 1157 `tf.distribute.DistributedDataset` from 1158 `tf.distribute.Strategy.experimental_distribute_dataset` or 1159 `tf.distribute.Strategy.distribute_datasets_from_function`, 1160 when `fn` is executed on a particular replica, it will be executed with the 1161 component of `tf.distribute.DistributedValues` that correspond to that 1162 replica. 1163 1164 `fn` is invoked under a replica context. `fn` may call 1165 `tf.distribute.get_replica_context()` to access members such as 1166 `all_reduce`. Please see the module-level docstring of tf.distribute for the 1167 concept of replica context. 1168 1169 All arguments in `args` or `kwargs` can be a nested structure of tensors, 1170 e.g. a list of tensors, in which case `args` and `kwargs` will be passed to 1171 the `fn` invoked on each replica. Or `args` or `kwargs` can be 1172 `tf.distribute.DistributedValues` containing tensors or composite tensors, 1173 i.e. `tf.compat.v1.TensorInfo.CompositeTensor`, in which case each `fn` call 1174 will get the component of a `tf.distribute.DistributedValues` corresponding 1175 to its replica. Note that arbitrary Python values that are not of the types 1176 above are not supported. 1177 1178 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and 1179 whether eager execution is enabled, `fn` may be called one or more times. If 1180 `fn` is annotated with `tf.function` or `tf.distribute.Strategy.run` is 1181 called inside a `tf.function` (eager execution is disabled inside a 1182 `tf.function` by default), `fn` is called once per replica to generate a 1183 Tensorflow graph, which will then be reused for execution with new inputs. 1184 Otherwise, if eager execution is enabled, `fn` will be called once per 1185 replica every step just like regular python code. 1186 1187 Example usage: 1188 1189 1. Constant tensor input. 1190 1191 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1192 >>> tensor_input = tf.constant(3.0) 1193 >>> @tf.function 1194 ... def replica_fn(input): 1195 ... return input*2.0 1196 >>> result = strategy.run(replica_fn, args=(tensor_input,)) 1197 >>> result 1198 PerReplica:{ 1199 0: <tf.Tensor: shape=(), dtype=float32, numpy=6.0>, 1200 1: <tf.Tensor: shape=(), dtype=float32, numpy=6.0> 1201 } 1202 1203 2. DistributedValues input. 1204 1205 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1206 >>> @tf.function 1207 ... def run(): 1208 ... def value_fn(value_context): 1209 ... return value_context.num_replicas_in_sync 1210 ... distributed_values = ( 1211 ... strategy.experimental_distribute_values_from_function( 1212 ... value_fn)) 1213 ... def replica_fn2(input): 1214 ... return input*2 1215 ... return strategy.run(replica_fn2, args=(distributed_values,)) 1216 >>> result = run() 1217 >>> result 1218 <tf.Tensor: shape=(), dtype=int32, numpy=4> 1219 1220 3. Use `tf.distribute.ReplicaContext` to allreduce values. 1221 1222 >>> strategy = tf.distribute.MirroredStrategy(["gpu:0", "gpu:1"]) 1223 >>> @tf.function 1224 ... def run(): 1225 ... def value_fn(value_context): 1226 ... return tf.constant(value_context.replica_id_in_sync_group) 1227 ... distributed_values = ( 1228 ... strategy.experimental_distribute_values_from_function( 1229 ... value_fn)) 1230 ... def replica_fn(input): 1231 ... return tf.distribute.get_replica_context().all_reduce("sum", input) 1232 ... return strategy.run(replica_fn, args=(distributed_values,)) 1233 >>> result = run() 1234 >>> result 1235 PerReplica:{ 1236 0: <tf.Tensor: shape=(), dtype=int32, numpy=1>, 1237 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 1238 } 1239 1240 Args: 1241 fn: The function to run on each replica. 1242 args: Optional positional arguments to `fn`. Its element can be a tensor, 1243 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1244 kwargs: Optional keyword arguments to `fn`. Its element can be a tensor, 1245 a nested structure of tensors or a `tf.distribute.DistributedValues`. 1246 options: An optional instance of `tf.distribute.RunOptions` specifying 1247 the options to run `fn`. 1248 1249 Returns: 1250 Merged return value of `fn` across replicas. The structure of the return 1251 value is the same as the return value from `fn`. Each element in the 1252 structure can either be `tf.distribute.DistributedValues`, `Tensor` 1253 objects, or `Tensor`s (for example, if running on a single replica). 1254 """ 1255 del options 1256 1257 if not isinstance(args, (list, tuple)): 1258 raise ValueError( 1259 "positional args must be a list or tuple, got {}".format(type(args))) 1260 1261 with self.scope(): 1262 # tf.distribute supports Eager functions, so AutoGraph should not be 1263 # applied when the caller is also in Eager mode. 1264 fn = autograph.tf_convert( 1265 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 1266 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 1267 1268 def reduce(self, reduce_op, value, axis): 1269 """Reduce `value` across replicas and return result on current device. 1270 1271 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1272 >>> def step_fn(): 1273 ... i = tf.distribute.get_replica_context().replica_id_in_sync_group 1274 ... return tf.identity(i) 1275 >>> 1276 >>> per_replica_result = strategy.run(step_fn) 1277 >>> total = strategy.reduce("SUM", per_replica_result, axis=None) 1278 >>> total 1279 <tf.Tensor: shape=(), dtype=int32, numpy=1> 1280 1281 To see how this would look with multiple replicas, consider the same 1282 example with MirroredStrategy with 2 GPUs: 1283 1284 ```python 1285 strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"]) 1286 def step_fn(): 1287 i = tf.distribute.get_replica_context().replica_id_in_sync_group 1288 return tf.identity(i) 1289 1290 per_replica_result = strategy.run(step_fn) 1291 # Check devices on which per replica result is: 1292 strategy.experimental_local_results(per_replica_result)[0].device 1293 # /job:localhost/replica:0/task:0/device:GPU:0 1294 strategy.experimental_local_results(per_replica_result)[1].device 1295 # /job:localhost/replica:0/task:0/device:GPU:1 1296 1297 total = strategy.reduce("SUM", per_replica_result, axis=None) 1298 # Check device on which reduced result is: 1299 total.device 1300 # /job:localhost/replica:0/task:0/device:CPU:0 1301 1302 ``` 1303 1304 This API is typically used for aggregating the results returned from 1305 different replicas, for reporting etc. For example, loss computed from 1306 different replicas can be averaged using this API before printing. 1307 1308 Note: The result is copied to the "current" device - which would typically 1309 be the CPU of the worker on which the program is running. For `TPUStrategy`, 1310 it is the first TPU host. For multi client `MultiWorkerMirroredStrategy`, 1311 this is CPU of each worker. 1312 1313 There are a number of different tf.distribute APIs for reducing values 1314 across replicas: 1315 * `tf.distribute.ReplicaContext.all_reduce`: This differs from 1316 `Strategy.reduce` in that it is for replica context and does 1317 not copy the results to the host device. `all_reduce` should be typically 1318 used for reductions inside the training step such as gradients. 1319 * `tf.distribute.StrategyExtended.reduce_to` and 1320 `tf.distribute.StrategyExtended.batch_reduce_to`: These APIs are more 1321 advanced versions of `Strategy.reduce` as they allow customizing the 1322 destination of the result. They are also called in cross replica context. 1323 1324 _What should axis be?_ 1325 1326 Given a per-replica value returned by `run`, say a 1327 per-example loss, the batch will be divided across all the replicas. This 1328 function allows you to aggregate across replicas and optionally also across 1329 batch elements by specifying the axis parameter accordingly. 1330 1331 For example, if you have a global batch size of 8 and 2 1332 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and 1333 `[4, 5, 6, 7]` will be on replica 1. With `axis=None`, `reduce` will 1334 aggregate only across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. 1335 This is useful when each replica is computing a scalar or some other value 1336 that doesn't have a "batch" dimension (like a gradient or loss). 1337 ``` 1338 strategy.reduce("sum", per_replica_result, axis=None) 1339 ``` 1340 1341 Sometimes, you will want to aggregate across both the global batch _and_ 1342 all replicas. You can get this behavior by specifying the batch 1343 dimension as the `axis`, typically `axis=0`. In this case it would return a 1344 scalar `0+1+2+3+4+5+6+7`. 1345 ``` 1346 strategy.reduce("sum", per_replica_result, axis=0) 1347 ``` 1348 1349 If there is a last partial batch, you will need to specify an axis so 1350 that the resulting shape is consistent across replicas. So if the last 1351 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you 1352 would get a shape mismatch unless you specify `axis=0`. If you specify 1353 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct 1354 denominator of 6. Contrast this with computing `reduce_mean` to get a 1355 scalar value on each replica and this function to average those means, 1356 which will weigh some values `1/8` and others `1/4`. 1357 1358 Args: 1359 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 1360 be combined. Allows using string representation of the enum such as 1361 "SUM", "MEAN". 1362 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1363 `Strategy.run`, to be combined into a single tensor. It can also be a 1364 regular tensor when used with `OneDeviceStrategy` or default strategy. 1365 axis: specifies the dimension to reduce along within each 1366 replica's tensor. Should typically be set to the batch dimension, or 1367 `None` to only reduce across replicas (e.g. if the tensor has no batch 1368 dimension). 1369 1370 Returns: 1371 A `Tensor`. 1372 """ 1373 # TODO(josh11b): support `value` being a nest. 1374 _require_cross_replica_or_default_context_extended(self._extended) 1375 if isinstance(reduce_op, six.string_types): 1376 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 1377 if axis is None: 1378 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1379 if reduce_op == reduce_util.ReduceOp.SUM: 1380 1381 def reduce_sum(v): 1382 return math_ops.reduce_sum(v, axis=axis) 1383 1384 if eager_context.executing_eagerly(): 1385 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1386 # execution, wrap the `reduce_sum_fn` with a `tf.function` so it can be 1387 # run from eager mode. Cache the tf.function by `axis` to avoid the 1388 # same function to be traced again. 1389 if axis not in self._reduce_sum_fns: 1390 1391 def reduce_sum_fn(v): 1392 return self.run(reduce_sum, args=(v,)) 1393 1394 self._reduce_sum_fns[axis] = def_function.function(reduce_sum_fn) 1395 value = self._reduce_sum_fns[axis](value) 1396 else: 1397 value = self.run(reduce_sum, args=(value,)) 1398 1399 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 1400 if reduce_op != reduce_util.ReduceOp.MEAN: 1401 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " 1402 "not: %r" % reduce_op) 1403 # TODO(josh11b): Support list/tuple and tensor axis values. 1404 if not isinstance(axis, six.integer_types): 1405 raise TypeError("Expected `axis` to be an integer not: %r" % axis) 1406 1407 def mean_reduce_helper(v, axis=axis): 1408 """Computes the numerator and denominator on each replica.""" 1409 numer = math_ops.reduce_sum(v, axis=axis) 1410 if v.shape.rank is not None: 1411 # Note(joshl): We support axis < 0 to be consistent with the 1412 # tf.math.reduce_* operations. 1413 if axis < 0: 1414 if axis + v.shape.rank < 0: 1415 raise ValueError( 1416 "`axis` = %r out of range for `value` with rank %d" % 1417 (axis, v.shape.rank)) 1418 axis += v.shape.rank 1419 elif axis >= v.shape.rank: 1420 raise ValueError( 1421 "`axis` = %r out of range for `value` with rank %d" % 1422 (axis, v.shape.rank)) 1423 # TF v2 returns `None` for unknown dimensions and an integer for 1424 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) 1425 # or tensor_shape.Dimension(integer). `dimension_value` hides this 1426 # difference, always returning `None` or an integer. 1427 dim = tensor_shape.dimension_value(v.shape[axis]) 1428 if dim is not None: 1429 # By returning a python value in the static shape case, we can 1430 # maybe get a fast path for reducing the denominator. 1431 # TODO(b/151871486): Remove array_ops.identity after we fallback to 1432 # simple reduction if inputs are all on CPU. 1433 return numer, array_ops.identity( 1434 constant_op.constant(dim, dtype=dtypes.int64)) 1435 elif axis < 0: 1436 axis = axis + array_ops.rank(v) 1437 # TODO(b/151871486): Remove array_ops.identity after we fallback to simple 1438 # reduction if inputs are all on CPU. 1439 denom = array_ops.identity( 1440 array_ops.shape_v2(v, out_type=dtypes.int64)[axis]) 1441 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the 1442 # reduce is complete? 1443 return numer, denom 1444 1445 if eager_context.executing_eagerly(): 1446 # As some strategies (e.g. TPUStrategy) doesn't support pure eager 1447 # execution, wrap the `mean_reduce_helper` with a `tf.function` so it can 1448 # be run from eager mode. Cache the tf.function by `axis` to avoid the 1449 # same function to be traced again. 1450 if axis not in self._mean_reduce_helper_fns: 1451 1452 def mean_reduce_fn(v): 1453 return self.run(mean_reduce_helper, args=(v,)) 1454 1455 self._mean_reduce_helper_fns[axis] = def_function.function( 1456 mean_reduce_fn) 1457 numer, denom = self._mean_reduce_helper_fns[axis](value) 1458 else: 1459 numer, denom = self.run(mean_reduce_helper, args=(value,)) 1460 1461 # TODO(josh11b): Should batch reduce here instead of doing two. 1462 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access 1463 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access 1464 denom = math_ops.cast(denom, numer.dtype) 1465 return math_ops.truediv(numer, denom) 1466 1467 @doc_controls.do_not_doc_inheritable # DEPRECATED 1468 def unwrap(self, value): 1469 """Returns the list of all local per-replica values contained in `value`. 1470 1471 DEPRECATED: Please use `experimental_local_results` instead. 1472 1473 Note: This only returns values on the workers initiated by this client. 1474 When using a `tf.distribute.Strategy` like 1475 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1476 will be its own client, and this function will only return values 1477 computed on that worker. 1478 1479 Args: 1480 value: A value returned by `experimental_run()`, 1481 `extended.call_for_each_replica()`, or a variable created in `scope`. 1482 1483 Returns: 1484 A tuple of values contained in `value`. If `value` represents a single 1485 value, this returns `(value,).` 1486 """ 1487 return self._extended._local_results(value) # pylint: disable=protected-access 1488 1489 def experimental_local_results(self, value): 1490 """Returns the list of all local per-replica values contained in `value`. 1491 1492 Note: This only returns values on the worker initiated by this client. 1493 When using a `tf.distribute.Strategy` like 1494 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 1495 will be its own client, and this function will only return values 1496 computed on that worker. 1497 1498 Args: 1499 value: A value returned by `experimental_run()`, `run()`, 1500 `extended.call_for_each_replica()`, or a variable created in `scope`. 1501 1502 Returns: 1503 A tuple of values contained in `value`. If `value` represents a single 1504 value, this returns `(value,).` 1505 """ 1506 return self._extended._local_results(value) # pylint: disable=protected-access 1507 1508 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only 1509 def group(self, value, name=None): 1510 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 1511 return self._extended._group(value, name) # pylint: disable=protected-access 1512 1513 @property 1514 def num_replicas_in_sync(self): 1515 """Returns number of replicas over which gradients are aggregated.""" 1516 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 1517 1518 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string 1519 def configure(self, 1520 session_config=None, 1521 cluster_spec=None, 1522 task_type=None, 1523 task_id=None): 1524 # pylint: disable=g-doc-return-or-yield,g-doc-args 1525 """DEPRECATED: use `update_config_proto` instead. 1526 1527 Configures the strategy class. 1528 1529 DEPRECATED: This method's functionality has been split into the strategy 1530 constructor and `update_config_proto`. In the future, we will allow passing 1531 cluster and config_proto to the constructor to configure the strategy. And 1532 `update_config_proto` can be used to update the config_proto based on the 1533 specific strategy. 1534 """ 1535 return self._extended._configure( # pylint: disable=protected-access 1536 session_config, cluster_spec, task_type, task_id) 1537 1538 @doc_controls.do_not_generate_docs # DEPRECATED 1539 def update_config_proto(self, config_proto): 1540 """DEPRECATED TF 1.x ONLY.""" 1541 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 1542 1543 def __deepcopy__(self, memo): 1544 # First do a regular deepcopy of `self`. 1545 cls = self.__class__ 1546 result = cls.__new__(cls) 1547 memo[id(self)] = result 1548 for k, v in self.__dict__.items(): 1549 setattr(result, k, copy.deepcopy(v, memo)) 1550 # One little fix-up: we want `result._extended` to reference `result` 1551 # instead of `self`. 1552 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 1553 return result 1554 1555 def __copy__(self): 1556 raise RuntimeError("Must only deepcopy DistributionStrategy.") 1557 1558 @property 1559 def cluster_resolver(self): 1560 """Returns the cluster resolver associated with this strategy. 1561 1562 In general, when using a multi-worker `tf.distribute` strategy such as 1563 `tf.distribute.experimental.MultiWorkerMirroredStrategy` or 1564 `tf.distribute.TPUStrategy()`, there is a 1565 `tf.distribute.cluster_resolver.ClusterResolver` associated with the 1566 strategy used, and such an instance is returned by this property. 1567 1568 Strategies that intend to have an associated 1569 `tf.distribute.cluster_resolver.ClusterResolver` must set the 1570 relevant attribute, or override this property; otherwise, `None` is returned 1571 by default. Those strategies should also provide information regarding what 1572 is returned by this property. 1573 1574 Single-worker strategies usually do not have a 1575 `tf.distribute.cluster_resolver.ClusterResolver`, and in those cases this 1576 property will return `None`. 1577 1578 The `tf.distribute.cluster_resolver.ClusterResolver` may be useful when the 1579 user needs to access information such as the cluster spec, task type or task 1580 id. For example, 1581 1582 ```python 1583 1584 os.environ['TF_CONFIG'] = json.dumps({ 1585 'cluster': { 1586 'worker': ["localhost:12345", "localhost:23456"], 1587 'ps': ["localhost:34567"] 1588 }, 1589 'task': {'type': 'worker', 'index': 0} 1590 }) 1591 1592 # This implicitly uses TF_CONFIG for the cluster and current task info. 1593 strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() 1594 1595 ... 1596 1597 if strategy.cluster_resolver.task_type == 'worker': 1598 # Perform something that's only applicable on workers. Since we set this 1599 # as a worker above, this block will run on this particular instance. 1600 elif strategy.cluster_resolver.task_type == 'ps': 1601 # Perform something that's only applicable on parameter servers. Since we 1602 # set this as a worker above, this block will not run on this particular 1603 # instance. 1604 ``` 1605 1606 For more information, please see 1607 `tf.distribute.cluster_resolver.ClusterResolver`'s API docstring. 1608 1609 Returns: 1610 The cluster resolver associated with this strategy. Returns `None` if a 1611 cluster resolver is not applicable or available in this strategy. 1612 """ 1613 if hasattr(self.extended, "_cluster_resolver"): 1614 return self.extended._cluster_resolver # pylint: disable=protected-access 1615 return None 1616 1617 1618@tf_export("distribute.Strategy", v1=[]) # pylint: disable=g-missing-docstring 1619class Strategy(StrategyBase): 1620 1621 __doc__ = StrategyBase.__doc__ 1622 1623 def experimental_distribute_values_from_function(self, value_fn): 1624 """Generates `tf.distribute.DistributedValues` from `value_fn`. 1625 1626 This function is to generate `tf.distribute.DistributedValues` to pass 1627 into `run`, `reduce`, or other methods that take 1628 distributed values when not using datasets. 1629 1630 Args: 1631 value_fn: The function to run to generate values. It is called for 1632 each replica with `tf.distribute.ValueContext` as the sole argument. It 1633 must return a Tensor or a type that can be converted to a Tensor. 1634 Returns: 1635 A `tf.distribute.DistributedValues` containing a value for each replica. 1636 1637 Example usage: 1638 1639 1. Return constant value per replica: 1640 1641 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1642 >>> def value_fn(ctx): 1643 ... return tf.constant(1.) 1644 >>> distributed_values = ( 1645 ... strategy.experimental_distribute_values_from_function( 1646 ... value_fn)) 1647 >>> local_result = strategy.experimental_local_results(distributed_values) 1648 >>> local_result 1649 (<tf.Tensor: shape=(), dtype=float32, numpy=1.0>, 1650 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>) 1651 1652 2. Distribute values in array based on replica_id: 1653 1654 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1655 >>> array_value = np.array([3., 2., 1.]) 1656 >>> def value_fn(ctx): 1657 ... return array_value[ctx.replica_id_in_sync_group] 1658 >>> distributed_values = ( 1659 ... strategy.experimental_distribute_values_from_function( 1660 ... value_fn)) 1661 >>> local_result = strategy.experimental_local_results(distributed_values) 1662 >>> local_result 1663 (3.0, 2.0) 1664 1665 3. Specify values using num_replicas_in_sync: 1666 1667 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1668 >>> def value_fn(ctx): 1669 ... return ctx.num_replicas_in_sync 1670 >>> distributed_values = ( 1671 ... strategy.experimental_distribute_values_from_function( 1672 ... value_fn)) 1673 >>> local_result = strategy.experimental_local_results(distributed_values) 1674 >>> local_result 1675 (2, 2) 1676 1677 4. Place values on devices and distribute: 1678 1679 ``` 1680 strategy = tf.distribute.TPUStrategy() 1681 worker_devices = strategy.extended.worker_devices 1682 multiple_values = [] 1683 for i in range(strategy.num_replicas_in_sync): 1684 with tf.device(worker_devices[i]): 1685 multiple_values.append(tf.constant(1.0)) 1686 1687 def value_fn(ctx): 1688 return multiple_values[ctx.replica_id_in_sync_group] 1689 1690 distributed_values = strategy. 1691 experimental_distribute_values_from_function( 1692 value_fn) 1693 ``` 1694 1695 """ 1696 return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access 1697 value_fn) 1698 1699 def gather(self, value, axis): 1700 # pylint: disable=line-too-long, protected-access 1701 """Gather `value` across replicas along `axis` to the current device. 1702 1703 Given a `tf.distribute.DistributedValues` or `tf.Tensor`-like 1704 object `value`, this API gathers and concatenates `value` across replicas 1705 along the `axis`-th dimension. The result is copied to the "current" device, 1706 which would typically be the CPU of the worker on which the program is 1707 running. For `tf.distribute.TPUStrategy`, it is the first TPU host. For 1708 multi-client `tf.distribute.MultiWorkerMirroredStrategy`, this is the CPU of 1709 each worker. 1710 1711 This API can only be called in the cross-replica context. For a counterpart 1712 in the replica context, see `tf.distribute.ReplicaContext.all_gather`. 1713 1714 Note: For all strategies except `tf.distribute.TPUStrategy`, the input 1715 `value` on different replicas must have the same rank, and their shapes must 1716 be the same in all dimensions except the `axis`-th dimension. In other 1717 words, their shapes cannot be different in a dimension `d` where `d` does 1718 not equal to the `axis` argument. For example, given a 1719 `tf.distribute.DistributedValues` with component tensors of shape 1720 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 1721 `gather(..., axis=1, ...)` on it, but not `gather(..., axis=0, ...)` or 1722 `gather(..., axis=2, ...)`. However, for `tf.distribute.TPUStrategy.gather`, 1723 all tensors must have exactly the same rank and same shape. 1724 1725 Note: Given a `tf.distribute.DistributedValues` `value`, its component 1726 tensors must have a non-zero rank. Otherwise, consider using 1727 `tf.expand_dims` before gathering them. 1728 1729 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 1730 >>> # A DistributedValues with component tensor of shape (2, 1) on each replica 1731 ... distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(tf.constant([[1], [2]]))) 1732 >>> @tf.function 1733 ... def run(): 1734 ... return strategy.gather(distributed_values, axis=0) 1735 >>> run() 1736 <tf.Tensor: shape=(4, 1), dtype=int32, numpy= 1737 array([[1], 1738 [2], 1739 [1], 1740 [2]], dtype=int32)> 1741 1742 1743 Consider the following example for more combinations: 1744 1745 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1", "GPU:2", "GPU:3"]) 1746 >>> single_tensor = tf.reshape(tf.range(6), shape=(1,2,3)) 1747 >>> distributed_values = strategy.experimental_distribute_values_from_function(lambda _: tf.identity(single_tensor)) 1748 >>> @tf.function 1749 ... def run(axis): 1750 ... return strategy.gather(distributed_values, axis=axis) 1751 >>> axis=0 1752 >>> run(axis) 1753 <tf.Tensor: shape=(4, 2, 3), dtype=int32, numpy= 1754 array([[[0, 1, 2], 1755 [3, 4, 5]], 1756 [[0, 1, 2], 1757 [3, 4, 5]], 1758 [[0, 1, 2], 1759 [3, 4, 5]], 1760 [[0, 1, 2], 1761 [3, 4, 5]]], dtype=int32)> 1762 >>> axis=1 1763 >>> run(axis) 1764 <tf.Tensor: shape=(1, 8, 3), dtype=int32, numpy= 1765 array([[[0, 1, 2], 1766 [3, 4, 5], 1767 [0, 1, 2], 1768 [3, 4, 5], 1769 [0, 1, 2], 1770 [3, 4, 5], 1771 [0, 1, 2], 1772 [3, 4, 5]]], dtype=int32)> 1773 >>> axis=2 1774 >>> run(axis) 1775 <tf.Tensor: shape=(1, 2, 12), dtype=int32, numpy= 1776 array([[[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2], 1777 [3, 4, 5, 3, 4, 5, 3, 4, 5, 3, 4, 5]]], dtype=int32)> 1778 1779 1780 Args: 1781 value: a `tf.distribute.DistributedValues` instance, e.g. returned by 1782 `Strategy.run`, to be combined into a single tensor. It can also be a 1783 regular tensor when used with `tf.distribute.OneDeviceStrategy` or the 1784 default strategy. The tensors that constitute the DistributedValues 1785 can only be dense tensors with non-zero rank, NOT a `tf.IndexedSlices`. 1786 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 1787 range [0, rank(value)). 1788 1789 Returns: 1790 A `Tensor` that's the concatenation of `value` across replicas along 1791 `axis` dimension. 1792 """ 1793 # pylint: enable=line-too-long 1794 error_message = ("tf.distribute.Strategy.gather method requires " 1795 "cross-replica context, use " 1796 "get_replica_context().all_gather() instead") 1797 _require_cross_replica_or_default_context_extended(self._extended, 1798 error_message) 1799 dst = device_util.current( 1800 ) or self._extended._default_device or "/device:CPU:0" 1801 if isinstance(value, ops.IndexedSlices): 1802 raise NotImplementedError("gather does not support IndexedSlices") 1803 return self._extended._local_results( 1804 self._extended._gather_to(value, dst, axis))[0] 1805 1806 1807# TF v1.x version has additional deprecated APIs 1808@tf_export(v1=["distribute.Strategy"]) 1809class StrategyV1(StrategyBase): 1810 """A list of devices with a state & compute distribution policy. 1811 1812 See [the guide](https://www.tensorflow.org/guide/distribute_strategy) 1813 for overview and examples. 1814 1815 Note: Not all `tf.distribute.Strategy` implementations currently support 1816 TensorFlow's partitioned variables (where a single variable is split across 1817 multiple devices) at this time. 1818 """ 1819 1820 def make_dataset_iterator(self, dataset): 1821 """Makes an iterator for input provided via `dataset`. 1822 1823 DEPRECATED: This method is not available in TF 2.x. 1824 1825 Data from the given dataset will be distributed evenly across all the 1826 compute replicas. We will assume that the input dataset is batched by the 1827 global batch size. With this assumption, we will make a best effort to 1828 divide each batch across all the replicas (one or more workers). 1829 If this effort fails, an error will be thrown, and the user should instead 1830 use `make_input_fn_iterator` which provides more control to the user, and 1831 does not try to divide a batch across replicas. 1832 1833 The user could also use `make_input_fn_iterator` if they want to 1834 customize which input is fed to which replica/worker etc. 1835 1836 Args: 1837 dataset: `tf.data.Dataset` that will be distributed evenly across all 1838 replicas. 1839 1840 Returns: 1841 An `tf.distribute.InputIterator` which returns inputs for each step of the 1842 computation. User should call `initialize` on the returned iterator. 1843 """ 1844 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 1845 1846 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation 1847 input_fn, 1848 replication_mode=InputReplicationMode.PER_WORKER): 1849 """Returns an iterator split across replicas created from an input function. 1850 1851 DEPRECATED: This method is not available in TF 2.x. 1852 1853 The `input_fn` should take an `tf.distribute.InputContext` object where 1854 information about batching and input sharding can be accessed: 1855 1856 ``` 1857 def input_fn(input_context): 1858 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 1859 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 1860 return d.shard(input_context.num_input_pipelines, 1861 input_context.input_pipeline_id) 1862 with strategy.scope(): 1863 iterator = strategy.make_input_fn_iterator(input_fn) 1864 replica_results = strategy.experimental_run(replica_fn, iterator) 1865 ``` 1866 1867 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 1868 batch size, which may be computed using 1869 `input_context.get_per_replica_batch_size`. 1870 1871 Args: 1872 input_fn: A function taking a `tf.distribute.InputContext` object and 1873 returning a `tf.data.Dataset`. 1874 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 1875 Only `PER_WORKER` is supported currently, which means there will be 1876 a single call to `input_fn` per worker. Replicas will dequeue from the 1877 local `tf.data.Dataset` on their worker. 1878 1879 Returns: 1880 An iterator object that should first be `.initialize()`-ed. It may then 1881 either be passed to `strategy.experimental_run()` or you can 1882 `iterator.get_next()` to get the next value to pass to 1883 `strategy.extended.call_for_each_replica()`. 1884 """ 1885 return super(StrategyV1, self).make_input_fn_iterator( 1886 input_fn, replication_mode) 1887 1888 def experimental_make_numpy_dataset(self, numpy_input, session=None): 1889 """Makes a tf.data.Dataset for input provided via a numpy array. 1890 1891 This avoids adding `numpy_input` as a large constant in the graph, 1892 and copies the data to the machine or machines that will be processing 1893 the input. 1894 1895 Note that you will likely need to use 1896 tf.distribute.Strategy.experimental_distribute_dataset 1897 with the returned dataset to further distribute it with the strategy. 1898 1899 Example: 1900 ``` 1901 numpy_input = np.ones([10], dtype=np.float32) 1902 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 1903 dist_dataset = strategy.experimental_distribute_dataset(dataset) 1904 ``` 1905 1906 Args: 1907 numpy_input: A nest of NumPy input arrays that will be converted into a 1908 dataset. Note that lists of Numpy arrays are stacked, as that is normal 1909 `tf.data.Dataset` behavior. 1910 session: (TensorFlow v1.x graph execution only) A session used for 1911 initialization. 1912 1913 Returns: 1914 A `tf.data.Dataset` representing `numpy_input`. 1915 """ 1916 return self.extended.experimental_make_numpy_dataset( 1917 numpy_input, session=session) 1918 1919 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation 1920 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 1921 1922 DEPRECATED: This method is not available in TF 2.x. Please switch 1923 to using `run` instead. 1924 1925 When eager execution is enabled, executes ops specified by `fn` on each 1926 replica. Otherwise, builds a graph to execute the ops on each replica. 1927 1928 Each replica will take a single, different input from the inputs provided by 1929 one `get_next` call on the input iterator. 1930 1931 `fn` may call `tf.distribute.get_replica_context()` to access members such 1932 as `replica_id_in_sync_group`. 1933 1934 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 1935 used, and whether eager execution is enabled, `fn` may be called one or more 1936 times (once for each replica). 1937 1938 Args: 1939 fn: The function to run. The inputs to the function must match the outputs 1940 of `input_iterator.get_next()`. The output must be a `tf.nest` of 1941 `Tensor`s. 1942 input_iterator: (Optional) input iterator from which the inputs are taken. 1943 1944 Returns: 1945 Merged return value of `fn` across replicas. The structure of the return 1946 value is the same as the return value from `fn`. Each element in the 1947 structure can either be `PerReplica` (if the values are unsynchronized), 1948 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 1949 single replica). 1950 """ 1951 return super(StrategyV1, self).experimental_run( 1952 fn, input_iterator) 1953 1954 def reduce(self, reduce_op, value, axis=None): 1955 return super(StrategyV1, self).reduce(reduce_op, value, axis) 1956 1957 reduce.__doc__ = StrategyBase.reduce.__doc__ 1958 1959 def update_config_proto(self, config_proto): 1960 """Returns a copy of `config_proto` modified for use with this strategy. 1961 1962 DEPRECATED: This method is not available in TF 2.x. 1963 1964 The updated config has something needed to run a strategy, e.g. 1965 configuration to run collective ops, or device filters to improve 1966 distributed training performance. 1967 1968 Args: 1969 config_proto: a `tf.ConfigProto` object. 1970 1971 Returns: 1972 The updated copy of the `config_proto`. 1973 """ 1974 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 1975 1976 1977# NOTE(josh11b): For any strategy that needs to support tf.compat.v1, 1978# instead descend from StrategyExtendedV1. 1979@tf_export("distribute.StrategyExtended", v1=[]) 1980class StrategyExtendedV2(object): 1981 """Additional APIs for algorithms that need to be distribution-aware. 1982 1983 Note: For most usage of `tf.distribute.Strategy`, there should be no need to 1984 call these methods, since TensorFlow libraries (such as optimizers) already 1985 call these methods when needed on your behalf. 1986 1987 1988 Some common use cases of functions on this page: 1989 1990 * _Locality_ 1991 1992 `tf.distribute.DistributedValues` can have the same _locality_ as a 1993 _distributed variable_, which leads to a mirrored value residing on the same 1994 devices as the variable (as opposed to the compute devices). Such values may 1995 be passed to a call to `tf.distribute.StrategyExtended.update` to update the 1996 value of a variable. You may use 1997 `tf.distribute.StrategyExtended.colocate_vars_with` to give a variable the 1998 same locality as another variable. You may convert a "PerReplica" value to a 1999 variable's locality by using `tf.distribute.StrategyExtended.reduce_to` or 2000 `tf.distribute.StrategyExtended.batch_reduce_to`. 2001 2002 * _How to update a distributed variable_ 2003 2004 A distributed variable is variables created on multiple devices. As discussed 2005 in the [glossary](https://www.tensorflow.org/api_docs/python/tf/distribute), 2006 mirrored variable and SyncOnRead variable are two examples. The standard 2007 pattern for updating distributed variables is to: 2008 2009 1. In your function passed to `tf.distribute.Strategy.run`, 2010 compute a list of (update, variable) pairs. For example, the update might 2011 be a gradient of the loss with respect to the variable. 2012 2. Switch to cross-replica mode by calling 2013 `tf.distribute.get_replica_context().merge_call()` with the updates and 2014 variables as arguments. 2015 3. Call 2016 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` 2017 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` 2018 (for a list of variables) to sum the updates. 2019 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update 2020 its value. 2021 2022 Steps 2 through 4 are done automatically by class 2023 `tf.keras.optimizers.Optimizer` if you call its 2024 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. 2025 2026 In fact, a higher-level solution to update a distributed variable is by 2027 calling `assign` on the variable as you would do to a regular `tf.Variable`. 2028 You can call the method in both _replica context_ and _cross-replica context_. 2029 For a _mirrored variable_, calling `assign` in _replica context_ requires you 2030 to specify the `aggregation` type in the variable constructor. In that case, 2031 the context switching and sync described in steps 2 through 4 are handled for 2032 you. If you call `assign` on _mirrored variable_ in _cross-replica context_, 2033 you can only assign a single value or assign values from another mirrored 2034 variable or a mirrored `tf.distribute.DistributedValues`. For a _SyncOnRead 2035 variable_, in _replica context_, you can simply call `assign` on it and no 2036 aggregation happens under the hood. In _cross-replica context_, you can only 2037 assign a single value to a SyncOnRead variable. One example case is restoring 2038 from a checkpoint: if the `aggregation` type of the variable is 2039 `tf.VariableAggregation.SUM`, it is assumed that replica values were added 2040 before checkpointing, so at the time of restoring, the value is divided by 2041 the number of replicas and then assigned to each replica; if the `aggregation` 2042 type is `tf.VariableAggregation.MEAN`, the value is assigned to each replica 2043 directly. 2044 2045 """ 2046 2047 def __init__(self, container_strategy): 2048 self._container_strategy_weakref = weakref.ref(container_strategy) 2049 self._default_device = None 2050 # This property is used to determine if we should set drop_remainder=True 2051 # when creating Datasets from numpy array inputs. 2052 self._require_static_shapes = False 2053 2054 def _container_strategy(self): 2055 """Get the containing `tf.distribute.Strategy`. 2056 2057 This should not generally be needed except when creating a new 2058 `ReplicaContext` and to validate that the caller is in the correct 2059 `scope()`. 2060 2061 Returns: 2062 The `tf.distribute.Strategy` such that `strategy.extended` is `self`. 2063 """ 2064 container_strategy = self._container_strategy_weakref() 2065 assert container_strategy is not None 2066 return container_strategy 2067 2068 def _scope(self, strategy): 2069 """Implementation of tf.distribute.Strategy.scope().""" 2070 2071 def creator_with_resource_vars(next_creator, **kwargs): 2072 """Variable creator to use in `_CurrentDistributionContext`.""" 2073 _require_strategy_scope_extended(self) 2074 kwargs["use_resource"] = True 2075 kwargs["distribute_strategy"] = strategy 2076 2077 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid 2078 # dereferencing a `Tensor` that is without a `name`. We still need to 2079 # propagate the metadata it's holding. 2080 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): 2081 checkpoint_restore_uid = kwargs[ 2082 "initial_value"].checkpoint_position.restore_uid 2083 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value 2084 elif isinstance(kwargs["initial_value"], 2085 trackable.CheckpointInitialValueCallable): 2086 checkpoint_restore_uid = kwargs[ 2087 "initial_value"].checkpoint_position.restore_uid 2088 else: 2089 checkpoint_restore_uid = None 2090 2091 created = self._create_variable(next_creator, **kwargs) 2092 2093 if checkpoint_restore_uid is not None: 2094 # pylint: disable=protected-access 2095 # Let the checkpointing infrastructure know that the variable was 2096 # already restored so it doesn't waste memory loading the value again. 2097 # In this case of CheckpointInitialValueCallable this may already be 2098 # done by the final variable creator, but it doesn't hurt to do it 2099 # again. 2100 created._maybe_initialize_trackable() 2101 created._update_uid = checkpoint_restore_uid 2102 # pylint: enable=protected-access 2103 return created 2104 2105 def distributed_getter(getter, *args, **kwargs): 2106 if not self._allow_variable_partition(): 2107 if kwargs.pop("partitioner", None) is not None: 2108 tf_logging.log_first_n( 2109 tf_logging.WARN, "Partitioned variables are disabled when using " 2110 "current tf.distribute.Strategy.", 1) 2111 return getter(*args, **kwargs) 2112 2113 return _CurrentDistributionContext( 2114 strategy, 2115 variable_scope.variable_creator_scope(creator_with_resource_vars), 2116 variable_scope.variable_scope( 2117 variable_scope.get_variable_scope(), 2118 custom_getter=distributed_getter), self._default_device) 2119 2120 def _allow_variable_partition(self): 2121 return False 2122 2123 def _create_variable(self, next_creator, **kwargs): 2124 # Note: should support "colocate_with" argument. 2125 raise NotImplementedError("must be implemented in descendants") 2126 2127 def variable_created_in_scope(self, v): 2128 """Tests whether `v` was created while this strategy scope was active. 2129 2130 Variables created inside the strategy scope are "owned" by it: 2131 2132 >>> strategy = tf.distribute.MirroredStrategy() 2133 >>> with strategy.scope(): 2134 ... v = tf.Variable(1.) 2135 >>> strategy.extended.variable_created_in_scope(v) 2136 True 2137 2138 Variables created outside the strategy are not owned by it: 2139 2140 >>> strategy = tf.distribute.MirroredStrategy() 2141 >>> v = tf.Variable(1.) 2142 >>> strategy.extended.variable_created_in_scope(v) 2143 False 2144 2145 Args: 2146 v: A `tf.Variable` instance. 2147 2148 Returns: 2149 True if `v` was created inside the scope, False if not. 2150 """ 2151 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 2152 2153 def colocate_vars_with(self, colocate_with_variable): 2154 """Scope that controls which devices variables will be created on. 2155 2156 No operations should be added to the graph inside this scope, it 2157 should only be used when creating variables (some implementations 2158 work by changing variable creation, others work by using a 2159 tf.compat.v1.colocate_with() scope). 2160 2161 This may only be used inside `self.scope()`. 2162 2163 Example usage: 2164 2165 ``` 2166 with strategy.scope(): 2167 var1 = tf.Variable(...) 2168 with strategy.extended.colocate_vars_with(var1): 2169 # var2 and var3 will be created on the same device(s) as var1 2170 var2 = tf.Variable(...) 2171 var3 = tf.Variable(...) 2172 2173 def fn(v1, v2, v3): 2174 # operates on v1 from var1, v2 from var2, and v3 from var3 2175 2176 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 2177 # too. 2178 strategy.extended.update(var1, fn, args=(var2, var3)) 2179 ``` 2180 2181 Args: 2182 colocate_with_variable: A variable created in this strategy's `scope()`. 2183 Variables created while in the returned context manager will be on the 2184 same set of devices as `colocate_with_variable`. 2185 2186 Returns: 2187 A context manager. 2188 """ 2189 2190 def create_colocated_variable(next_creator, **kwargs): 2191 _require_strategy_scope_extended(self) 2192 kwargs["use_resource"] = True 2193 kwargs["colocate_with"] = colocate_with_variable 2194 return next_creator(**kwargs) 2195 2196 _require_strategy_scope_extended(self) 2197 self._validate_colocate_with_variable(colocate_with_variable) 2198 return variable_scope.variable_creator_scope(create_colocated_variable) 2199 2200 def _validate_colocate_with_variable(self, colocate_with_variable): 2201 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 2202 pass 2203 2204 def _make_dataset_iterator(self, dataset): 2205 raise NotImplementedError("must be implemented in descendants") 2206 2207 def _make_input_fn_iterator(self, input_fn, replication_mode): 2208 raise NotImplementedError("must be implemented in descendants") 2209 2210 def _experimental_distribute_dataset(self, dataset, options): 2211 raise NotImplementedError("must be implemented in descendants") 2212 2213 def _distribute_datasets_from_function(self, dataset_fn, options): 2214 raise NotImplementedError("must be implemented in descendants") 2215 2216 def _experimental_distribute_values_from_function(self, value_fn): 2217 raise NotImplementedError("must be implemented in descendants") 2218 2219 def _reduce(self, reduce_op, value): 2220 # Default implementation until we have an implementation for each strategy. 2221 dst = device_util.current() or self._default_device or "/device:CPU:0" 2222 return self._local_results(self.reduce_to(reduce_op, value, dst))[0] 2223 2224 def reduce_to(self, reduce_op, value, destinations, options=None): 2225 """Combine (via e.g. sum or mean) values across replicas. 2226 2227 `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed 2228 variables. It supports both dense values and `tf.IndexedSlices`. 2229 2230 This API currently can only be called in cross-replica context. Other 2231 variants to reduce values across replicas are: 2232 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of 2233 this API. 2234 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2235 in replica context. It supports both batched and non-batched all-reduce. 2236 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2237 to the host in cross-replica context. 2238 2239 `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can 2240 also pass in a `Tensor`, and the destinations will be the device of that 2241 tensor. For all-reduce, pass the same to `value` and `destinations`. 2242 2243 It can be used in `tf.distribute.ReplicaContext.merge_call` to write code 2244 that works for all `tf.distribute.Strategy`. 2245 2246 >>> @tf.function 2247 ... def step_fn(var): 2248 ... 2249 ... def merge_fn(strategy, value, var): 2250 ... # All-reduce the value. Note that `value` here is a 2251 ... # `tf.distribute.DistributedValues`. 2252 ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, 2253 ... value, destinations=var) 2254 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2255 ... args=(reduced,)) 2256 ... 2257 ... value = tf.identity(1.) 2258 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2259 ... args=(value, var)) 2260 >>> 2261 >>> def run(strategy): 2262 ... with strategy.scope(): 2263 ... v = tf.Variable(0.) 2264 ... strategy.run(step_fn, args=(v,)) 2265 ... return v 2266 >>> 2267 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2268 MirroredVariable:{ 2269 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2270 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2271 } 2272 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2273 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2274 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2275 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2276 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2277 2278 Args: 2279 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2280 be combined. Allows using string representation of the enum such as 2281 "SUM", "MEAN". 2282 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 2283 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2284 `tf.Tensor` alike object, or a device string. It specifies the devices 2285 to reduce to. To perform an all-reduce, pass the same to `value` and 2286 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2287 to the devices of that variable, and this method doesn't update the 2288 variable. 2289 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2290 perform collective operations. This overrides the default options if the 2291 `tf.distribute.Strategy` takes one in the constructor. See 2292 `tf.distribute.experimental.CommunicationOptions` for details of the 2293 options. 2294 2295 Returns: 2296 A tensor or value reduced to `destinations`. 2297 """ 2298 if options is None: 2299 options = collective_util.Options() 2300 _require_cross_replica_or_default_context_extended(self) 2301 assert not isinstance(destinations, (list, tuple)) 2302 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2303 if isinstance(reduce_op, six.string_types): 2304 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2305 assert (reduce_op == reduce_util.ReduceOp.SUM or 2306 reduce_op == reduce_util.ReduceOp.MEAN) 2307 return self._reduce_to(reduce_op, value, destinations, options) 2308 2309 def _reduce_to(self, reduce_op, value, destinations, options): 2310 raise NotImplementedError("must be implemented in descendants") 2311 2312 def batch_reduce_to(self, reduce_op, value_destination_pairs, options=None): 2313 """Combine multiple `reduce_to` calls into one for faster execution. 2314 2315 Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. 2316 It's more efficient than reduce each value separately. 2317 2318 This API currently can only be called in cross-replica context. Other 2319 variants to reduce values across replicas are: 2320 * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of 2321 this API. 2322 * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API 2323 in replica context. It supports both batched and non-batched all-reduce. 2324 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 2325 to the host in cross-replica context. 2326 2327 See `reduce_to` for more information. 2328 2329 >>> @tf.function 2330 ... def step_fn(var): 2331 ... 2332 ... def merge_fn(strategy, value, var): 2333 ... # All-reduce the value. Note that `value` here is a 2334 ... # `tf.distribute.DistributedValues`. 2335 ... reduced = strategy.extended.batch_reduce_to( 2336 ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] 2337 ... strategy.extended.update(var, lambda var, value: var.assign(value), 2338 ... args=(reduced,)) 2339 ... 2340 ... value = tf.identity(1.) 2341 ... tf.distribute.get_replica_context().merge_call(merge_fn, 2342 ... args=(value, var)) 2343 >>> 2344 >>> def run(strategy): 2345 ... with strategy.scope(): 2346 ... v = tf.Variable(0.) 2347 ... strategy.run(step_fn, args=(v,)) 2348 ... return v 2349 >>> 2350 >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) 2351 MirroredVariable:{ 2352 0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0>, 2353 1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=2.0> 2354 } 2355 >>> run(tf.distribute.experimental.CentralStorageStrategy( 2356 ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) 2357 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=2.0> 2358 >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) 2359 <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0> 2360 2361 Args: 2362 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 2363 be combined. Allows using string representation of the enum such as 2364 "SUM", "MEAN". 2365 value_destination_pairs: a sequence of (value, destinations) pairs. See 2366 `tf.distribute.Strategy.reduce_to` for descriptions. 2367 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2368 perform collective operations. This overrides the default options if the 2369 `tf.distribute.Strategy` takes one in the constructor. See 2370 `tf.distribute.experimental.CommunicationOptions` for details of the 2371 options. 2372 2373 Returns: 2374 A list of reduced values, one per pair in `value_destination_pairs`. 2375 """ 2376 if options is None: 2377 options = collective_util.Options() 2378 _require_cross_replica_or_default_context_extended(self) 2379 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 2380 if isinstance(reduce_op, six.string_types): 2381 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2382 return self._batch_reduce_to(reduce_op, value_destination_pairs, options) 2383 2384 def _batch_reduce_to(self, reduce_op, value_destination_pairs, options): 2385 return [ 2386 self.reduce_to(reduce_op, t, destinations=v, options=options) 2387 for t, v in value_destination_pairs 2388 ] 2389 2390 def _replica_ctx_all_reduce(self, reduce_op, value, options=None): 2391 """All-reduce `value` across all replicas so that all get the final result. 2392 2393 If `value` is a nested structure of tensors, all-reduces of these tensors 2394 will be batched when possible. `options` can be set to hint the batching 2395 behavior. 2396 2397 This API must be called in a replica context. 2398 2399 Args: 2400 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 2401 be combined. Allows using string representation of the enum such as 2402 "SUM", "MEAN". 2403 value: Value to be reduced. A tensor or a nested structure of tensors. 2404 options: A `tf.distribute.experimental.CommunicationOptions`. Options to 2405 perform collective operations. This overrides the default options if the 2406 `tf.distribute.Strategy` takes one in the constructor. 2407 2408 Returns: 2409 A tensor or a nested strucutre of tensors with the reduced values. The 2410 structure is the same as `value`. 2411 """ 2412 if options is None: 2413 options = collective_util.Options() 2414 replica_context = distribution_strategy_context.get_replica_context() 2415 assert replica_context, ( 2416 "`StrategyExtended._replica_ctx_all_reduce` must be called in" 2417 " a replica context") 2418 2419 def merge_fn(_, flat_value): 2420 return self.batch_reduce_to(reduce_op, [(v, v) for v in flat_value], 2421 options) 2422 2423 reduced = replica_context.merge_call(merge_fn, args=(nest.flatten(value),)) 2424 return nest.pack_sequence_as(value, reduced) 2425 2426 def _gather_to(self, value, destinations, axis, options=None): 2427 """Gather `value` across replicas along axis-th dimension to `destinations`. 2428 2429 `gather_to` gathers `tf.distribute.DistributedValues` or `tf.Tensor`-like 2430 object, along `axis`-th dimension. It supports only dense tensors but NOT 2431 sparse tensor. This API can only be called in cross-replica context. 2432 2433 Args: 2434 value: a `tf.distribute.DistributedValues`, or a `tf.Tensor` like object. 2435 destinations: a `tf.distribute.DistributedValues`, a `tf.Variable`, a 2436 `tf.Tensor` alike object, or a device string. It specifies the devices 2437 to reduce to. To perform an all-gather, pass the same to `value` and 2438 `destinations`. Note that if it's a `tf.Variable`, the value is reduced 2439 to the devices of that variable, and this method doesn't update the 2440 variable. 2441 axis: 0-D int32 Tensor. Dimension along which to gather. Must be in the 2442 range [0, rank(value)). 2443 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 2444 perform collective operations. This overrides the default options if the 2445 `tf.distribute.Strategy` takes one in the constructor. See 2446 `tf.distribute.experimental.CommunicationOptions` for details of the 2447 options. 2448 2449 Returns: 2450 A tensor or value gathered to `destinations`. 2451 """ 2452 _require_cross_replica_or_default_context_extended(self) 2453 assert not isinstance(destinations, (list, tuple)) 2454 if options is None: 2455 options = collective_util.Options() 2456 return self._gather_to_implementation(value, destinations, axis, options) 2457 2458 def _gather_to_implementation(self, value, destinations, axis, options): 2459 raise NotImplementedError("_gather_to must be implemented in descendants") 2460 2461 def _batch_gather_to(self, value_destination_pairs, axis, options=None): 2462 _require_cross_replica_or_default_context_extended(self) 2463 if options is None: 2464 options = collective_util.Options() 2465 return [ 2466 self._gather_to(t, destinations=v, axis=axis, options=options) 2467 for t, v in value_destination_pairs 2468 ] 2469 2470 def update(self, var, fn, args=(), kwargs=None, group=True): 2471 """Run `fn` to update `var` using inputs mirrored to the same devices. 2472 2473 `tf.distribute.StrategyExtended.update` takes a distributed variable `var` 2474 to be updated, an update function `fn`, and `args` and `kwargs` for `fn`. It 2475 applies `fn` to each component variable of `var` and passes corresponding 2476 values from `args` and `kwargs`. Neither `args` nor `kwargs` may contain 2477 per-replica values. If they contain mirrored values, they will be unwrapped 2478 before calling `fn`. For example, `fn` can be `assign_add` and `args` can be 2479 a mirrored DistributedValues where each component contains the value to be 2480 added to this mirrored variable `var`. Calling `update` will call 2481 `assign_add` on each component variable of `var` with the corresponding 2482 tensor value on that device. 2483 2484 Example usage: 2485 2486 ```python 2487 strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) # With 2 2488 devices 2489 with strategy.scope(): 2490 v = tf.Variable(5.0, aggregation=tf.VariableAggregation.SUM) 2491 def update_fn(v): 2492 return v.assign(1.0) 2493 result = strategy.extended.update(v, update_fn) 2494 # result is 2495 # Mirrored:{ 2496 # 0: tf.Tensor(1.0, shape=(), dtype=float32), 2497 # 1: tf.Tensor(1.0, shape=(), dtype=float32) 2498 # } 2499 ``` 2500 2501 If `var` is mirrored across multiple devices, then this method implements 2502 logic as following: 2503 2504 ```python 2505 results = {} 2506 for device, v in var: 2507 with tf.device(device): 2508 # args and kwargs will be unwrapped if they are mirrored. 2509 results[device] = fn(v, *args, **kwargs) 2510 return merged(results) 2511 ``` 2512 2513 Otherwise, this method returns `fn(var, *args, **kwargs)` colocated with 2514 `var`. 2515 2516 Args: 2517 var: Variable, possibly mirrored to multiple devices, to operate on. 2518 fn: Function to call. Should take the variable as the first argument. 2519 args: Tuple or list. Additional positional arguments to pass to `fn()`. 2520 kwargs: Dict with keyword arguments to pass to `fn()`. 2521 group: Boolean. Defaults to True. If False, the return value will be 2522 unwrapped. 2523 2524 Returns: 2525 By default, the merged return value of `fn` across all replicas. The 2526 merged result has dependencies to make sure that if it is evaluated at 2527 all, the side effects (updates) will happen on every replica. If instead 2528 "group=False" is specified, this function will return a nest of lists 2529 where each list has an element per replica, and the caller is responsible 2530 for ensuring all elements are executed. 2531 """ 2532 _require_cross_replica_or_default_context_extended(self) 2533 if kwargs is None: 2534 kwargs = {} 2535 fn = autograph.tf_convert( 2536 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2537 with self._container_strategy().scope(): 2538 return self._update(var, fn, args, kwargs, group) 2539 2540 def _update(self, var, fn, args, kwargs, group): 2541 raise NotImplementedError("must be implemented in descendants") 2542 2543 def _local_results(self, distributed_value): 2544 raise NotImplementedError("must be implemented in descendants") 2545 2546 def value_container(self, value): 2547 """Returns the container that this per-replica `value` belongs to. 2548 2549 Args: 2550 value: A value returned by `run()` or a variable created in `scope()`. 2551 2552 Returns: 2553 A container that `value` belongs to. 2554 If value does not belong to any container (including the case of 2555 container having been destroyed), returns the value itself. 2556 `value in experimental_local_results(value_container(value))` will 2557 always be true. 2558 """ 2559 raise NotImplementedError("must be implemented in descendants") 2560 2561 def _group(self, value, name=None): 2562 """Implementation of `group`.""" 2563 value = nest.flatten(self._local_results(value)) 2564 2565 if len(value) != 1 or name is not None: 2566 return control_flow_ops.group(value, name=name) 2567 # Special handling for the common case of one op. 2568 v, = value 2569 if hasattr(v, "op"): 2570 v = v.op 2571 return v 2572 2573 @property 2574 def experimental_require_static_shapes(self): 2575 """Returns `True` if static shape is required; `False` otherwise.""" 2576 return self._require_static_shapes 2577 2578 @property 2579 def _num_replicas_in_sync(self): 2580 """Returns number of replicas over which gradients are aggregated.""" 2581 raise NotImplementedError("must be implemented in descendants") 2582 2583 @property 2584 def worker_devices(self): 2585 """Returns the tuple of all devices used to for compute replica execution. 2586 """ 2587 # TODO(josh11b): More docstring 2588 raise NotImplementedError("must be implemented in descendants") 2589 2590 @property 2591 def parameter_devices(self): 2592 """Returns the tuple of all devices used to place variables.""" 2593 # TODO(josh11b): More docstring 2594 raise NotImplementedError("must be implemented in descendants") 2595 2596 def _configure(self, 2597 session_config=None, 2598 cluster_spec=None, 2599 task_type=None, 2600 task_id=None): 2601 """Configures the strategy class.""" 2602 del session_config, cluster_spec, task_type, task_id 2603 2604 def _update_config_proto(self, config_proto): 2605 return copy.deepcopy(config_proto) 2606 2607 def _in_multi_worker_mode(self): 2608 """Whether this strategy indicates working in multi-worker settings. 2609 2610 Multi-worker training refers to the setup where the training is 2611 distributed across multiple workers, as opposed to the case where 2612 only a local process performs the training. This function is 2613 used by higher-level APIs such as Keras' `model.fit()` to infer 2614 for example whether or not a distribute coordinator should be run, 2615 and thus TensorFlow servers should be started for communication 2616 with other servers in the cluster, or whether or not saving/restoring 2617 checkpoints is relevant for preemption fault tolerance. 2618 2619 Subclasses should override this to provide whether the strategy is 2620 currently in multi-worker setup. 2621 2622 Experimental. Signature and implementation are subject to change. 2623 """ 2624 raise NotImplementedError("must be implemented in descendants") 2625 2626 2627@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring 2628class StrategyExtendedV1(StrategyExtendedV2): 2629 2630 __doc__ = StrategyExtendedV2.__doc__ 2631 2632 def experimental_make_numpy_dataset(self, numpy_input, session=None): 2633 """Makes a dataset for input provided via a numpy array. 2634 2635 This avoids adding `numpy_input` as a large constant in the graph, 2636 and copies the data to the machine or machines that will be processing 2637 the input. 2638 2639 Args: 2640 numpy_input: A nest of NumPy input arrays that will be distributed evenly 2641 across all replicas. Note that lists of Numpy arrays are stacked, as 2642 that is normal `tf.data.Dataset` behavior. 2643 session: (TensorFlow v1.x graph execution only) A session used for 2644 initialization. 2645 2646 Returns: 2647 A `tf.data.Dataset` representing `numpy_input`. 2648 """ 2649 _require_cross_replica_or_default_context_extended(self) 2650 return self._experimental_make_numpy_dataset(numpy_input, session=session) 2651 2652 def _experimental_make_numpy_dataset(self, numpy_input, session): 2653 raise NotImplementedError("must be implemented in descendants") 2654 2655 def broadcast_to(self, tensor, destinations): 2656 """Mirror a tensor on one device to all worker devices. 2657 2658 Args: 2659 tensor: A Tensor value to broadcast. 2660 destinations: A mirrored variable or device string specifying the 2661 destination devices to copy `tensor` to. 2662 2663 Returns: 2664 A value mirrored to `destinations` devices. 2665 """ 2666 assert destinations is not None # from old strategy.broadcast() 2667 # TODO(josh11b): More docstring 2668 _require_cross_replica_or_default_context_extended(self) 2669 assert not isinstance(destinations, (list, tuple)) 2670 return self._broadcast_to(tensor, destinations) 2671 2672 def _broadcast_to(self, tensor, destinations): 2673 raise NotImplementedError("must be implemented in descendants") 2674 2675 def experimental_run_steps_on_iterator(self, 2676 fn, 2677 iterator, 2678 iterations=1, 2679 initial_loop_values=None): 2680 """DEPRECATED: please use `run` instead. 2681 2682 Run `fn` with input from `iterator` for `iterations` times. 2683 2684 This method can be used to run a step function for training a number of 2685 times using input from a dataset. 2686 2687 Args: 2688 fn: function to run using this distribution strategy. The function must 2689 have the following signature: `def fn(context, inputs)`. `context` is an 2690 instance of `MultiStepContext` that will be passed when `fn` is run. 2691 `context` can be used to specify the outputs to be returned from `fn` 2692 by calling `context.set_last_step_output`. It can also be used to 2693 capture non tensor outputs by `context.set_non_tensor_output`. See 2694 `MultiStepContext` documentation for more information. `inputs` will 2695 have same type/structure as `iterator.get_next()`. Typically, `fn` 2696 will use `call_for_each_replica` method of the strategy to distribute 2697 the computation over multiple replicas. 2698 iterator: Iterator of a dataset that represents the input for `fn`. The 2699 caller is responsible for initializing the iterator as needed. 2700 iterations: (Optional) Number of iterations that `fn` should be run. 2701 Defaults to 1. 2702 initial_loop_values: (Optional) Initial values to be passed into the 2703 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 2704 initial_loop_values argument when we have a mechanism to infer the 2705 outputs of `fn`. 2706 2707 Returns: 2708 Returns the `MultiStepContext` object which has the following properties, 2709 among other things: 2710 - run_op: An op that runs `fn` `iterations` times. 2711 - last_step_outputs: A dictionary containing tensors set using 2712 `context.set_last_step_output`. Evaluating this returns the value of 2713 the tensors after the last iteration. 2714 - non_tensor_outputs: A dictionary containing anything that was set by 2715 `fn` by calling `context.set_non_tensor_output`. 2716 """ 2717 _require_cross_replica_or_default_context_extended(self) 2718 with self._container_strategy().scope(): 2719 return self._experimental_run_steps_on_iterator(fn, iterator, iterations, 2720 initial_loop_values) 2721 2722 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 2723 initial_loop_values): 2724 raise NotImplementedError("must be implemented in descendants") 2725 2726 def call_for_each_replica(self, fn, args=(), kwargs=None): 2727 """Run `fn` once per replica. 2728 2729 `fn` may call `tf.get_replica_context()` to access methods such as 2730 `replica_id_in_sync_group` and `merge_call()`. 2731 2732 `merge_call()` is used to communicate between the replicas and 2733 re-enter the cross-replica context. All replicas pause their execution 2734 having encountered a `merge_call()` call. After that the 2735 `merge_fn`-function is executed. Its results are then unwrapped and 2736 given back to each replica call. After that execution resumes until 2737 `fn` is complete or encounters another `merge_call()`. Example: 2738 2739 ```python 2740 # Called once in "cross-replica" context. 2741 def merge_fn(distribution, three_plus_replica_id): 2742 # sum the values across replicas 2743 return sum(distribution.experimental_local_results(three_plus_replica_id)) 2744 2745 # Called once per replica in `distribution`, in a "replica" context. 2746 def fn(three): 2747 replica_ctx = tf.get_replica_context() 2748 v = three + replica_ctx.replica_id_in_sync_group 2749 # Computes the sum of the `v` values across all replicas. 2750 s = replica_ctx.merge_call(merge_fn, args=(v,)) 2751 return s + v 2752 2753 with distribution.scope(): 2754 # in "cross-replica" context 2755 ... 2756 merged_results = distribution.run(fn, args=[3]) 2757 # merged_results has the values from every replica execution of `fn`. 2758 # This statement prints a list: 2759 print(distribution.experimental_local_results(merged_results)) 2760 ``` 2761 2762 Args: 2763 fn: function to run (will be run once per replica). 2764 args: Tuple or list with positional arguments for `fn`. 2765 kwargs: Dict with keyword arguments for `fn`. 2766 2767 Returns: 2768 Merged return value of `fn` across all replicas. 2769 """ 2770 _require_cross_replica_or_default_context_extended(self) 2771 if kwargs is None: 2772 kwargs = {} 2773 with self._container_strategy().scope(): 2774 return self._call_for_each_replica(fn, args, kwargs) 2775 2776 def _call_for_each_replica(self, fn, args, kwargs): 2777 raise NotImplementedError("must be implemented in descendants") 2778 2779 def read_var(self, v): 2780 """Reads the value of a variable. 2781 2782 Returns the aggregate value of a replica-local variable, or the 2783 (read-only) value of any other variable. 2784 2785 Args: 2786 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 2787 2788 Returns: 2789 A tensor representing the value of `v`, aggregated across replicas if 2790 necessary. 2791 """ 2792 raise NotImplementedError("must be implemented in descendants") 2793 2794 def update_non_slot( 2795 self, colocate_with, fn, args=(), kwargs=None, group=True): 2796 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 2797 2798 Used to update non-slot variables. 2799 2800 DEPRECATED: TF 1.x ONLY. 2801 2802 Args: 2803 colocate_with: Devices returned by `non_slot_devices()`. 2804 fn: Function to execute. 2805 args: Tuple or list. Positional arguments to pass to `fn()`. 2806 kwargs: Dict with keyword arguments to pass to `fn()`. 2807 group: Boolean. Defaults to True. If False, the return value will be 2808 unwrapped. 2809 2810 Returns: 2811 Return value of `fn`, possibly merged across devices. 2812 """ 2813 _require_cross_replica_or_default_context_extended(self) 2814 if kwargs is None: 2815 kwargs = {} 2816 fn = autograph.tf_convert( 2817 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2818 with self._container_strategy().scope(): 2819 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 2820 2821 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 2822 raise NotImplementedError("must be implemented in descendants") 2823 2824 def non_slot_devices(self, var_list): 2825 """Device(s) for non-slot variables. 2826 2827 DEPRECATED: TF 1.x ONLY. 2828 2829 This method returns non-slot devices where non-slot variables are placed. 2830 Users can create non-slot variables on these devices by using a block: 2831 2832 ```python 2833 with tf.distribute.StrategyExtended.colocate_vars_with(tf.distribute.StrategyExtended.non_slot_devices(...)): 2834 ... 2835 ``` 2836 2837 Args: 2838 var_list: The list of variables being optimized, needed with the 2839 default `tf.distribute.Strategy`. 2840 Returns: 2841 A sequence of devices for non-slot variables. 2842 """ 2843 raise NotImplementedError("must be implemented in descendants") 2844 2845 @property 2846 def experimental_between_graph(self): 2847 """Whether the strategy uses between-graph replication or not. 2848 2849 This is expected to return a constant value that will not be changed 2850 throughout its life cycle. 2851 """ 2852 raise NotImplementedError("must be implemented in descendants") 2853 2854 @property 2855 def experimental_should_init(self): 2856 """Whether initialization is needed.""" 2857 raise NotImplementedError("must be implemented in descendants") 2858 2859 @property 2860 def should_checkpoint(self): 2861 """Whether checkpointing is needed.""" 2862 raise NotImplementedError("must be implemented in descendants") 2863 2864 @property 2865 def should_save_summary(self): 2866 """Whether saving summaries is needed.""" 2867 raise NotImplementedError("must be implemented in descendants") 2868 2869 2870# A note about the difference between the context managers 2871# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 2872# (defined above) used by `tf.distribute.Strategy.scope()`: 2873# 2874# * a ReplicaContext is only present during a `run()` 2875# call (except during a `merge_run` call) and in such a scope it 2876# will be returned by calls to `get_replica_context()`. Implementers of new 2877# Strategy descendants will frequently also need to 2878# define a descendant of ReplicaContext, and are responsible for 2879# entering and exiting this context. 2880# 2881# * Strategy.scope() sets up a variable_creator scope that 2882# changes variable creation calls (e.g. to make mirrored 2883# variables). This is intended as an outer scope that users enter once 2884# around their model creation and graph definition. There is no 2885# anticipated need to define descendants of _CurrentDistributionContext. 2886# It sets the current Strategy for purposes of 2887# `get_strategy()` and `has_strategy()` 2888# and switches the thread mode to a "cross-replica context". 2889class ReplicaContextBase(object): 2890 """A class with a collection of APIs that can be called in a replica context. 2891 2892 You can use `tf.distribute.get_replica_context` to get an instance of 2893 `ReplicaContext`, which can only be called inside the function passed to 2894 `tf.distribute.Strategy.run`. 2895 2896 >>> strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1']) 2897 >>> def func(): 2898 ... replica_context = tf.distribute.get_replica_context() 2899 ... return replica_context.replica_id_in_sync_group 2900 >>> strategy.run(func) 2901 PerReplica:{ 2902 0: <tf.Tensor: shape=(), dtype=int32, numpy=0>, 2903 1: <tf.Tensor: shape=(), dtype=int32, numpy=1> 2904 } 2905 """ 2906 2907 def __init__(self, strategy, replica_id_in_sync_group): 2908 """Creates a ReplicaContext. 2909 2910 Args: 2911 strategy: A `tf.distribute.Strategy`. 2912 replica_id_in_sync_group: An integer, a `Tensor` or None. Prefer an 2913 integer whenever possible to avoid issues with nested `tf.function`. It 2914 accepts a `Tensor` only to be compatible with `tpu.replicate`. 2915 """ 2916 self._strategy = strategy 2917 self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access 2918 self) 2919 if not (replica_id_in_sync_group is None or 2920 tensor_util.is_tf_type(replica_id_in_sync_group) or 2921 isinstance(replica_id_in_sync_group, int)): 2922 raise ValueError( 2923 "replica_id_in_sync_group can only be an integer, a Tensor or None.") 2924 self._replica_id_in_sync_group = replica_id_in_sync_group 2925 # We need this check because TPUContext extends from ReplicaContext and 2926 # does not pass a strategy object since it is used by TPUEstimator. 2927 if strategy: 2928 self._local_replica_id = strategy.extended._get_local_replica_id( 2929 replica_id_in_sync_group) 2930 self._summary_recording_distribution_strategy = None 2931 2932 @doc_controls.do_not_generate_docs 2933 def __enter__(self): 2934 _push_per_thread_mode(self._thread_context) 2935 2936 def replica_id_is_zero(): 2937 return math_ops.equal(self.replica_id_in_sync_group, 2938 constant_op.constant(0)) 2939 2940 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 2941 self._summary_recording_distribution_strategy = ( 2942 summary_state.is_recording_distribution_strategy) 2943 summary_state.is_recording_distribution_strategy = replica_id_is_zero 2944 2945 @doc_controls.do_not_generate_docs 2946 def __exit__(self, exception_type, exception_value, traceback): 2947 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 2948 summary_state.is_recording_distribution_strategy = ( 2949 self._summary_recording_distribution_strategy) 2950 _pop_per_thread_mode() 2951 2952 def merge_call(self, merge_fn, args=(), kwargs=None): 2953 """Merge args across replicas and run `merge_fn` in a cross-replica context. 2954 2955 This allows communication and coordination when there are multiple calls 2956 to the step_fn triggered by a call to `strategy.run(step_fn, ...)`. 2957 2958 See `tf.distribute.Strategy.run` for an explanation. 2959 2960 If not inside a distributed scope, this is equivalent to: 2961 2962 ``` 2963 strategy = tf.distribute.get_strategy() 2964 with cross-replica-context(strategy): 2965 return merge_fn(strategy, *args, **kwargs) 2966 ``` 2967 2968 Args: 2969 merge_fn: Function that joins arguments from threads that are given as 2970 PerReplica. It accepts `tf.distribute.Strategy` object as 2971 the first argument. 2972 args: List or tuple with positional per-thread arguments for `merge_fn`. 2973 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 2974 2975 Returns: 2976 The return value of `merge_fn`, except for `PerReplica` values which are 2977 unpacked. 2978 """ 2979 require_replica_context(self) 2980 if kwargs is None: 2981 kwargs = {} 2982 2983 merge_fn = autograph.tf_convert( 2984 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2985 return self._merge_call(merge_fn, args, kwargs) 2986 2987 def _merge_call(self, merge_fn, args, kwargs): 2988 """Default implementation for single replica.""" 2989 _push_per_thread_mode( # thread-local, so not needed with multiple threads 2990 distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 2991 try: 2992 return merge_fn(self._strategy, *args, **kwargs) 2993 finally: 2994 _pop_per_thread_mode() 2995 2996 @property 2997 def num_replicas_in_sync(self): 2998 """Returns number of replicas that are kept in sync.""" 2999 return self._strategy.num_replicas_in_sync 3000 3001 @property 3002 def replica_id_in_sync_group(self): 3003 """Returns the id of the replica. 3004 3005 This identifies the replica among all replicas that are kept in sync. The 3006 value of the replica id can range from 0 to 3007 `tf.distribute.ReplicaContext.num_replicas_in_sync` - 1. 3008 3009 NOTE: This is not guaranteed to be the same ID as the XLA replica ID use 3010 for low-level operations such as collective_permute. 3011 3012 Returns: 3013 a `Tensor`. 3014 """ 3015 # It's important to prefer making the Tensor at call time whenever possible. 3016 # Keeping Tensors in global states doesn't work well with nested 3017 # tf.function, since it's possible that the tensor is generated in one func 3018 # graph, and gets captured by another, which will result in a subtle "An op 3019 # outside of the function building code is being passed a Graph tensor" 3020 # error. Making the tensor at call time to ensure it is the same graph where 3021 # it's used. However to be compatible with tpu.replicate(), 3022 # self._replica_id_in_sync_group can also be a Tensor. 3023 if tensor_util.is_tf_type(self._replica_id_in_sync_group): 3024 return self._replica_id_in_sync_group 3025 return constant_op.constant( 3026 self._replica_id_in_sync_group, 3027 dtypes.int32, 3028 name="replica_id_in_sync_group") 3029 3030 @property 3031 def _replica_id(self): 3032 """This is the local replica id in a given sync group.""" 3033 return self._local_replica_id 3034 3035 @property 3036 def strategy(self): 3037 """The current `tf.distribute.Strategy` object.""" 3038 return self._strategy 3039 3040 @property 3041 @deprecation.deprecated(None, "Please avoid relying on devices property.") 3042 def devices(self): 3043 """Returns the devices this replica is to be executed on, as a tuple of strings. 3044 3045 NOTE: For `tf.distribute.MirroredStrategy` and 3046 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, this returns a 3047 nested 3048 list of device strings, e.g, [["GPU:0"]]. 3049 """ 3050 require_replica_context(self) 3051 return (device_util.current(),) 3052 3053 def all_reduce(self, reduce_op, value, options=None): 3054 """All-reduces `value` across all replicas. 3055 3056 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3057 >>> def step_fn(): 3058 ... ctx = tf.distribute.get_replica_context() 3059 ... value = tf.identity(1.) 3060 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value) 3061 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3062 (<tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3063 <tf.Tensor: shape=(), dtype=float32, numpy=2.0>) 3064 3065 It supports batched operations. You can pass a list of values and it 3066 attempts to batch them when possible. You can also specify `options` 3067 to indicate the desired batching behavior, e.g. batch the values into 3068 multiple packs so that they can better overlap with computations. 3069 3070 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3071 >>> def step_fn(): 3072 ... ctx = tf.distribute.get_replica_context() 3073 ... value1 = tf.identity(1.) 3074 ... value2 = tf.identity(2.) 3075 ... return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2]) 3076 >>> strategy.experimental_local_results(strategy.run(step_fn)) 3077 ([PerReplica:{ 3078 0: <tf.Tensor: shape=(), dtype=float32, numpy=2.0>, 3079 1: <tf.Tensor: shape=(), dtype=float32, numpy=2.0> 3080 }, PerReplica:{ 3081 0: <tf.Tensor: shape=(), dtype=float32, numpy=4.0>, 3082 1: <tf.Tensor: shape=(), dtype=float32, numpy=4.0> 3083 }],) 3084 3085 Note that all replicas need to participate in the all-reduce, otherwise this 3086 operation hangs. Note that if there're multiple all-reduces, they need to 3087 execute in the same order on all replicas. Dispatching all-reduce based on 3088 conditions is usually error-prone. 3089 3090 This API currently can only be called in the replica context. Other 3091 variants to reduce values across replicas are: 3092 * `tf.distribute.StrategyExtended.reduce_to`: the reduce and all-reduce API 3093 in the cross-replica context. 3094 * `tf.distribute.StrategyExtended.batch_reduce_to`: the batched reduce and 3095 all-reduce API in the cross-replica context. 3096 * `tf.distribute.Strategy.reduce`: a more convenient method to reduce 3097 to the host in cross-replica context. 3098 3099 Args: 3100 reduce_op: a `tf.distribute.ReduceOp` value specifying how values should 3101 be combined. Allows using string representation of the enum such as 3102 "SUM", "MEAN". 3103 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts. 3104 The structure and the shapes of the `tf.Tensor` need to be same on all 3105 replicas. 3106 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3107 perform collective operations. This overrides the default options if the 3108 `tf.distribute.Strategy` takes one in the constructor. See 3109 `tf.distribute.experimental.CommunicationOptions` for details of the 3110 options. 3111 3112 Returns: 3113 A nested structure of `tf.Tensor` with the reduced values. The structure 3114 is the same as `value`. 3115 """ 3116 if isinstance(reduce_op, six.string_types): 3117 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 3118 if options is None: 3119 options = collective_util.Options() 3120 3121 def batch_all_reduce(strategy, *value_flat): 3122 return strategy.extended.batch_reduce_to( 3123 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat], 3124 options) 3125 3126 if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: 3127 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 3128 @custom_gradient.custom_gradient 3129 def grad_wrapper(*xs): 3130 ys = self.merge_call(batch_all_reduce, args=xs) 3131 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 3132 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 3133 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 3134 else: 3135 # TODO(cjfj): Implement gradients for other reductions. 3136 reduced = nest.pack_sequence_as( 3137 value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) 3138 return nest.map_structure(array_ops.prevent_gradient, reduced) 3139 3140 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 3141 # all-reduce. It would return a function returning the result of reducing `t` 3142 # across all replicas. The caller would wait to call this function until they 3143 # needed the reduce result, allowing an efficient implementation: 3144 # * With eager execution, the reduction could be performed asynchronously 3145 # in the background, not blocking until the result was needed. 3146 # * When constructing a graph, it could batch up all reduction requests up 3147 # to that point that the first result is needed. Most likely this can be 3148 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 3149 3150 3151@tf_export("distribute.ReplicaContext", v1=[]) 3152class ReplicaContext(ReplicaContextBase): 3153 3154 __doc__ = ReplicaContextBase.__doc__ 3155 3156 def all_gather(self, value, axis, options=None): 3157 """All-gathers `value` across all replicas along `axis`. 3158 3159 Note: An `all_gather` method can only be called in replica context. For 3160 a cross-replica context counterpart, see `tf.distribute.Strategy.gather`. 3161 All replicas need to participate in the all-gather, otherwise this 3162 operation hangs. So if `all_gather` is called in any replica, it must be 3163 called in all replicas. 3164 3165 Note: If there are multiple `all_gather` calls, they need to be executed in 3166 the same order on all replicas. Dispatching `all_gather` based on conditions 3167 is usually error-prone. 3168 3169 For all strategies except `tf.distribute.TPUStrategy`, the input 3170 `value` on different replicas must have the same rank, and their shapes must 3171 be the same in all dimensions except the `axis`-th dimension. In other 3172 words, their shapes cannot be different in a dimension `d` where `d` does 3173 not equal to the `axis` argument. For example, given a 3174 `tf.distribute.DistributedValues` with component tensors of shape 3175 `(1, 2, 3)` and `(1, 3, 3)` on two replicas, you can call 3176 `all_gather(..., axis=1, ...)` on it, but not `all_gather(..., axis=0, ...)` 3177 or `all_gather(..., axis=2, ...)`. However, with 3178 `tf.distribute.TPUStrategy`, all tensors must have exactly the same rank and 3179 same shape. 3180 3181 Note: The input `value` must have a non-zero rank. Otherwise, consider using 3182 `tf.expand_dims` before gathering them. 3183 3184 You can pass in a single tensor to all-gather: 3185 3186 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3187 >>> @tf.function 3188 ... def gather_value(): 3189 ... ctx = tf.distribute.get_replica_context() 3190 ... local_value = tf.constant([1, 2, 3]) 3191 ... return ctx.all_gather(local_value, axis=0) 3192 >>> result = strategy.run(gather_value) 3193 >>> result 3194 PerReplica:{ 3195 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3196 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3197 } 3198 >>> strategy.experimental_local_results(result) 3199 (<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3200 dtype=int32)>, 3201 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], 3202 dtype=int32)>) 3203 3204 3205 You can also pass in a nested structure of tensors to all-gather, say, a 3206 list: 3207 3208 >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) 3209 >>> @tf.function 3210 ... def gather_nest(): 3211 ... ctx = tf.distribute.get_replica_context() 3212 ... value_1 = tf.constant([1, 2, 3]) 3213 ... value_2 = tf.constant([[1, 2], [3, 4]]) 3214 ... # all_gather a nest of `tf.distribute.DistributedValues` 3215 ... return ctx.all_gather([value_1, value_2], axis=0) 3216 >>> result = strategy.run(gather_nest) 3217 >>> result 3218 [PerReplica:{ 3219 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3220 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3221 }, PerReplica:{ 3222 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3223 array([[1, 2], 3224 [3, 4], 3225 [1, 2], 3226 [3, 4]], dtype=int32)>, 3227 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3228 array([[1, 2], 3229 [3, 4], 3230 [1, 2], 3231 [3, 4]], dtype=int32)> 3232 }] 3233 >>> strategy.experimental_local_results(result) 3234 ([PerReplica:{ 3235 0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>, 3236 1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)> 3237 }, PerReplica:{ 3238 0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3239 array([[1, 2], 3240 [3, 4], 3241 [1, 2], 3242 [3, 4]], dtype=int32)>, 3243 1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy= 3244 array([[1, 2], 3245 [3, 4], 3246 [1, 2], 3247 [3, 4]], dtype=int32)> 3248 }],) 3249 3250 3251 What if you are all-gathering tensors with different shapes on different 3252 replicas? Consider the following example with two replicas, where you have 3253 `value` as a nested structure consisting of two items to all-gather, `a` and 3254 `b`. 3255 3256 On Replica 0, `value` is `{'a': [0], 'b': [[0, 1]]}`. 3257 3258 On Replica 1, `value` is `{'a': [1], 'b': [[2, 3], [4, 5]]}`. 3259 3260 Result for `all_gather` with `axis`=0 (on each of the replicas) is: 3261 3262 ```{'a': [1, 2], 'b': [[0, 1], [2, 3], [4, 5]]}``` 3263 3264 Args: 3265 value: a nested structure of `tf.Tensor` which `tf.nest.flatten` accepts, 3266 or a `tf.distribute.DistributedValues` instance. The structure of the 3267 `tf.Tensor` need to be same on all replicas. The underlying tensor 3268 constructs can only be dense tensors with non-zero rank, NOT 3269 `tf.IndexedSlices`. 3270 axis: 0-D int32 Tensor. Dimension along which to gather. 3271 options: a `tf.distribute.experimental.CommunicationOptions`. Options to 3272 perform collective operations. This overrides the default options if the 3273 `tf.distribute.Strategy` takes one in the constructor. See 3274 `tf.distribute.experimental.CommunicationOptions` for details of the 3275 options. 3276 3277 Returns: 3278 A nested structure of `tf.Tensor` with the gathered values. The structure 3279 is the same as `value`. 3280 """ 3281 for v in nest.flatten(value): 3282 if isinstance(v, ops.IndexedSlices): 3283 raise NotImplementedError("all_gather does not support IndexedSlices") 3284 3285 if options is None: 3286 options = collective_util.Options() 3287 3288 def batch_all_gather(strategy, *value_flat): 3289 return strategy.extended._batch_gather_to( # pylint: disable=protected-access 3290 [(v, _batch_reduce_destination(v)) for v in value_flat], axis, 3291 options) 3292 3293 @custom_gradient.custom_gradient 3294 def grad_wrapper(*xs): 3295 ys = self.merge_call(batch_all_gather, args=xs) 3296 3297 def grad(*dy_s): 3298 grads = self.all_reduce(reduce_util.ReduceOp.SUM, dy_s) 3299 new_grads = [] 3300 for i, grad in enumerate(grads): 3301 input_shape = array_ops.shape(xs[i]) 3302 axis_dim = array_ops.reshape(input_shape[axis], [1]) 3303 with ops.control_dependencies([array_ops.identity(grads)]): 3304 d = self.all_gather(axis_dim, axis=0) 3305 begin_dim = math_ops.reduce_sum(d[:self.replica_id_in_sync_group]) 3306 end_dim = begin_dim + array_ops.shape(xs[i])[axis] 3307 new_grad = array_ops.gather( 3308 grad, axis=axis, indices=math_ops.range(begin_dim, end_dim)) 3309 new_grads.append(new_grad) 3310 return new_grads 3311 3312 return ys, grad 3313 3314 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 3315 3316 3317@tf_export(v1=["distribute.ReplicaContext"]) 3318class ReplicaContextV1(ReplicaContextBase): 3319 __doc__ = ReplicaContextBase.__doc__ 3320 3321 3322def _batch_reduce_destination(x): 3323 """Returns the destinations for batch all-reduce.""" 3324 if isinstance(x, ops.Tensor): 3325 # If this is a one device strategy. 3326 return x.device 3327 else: 3328 return x 3329 3330 3331# ------------------------------------------------------------------------------ 3332 3333 3334_creating_default_strategy_singleton = False 3335 3336 3337class _DefaultDistributionStrategyV1(StrategyV1): 3338 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3339 3340 def __init__(self): 3341 if not _creating_default_strategy_singleton: 3342 raise RuntimeError("Should only create a single instance of " 3343 "_DefaultDistributionStrategy") 3344 super(_DefaultDistributionStrategyV1, 3345 self).__init__(_DefaultDistributionExtended(self)) 3346 3347 def __deepcopy__(self, memo): 3348 del memo 3349 raise RuntimeError("Should only create a single instance of " 3350 "_DefaultDistributionStrategy") 3351 3352 3353class _DefaultDistributionStrategy(Strategy): 3354 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 3355 3356 def __init__(self): 3357 if not _creating_default_strategy_singleton: 3358 raise RuntimeError("Should only create a single instance of " 3359 "_DefaultDistributionStrategy") 3360 super(_DefaultDistributionStrategy, self).__init__( 3361 _DefaultDistributionExtended(self)) 3362 3363 def __deepcopy__(self, memo): 3364 del memo 3365 raise RuntimeError("Should only create a single instance of " 3366 "_DefaultDistributionStrategy") 3367 3368 3369class _DefaultDistributionContext(object): 3370 """Context manager setting the default `tf.distribute.Strategy`.""" 3371 3372 __slots__ = ["_var_creator_scope", "_strategy", "_nested_count"] 3373 3374 def __init__(self, strategy): 3375 3376 def creator(next_creator, **kwargs): 3377 _require_strategy_scope_strategy(strategy) 3378 return next_creator(**kwargs) 3379 3380 self._var_creator_scope = variable_scope.variable_creator_scope(creator) 3381 self._strategy = strategy 3382 self._nested_count = 0 3383 3384 def __enter__(self): 3385 # Allow this scope to be entered if this strategy is already in scope. 3386 if distribution_strategy_context.has_strategy(): 3387 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 3388 if self._nested_count == 0: 3389 self._var_creator_scope.__enter__() 3390 self._nested_count += 1 3391 return self._strategy 3392 3393 def __exit__(self, exception_type, exception_value, traceback): 3394 self._nested_count -= 1 3395 if self._nested_count == 0: 3396 try: 3397 self._var_creator_scope.__exit__( 3398 exception_type, exception_value, traceback) 3399 except RuntimeError as e: 3400 six.raise_from( 3401 RuntimeError("Variable creator scope nesting error: move call to " 3402 "tf.distribute.set_strategy() out of `with` scope."), 3403 e) 3404 3405 3406class _DefaultDistributionExtended(StrategyExtendedV1): 3407 """Implementation of _DefaultDistributionStrategy.""" 3408 3409 def __init__(self, container_strategy): 3410 super(_DefaultDistributionExtended, self).__init__(container_strategy) 3411 self._retrace_functions_for_each_device = False 3412 3413 def _scope(self, strategy): 3414 """Context manager setting a variable creator and `self` as current.""" 3415 return _DefaultDistributionContext(strategy) 3416 3417 def colocate_vars_with(self, colocate_with_variable): 3418 """Does not require `self.scope`.""" 3419 _require_strategy_scope_extended(self) 3420 return ops.colocate_with(colocate_with_variable) 3421 3422 def variable_created_in_scope(self, v): 3423 return v._distribute_strategy is None # pylint: disable=protected-access 3424 3425 def _experimental_distribute_dataset(self, dataset, options): 3426 return dataset 3427 3428 def _distribute_datasets_from_function(self, dataset_fn, options): 3429 return dataset_fn(InputContext()) 3430 3431 def _experimental_distribute_values_from_function(self, value_fn): 3432 return value_fn(ValueContext()) 3433 3434 def _make_dataset_iterator(self, dataset): 3435 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3436 3437 def _make_input_fn_iterator(self, 3438 input_fn, 3439 replication_mode=InputReplicationMode.PER_WORKER): 3440 dataset = input_fn(InputContext()) 3441 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 3442 3443 def _experimental_make_numpy_dataset(self, numpy_input, session): 3444 numpy_flat = nest.flatten(numpy_input) 3445 vars_flat = tuple( 3446 variable_scope.variable(array_ops.zeros(i.shape, i.dtype), 3447 trainable=False, use_resource=True) 3448 for i in numpy_flat 3449 ) 3450 for v, i in zip(vars_flat, numpy_flat): 3451 numpy_dataset.init_var_from_numpy(v, i, session) 3452 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 3453 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 3454 3455 def _broadcast_to(self, tensor, destinations): 3456 if destinations is None: 3457 return tensor 3458 else: 3459 raise NotImplementedError("TODO") 3460 3461 def _call_for_each_replica(self, fn, args, kwargs): 3462 with ReplicaContext(self._container_strategy(), replica_id_in_sync_group=0): 3463 return fn(*args, **kwargs) 3464 3465 def _reduce_to(self, reduce_op, value, destinations, options): 3466 # TODO(josh11b): Use destinations? 3467 del reduce_op, destinations, options 3468 return value 3469 3470 def _gather_to_implementation(self, value, destinations, axis, options): 3471 del destinations, axis, options 3472 return value 3473 3474 def _update(self, var, fn, args, kwargs, group): 3475 # The implementations of _update() and _update_non_slot() are identical 3476 # except _update() passes `var` as the first argument to `fn()`. 3477 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 3478 3479 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 3480 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 3481 # once that value is used for something. 3482 with UpdateContext(colocate_with): 3483 result = fn(*args, **kwargs) 3484 if should_group: 3485 return result 3486 else: 3487 return nest.map_structure(self._local_results, result) 3488 3489 def read_var(self, replica_local_var): 3490 return array_ops.identity(replica_local_var) 3491 3492 def _local_results(self, distributed_value): 3493 return (distributed_value,) 3494 3495 def value_container(self, value): 3496 return value 3497 3498 @property 3499 def _num_replicas_in_sync(self): 3500 return 1 3501 3502 @property 3503 def worker_devices(self): 3504 raise RuntimeError("worker_devices() method unsupported by default " 3505 "tf.distribute.Strategy.") 3506 3507 @property 3508 def parameter_devices(self): 3509 raise RuntimeError("parameter_devices() method unsupported by default " 3510 "tf.distribute.Strategy.") 3511 3512 def non_slot_devices(self, var_list): 3513 return min(var_list, key=lambda x: x.name) 3514 3515 def _in_multi_worker_mode(self): 3516 """Whether this strategy indicates working in multi-worker settings.""" 3517 # Default strategy doesn't indicate multi-worker training. 3518 return False 3519 3520 @property 3521 def should_checkpoint(self): 3522 return True 3523 3524 @property 3525 def should_save_summary(self): 3526 return True 3527 3528 def _get_local_replica_id(self, replica_id_in_sync_group): 3529 return replica_id_in_sync_group 3530 3531 def _get_replica_id_in_sync_group(self, replica_id): 3532 return replica_id 3533 3534 # TODO(priyag): This should inherit from `InputIterator`, once dependency 3535 # issues have been resolved. 3536 class DefaultInputIterator(object): 3537 """Default implementation of `InputIterator` for default strategy.""" 3538 3539 def __init__(self, dataset): 3540 self._dataset = dataset 3541 if eager_context.executing_eagerly(): 3542 self._iterator = dataset_ops.make_one_shot_iterator(dataset) 3543 else: 3544 self._iterator = dataset_ops.make_initializable_iterator(dataset) 3545 3546 def get_next(self): 3547 return self._iterator.get_next() 3548 3549 def get_next_as_optional(self): 3550 return self._iterator.get_next_as_optional() 3551 3552 @deprecated(None, "Use the iterator's `initializer` property instead.") 3553 def initialize(self): 3554 """Initialize underlying iterators. 3555 3556 Returns: 3557 A list of any initializer ops that should be run. 3558 """ 3559 if eager_context.executing_eagerly(): 3560 self._iterator = self._dataset.make_one_shot_iterator() 3561 return [] 3562 else: 3563 return [self._iterator.initializer] 3564 3565 @property 3566 def initializer(self): 3567 """Returns a list of ops that initialize the iterator.""" 3568 return self.initialize() 3569 3570 # TODO(priyag): Delete this once all strategies use global batch size. 3571 @property 3572 def _global_batch_size(self): 3573 """Global and per-replica batching are equivalent for this strategy.""" 3574 return True 3575 3576 3577class _DefaultReplicaContext(ReplicaContext): 3578 """ReplicaContext for _DefaultDistributionStrategy.""" 3579 3580 @property 3581 def replica_id_in_sync_group(self): 3582 # Return 0 instead of a constant tensor to avoid creating a new node for 3583 # users who don't use distribution strategy. 3584 return 0 3585 3586 3587# ------------------------------------------------------------------------------ 3588# We haven't yet implemented deserialization for DistributedVariables. 3589# So here we catch any attempts to deserialize variables 3590# when using distribution strategies. 3591# pylint: disable=protected-access 3592_original_from_proto = resource_variable_ops._from_proto_fn 3593 3594 3595def _from_proto_fn(v, import_scope=None): 3596 if distribution_strategy_context.has_strategy(): 3597 raise NotImplementedError( 3598 "Deserialization of variables is not yet supported when using a " 3599 "tf.distribute.Strategy.") 3600 else: 3601 return _original_from_proto(v, import_scope=import_scope) 3602 3603resource_variable_ops._from_proto_fn = _from_proto_fn 3604# pylint: enable=protected-access 3605 3606 3607#------------------------------------------------------------------------------- 3608# Shorthand for some methods from distribution_strategy_context. 3609_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access 3610_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access 3611_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access 3612_get_default_replica_mode = ( 3613 distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access 3614 3615 3616# ------------------------------------------------------------------------------ 3617# Metrics to track which distribution strategy is being called 3618distribution_strategy_gauge = monitoring.StringGauge( 3619 "/tensorflow/api/distribution_strategy", 3620 "Gauge to track the type of distribution strategy used.", "TFVersion") 3621distribution_strategy_replica_gauge = monitoring.IntGauge( 3622 "/tensorflow/api/distribution_strategy/replica", 3623 "Gauge to track the number of replica each distribution strategy used.", 3624 "CountType") 3625