1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library for running a computation across multiple devices. 16 17See the guide for overview and examples: 18[TensorFlow v2.x](https://www.tensorflow.org/guide/distributed_training), 19[TensorFlow v1.x](https://github.com/tensorflow/docs/blob/master/site/en/r1/guide/distribute_strategy.ipynb). # pylint: disable=line-too-long 20 21The intent of this library is that you can write an algorithm in a stylized way 22and it will be usable with a variety of different `tf.distribute.Strategy` 23implementations. Each descendant will implement a different strategy for 24distributing the algorithm across multiple devices/machines. Furthermore, these 25changes can be hidden inside the specific layers and other library classes that 26need special treatment to run in a distributed setting, so that most users' 27model definition code can run unchanged. The `tf.distribute.Strategy` API works 28the same way with eager and graph execution. 29 30*Glossary* 31 32* _Data parallelism_ is where we run multiple copies of the model 33 on different slices of the input data. This is in contrast to 34 _model parallelism_ where we divide up a single copy of a model 35 across multiple devices. 36 Note: we only support data parallelism for now, but 37 hope to add support for model parallelism in the future. 38* A _device_ is a CPU or accelerator (e.g. GPUs, TPUs) on some machine that 39 TensorFlow can run operations on (see e.g. `tf.device`). You may have multiple 40 devices on a single machine, or be connected to devices on multiple 41 machines. Devices used to run computations are called _worker devices_. 42 Devices used to store variables are _parameter devices_. For some strategies, 43 such as `tf.distribute.MirroredStrategy`, the worker and parameter devices 44 will be the same (see mirrored variables below). For others they will be 45 different. For example, `tf.distribute.experimental.CentralStorageStrategy` 46 puts the variables on a single device (which may be a worker device or may be 47 the CPU), and `tf.distribute.experimental.ParameterServerStrategy` puts the 48 variables on separate machines called parameter servers (see below). 49* A _replica_ is one copy of the model, running on one slice of the 50 input data. Right now each replica is executed on its own 51 worker device, but once we add support for model parallelism 52 a replica may span multiple worker devices. 53* A _host_ is the CPU device on a machine with worker devices, typically 54 used for running input pipelines. 55* A _worker_ is defined to be the physical machine(s) containing the physical 56 devices (e.g. GPUs, TPUs) on which the replicated computation is executed. A 57 worker may contain one or more replicas, but contains at least one 58 replica. Typically one worker will correspond to one machine, but in the case 59 of very large models with model parallelism, one worker may span multiple 60 machines. We typically run one input pipeline per worker, feeding all the 61 replicas on that worker. 62* _Synchronous_, or more commonly _sync_, training is where the updates from 63 each replica are aggregated together before updating the model variables. This 64 is in contrast to _asynchronous_, or _async_ training, where each replica 65 updates the model variables independently. You may also have replicas 66 partitioned into groups which are in sync within each group but async between 67 groups. 68* _Parameter servers_: These are machines that hold a single copy of 69 parameters/variables, used by some strategies (right now just 70 `tf.distribute.experimental.ParameterServerStrategy`). All replicas that want 71 to operate on a variable retrieve it at the beginning of a step and send an 72 update to be applied at the end of the step. These can in priniciple support 73 either sync or async training, but right now we only have support for async 74 training with parameter servers. Compare to 75 `tf.distribute.experimental.CentralStorageStrategy`, which puts all variables 76 on a single device on the same machine (and does sync training), and 77 `tf.distribute.MirroredStrategy`, which mirrors variables to multiple devices 78 (see below). 79* _Mirrored variables_: These are variables that are copied to multiple 80 devices, where we keep the copies in sync by applying the same 81 updates to every copy. Normally would only be used with sync training. 82* Reductions and all-reduce: A _reduction_ is some method of aggregating 83 multiple values into one value, like "sum" or "mean". If a strategy is doing 84 sync training, we will perform a reduction on the gradients to a parameter 85 from all replicas before applying the update. _All-reduce_ is an algorithm for 86 performing a reduction on values from multiple devices and making the result 87 available on all of those devices. 88 89Note that we provide a default version of `tf.distribute.Strategy` that is 90used when no other strategy is in scope, that provides the same API with 91reasonable default behavior. 92""" 93 94from __future__ import absolute_import 95from __future__ import division 96from __future__ import print_function 97 98import copy 99import enum # pylint: disable=g-bad-import-order 100import threading 101import weakref 102 103import six 104 105from tensorflow.python.autograph.core import ag_ctx as autograph_ctx 106from tensorflow.python.autograph.impl import api as autograph 107from tensorflow.python.data.ops import dataset_ops 108from tensorflow.python.distribute import device_util 109from tensorflow.python.distribute import distribution_strategy_context 110from tensorflow.python.distribute import numpy_dataset 111from tensorflow.python.distribute import reduce_util 112from tensorflow.python.eager import context as eager_context 113from tensorflow.python.eager import monitoring 114from tensorflow.python.framework import constant_op 115from tensorflow.python.framework import dtypes 116from tensorflow.python.framework import ops 117from tensorflow.python.framework import tensor_shape 118from tensorflow.python.ops import array_ops 119from tensorflow.python.ops import control_flow_ops 120from tensorflow.python.ops import custom_gradient 121from tensorflow.python.ops import math_ops 122from tensorflow.python.ops import resource_variable_ops 123from tensorflow.python.ops import summary_ops_v2 124from tensorflow.python.ops import variable_scope 125from tensorflow.python.ops.losses import loss_reduction 126from tensorflow.python.ops.losses import losses_impl 127from tensorflow.python.platform import tf_logging 128from tensorflow.python.training.tracking import base as trackable 129from tensorflow.python.util import nest 130from tensorflow.python.util import tf_contextlib 131from tensorflow.python.util.deprecation import deprecated 132from tensorflow.python.util.tf_export import tf_export 133from tensorflow.tools.docs import doc_controls 134 135 136# ------------------------------------------------------------------------------ 137# Context tracking whether in a strategy.update() or .update_non_slot() call. 138 139 140_update_replica_id = threading.local() 141 142 143def get_update_replica_id(): 144 """Get the current device if in a `tf.distribute.Strategy.update()` call.""" 145 try: 146 return _update_replica_id.current 147 except AttributeError: 148 return None 149 150 151class UpdateContext(object): 152 """Context manager when you are in `update()` or `update_non_slot()`.""" 153 154 def __init__(self, replica_id): 155 self._replica_id = replica_id 156 self._old_replica_id = None 157 158 def __enter__(self): 159 self._old_replica_id = get_update_replica_id() 160 _update_replica_id.current = self._replica_id 161 162 def __exit__(self, exception_type, exception_value, traceback): 163 del exception_type, exception_value, traceback 164 _update_replica_id.current = self._old_replica_id 165 166 167# ------------------------------------------------------------------------------ 168# Public utility functions. 169 170 171@tf_export(v1=["distribute.get_loss_reduction"]) 172def get_loss_reduction(): 173 """`tf.distribute.ReduceOp` corresponding to the last loss reduction. 174 175 This is used to decide whether loss should be scaled in optimizer (used only 176 for estimator + v1 optimizer use case). 177 178 Returns: 179 `tf.distribute.ReduceOp` corresponding to the last loss reduction for 180 estimator and v1 optimizer use case. `tf.distribute.ReduceOp.SUM` otherwise. 181 """ 182 if not distribution_strategy_context.get_strategy()._scale_loss_for_estimator: # pylint: disable=protected-access 183 # If we are not in Estimator context then return 'SUM'. We do not need to 184 # scale loss in the optimizer. 185 return reduce_util.ReduceOp.SUM 186 last_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access 187 if (last_reduction == losses_impl.Reduction.SUM or 188 last_reduction == loss_reduction.ReductionV2.SUM): 189 return reduce_util.ReduceOp.SUM 190 return reduce_util.ReduceOp.MEAN 191 192 193# ------------------------------------------------------------------------------ 194# Internal API for validating the current thread mode 195 196 197def _require_cross_replica_or_default_context_extended(extended): 198 """Verify in cross-replica context.""" 199 context = _get_per_thread_mode() 200 cross_replica = context.cross_replica_context 201 if cross_replica is not None and cross_replica.extended is extended: 202 return 203 if context is _get_default_replica_mode(): 204 return 205 strategy = extended._container_strategy() # pylint: disable=protected-access 206 # We have an error to report, figure out the right message. 207 if context.strategy is not strategy: 208 _wrong_strategy_scope(strategy, context) 209 assert cross_replica is None 210 raise RuntimeError("Method requires being in cross-replica context, use " 211 "get_replica_context().merge_call()") 212 213 214def _wrong_strategy_scope(strategy, context): 215 # Figure out the right error message. 216 if not distribution_strategy_context.has_strategy(): 217 raise RuntimeError( 218 'Need to be inside "with strategy.scope()" for %s' % 219 (strategy,)) 220 else: 221 raise RuntimeError( 222 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 223 (context.strategy, strategy)) 224 225 226def require_replica_context(replica_ctx): 227 """Verify in `replica_ctx` replica context.""" 228 context = _get_per_thread_mode() 229 if context.replica_context is replica_ctx: return 230 # We have an error to report, figure out the right message. 231 if context.replica_context is None: 232 raise RuntimeError("Need to be inside `call_for_each_replica()`") 233 if context.strategy is replica_ctx.strategy: 234 # Two different ReplicaContexts with the same tf.distribute.Strategy. 235 raise RuntimeError("Mismatching ReplicaContext.") 236 raise RuntimeError( 237 "Mismatching tf.distribute.Strategy objects: %s is not %s." % 238 (context.strategy, replica_ctx.strategy)) 239 240 241def _require_strategy_scope_strategy(strategy): 242 """Verify in a `strategy.scope()` in this thread.""" 243 context = _get_per_thread_mode() 244 if context.strategy is strategy: return 245 _wrong_strategy_scope(strategy, context) 246 247 248def _require_strategy_scope_extended(extended): 249 """Verify in a `distribution_strategy.scope()` in this thread.""" 250 context = _get_per_thread_mode() 251 if context.strategy.extended is extended: return 252 # Report error. 253 strategy = extended._container_strategy() # pylint: disable=protected-access 254 _wrong_strategy_scope(strategy, context) 255 256 257# ------------------------------------------------------------------------------ 258# Internal context managers used to implement the DistributionStrategy 259# base class 260 261 262class _CurrentDistributionContext(object): 263 """Context manager setting the current `tf.distribute.Strategy`. 264 265 Also: overrides the variable creator and optionally the current device. 266 """ 267 268 def __init__(self, 269 strategy, 270 var_creator_scope, 271 var_scope=None, 272 default_device=None): 273 self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access 274 strategy) 275 self._var_creator_scope = var_creator_scope 276 self._var_scope = var_scope 277 if default_device: 278 self._device_scope = ops.device(default_device) 279 else: 280 self._device_scope = None 281 self._same_scope_again_count = 0 282 283 def __enter__(self): 284 # Allow this scope to be entered if this strategy is already in scope. 285 if distribution_strategy_context.has_strategy(): 286 _require_cross_replica_or_default_context_extended( 287 self._context.strategy.extended) 288 self._same_scope_again_count += 1 289 else: 290 _push_per_thread_mode(self._context) 291 if self._var_scope: 292 self._var_scope.__enter__() 293 self._var_creator_scope.__enter__() 294 if self._device_scope: 295 self._device_scope.__enter__() 296 return self._context.strategy 297 298 def __exit__(self, exception_type, exception_value, traceback): 299 if self._same_scope_again_count > 0: 300 self._same_scope_again_count -= 1 301 return 302 if self._device_scope: 303 try: 304 self._device_scope.__exit__(exception_type, exception_value, traceback) 305 except RuntimeError as e: 306 six.raise_from( 307 RuntimeError("Device scope nesting error: move call to " 308 "tf.distribute.set_strategy() out of `with` scope."), 309 e) 310 311 try: 312 self._var_creator_scope.__exit__( 313 exception_type, exception_value, traceback) 314 except RuntimeError as e: 315 six.raise_from( 316 RuntimeError("Variable creator scope nesting error: move call to " 317 "tf.distribute.set_strategy() out of `with` scope."), 318 e) 319 320 if self._var_scope: 321 try: 322 self._var_scope.__exit__(exception_type, exception_value, traceback) 323 except RuntimeError as e: 324 six.raise_from( 325 RuntimeError("Variable scope nesting error: move call to " 326 "tf.distribute.set_strategy() out of `with` scope."), 327 e) 328 _pop_per_thread_mode() 329 330 331# TODO(yuefengz): add more replication modes. 332@tf_export("distribute.InputReplicationMode") 333class InputReplicationMode(enum.Enum): 334 """Replication mode for input function. 335 336 * `PER_WORKER`: The input function will be called on each worker 337 independently, creating as many input pipelines as number of workers. 338 Replicas will dequeue from the local Dataset on their worker. 339 `tf.distribute.Strategy` doesn't manage any state sharing between such 340 separate input pipelines. 341 """ 342 PER_WORKER = "PER_WORKER" 343 344 345@tf_export("distribute.InputContext") 346class InputContext(object): 347 """A class wrapping information needed by an input function. 348 349 This is a context class that is passed to the user's input function and 350 contains information about the compute replicas and input pipelines. The 351 number of compute replicas (in sync training) helps compute the local batch 352 size from the desired global batch size for each replica. The input pipeline 353 information can be used to return a different subset of the input in each 354 replica (for e.g. shard the input pipeline, use a different input 355 source etc). 356 """ 357 358 def __init__(self, 359 num_input_pipelines=1, 360 input_pipeline_id=0, 361 num_replicas_in_sync=1): 362 """Initializes an InputContext object. 363 364 Args: 365 num_input_pipelines: the number of input pipelines in a cluster. 366 input_pipeline_id: the current input pipeline id, should be an int in 367 [0,`num_input_pipelines`). 368 num_replicas_in_sync: the number of replicas that are in sync. 369 """ 370 self._num_input_pipelines = num_input_pipelines 371 self._input_pipeline_id = input_pipeline_id 372 self._num_replicas_in_sync = num_replicas_in_sync 373 374 @property 375 def num_replicas_in_sync(self): 376 """Returns the number of compute replicas in sync.""" 377 return self._num_replicas_in_sync 378 379 @property 380 def input_pipeline_id(self): 381 """Returns the input pipeline ID.""" 382 return self._input_pipeline_id 383 384 @property 385 def num_input_pipelines(self): 386 """Returns the number of input pipelines.""" 387 return self._num_input_pipelines 388 389 def get_per_replica_batch_size(self, global_batch_size): 390 """Returns the per-replica batch size. 391 392 Args: 393 global_batch_size: the global batch size which should be divisible by 394 `num_replicas_in_sync`. 395 396 Returns: 397 the per-replica batch size. 398 399 Raises: 400 ValueError: if `global_batch_size` not divisible by 401 `num_replicas_in_sync`. 402 """ 403 if global_batch_size % self._num_replicas_in_sync != 0: 404 raise ValueError("The `global_batch_size` %r is not divisible by " 405 "`num_replicas_in_sync` %r " % 406 (global_batch_size, self._num_replicas_in_sync)) 407 return global_batch_size // self._num_replicas_in_sync 408 409 def __str__(self): 410 return "tf.distribute.InputContext(input pipeline id {}, total: {})".format( 411 self.input_pipeline_id, self.num_input_pipelines) 412 413 414# ------------------------------------------------------------------------------ 415# Base classes for all distribution strategies. 416 417 418# pylint: disable=line-too-long 419@tf_export("distribute.Strategy", v1=[]) 420class Strategy(object): 421 """A state & compute distribution policy on a list of devices. 422 423 See [the guide](https://www.tensorflow.org/guide/distributed_training) 424 for overview and examples. 425 426 In short: 427 428 * To use it with Keras `compile`/`fit`, 429 [please 430 read](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras). 431 * You may pass descendant of `tf.distribute.Strategy` to 432 `tf.estimator.RunConfig` to specify how a `tf.estimator.Estimator` 433 should distribute its computation. See 434 [guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support). 435 * Otherwise, use `tf.distribute.Strategy.scope` to specify that a 436 strategy should be used when building an executing your model. 437 (This puts you in the "cross-replica context" for this strategy, which 438 means the strategy is put in control of things like variable placement.) 439 * If you are writing a custom training loop, you will need to call a few more 440 methods, 441 [see the 442 guide](https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_custom_training_loops): 443 444 * Start by either creating a `tf.data.Dataset` normally or using 445 `tf.distribute.experimental_make_numpy_dataset` to make a dataset out of 446 a `numpy` array. 447 * Use `tf.distribute.Strategy.experimental_distribute_dataset` to convert 448 a `tf.data.Dataset` to something that produces "per-replica" values. 449 If you want to manually specify how the dataset should be partitioned 450 across replicas, use 451 `tf.distribute.Strategy.experimental_distribute_datasets_from_function` 452 instead. 453 * Use `tf.distribute.Strategy.experimental_run_v2` to run a function 454 once per replica, taking values that may be "per-replica" (e.g. 455 from a distributed dataset) and returning "per-replica" values. 456 This function is executed in "replica context", which means each 457 operation is performed separately on each replica. 458 * Finally use a method (such as `tf.distribute.Strategy.reduce`) to 459 convert the resulting "per-replica" values into ordinary `Tensor`s. 460 461 A custom training loop can be as simple as: 462 463 ``` 464 with my_strategy.scope(): 465 @tf.function 466 def distribute_train_epoch(dataset): 467 def replica_fn(input): 468 # process input and return result 469 return result 470 471 total_result = 0 472 for x in dataset: 473 per_replica_result = my_strategy.experimental_run_v2(replica_fn, 474 args=(x,)) 475 total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM, 476 per_replica_result, axis=None) 477 return total_result 478 479 dist_dataset = my_strategy.experimental_distribute_dataset(dataset) 480 for _ in range(EPOCHS): 481 train_result = distribute_train_epoch(dist_dataset) 482 ``` 483 484 This takes an ordinary `dataset` and `replica_fn` and runs it 485 distributed using a particular `tf.distribute.Strategy` named 486 `my_strategy` above. Any variables created in `replica_fn` are created 487 using `my_strategy`'s policy, and library functions called by 488 `replica_fn` can use the `get_replica_context()` API to implement 489 distributed-specific behavior. 490 491 You can use the `reduce` API to aggregate results across replicas and use 492 this as a return value from one iteration over the distributed dataset. Or 493 you can use `tf.keras.metrics` (such as loss, accuracy, etc.) to 494 accumulate metrics across steps in a given epoch. 495 496 See the 497 [custom training loop 498 tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training) 499 for a more detailed example. 500 501 Note: `tf.distribute.Strategy` currently does not support TensorFlow's 502 partitioned variables (where a single variable is split across multiple 503 devices) at this time. 504 """ 505 # pylint: enable=line-too-long 506 507 # TODO(josh11b): Partitioned computations, state; sharding 508 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling 509 510 def __init__(self, extended): 511 self._extended = extended 512 513 # Flag that is used to indicate whether distribution strategy is used with 514 # Estimator. This is required for backward compatibility of loss scaling 515 # when using v1 optimizer with estimator. 516 self._scale_loss_for_estimator = False 517 518 if not hasattr(extended, "_retrace_functions_for_each_device"): 519 # pylint: disable=protected-access 520 try: 521 extended._retrace_functions_for_each_device = ( 522 len(extended.worker_devices) > 1) 523 distribution_strategy_replica_gauge.get_cell("num_replicas").set( 524 self.num_replicas_in_sync) 525 except: # pylint: disable=bare-except 526 # Default for the case where extended.worker_devices can't return 527 # a sensible value. 528 extended._retrace_functions_for_each_device = True 529 530 @property 531 def extended(self): 532 """`tf.distribute.StrategyExtended` with additional methods.""" 533 return self._extended 534 535 @tf_contextlib.contextmanager 536 def _scale_loss_for_estimator_enabled(self): 537 """Scope which sets a flag used for scaling losses in optimizer. 538 539 Yields: 540 `_scale_loss_for_estimator_enabled` is a context manager with a 541 side effect, but doesn't return a value. 542 """ 543 self._scale_loss_for_estimator = True 544 try: 545 yield 546 finally: 547 self._scale_loss_for_estimator = False 548 549 def scope(self): 550 """Returns a context manager selecting this Strategy as current. 551 552 Inside a `with strategy.scope():` code block, this thread 553 will use a variable creator set by `strategy`, and will 554 enter its "cross-replica context". 555 556 Returns: 557 A context manager. 558 """ 559 return self._extended._scope(self) # pylint: disable=protected-access 560 561 @doc_controls.do_not_doc_inheritable # DEPRECATED, moving to `extended` 562 def colocate_vars_with(self, colocate_with_variable): 563 """DEPRECATED: use extended.colocate_vars_with() instead.""" 564 return self._extended.colocate_vars_with(colocate_with_variable) 565 566 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 567 def make_dataset_iterator(self, dataset): 568 """DEPRECATED TF 1.x ONLY.""" 569 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 570 571 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 572 def make_input_fn_iterator(self, 573 input_fn, 574 replication_mode=InputReplicationMode.PER_WORKER): 575 """DEPRECATED TF 1.x ONLY.""" 576 if replication_mode != InputReplicationMode.PER_WORKER: 577 raise ValueError( 578 "Input replication mode not supported: %r" % replication_mode) 579 with self.scope(): 580 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access 581 input_fn, replication_mode=replication_mode) 582 583 def experimental_make_numpy_dataset(self, numpy_input): 584 """Makes a `tf.data.Dataset` for input provided via a numpy array. 585 586 This avoids adding `numpy_input` as a large constant in the graph, 587 and copies the data to the machine or machines that will be processing 588 the input. 589 590 Note that you will likely need to use `experimental_distribute_dataset` 591 with the returned dataset to further distribute it with the strategy. 592 593 Example: 594 ``` 595 numpy_input = np.ones([10], dtype=np.float32) 596 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 597 dist_dataset = strategy.experimental_distribute_dataset(dataset) 598 ``` 599 600 Args: 601 numpy_input: A nest of NumPy input arrays that will be converted into a 602 dataset. Note that lists of Numpy arrays are stacked, as that is normal 603 `tf.data.Dataset` behavior. 604 605 Returns: 606 A `tf.data.Dataset` representing `numpy_input`. 607 """ 608 return self.extended.experimental_make_numpy_dataset( 609 numpy_input, session=None) 610 611 @doc_controls.do_not_generate_docs # DEPRECATED: TF 1.x only 612 def experimental_run(self, fn, input_iterator=None): 613 """DEPRECATED TF 1.x ONLY.""" 614 with self.scope(): 615 args = (input_iterator.get_next(),) if input_iterator is not None else () 616 return self.experimental_run_v2(fn, args=args) 617 618 def experimental_distribute_dataset(self, dataset): 619 """Distributes a tf.data.Dataset instance provided via `dataset`. 620 621 The returned distributed dataset can be iterated over similar to how 622 regular datasets can. 623 NOTE: Currently, the user cannot add any more transformations to a 624 distributed dataset. 625 626 The following is an example: 627 628 ```python 629 strategy = tf.distribute.MirroredStrategy() 630 631 # Create a dataset 632 dataset = dataset_ops.Dataset.TFRecordDataset([ 633 "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]) 634 635 # Distribute that dataset 636 dist_dataset = strategy.experimental_distribute_dataset(dataset) 637 638 # Iterate over the distributed dataset 639 for x in dist_dataset: 640 # process dataset elements 641 strategy.experimental_run_v2(train_step, args=(x,)) 642 ``` 643 644 We will assume that the input dataset is batched by the 645 global batch size. With this assumption, we will make a best effort to 646 divide each batch across all the replicas (one or more workers). 647 648 In a multi-worker setting, we will first attempt to distribute the dataset 649 by attempting to detect whether the dataset is being created out of 650 ReaderDatasets (e.g. TFRecordDataset, TextLineDataset, etc.) and if so, 651 attempting to shard the input files. Note that there has to be at least one 652 input file per worker. If you have less than one input file per worker, we 653 suggest that you should disable distributing your dataset using the method 654 below. 655 656 If that attempt is unsuccessful (e.g. the dataset is created from a 657 Dataset.range), we will shard the dataset evenly at the end by appending a 658 `.shard` operation to the end of the processing pipeline. This will cause 659 the entire preprocessing pipeline for all the data to be run on every 660 worker, and each worker will do redundant work. We will print a warning 661 if this method of sharding is selected. 662 663 You can disable dataset sharding across workers using the 664 `auto_shard_policy` option in `tf.data.experimental.DistributeOptions`. 665 666 Within each worker, we will also split the data among all the worker 667 devices (if more than one a present), and this will happen even if 668 multi-worker sharding is disabled using the method above. 669 670 If the above batch splitting and dataset sharding logic is undesirable, 671 please use `experimental_distribute_datasets_from_function` instead, which 672 does not do any automatic splitting or sharding. 673 674 You can also use the `element_spec` property of the distributed dataset 675 returned by this API to query the `tf.TypeSpec` of the elements returned 676 by the iterator. This can be used to set the `input_signature` property 677 of a `tf.function`. 678 679 ```python 680 strategy = tf.distribute.MirroredStrategy() 681 682 # Create a dataset 683 dataset = dataset_ops.Dataset.TFRecordDataset([ 684 "/a/1.tfr", "/a/2.tfr", "/a/3.tfr", "/a/4.tfr"]) 685 686 # Distribute that dataset 687 dist_dataset = strategy.experimental_distribute_dataset(dataset) 688 689 @tf.function(input_signature=[dist_dataset.element_spec]) 690 def train_step(inputs): 691 # train model with inputs 692 return 693 694 # Iterate over the distributed dataset 695 for x in dist_dataset: 696 # process dataset elements 697 strategy.experimental_run_v2(train_step, args=(x,)) 698 ``` 699 700 Args: 701 dataset: `tf.data.Dataset` that will be sharded across all replicas using 702 the rules stated above. 703 704 Returns: 705 A "distributed `Dataset`", which acts like a `tf.data.Dataset` except 706 it produces "per-replica" values. 707 """ 708 return self._extended._experimental_distribute_dataset(dataset) # pylint: disable=protected-access 709 710 def experimental_distribute_datasets_from_function(self, dataset_fn): 711 """Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`. 712 713 `dataset_fn` will be called once for each worker in the strategy. Each 714 replica on that worker will dequeue one batch of inputs from the local 715 `Dataset` (i.e. if a worker has two replicas, two batches will be dequeued 716 from the `Dataset` every step). 717 718 This method can be used for several purposes. For example, where 719 `experimental_distribute_dataset` is unable to shard the input files, this 720 method might be used to manually shard the dataset (avoiding the slow 721 fallback behavior in `experimental_distribute_dataset`). In cases where the 722 dataset is infinite, this sharding can be done by creating dataset replicas 723 that differ only in their random seed. 724 `experimental_distribute_dataset` may also sometimes fail to split the 725 batch across replicas on a worker. In that case, this method can be used 726 where that limitation does not exist. 727 728 The `dataset_fn` should take an `tf.distribute.InputContext` instance where 729 information about batching and input replication can be accessed: 730 731 ``` 732 def dataset_fn(input_context): 733 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 734 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 735 return d.shard( 736 input_context.num_input_pipelines, input_context.input_pipeline_id) 737 738 inputs = strategy.experimental_distribute_datasets_from_function(dataset_fn) 739 740 for batch in inputs: 741 replica_results = strategy.experimental_run_v2(replica_fn, args=(batch,)) 742 ``` 743 744 IMPORTANT: The `tf.data.Dataset` returned by `dataset_fn` should have a 745 per-replica batch size, unlike `experimental_distribute_dataset`, which uses 746 the global batch size. This may be computed using 747 `input_context.get_per_replica_batch_size`. 748 749 To query the `tf.TypeSpec` of the elements in the distributed dataset 750 returned by this API, you need to use the `element_spec` property of the 751 distributed iterator. This `tf.TypeSpec` can be used to set the 752 `input_signature` property of a `tf.function`. 753 754 ```python 755 # If you want to specify `input_signature` for a `tf.function` you must 756 # first create the iterator. 757 iterator = iter(inputs) 758 759 @tf.function(input_signature=[iterator.element_spec]) 760 def replica_fn_with_signature(inputs): 761 # train the model with inputs 762 return 763 764 for _ in range(steps): 765 strategy.experimental_run_v2(replica_fn_with_signature, 766 args=(next(iterator),)) 767 ``` 768 769 Args: 770 dataset_fn: A function taking a `tf.distribute.InputContext` instance and 771 returning a `tf.data.Dataset`. 772 773 Returns: 774 A "distributed `Dataset`", which acts like a `tf.data.Dataset` except 775 it produces "per-replica" values. 776 """ 777 return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access 778 dataset_fn) 779 780 def experimental_run_v2(self, fn, args=(), kwargs=None): 781 """Run `fn` on each replica, with the given arguments. 782 783 Executes ops specified by `fn` on each replica. If `args` or `kwargs` have 784 "per-replica" values, such as those produced by a "distributed `Dataset`", 785 when `fn` is executed on a particular replica, it will be executed with the 786 component of those "per-replica" values that correspond to that replica. 787 788 `fn` may call `tf.distribute.get_replica_context()` to access members such 789 as `all_reduce`. 790 791 All arguments in `args` or `kwargs` should either be nest of tensors or 792 per-replica objects containing tensors or composite tensors. 793 794 IMPORTANT: Depending on the implementation of `tf.distribute.Strategy` and 795 whether eager execution is enabled, `fn` may be called one or more times ( 796 once for each replica). 797 798 Args: 799 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 800 args: (Optional) Positional arguments to `fn`. 801 kwargs: (Optional) Keyword arguments to `fn`. 802 803 Returns: 804 Merged return value of `fn` across replicas. The structure of the return 805 value is the same as the return value from `fn`. Each element in the 806 structure can either be "per-replica" `Tensor` objects or `Tensor`s 807 (for example, if running on a single replica). 808 """ 809 if not isinstance(args, (list, tuple)): 810 raise ValueError( 811 "positional args must be a list or tuple, got {}".format(type(args))) 812 813 with self.scope(): 814 # tf.distribute supports Eager functions, so AutoGraph should not be 815 # applied when when the caller is also in Eager mode. 816 fn = autograph.tf_convert( 817 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 818 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 819 820 def reduce(self, reduce_op, value, axis): 821 """Reduce `value` across replicas. 822 823 Given a per-replica value returned by `experimental_run_v2`, say a 824 per-example loss, the batch will be divided across all the replicas. This 825 function allows you to aggregate across replicas and optionally also across 826 batch elements. For example, if you have a global batch size of 8 and 2 827 replicas, values for examples `[0, 1, 2, 3]` will be on replica 0 and 828 `[4, 5, 6, 7]` will be on replica 1. By default, `reduce` will just 829 aggregate across replicas, returning `[0+4, 1+5, 2+6, 3+7]`. This is useful 830 when each replica is computing a scalar or some other value that doesn't 831 have a "batch" dimension (like a gradient). More often you will want to 832 aggregate across the global batch, which you can get by specifying the batch 833 dimension as the `axis`, typically `axis=0`. In this case it would return a 834 scalar `0+1+2+3+4+5+6+7`. 835 836 If there is a last partial batch, you will need to specify an axis so 837 that the resulting shape is consistent across replicas. So if the last 838 batch has size 6 and it is divided into [0, 1, 2, 3] and [4, 5], you 839 would get a shape mismatch unless you specify `axis=0`. If you specify 840 `tf.distribute.ReduceOp.MEAN`, using `axis=0` will use the correct 841 denominator of 6. Contrast this with computing `reduce_mean` to get a 842 scalar value on each replica and this function to average those means, 843 which will weigh some values `1/8` and others `1/4`. 844 845 Args: 846 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 847 be combined. 848 value: A "per replica" value, e.g. returned by `experimental_run_v2` to 849 be combined into a single tensor. 850 axis: Specifies the dimension to reduce along within each 851 replica's tensor. Should typically be set to the batch dimension, or 852 `None` to only reduce across replicas (e.g. if the tensor has no batch 853 dimension). 854 855 Returns: 856 A `Tensor`. 857 """ 858 # TODO(josh11b): support `value` being a nest. 859 _require_cross_replica_or_default_context_extended(self._extended) 860 if isinstance(reduce_op, six.string_types): 861 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 862 if axis is None: 863 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 864 if reduce_op == reduce_util.ReduceOp.SUM: 865 value = self.experimental_run_v2( 866 lambda v: math_ops.reduce_sum(v, axis=axis), args=(value,)) 867 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 868 if reduce_op != reduce_util.ReduceOp.MEAN: 869 raise TypeError("Expected `reduce_op` to be a `tf.distribute.ReduceOp`, " 870 "not: %r" % reduce_op) 871 # TODO(josh11b): Support list/tuple and tensor axis values. 872 if not isinstance(axis, six.integer_types): 873 raise TypeError("Expected `axis` to be an integer not: %r" % axis) 874 875 def mean_reduce_helper(v, axis=axis): 876 """Computes the numerator and denominator on each replica.""" 877 numer = math_ops.reduce_sum(v, axis=axis) 878 if v.shape.rank is not None: 879 # Note(joshl): We support axis < 0 to be consistent with the 880 # tf.math.reduce_* operations. 881 if axis < 0: 882 if axis + v.shape.rank < 0: 883 raise ValueError( 884 "`axis` = %r out of range for `value` with rank %d" % 885 (axis, v.shape.rank)) 886 axis += v.shape.rank 887 elif axis >= v.shape.rank: 888 raise ValueError( 889 "`axis` = %r out of range for `value` with rank %d" % 890 (axis, v.shape.rank)) 891 # TF v2 returns `None` for unknown dimensions and an integer for 892 # known dimension, whereas TF v1 returns tensor_shape.Dimension(None) 893 # or tensor_shape.Dimension(integer). `dimension_value` hides this 894 # difference, always returning `None` or an integer. 895 dim = tensor_shape.dimension_value(v.shape[axis]) 896 if dim is not None: 897 # By returning a python value in the static shape case, we can 898 # maybe get a fast path for reducing the denominator. 899 return numer, array_ops.constant(dim, dtype=dtypes.int64) 900 elif axis < 0: 901 axis = axis + array_ops.rank(v) 902 if v.shape.rank == 1: 903 # TODO(b/139422050): Currently tf.shape is not supported in TPU dynamic 904 # padder, use tf.size instead to workaround if the rank is 1. 905 denom = array_ops.size(v, out_type=dtypes.int64) 906 else: 907 denom = array_ops.shape_v2(v, out_type=dtypes.int64)[axis] 908 # TODO(josh11b): Should we cast denom to v.dtype here instead of after the 909 # reduce is complete? 910 return numer, denom 911 912 numer, denom = self.experimental_run_v2(mean_reduce_helper, args=(value,)) 913 # TODO(josh11b): Should batch reduce here instead of doing two. 914 numer = self._extended._reduce(reduce_util.ReduceOp.SUM, numer) # pylint: disable=protected-access 915 denom = self._extended._reduce(reduce_util.ReduceOp.SUM, denom) # pylint: disable=protected-access 916 denom = math_ops.cast(denom, numer.dtype) 917 return math_ops.truediv(numer, denom) 918 919 @doc_controls.do_not_doc_inheritable # DEPRECATED 920 def unwrap(self, value): 921 """Returns the list of all local per-replica values contained in `value`. 922 923 DEPRECATED: Please use `experimental_local_results` instead. 924 925 Note: This only returns values on the workers initiated by this client. 926 When using a `tf.distribute.Strategy` like 927 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 928 will be its own client, and this function will only return values 929 computed on that worker. 930 931 Args: 932 value: A value returned by `experimental_run()`, 933 `extended.call_for_each_replica()`, or a variable created in `scope`. 934 935 Returns: 936 A tuple of values contained in `value`. If `value` represents a single 937 value, this returns `(value,).` 938 """ 939 return self._extended._local_results(value) # pylint: disable=protected-access 940 941 def experimental_local_results(self, value): 942 """Returns the list of all local per-replica values contained in `value`. 943 944 Note: This only returns values on the worker initiated by this client. 945 When using a `tf.distribute.Strategy` like 946 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 947 will be its own client, and this function will only return values 948 computed on that worker. 949 950 Args: 951 value: A value returned by `experimental_run()`, `experimental_run_v2()`, 952 `extended.call_for_each_replica()`, or a variable created in `scope`. 953 954 Returns: 955 A tuple of values contained in `value`. If `value` represents a single 956 value, this returns `(value,).` 957 """ 958 return self._extended._local_results(value) # pylint: disable=protected-access 959 960 @doc_controls.do_not_doc_inheritable # DEPRECATED: TF v1.x only 961 def group(self, value, name=None): 962 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 963 return self._extended._group(value, name) # pylint: disable=protected-access 964 965 @property 966 def num_replicas_in_sync(self): 967 """Returns number of replicas over which gradients are aggregated.""" 968 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 969 970 @doc_controls.do_not_doc_inheritable # DEPRECATED: see doc string 971 def configure(self, 972 session_config=None, 973 cluster_spec=None, 974 task_type=None, 975 task_id=None): 976 # pylint: disable=g-doc-return-or-yield,g-doc-args 977 """DEPRECATED: use `update_config_proto` instead. 978 979 Configures the strategy class. 980 981 DEPRECATED: This method's functionality has been split into the strategy 982 constructor and `update_config_proto`. In the future, we will allow passing 983 cluster and config_proto to the constructor to configure the strategy. And 984 `update_config_proto` can be used to update the config_proto based on the 985 specific strategy. 986 """ 987 return self._extended._configure( # pylint: disable=protected-access 988 session_config, cluster_spec, task_type, task_id) 989 990 @doc_controls.do_not_generate_docs # DEPRECATED 991 def update_config_proto(self, config_proto): 992 """DEPRECATED TF 1.x ONLY.""" 993 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 994 995 def __deepcopy__(self, memo): 996 # First do a regular deepcopy of `self`. 997 cls = self.__class__ 998 result = cls.__new__(cls) 999 memo[id(self)] = result 1000 for k, v in self.__dict__.items(): 1001 setattr(result, k, copy.deepcopy(v, memo)) 1002 # One little fix-up: we want `result._extended` to reference `result` 1003 # instead of `self`. 1004 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 1005 return result 1006 1007 def __copy__(self): 1008 raise RuntimeError("Must only deepcopy DistributionStrategy.") 1009 1010 1011# TF v1.x version has additional deprecated APIs 1012@tf_export(v1=["distribute.Strategy"]) 1013class StrategyV1(Strategy): 1014 """A list of devices with a state & compute distribution policy. 1015 1016 See [the guide](https://www.tensorflow.org/guide/distribute_strategy) 1017 for overview and examples. 1018 1019 Note: Not all `tf.distribute.Strategy` implementations currently support 1020 TensorFlow's partitioned variables (where a single variable is split across 1021 multiple devices) at this time. 1022 """ 1023 1024 def make_dataset_iterator(self, dataset): 1025 """Makes an iterator for input provided via `dataset`. 1026 1027 DEPRECATED: This method is not available in TF 2.x. 1028 1029 Data from the given dataset will be distributed evenly across all the 1030 compute replicas. We will assume that the input dataset is batched by the 1031 global batch size. With this assumption, we will make a best effort to 1032 divide each batch across all the replicas (one or more workers). 1033 If this effort fails, an error will be thrown, and the user should instead 1034 use `make_input_fn_iterator` which provides more control to the user, and 1035 does not try to divide a batch across replicas. 1036 1037 The user could also use `make_input_fn_iterator` if they want to 1038 customize which input is fed to which replica/worker etc. 1039 1040 Args: 1041 dataset: `tf.data.Dataset` that will be distributed evenly across all 1042 replicas. 1043 1044 Returns: 1045 An `tf.distribute.InputIterator` which returns inputs for each step of the 1046 computation. User should call `initialize` on the returned iterator. 1047 """ 1048 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 1049 1050 def make_input_fn_iterator(self, # pylint: disable=useless-super-delegation 1051 input_fn, 1052 replication_mode=InputReplicationMode.PER_WORKER): 1053 """Returns an iterator split across replicas created from an input function. 1054 1055 DEPRECATED: This method is not available in TF 2.x. 1056 1057 The `input_fn` should take an `tf.distribute.InputContext` object where 1058 information about batching and input sharding can be accessed: 1059 1060 ``` 1061 def input_fn(input_context): 1062 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 1063 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 1064 return d.shard(input_context.num_input_pipelines, 1065 input_context.input_pipeline_id) 1066 with strategy.scope(): 1067 iterator = strategy.make_input_fn_iterator(input_fn) 1068 replica_results = strategy.experimental_run(replica_fn, iterator) 1069 ``` 1070 1071 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 1072 batch size, which may be computed using 1073 `input_context.get_per_replica_batch_size`. 1074 1075 Args: 1076 input_fn: A function taking a `tf.distribute.InputContext` object and 1077 returning a `tf.data.Dataset`. 1078 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 1079 Only `PER_WORKER` is supported currently, which means there will be 1080 a single call to `input_fn` per worker. Replicas will dequeue from the 1081 local `tf.data.Dataset` on their worker. 1082 1083 Returns: 1084 An iterator object that should first be `.initialize()`-ed. It may then 1085 either be passed to `strategy.experimental_run()` or you can 1086 `iterator.get_next()` to get the next value to pass to 1087 `strategy.extended.call_for_each_replica()`. 1088 """ 1089 return super(StrategyV1, self).make_input_fn_iterator( 1090 input_fn, replication_mode) 1091 1092 def experimental_make_numpy_dataset(self, numpy_input, session=None): 1093 """Makes a tf.data.Dataset for input provided via a numpy array. 1094 1095 This avoids adding `numpy_input` as a large constant in the graph, 1096 and copies the data to the machine or machines that will be processing 1097 the input. 1098 1099 Note that you will likely need to use 1100 tf.distribute.Strategy.experimental_distribute_dataset 1101 with the returned dataset to further distribute it with the strategy. 1102 1103 Example: 1104 ``` 1105 numpy_input = np.ones([10], dtype=np.float32) 1106 dataset = strategy.experimental_make_numpy_dataset(numpy_input) 1107 dist_dataset = strategy.experimental_distribute_dataset(dataset) 1108 ``` 1109 1110 Args: 1111 numpy_input: A nest of NumPy input arrays that will be converted into a 1112 dataset. Note that lists of Numpy arrays are stacked, as that is normal 1113 `tf.data.Dataset` behavior. 1114 session: (TensorFlow v1.x graph execution only) A session used for 1115 initialization. 1116 1117 Returns: 1118 A `tf.data.Dataset` representing `numpy_input`. 1119 """ 1120 return self.extended.experimental_make_numpy_dataset( 1121 numpy_input, session=session) 1122 1123 def experimental_run(self, fn, input_iterator=None): # pylint: disable=useless-super-delegation 1124 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 1125 1126 DEPRECATED: This method is not available in TF 2.x. Please switch 1127 to using `experimental_run_v2` instead. 1128 1129 When eager execution is enabled, executes ops specified by `fn` on each 1130 replica. Otherwise, builds a graph to execute the ops on each replica. 1131 1132 Each replica will take a single, different input from the inputs provided by 1133 one `get_next` call on the input iterator. 1134 1135 `fn` may call `tf.distribute.get_replica_context()` to access members such 1136 as `replica_id_in_sync_group`. 1137 1138 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 1139 used, and whether eager execution is enabled, `fn` may be called one or more 1140 times (once for each replica). 1141 1142 Args: 1143 fn: The function to run. The inputs to the function must match the outputs 1144 of `input_iterator.get_next()`. The output must be a `tf.nest` of 1145 `Tensor`s. 1146 input_iterator: (Optional) input iterator from which the inputs are taken. 1147 1148 Returns: 1149 Merged return value of `fn` across replicas. The structure of the return 1150 value is the same as the return value from `fn`. Each element in the 1151 structure can either be `PerReplica` (if the values are unsynchronized), 1152 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 1153 single replica). 1154 """ 1155 return super(StrategyV1, self).experimental_run( 1156 fn, input_iterator) 1157 1158 def reduce(self, reduce_op, value, axis=None): 1159 return super(StrategyV1, self).reduce(reduce_op, value, axis) 1160 1161 reduce.__doc__ = Strategy.reduce.__doc__ 1162 1163 def update_config_proto(self, config_proto): 1164 """Returns a copy of `config_proto` modified for use with this strategy. 1165 1166 DEPRECATED: This method is not available in TF 2.x. 1167 1168 The updated config has something needed to run a strategy, e.g. 1169 configuration to run collective ops, or device filters to improve 1170 distributed training performance. 1171 1172 Args: 1173 config_proto: a `tf.ConfigProto` object. 1174 1175 Returns: 1176 The updated copy of the `config_proto`. 1177 """ 1178 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 1179 1180 1181# NOTE(josh11b): For any strategy that needs to support tf.compat.v1, 1182# instead descend from StrategyExtendedV1. 1183@tf_export("distribute.StrategyExtended", v1=[]) 1184class StrategyExtendedV2(object): 1185 """Additional APIs for algorithms that need to be distribution-aware. 1186 1187 Note: For most usage of `tf.distribute.Strategy`, there should be no need to 1188 call these methods, since TensorFlow libraries (such as optimizers) already 1189 call these methods when needed on your behalf. 1190 1191 Lower-level concepts: 1192 1193 * Wrapped values: In order to represent values parallel across devices 1194 (either replicas or the devices associated with a particular value), we 1195 wrap them in a "PerReplica" or "Mirrored" object that contains a map 1196 from replica id to values. "PerReplica" is used when the value may be 1197 different across replicas, and "Mirrored" when the value are the same. 1198 * Unwrapping and merging: Consider calling a function `fn` on multiple 1199 replicas, like `experimental_run_v2(fn, args=[w])` with an 1200 argument `w` that is a wrapped value. This means `w` will have a map taking 1201 replica id `0` to `w0`, replica id `11` to `w1`, etc. 1202 `experimental_run_v2()` unwraps `w` before calling `fn`, so 1203 it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return 1204 values from `fn()`, which can possibly result in wrapped values. For 1205 example, let's say `fn()` returns a tuple with three components: `(x, a, 1206 v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component 1207 is the same object `x` from every replica, then the first component of the 1208 merged result will also be `x`. If the second component is different (`a`, 1209 `b`, ...) from each replica, then the merged value will have a wrapped map 1210 from replica device to the different values. If the third component is the 1211 members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.), 1212 then the merged result will be that mirrored variable (`v`). 1213 * Worker devices vs. parameter devices: Most replica computations will 1214 happen on worker devices. Since we don't yet support model 1215 parallelism, there will be one worker device per replica. When using 1216 parameter servers or central storage, the set of devices holding 1217 variables may be different, otherwise the parameter devices might 1218 match the worker devices. 1219 1220 *Replica context vs. Cross-replica context* 1221 1222 A _replica context_ applies when we are in some function that is being called 1223 once for each replica. Otherwise we are in cross-replica context, which is 1224 useful for calling `tf.distribute.Strategy` methods which operate across the 1225 replicas (like `reduce_to()`). By default you start in a replica context 1226 (the "default single replica context") and then some methods can switch you 1227 back and forth. There is a third mode you can be in called _update context_ 1228 used when updating variables. 1229 1230 * `tf.distribute.Strategy.scope`: enters cross-replica context when 1231 no other strategy is in scope. 1232 * `tf.distribute.Strategy.experimental_run_v2`: calls a function in 1233 replica context. 1234 * `tf.distribute.ReplicaContext.merge_call`: transitions from replica 1235 context to cross-replica context. 1236 * `tf.distribute.StrategyExtended.update`: calls a function in an update 1237 context from a cross-replica context. 1238 1239 In a replica context, you may freely read the values of variables, but 1240 you may only update their value if they specify a way to aggregate the 1241 update using the `aggregation` parameter in the variable's constructor. 1242 In a cross-replica context, you may read or write variables (writes may 1243 need to be broadcast to all copies of the variable if it is mirrored). 1244 1245 *Sync on read variables* 1246 1247 In some cases, such as a metric, we want to accumulate a bunch of updates on 1248 each replica independently and only aggregate when reading. This can be a big 1249 performance win when the value is read only rarely (maybe the value is only 1250 read at the end of an epoch or when checkpointing). These are variables 1251 created by passing `synchronization=ON_READ` to the variable's constructor 1252 (and some value for `aggregation`). 1253 1254 The strategy may choose to put the variable on multiple devices, like mirrored 1255 variables, but unlike mirrored variables we don't synchronize the updates to 1256 them to make sure they have the same value. Instead, the synchronization is 1257 performed when reading in cross-replica context. In a replica context, reads 1258 and writes are performed on the local copy (we allow reads so you can write 1259 code like `v = 0.9*v + 0.1*update`). We don't allow operations like 1260 `v.assign_add` in a cross-replica context for sync on read variables; right 1261 now we don't have a use case for such updates and depending on the aggregation 1262 mode such updates may not be sensible. 1263 1264 *Locality* 1265 1266 Depending on how a value is produced, it will have a type that will determine 1267 how it may be used. 1268 1269 "Per-replica" values exist on the worker devices, with a different value for 1270 each replica. They are produced by iterating through a "distributed `Dataset`" 1271 returned by `tf.distribute.Strategy.experimental_distribute_dataset` and 1272 `tf.distribute.Strategy.experimental_distribute_datasets_from_function`. They 1273 are also the typical result returned by 1274 `tf.distribute.Strategy.experimental_run_v2`. You typically can't use a 1275 per-replica value directly in a cross-replica context, without first resolving 1276 how to aggregate the values across replicas, for instance by using 1277 `tf.distribute.Strategy.reduce`. 1278 1279 "Mirrored" values are like per-replica values, except we know that the value 1280 on all replicas are the same. We can safely read a mirrored value in a 1281 cross-replica context by using the value on any replica. You can convert 1282 a per-replica value into a mirrored value by using 1283 `tf.distribute.ReplicaContext.all_reduce`. 1284 1285 Values can also have the same locality as a variable, which is a mirrored 1286 value but residing on the same devices as the variable (as opposed to the 1287 compute devices). Such values may be passed to a call to 1288 `tf.distribute.StrategyExtended.update` to update the value of a variable. 1289 You may use `tf.distribute.StrategyExtended.colocate_vars_with` to give a 1290 variable the same locality as another variable. This is useful, for example, 1291 for "slot" variables used by an optimizer for keeping track of statistics 1292 used to update a primary/model variable. You may convert a per-replica 1293 value to a variable's locality by using 1294 `tf.distribute.StrategyExtended.reduce_to` or 1295 `tf.distribute.StrategyExtended.batch_reduce_to`. 1296 1297 In addition to slot variables which should be colocated with their primary 1298 variables, optimizers also define non-slot variables. These can be things like 1299 "number of step updates performed" or "beta1^t" and "beta2^t". Each strategy 1300 has some policy for which devices those variables should be copied too, called 1301 the "non-slot devices" (some subset of the parameter devices). We require that 1302 all non-slot variables are allocated on the same device, or mirrored across 1303 the same set of devices. You can use 1304 `tf.distribute.StrategyExtended.non_slot_devices` to pick a consistent set of 1305 devices to pass to both `tf.distribute.StrategyExtended.colocate_vars_with` 1306 and `tf.distribute.StrategyExtended.update_non_slot`. 1307 1308 *How to update a variable* 1309 1310 The standard pattern for updating variables is to: 1311 1312 1. In your function passed to `tf.distribute.Strategy.experimental_run_v2`, 1313 compute a list of (update, variable) pairs. For example, the update might 1314 be a the gradient of the loss with respect to the variable. 1315 2. Switch to cross-replica mode by calling 1316 `tf.distribute.get_replica_context().merge_call()` with the updates and 1317 variables as arguments. 1318 3. Call 1319 `tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v)` 1320 (for one variable) or `tf.distribute.StrategyExtended.batch_reduce_to` 1321 (for a list of variables) to sum the updates. 1322 and broadcast the result to the variable's devices. 1323 4. Call `tf.distribute.StrategyExtended.update(v)` for each variable to update 1324 its value. 1325 1326 Steps 2 through 4 are done automatically by class 1327 `tf.keras.optimizers.Optimizer` if you call its 1328 `tf.keras.optimizers.Optimizer.apply_gradients` method in a replica context. 1329 They are also done automatically if you call an `assign*` method on a (non 1330 sync-on-read) variable that was constructed with an aggregation method (which 1331 is used to determine the reduction used in step 3). 1332 1333 *Distribute-aware layers* 1334 1335 Layers are generally called in a replica context, except when defining a 1336 functional model. `tf.distribute.in_cross_replica_context` will let you 1337 determine which case you are in. If in a replica context, 1338 the `tf.distribute.get_replica_context` function will return a 1339 `tf.distribute.ReplicaContext` object. The `ReplicaContext` object has an 1340 `all_reduce` method for aggregating across all replicas. Alternatively, you 1341 can update variables following steps 2-4 above. 1342 1343 Note: For new `tf.distribute.Strategy` implementations, please put all logic 1344 in a subclass of `tf.distribute.StrategyExtended`. The only code needed for 1345 the `tf.distribute.Strategy` subclass is for instantiating your subclass of 1346 `tf.distribute.StrategyExtended` in the `__init__` method. 1347 """ 1348 1349 def __init__(self, container_strategy): 1350 self._container_strategy_weakref = weakref.ref(container_strategy) 1351 self._default_device = None 1352 # This property is used to determine if we should set drop_remainder=True 1353 # when creating Datasets from numpy array inputs. 1354 self._require_static_shapes = False 1355 1356 def _container_strategy(self): 1357 """Get the containing `tf.distribute.Strategy`. 1358 1359 This should not generally be needed except when creating a new 1360 `ReplicaContext` and to validate that the caller is in the correct 1361 `scope()`. 1362 1363 Returns: 1364 The `tf.distribute.Strategy` such that `strategy.extended` is `self`. 1365 """ 1366 container_strategy = self._container_strategy_weakref() 1367 assert container_strategy is not None 1368 return container_strategy 1369 1370 def _scope(self, strategy): 1371 """Implementation of tf.distribute.Strategy.scope().""" 1372 1373 def creator_with_resource_vars(next_creator, **kwargs): 1374 """Variable creator to use in `_CurrentDistributionContext`.""" 1375 _require_strategy_scope_extended(self) 1376 kwargs["use_resource"] = True 1377 kwargs["distribute_strategy"] = strategy 1378 1379 # Unwrap `initial_value` if it is a `CheckpointInitialValue` to avoid 1380 # dereferencing a `Tensor` that is without a `name`. 1381 # TODO(b/138130844): Revisit the following check once 1382 # `CheckpointInitialValue` class is removed. 1383 if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue): 1384 kwargs["initial_value"] = kwargs["initial_value"].wrapped_value 1385 1386 return self._create_variable(next_creator, **kwargs) 1387 1388 def distributed_getter(getter, *args, **kwargs): 1389 if not self._allow_variable_partition(): 1390 if kwargs.pop("partitioner", None) is not None: 1391 tf_logging.log_first_n( 1392 tf_logging.WARN, "Partitioned variables are disabled when using " 1393 "current tf.distribute.Strategy.", 1) 1394 return getter(*args, **kwargs) 1395 1396 return _CurrentDistributionContext( 1397 strategy, 1398 variable_scope.variable_creator_scope(creator_with_resource_vars), 1399 variable_scope.variable_scope( 1400 variable_scope.get_variable_scope(), 1401 custom_getter=distributed_getter), self._default_device) 1402 1403 def _allow_variable_partition(self): 1404 return False 1405 1406 def _create_variable(self, next_creator, **kwargs): 1407 # Note: should support "colocate_with" argument. 1408 raise NotImplementedError("must be implemented in descendants") 1409 1410 def variable_created_in_scope(self, v): 1411 """Tests whether `v` was created while this strategy scope was active. 1412 1413 Variables created inside the strategy scope are "owned" by it: 1414 1415 ```python 1416 strategy = tf.distribute.StrategyExtended() 1417 with strategy.scope(): 1418 v = tf.Variable(1.) 1419 strategy.variable_created_in_scope(v) 1420 True 1421 ``` 1422 1423 Variables created outside the strategy are not owned by it: 1424 1425 ```python 1426 v = tf.Variable(1.) 1427 strategy.variable_created_in_scope(v) 1428 False 1429 ``` 1430 1431 Args: 1432 v: A `tf.Variable` instance. 1433 1434 Returns: 1435 True if `v` was created inside the scope, False if not. 1436 """ 1437 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 1438 1439 def colocate_vars_with(self, colocate_with_variable): 1440 """Scope that controls which devices variables will be created on. 1441 1442 No operations should be added to the graph inside this scope, it 1443 should only be used when creating variables (some implementations 1444 work by changing variable creation, others work by using a 1445 tf.compat.v1.colocate_with() scope). 1446 1447 This may only be used inside `self.scope()`. 1448 1449 Example usage: 1450 1451 ``` 1452 with strategy.scope(): 1453 var1 = tf.Variable(...) 1454 with strategy.extended.colocate_vars_with(var1): 1455 # var2 and var3 will be created on the same device(s) as var1 1456 var2 = tf.Variable(...) 1457 var3 = tf.Variable(...) 1458 1459 def fn(v1, v2, v3): 1460 # operates on v1 from var1, v2 from var2, and v3 from var3 1461 1462 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 1463 # too. 1464 strategy.extended.update(var1, fn, args=(var2, var3)) 1465 ``` 1466 1467 Args: 1468 colocate_with_variable: A variable created in this strategy's `scope()`. 1469 Variables created while in the returned context manager will be on the 1470 same set of devices as `colocate_with_variable`. 1471 1472 Returns: 1473 A context manager. 1474 """ 1475 1476 def create_colocated_variable(next_creator, **kwargs): 1477 _require_strategy_scope_extended(self) 1478 kwargs["use_resource"] = True 1479 kwargs["colocate_with"] = colocate_with_variable 1480 return next_creator(**kwargs) 1481 1482 _require_strategy_scope_extended(self) 1483 self._validate_colocate_with_variable(colocate_with_variable) 1484 return variable_scope.variable_creator_scope(create_colocated_variable) 1485 1486 def _validate_colocate_with_variable(self, colocate_with_variable): 1487 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 1488 pass 1489 1490 def _make_dataset_iterator(self, dataset): 1491 raise NotImplementedError("must be implemented in descendants") 1492 1493 def _make_input_fn_iterator(self, input_fn, replication_mode): 1494 raise NotImplementedError("must be implemented in descendants") 1495 1496 def _experimental_distribute_dataset(self, dataset): 1497 raise NotImplementedError("must be implemented in descendants") 1498 1499 def _experimental_distribute_datasets_from_function(self, dataset_fn): 1500 raise NotImplementedError("must be implemented in descendants") 1501 1502 def _reduce(self, reduce_op, value): 1503 # Default implementation until we have an implementation for each strategy. 1504 return self._local_results( 1505 self._reduce_to(reduce_op, value, 1506 device_util.current() or "/device:CPU:0"))[0] 1507 1508 def reduce_to(self, reduce_op, value, destinations): 1509 """Combine (via e.g. sum or mean) values across replicas. 1510 1511 Args: 1512 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 1513 value: A per-replica value with one value per replica. 1514 destinations: A mirrored variable, a per-replica tensor, or a device 1515 string. The return value will be copied to all destination devices (or 1516 all the devices where the `destinations` value resides). To perform an 1517 all-reduction, pass `value` to `destinations`. 1518 1519 Returns: 1520 A tensor or value mirrored to `destinations`. 1521 """ 1522 # TODO(josh11b): More docstring 1523 _require_cross_replica_or_default_context_extended(self) 1524 assert not isinstance(destinations, (list, tuple)) 1525 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 1526 if isinstance(reduce_op, six.string_types): 1527 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 1528 assert (reduce_op == reduce_util.ReduceOp.SUM or 1529 reduce_op == reduce_util.ReduceOp.MEAN) 1530 return self._reduce_to(reduce_op, value, destinations) 1531 1532 def _reduce_to(self, reduce_op, value, destinations): 1533 raise NotImplementedError("must be implemented in descendants") 1534 1535 def batch_reduce_to(self, reduce_op, value_destination_pairs): 1536 """Combine multiple `reduce_to` calls into one for faster execution. 1537 1538 Args: 1539 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 1540 value_destination_pairs: A sequence of (value, destinations) 1541 pairs. See `reduce_to()` for a description. 1542 1543 Returns: 1544 A list of mirrored values, one per pair in `value_destination_pairs`. 1545 """ 1546 # TODO(josh11b): More docstring 1547 _require_cross_replica_or_default_context_extended(self) 1548 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 1549 if isinstance(reduce_op, six.string_types): 1550 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 1551 return self._batch_reduce_to(reduce_op, value_destination_pairs) 1552 1553 def _batch_reduce_to(self, reduce_op, value_destination_pairs): 1554 return [ 1555 self.reduce_to(reduce_op, t, destinations=v) 1556 for t, v in value_destination_pairs 1557 ] 1558 1559 def update(self, var, fn, args=(), kwargs=None, group=True): 1560 """Run `fn` to update `var` using inputs mirrored to the same devices. 1561 1562 If `var` is mirrored across multiple devices, then this implements 1563 logic like: 1564 1565 ``` 1566 results = {} 1567 for device, v in var: 1568 with tf.device(device): 1569 # args and kwargs will be unwrapped if they are mirrored. 1570 results[device] = fn(v, *args, **kwargs) 1571 return merged(results) 1572 ``` 1573 1574 Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`. 1575 1576 Neither `args` nor `kwargs` may contain per-replica values. 1577 If they contain mirrored values, they will be unwrapped before 1578 calling `fn`. 1579 1580 Args: 1581 var: Variable, possibly mirrored to multiple devices, to operate on. 1582 fn: Function to call. Should take the variable as the first argument. 1583 args: Tuple or list. Additional positional arguments to pass to `fn()`. 1584 kwargs: Dict with keyword arguments to pass to `fn()`. 1585 group: Boolean. Defaults to True. If False, the return value will be 1586 unwrapped. 1587 1588 Returns: 1589 By default, the merged return value of `fn` across all replicas. The 1590 merged result has dependencies to make sure that if it is evaluated at 1591 all, the side effects (updates) will happen on every replica. If instead 1592 "group=False" is specified, this function will return a nest of lists 1593 where each list has an element per replica, and the caller is responsible 1594 for ensuring all elements are executed. 1595 """ 1596 _require_cross_replica_or_default_context_extended(self) 1597 if kwargs is None: 1598 kwargs = {} 1599 fn = autograph.tf_convert( 1600 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 1601 with self._container_strategy().scope(): 1602 return self._update(var, fn, args, kwargs, group) 1603 1604 def _update(self, var, fn, args, kwargs, group): 1605 raise NotImplementedError("must be implemented in descendants") 1606 1607 def update_non_slot( 1608 self, colocate_with, fn, args=(), kwargs=None, group=True): 1609 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 1610 1611 Args: 1612 colocate_with: The return value of `non_slot_devices()`. 1613 fn: Function to execute. 1614 args: Tuple or list. Positional arguments to pass to `fn()`. 1615 kwargs: Dict with keyword arguments to pass to `fn()`. 1616 group: Boolean. Defaults to True. If False, the return value will be 1617 unwrapped. 1618 1619 Returns: 1620 Return value of `fn`, possibly merged across devices. 1621 """ 1622 _require_cross_replica_or_default_context_extended(self) 1623 if kwargs is None: 1624 kwargs = {} 1625 fn = autograph.tf_convert( 1626 fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 1627 with self._container_strategy().scope(): 1628 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 1629 1630 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 1631 raise NotImplementedError("must be implemented in descendants") 1632 1633 def _local_results(self, distributed_value): 1634 raise NotImplementedError("must be implemented in descendants") 1635 1636 def value_container(self, value): 1637 """Returns the container that this per-replica `value` belongs to. 1638 1639 Args: 1640 value: A value returned by `experimental_run_v2()` or a variable 1641 created in `scope()`. 1642 1643 Returns: 1644 A container that `value` belongs to. 1645 If value does not belong to any container (including the case of 1646 container having been destroyed), returns the value itself. 1647 `value in experimental_local_results(value_container(value))` will 1648 always be true. 1649 """ 1650 raise NotImplementedError("must be implemented in descendants") 1651 1652 def _group(self, value, name=None): 1653 """Implementation of `group`.""" 1654 value = nest.flatten(self._local_results(value)) 1655 1656 if len(value) != 1 or name is not None: 1657 return control_flow_ops.group(value, name=name) 1658 # Special handling for the common case of one op. 1659 v, = value 1660 if hasattr(v, "op"): 1661 v = v.op 1662 return v 1663 1664 @property 1665 def experimental_require_static_shapes(self): 1666 """Returns `True` if static shape is required; `False` otherwise.""" 1667 return self._require_static_shapes 1668 1669 @property 1670 def _num_replicas_in_sync(self): 1671 """Returns number of replicas over which gradients are aggregated.""" 1672 raise NotImplementedError("must be implemented in descendants") 1673 1674 @property 1675 def worker_devices(self): 1676 """Returns the tuple of all devices used to for compute replica execution. 1677 """ 1678 # TODO(josh11b): More docstring 1679 raise NotImplementedError("must be implemented in descendants") 1680 1681 @property 1682 def parameter_devices(self): 1683 """Returns the tuple of all devices used to place variables.""" 1684 # TODO(josh11b): More docstring 1685 raise NotImplementedError("must be implemented in descendants") 1686 1687 def non_slot_devices(self, var_list): 1688 """Device(s) for non-slot variables. 1689 1690 Create variables on these devices in a 1691 `with colocate_vars_with(non_slot_devices(...)):` block. 1692 Update those using `update_non_slot()`. 1693 1694 Args: 1695 var_list: The list of variables being optimized, needed with the 1696 default `tf.distribute.Strategy`. 1697 Returns: 1698 A sequence of devices for non-slot variables. 1699 """ 1700 raise NotImplementedError("must be implemented in descendants") 1701 1702 def _configure(self, 1703 session_config=None, 1704 cluster_spec=None, 1705 task_type=None, 1706 task_id=None): 1707 """Configures the strategy class.""" 1708 del session_config, cluster_spec, task_type, task_id 1709 1710 def _update_config_proto(self, config_proto): 1711 return copy.deepcopy(config_proto) 1712 1713 def _in_multi_worker_mode(self): 1714 """Whether this strategy indicates working in multi-worker settings. 1715 1716 Multi-worker training refers to the setup where the training is 1717 distributed across multiple workers, as opposed to the case where 1718 only a local process performs the training. This function is 1719 used by higher-level apis such as Keras' `model.fit()` to infer 1720 for example whether or not a distribute coordinator should be run, 1721 and thus TensorFlow servers should be started for communication 1722 with other servers in the cluster, or whether or not saving/restoring 1723 checkpoints is relevant for preemption fault tolerance. 1724 1725 Subclasses should override this to provide whether the strategy is 1726 currently in multi-worker setup. 1727 1728 Experimental. Signature and implementation are subject to change. 1729 """ 1730 raise NotImplementedError("must be implemented in descendants") 1731 1732 1733@tf_export(v1=["distribute.StrategyExtended"]) # pylint: disable=missing-docstring 1734class StrategyExtendedV1(StrategyExtendedV2): 1735 1736 __doc__ = StrategyExtendedV2.__doc__ 1737 1738 def experimental_make_numpy_dataset(self, numpy_input, session=None): 1739 """Makes a dataset for input provided via a numpy array. 1740 1741 This avoids adding `numpy_input` as a large constant in the graph, 1742 and copies the data to the machine or machines that will be processing 1743 the input. 1744 1745 Args: 1746 numpy_input: A nest of NumPy input arrays that will be distributed evenly 1747 across all replicas. Note that lists of Numpy arrays are stacked, as 1748 that is normal `tf.data.Dataset` behavior. 1749 session: (TensorFlow v1.x graph execution only) A session used for 1750 initialization. 1751 1752 Returns: 1753 A `tf.data.Dataset` representing `numpy_input`. 1754 """ 1755 _require_cross_replica_or_default_context_extended(self) 1756 return self._experimental_make_numpy_dataset(numpy_input, session=session) 1757 1758 def _experimental_make_numpy_dataset(self, numpy_input, session): 1759 raise NotImplementedError("must be implemented in descendants") 1760 1761 def broadcast_to(self, tensor, destinations): 1762 """Mirror a tensor on one device to all worker devices. 1763 1764 Args: 1765 tensor: A Tensor value to broadcast. 1766 destinations: A mirrored variable or device string specifying the 1767 destination devices to copy `tensor` to. 1768 1769 Returns: 1770 A value mirrored to `destinations` devices. 1771 """ 1772 assert destinations is not None # from old strategy.broadcast() 1773 # TODO(josh11b): More docstring 1774 _require_cross_replica_or_default_context_extended(self) 1775 assert not isinstance(destinations, (list, tuple)) 1776 return self._broadcast_to(tensor, destinations) 1777 1778 def _broadcast_to(self, tensor, destinations): 1779 raise NotImplementedError("must be implemented in descendants") 1780 1781 def experimental_run_steps_on_iterator(self, 1782 fn, 1783 iterator, 1784 iterations=1, 1785 initial_loop_values=None): 1786 """DEPRECATED: please use `experimental_run_v2` instead. 1787 1788 Run `fn` with input from `iterator` for `iterations` times. 1789 1790 This method can be used to run a step function for training a number of 1791 times using input from a dataset. 1792 1793 Args: 1794 fn: function to run using this distribution strategy. The function must 1795 have the following signature: `def fn(context, inputs)`. `context` is an 1796 instance of `MultiStepContext` that will be passed when `fn` is run. 1797 `context` can be used to specify the outputs to be returned from `fn` 1798 by calling `context.set_last_step_output`. It can also be used to 1799 capture non tensor outputs by `context.set_non_tensor_output`. See 1800 `MultiStepContext` documentation for more information. `inputs` will 1801 have same type/structure as `iterator.get_next()`. Typically, `fn` 1802 will use `call_for_each_replica` method of the strategy to distribute 1803 the computation over multiple replicas. 1804 iterator: Iterator of a dataset that represents the input for `fn`. The 1805 caller is responsible for initializing the iterator as needed. 1806 iterations: (Optional) Number of iterations that `fn` should be run. 1807 Defaults to 1. 1808 initial_loop_values: (Optional) Initial values to be passed into the 1809 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 1810 initial_loop_values argument when we have a mechanism to infer the 1811 outputs of `fn`. 1812 1813 Returns: 1814 Returns the `MultiStepContext` object which has the following properties, 1815 among other things: 1816 - run_op: An op that runs `fn` `iterations` times. 1817 - last_step_outputs: A dictionary containing tensors set using 1818 `context.set_last_step_output`. Evaluating this returns the value of 1819 the tensors after the last iteration. 1820 - non_tensor_outputs: A dictionatry containing anything that was set by 1821 `fn` by calling `context.set_non_tensor_output`. 1822 """ 1823 _require_cross_replica_or_default_context_extended(self) 1824 with self._container_strategy().scope(): 1825 return self._experimental_run_steps_on_iterator(fn, iterator, iterations, 1826 initial_loop_values) 1827 1828 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 1829 initial_loop_values): 1830 raise NotImplementedError("must be implemented in descendants") 1831 1832 def call_for_each_replica(self, fn, args=(), kwargs=None): 1833 """Run `fn` once per replica. 1834 1835 `fn` may call `tf.get_replica_context()` to access methods such as 1836 `replica_id_in_sync_group` and `merge_call()`. 1837 1838 `merge_call()` is used to communicate between the replicas and 1839 re-enter the cross-replica context. All replicas pause their execution 1840 having encountered a `merge_call()` call. After that the 1841 `merge_fn`-function is executed. Its results are then unwrapped and 1842 given back to each replica call. After that execution resumes until 1843 `fn` is complete or encounters another `merge_call()`. Example: 1844 1845 ```python 1846 # Called once in "cross-replica" context. 1847 def merge_fn(distribution, three_plus_replica_id): 1848 # sum the values across replicas 1849 return sum(distribution.experimental_local_results(three_plus_replica_id)) 1850 1851 # Called once per replica in `distribution`, in a "replica" context. 1852 def fn(three): 1853 replica_ctx = tf.get_replica_context() 1854 v = three + replica_ctx.replica_id_in_sync_group 1855 # Computes the sum of the `v` values across all replicas. 1856 s = replica_ctx.merge_call(merge_fn, args=(v,)) 1857 return s + v 1858 1859 with distribution.scope(): 1860 # in "cross-replica" context 1861 ... 1862 merged_results = distribution.experimental_run_v2(fn, args=[3]) 1863 # merged_results has the values from every replica execution of `fn`. 1864 # This statement prints a list: 1865 print(distribution.experimental_local_results(merged_results)) 1866 ``` 1867 1868 Args: 1869 fn: function to run (will be run once per replica). 1870 args: Tuple or list with positional arguments for `fn`. 1871 kwargs: Dict with keyword arguments for `fn`. 1872 1873 Returns: 1874 Merged return value of `fn` across all replicas. 1875 """ 1876 _require_cross_replica_or_default_context_extended(self) 1877 if kwargs is None: 1878 kwargs = {} 1879 with self._container_strategy().scope(): 1880 return self._call_for_each_replica(fn, args, kwargs) 1881 1882 def _call_for_each_replica(self, fn, args, kwargs): 1883 raise NotImplementedError("must be implemented in descendants") 1884 1885 def read_var(self, v): 1886 """Reads the value of a variable. 1887 1888 Returns the aggregate value of a replica-local variable, or the 1889 (read-only) value of any other variable. 1890 1891 Args: 1892 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 1893 1894 Returns: 1895 A tensor representing the value of `v`, aggregated across replicas if 1896 necessary. 1897 """ 1898 raise NotImplementedError("must be implemented in descendants") 1899 1900 @property 1901 def experimental_between_graph(self): 1902 """Whether the strategy uses between-graph replication or not. 1903 1904 This is expected to return a constant value that will not be changed 1905 throughout its life cycle. 1906 """ 1907 raise NotImplementedError("must be implemented in descendants") 1908 1909 @property 1910 def experimental_should_init(self): 1911 """Whether initialization is needed.""" 1912 raise NotImplementedError("must be implemented in descendants") 1913 1914 @property 1915 def should_checkpoint(self): 1916 """Whether checkpointing is needed.""" 1917 raise NotImplementedError("must be implemented in descendants") 1918 1919 @property 1920 def should_save_summary(self): 1921 """Whether saving summaries is needed.""" 1922 raise NotImplementedError("must be implemented in descendants") 1923 1924 1925# A note about the difference between the context managers 1926# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 1927# (defined above) used by `tf.distribute.Strategy.scope()`: 1928# 1929# * a ReplicaContext is only present during a `experimental_run_v2()` 1930# call (except during a `merge_run` call) and in such a scope it 1931# will be returned by calls to `get_replica_context()`. Implementers of new 1932# Strategy descendants will frequently also need to 1933# define a descendant of ReplicaContext, and are responsible for 1934# entering and exiting this context. 1935# 1936# * Strategy.scope() sets up a variable_creator scope that 1937# changes variable creation calls (e.g. to make mirrored 1938# variables). This is intended as an outer scope that users enter once 1939# around their model creation and graph definition. There is no 1940# anticipated need to define descendants of _CurrentDistributionContext. 1941# It sets the current Strategy for purposes of 1942# `get_strategy()` and `has_strategy()` 1943# and switches the thread mode to a "cross-replica context". 1944@tf_export("distribute.ReplicaContext") 1945class ReplicaContext(object): 1946 """`tf.distribute.Strategy` API when in a replica context. 1947 1948 You can use `tf.distribute.get_replica_context` to get an instance of 1949 `ReplicaContext`. This should be inside your replicated step function, such 1950 as in a `tf.distribute.Strategy.experimental_run_v2` call. 1951 """ 1952 1953 def __init__(self, strategy, replica_id_in_sync_group): 1954 self._strategy = strategy 1955 self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access 1956 self) 1957 self._replica_id_in_sync_group = replica_id_in_sync_group 1958 self._summary_recording_distribution_strategy = None 1959 1960 def __enter__(self): 1961 _push_per_thread_mode(self._thread_context) 1962 1963 def replica_id_is_zero(): 1964 return math_ops.equal(self._replica_id_in_sync_group, 1965 constant_op.constant(0)) 1966 1967 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1968 self._summary_recording_distribution_strategy = ( 1969 summary_state.is_recording_distribution_strategy) 1970 summary_state.is_recording_distribution_strategy = replica_id_is_zero 1971 1972 def __exit__(self, exception_type, exception_value, traceback): 1973 summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access 1974 summary_state.is_recording_distribution_strategy = ( 1975 self._summary_recording_distribution_strategy) 1976 _pop_per_thread_mode() 1977 1978 def merge_call(self, merge_fn, args=(), kwargs=None): 1979 """Merge args across replicas and run `merge_fn` in a cross-replica context. 1980 1981 This allows communication and coordination when there are multiple calls 1982 to the step_fn triggered by a call to 1983 `strategy.experimental_run_v2(step_fn, ...)`. 1984 1985 See `tf.distribute.Strategy.experimental_run_v2` for an 1986 explanation. 1987 1988 If not inside a distributed scope, this is equivalent to: 1989 1990 ``` 1991 strategy = tf.distribute.get_strategy() 1992 with cross-replica-context(strategy): 1993 return merge_fn(strategy, *args, **kwargs) 1994 ``` 1995 1996 Args: 1997 merge_fn: Function that joins arguments from threads that are given as 1998 PerReplica. It accepts `tf.distribute.Strategy` object as 1999 the first argument. 2000 args: List or tuple with positional per-thread arguments for `merge_fn`. 2001 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 2002 2003 Returns: 2004 The return value of `merge_fn`, except for `PerReplica` values which are 2005 unpacked. 2006 """ 2007 require_replica_context(self) 2008 if kwargs is None: 2009 kwargs = {} 2010 merge_fn = autograph.tf_convert( 2011 merge_fn, autograph_ctx.control_status_ctx(), convert_by_default=False) 2012 return self._merge_call(merge_fn, args, kwargs) 2013 2014 def _merge_call(self, merge_fn, args, kwargs): 2015 """Default implementation for single replica.""" 2016 _push_per_thread_mode( # thread-local, so not needed with multiple threads 2017 distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 2018 try: 2019 return merge_fn(self._strategy, *args, **kwargs) 2020 finally: 2021 _pop_per_thread_mode() 2022 2023 @property 2024 def num_replicas_in_sync(self): 2025 """Returns number of replicas over which gradients are aggregated.""" 2026 return self._strategy.num_replicas_in_sync 2027 2028 @property 2029 def replica_id_in_sync_group(self): 2030 """Returns the id of the replica being defined. 2031 2032 This identifies the replica that is part of a sync group. Currently we 2033 assume that all sync groups contain the same number of replicas. The value 2034 of the replica id can range from 0 to `num_replica_in_sync` - 1. 2035 2036 NOTE: This is not guaranteed to be the same ID as the XLA replica ID use 2037 for low-level operations such as collective_permute. 2038 """ 2039 require_replica_context(self) 2040 return self._replica_id_in_sync_group 2041 2042 @property 2043 def strategy(self): 2044 """The current `tf.distribute.Strategy` object.""" 2045 return self._strategy 2046 2047 @property 2048 def devices(self): 2049 """The devices this replica is to be executed on, as a tuple of strings.""" 2050 require_replica_context(self) 2051 return (device_util.current(),) 2052 2053 def all_reduce(self, reduce_op, value): 2054 """All-reduces the given `value Tensor` nest across replicas. 2055 2056 If `all_reduce` is called in any replica, it must be called in all replicas. 2057 The nested structure and `Tensor` shapes must be identical in all replicas. 2058 2059 IMPORTANT: The ordering of communications must be identical in all replicas. 2060 2061 Example with two replicas: 2062 Replica 0 `value`: {'a': 1, 'b': [40, 1]} 2063 Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} 2064 2065 If `reduce_op` == `SUM`: 2066 Result (on all replicas): {'a': 4, 'b': [42, 99]} 2067 2068 If `reduce_op` == `MEAN`: 2069 Result (on all replicas): {'a': 2, 'b': [21, 49.5]} 2070 2071 Args: 2072 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 2073 value: The nested structure of `Tensor`s to all-reduce. The structure must 2074 be compatible with `tf.nest`. 2075 2076 Returns: 2077 A `Tensor` nest with the reduced `value`s from each replica. 2078 """ 2079 if isinstance(reduce_op, six.string_types): 2080 reduce_op = reduce_util.ReduceOp(reduce_op.upper()) 2081 2082 def batch_all_reduce(strategy, *value_flat): 2083 return strategy.extended.batch_reduce_to( 2084 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) 2085 2086 if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: 2087 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 2088 @custom_gradient.custom_gradient 2089 def grad_wrapper(*xs): 2090 ys = self.merge_call(batch_all_reduce, args=xs) 2091 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 2092 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 2093 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 2094 else: 2095 # TODO(cjfj): Implement gradients for other reductions. 2096 reduced = nest.pack_sequence_as( 2097 value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) 2098 return nest.map_structure(array_ops.prevent_gradient, reduced) 2099 2100 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 2101 # all-reduce. It would return a function returning the result of reducing `t` 2102 # across all replicas. The caller would wait to call this function until they 2103 # needed the reduce result, allowing an efficient implementation: 2104 # * With eager execution, the reduction could be performed asynchronously 2105 # in the background, not blocking until the result was needed. 2106 # * When constructing a graph, it could batch up all reduction requests up 2107 # to that point that the first result is needed. Most likely this can be 2108 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 2109 2110 2111def _batch_reduce_destination(x): 2112 """Returns the destinations for batch all-reduce.""" 2113 if isinstance(x, ops.Tensor): 2114 # If this is a one device strategy. 2115 return x.device 2116 else: 2117 return x 2118 2119 2120# ------------------------------------------------------------------------------ 2121 2122 2123_creating_default_strategy_singleton = False 2124 2125 2126class _DefaultDistributionStrategy(StrategyV1): 2127 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 2128 2129 def __init__(self): 2130 if not _creating_default_strategy_singleton: 2131 raise RuntimeError("Should only create a single instance of " 2132 "_DefaultDistributionStrategy") 2133 super(_DefaultDistributionStrategy, self).__init__( 2134 _DefaultDistributionExtended(self)) 2135 2136 def __deepcopy__(self, memo): 2137 del memo 2138 raise RuntimeError("Should only create a single instance of " 2139 "_DefaultDistributionStrategy") 2140 2141 2142class _DefaultDistributionContext(object): 2143 """Context manager setting the default `tf.distribute.Strategy`.""" 2144 2145 def __init__(self, strategy): 2146 2147 def creator(next_creator, **kwargs): 2148 _require_strategy_scope_strategy(strategy) 2149 return next_creator(**kwargs) 2150 2151 self._var_creator_scope = variable_scope.variable_creator_scope(creator) 2152 self._strategy = strategy 2153 self._nested_count = 0 2154 2155 def __enter__(self): 2156 # Allow this scope to be entered if this strategy is already in scope. 2157 if distribution_strategy_context.has_strategy(): 2158 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 2159 if self._nested_count == 0: 2160 self._var_creator_scope.__enter__() 2161 self._nested_count += 1 2162 return self._strategy 2163 2164 def __exit__(self, exception_type, exception_value, traceback): 2165 self._nested_count -= 1 2166 if self._nested_count == 0: 2167 try: 2168 self._var_creator_scope.__exit__( 2169 exception_type, exception_value, traceback) 2170 except RuntimeError as e: 2171 six.raise_from( 2172 RuntimeError("Variable creator scope nesting error: move call to " 2173 "tf.distribute.set_strategy() out of `with` scope."), 2174 e) 2175 2176 2177class _DefaultDistributionExtended(StrategyExtendedV1): 2178 """Implementation of _DefaultDistributionStrategy.""" 2179 2180 def __init__(self, container_strategy): 2181 super(_DefaultDistributionExtended, self).__init__(container_strategy) 2182 self._retrace_functions_for_each_device = False 2183 2184 def _scope(self, strategy): 2185 """Context manager setting a variable creator and `self` as current.""" 2186 return _DefaultDistributionContext(strategy) 2187 2188 def colocate_vars_with(self, colocate_with_variable): 2189 """Does not require `self.scope`.""" 2190 _require_strategy_scope_extended(self) 2191 return ops.colocate_with(colocate_with_variable) 2192 2193 def variable_created_in_scope(self, v): 2194 return v._distribute_strategy is None # pylint: disable=protected-access 2195 2196 def _experimental_distribute_dataset(self, dataset): 2197 return dataset 2198 2199 def _experimental_distribute_datasets_from_function(self, dataset_fn): 2200 return dataset_fn(InputContext()) 2201 2202 def _make_dataset_iterator(self, dataset): 2203 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 2204 2205 def _make_input_fn_iterator(self, 2206 input_fn, 2207 replication_mode=InputReplicationMode.PER_WORKER): 2208 dataset = input_fn(InputContext()) 2209 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 2210 2211 def _experimental_make_numpy_dataset(self, numpy_input, session): 2212 numpy_flat = nest.flatten(numpy_input) 2213 vars_flat = tuple( 2214 variable_scope.variable(array_ops.zeros(i.shape, i.dtype), 2215 trainable=False, use_resource=True) 2216 for i in numpy_flat 2217 ) 2218 for v, i in zip(vars_flat, numpy_flat): 2219 numpy_dataset.init_var_from_numpy(v, i, session) 2220 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 2221 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 2222 2223 def _broadcast_to(self, tensor, destinations): 2224 if destinations is None: 2225 return tensor 2226 else: 2227 raise NotImplementedError("TODO") 2228 2229 def _call_for_each_replica(self, fn, args, kwargs): 2230 with ReplicaContext( 2231 self._container_strategy(), 2232 replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)): 2233 return fn(*args, **kwargs) 2234 2235 def _reduce_to(self, reduce_op, value, destinations): 2236 # TODO(josh11b): Use destinations? 2237 del reduce_op, destinations 2238 return value 2239 2240 def _update(self, var, fn, args, kwargs, group): 2241 # The implementations of _update() and _update_non_slot() are identical 2242 # except _update() passes `var` as the first argument to `fn()`. 2243 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 2244 2245 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 2246 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 2247 # once that value is used for something. 2248 with UpdateContext(colocate_with): 2249 result = fn(*args, **kwargs) 2250 if should_group: 2251 return result 2252 else: 2253 return nest.map_structure(self._local_results, result) 2254 2255 def read_var(self, replica_local_var): 2256 return array_ops.identity(replica_local_var) 2257 2258 def _local_results(self, distributed_value): 2259 return (distributed_value,) 2260 2261 def value_container(self, value): 2262 return value 2263 2264 @property 2265 def _num_replicas_in_sync(self): 2266 return 1 2267 2268 @property 2269 def worker_devices(self): 2270 raise RuntimeError("worker_devices() method unsupported by default " 2271 "tf.distribute.Strategy.") 2272 2273 @property 2274 def parameter_devices(self): 2275 raise RuntimeError("parameter_devices() method unsupported by default " 2276 "tf.distribute.Strategy.") 2277 2278 def non_slot_devices(self, var_list): 2279 return min(var_list, key=lambda x: x.name) 2280 2281 def _in_multi_worker_mode(self): 2282 """Whether this strategy indicates working in multi-worker settings.""" 2283 # Default strategy doesn't indicate multi-worker training. 2284 return False 2285 2286 # TODO(priyag): This should inherit from `InputIterator`, once dependency 2287 # issues have been resolved. 2288 class DefaultInputIterator(object): 2289 """Default implementation of `InputIterator` for default strategy.""" 2290 2291 def __init__(self, dataset): 2292 self._dataset = dataset 2293 if eager_context.executing_eagerly(): 2294 self._iterator = dataset_ops.make_one_shot_iterator(dataset) 2295 else: 2296 self._iterator = dataset_ops.make_initializable_iterator(dataset) 2297 2298 def get_next(self): 2299 return self._iterator.get_next() 2300 2301 @deprecated(None, "Use the iterator's `initializer` property instead.") 2302 def initialize(self): 2303 """Initialize underlying iterators. 2304 2305 Returns: 2306 A list of any initializer ops that should be run. 2307 """ 2308 if eager_context.executing_eagerly(): 2309 self._iterator = self._dataset.make_one_shot_iterator() 2310 return [] 2311 else: 2312 return [self._iterator.initializer] 2313 2314 @property 2315 def initializer(self): 2316 """Returns a list of ops that initialize the iterator.""" 2317 return self.initialize() 2318 2319 # TODO(priyag): Delete this once all strategies use global batch size. 2320 @property 2321 def _global_batch_size(self): 2322 """Global and per-replica batching are equivalent for this strategy.""" 2323 return True 2324 2325 2326# ------------------------------------------------------------------------------ 2327# We haven't yet implemented deserialization for DistributedVariables. 2328# So here we catch any attempts to deserialize variables 2329# when using distribution strategies. 2330# pylint: disable=protected-access 2331_original_from_proto = resource_variable_ops._from_proto_fn 2332 2333 2334def _from_proto_fn(v, import_scope=None): 2335 if distribution_strategy_context.has_strategy(): 2336 raise NotImplementedError( 2337 "Deserialization of variables is not yet supported when using a " 2338 "tf.distribute.Strategy.") 2339 else: 2340 return _original_from_proto(v, import_scope=import_scope) 2341 2342resource_variable_ops._from_proto_fn = _from_proto_fn 2343# pylint: enable=protected-access 2344 2345 2346#------------------------------------------------------------------------------- 2347# Shorthand for some methods from distribution_strategy_context. 2348_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access 2349_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access 2350_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access 2351_get_default_replica_mode = ( 2352 distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access 2353 2354 2355# ------------------------------------------------------------------------------ 2356# Metrics to track which distribution strategy is being called 2357distribution_strategy_gauge = monitoring.StringGauge( 2358 "/tensorflow/api/distribution_strategy", 2359 "Gauge to track the type of distribution strategy used.", "TFVersion") 2360distribution_strategy_replica_gauge = monitoring.IntGauge( 2361 "/tensorflow/api/distribution_strategy/replica", 2362 "Gauge to track the number of replica each distribution strategy used.", 2363 "CountType") 2364