1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Library for running a computation across multiple devices.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import threading 23import weakref 24import enum 25 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.distribute import device_util 28from tensorflow.python.distribute import distribution_strategy_context 29from tensorflow.python.distribute import numpy_dataset 30from tensorflow.python.distribute import reduce_util 31from tensorflow.python.eager import context as eager_context 32from tensorflow.python.framework import constant_op 33from tensorflow.python.framework import dtypes 34from tensorflow.python.framework import ops 35from tensorflow.python.ops import array_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import custom_gradient 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import resource_variable_ops 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.platform import tf_logging 42from tensorflow.python.util import nest 43from tensorflow.python.util.tf_export import tf_export 44from tensorflow.tools.docs import doc_controls 45 46 47# ------------------------------------------------------------------------------ 48# Context tracking whether in a strategy.update() or .update_non_slot() call. 49 50 51_update_device = threading.local() 52 53 54def get_update_device(): 55 """Get the current device if in a `tf.distribute.Strategy.update()` call.""" 56 try: 57 return _update_device.current 58 except AttributeError: 59 return None 60 61 62class UpdateContext(object): 63 """Context manager when you are in `update()` or `update_non_slot()`.""" 64 65 def __init__(self, device): 66 self._device = device 67 self._old_device = None 68 69 def __enter__(self): 70 self._old_device = get_update_device() 71 _update_device.current = self._device 72 73 def __exit__(self, exception_type, exception_value, traceback): 74 del exception_type, exception_value, traceback 75 _update_device.current = self._old_device 76 77 78# ------------------------------------------------------------------------------ 79# Public utility functions. 80 81 82@tf_export(v1=["distribute.get_loss_reduction"]) 83def get_loss_reduction(): 84 """DEPRECATED: Now always returns `tf.distribute.ReduceOp.SUM`. 85 86 We now always make the complete adjustment when computing the loss, so 87 code should always add gradients/losses across replicas, never average. 88 """ 89 return reduce_util.ReduceOp.SUM 90 91 92# ------------------------------------------------------------------------------ 93# Internal API for validating the current thread mode 94 95 96def _require_cross_replica_or_default_context_extended(extended): 97 """Verify in cross-replica context.""" 98 context = _get_per_thread_mode() 99 cross_replica = context.cross_replica_context 100 if cross_replica is not None and cross_replica.extended is extended: 101 return 102 if context is _get_default_replica_mode(): 103 return 104 strategy = extended._container_strategy() # pylint: disable=protected-access 105 # We have an error to report, figure out the right message. 106 if context.strategy is not strategy: 107 _wrong_strategy_scope(strategy, context) 108 assert cross_replica is None 109 raise RuntimeError("Method requires being in cross-replica context, use " 110 "get_replica_context().merge_call()") 111 112 113def _wrong_strategy_scope(strategy, context): 114 # Figure out the right error message. 115 if not distribution_strategy_context.has_strategy(): 116 raise RuntimeError( 117 'Need to be inside "with strategy.scope()" for %s' % 118 (strategy,)) 119 else: 120 raise RuntimeError( 121 "Mixing different tf.distribute.Strategy objects: %s is not %s" % 122 (context.strategy, strategy)) 123 124 125def require_replica_context(replica_ctx): 126 """Verify in `replica_ctx` replica context.""" 127 context = _get_per_thread_mode() 128 if context.replica_context is replica_ctx: return 129 # We have an error to report, figure out the right message. 130 if context.replica_context is None: 131 raise RuntimeError("Need to be inside `call_for_each_replica()`") 132 if context.strategy is replica_ctx.strategy: 133 # Two different ReplicaContexts with the same tf.distribute.Strategy. 134 raise RuntimeError("Mismatching ReplicaContext.") 135 raise RuntimeError( 136 "Mismatching tf.distribute.Strategy objects: %s is not %s." % 137 (context.strategy, replica_ctx.strategy)) 138 139 140def _require_strategy_scope_strategy(strategy): 141 """Verify in a `strategy.scope()` in this thread.""" 142 context = _get_per_thread_mode() 143 if context.strategy is strategy: return 144 _wrong_strategy_scope(strategy, context) 145 146 147def _require_strategy_scope_extended(extended): 148 """Verify in a `distribution_strategy.scope()` in this thread.""" 149 context = _get_per_thread_mode() 150 if context.strategy.extended is extended: return 151 # Report error. 152 strategy = extended._container_strategy() # pylint: disable=protected-access 153 _wrong_strategy_scope(strategy, context) 154 155 156# ------------------------------------------------------------------------------ 157# Internal context managers used to implement the DistributionStrategy 158# base class 159 160 161class _CurrentDistributionContext(object): 162 """Context manager setting the current `tf.distribute.Strategy`. 163 164 Also: overrides the variable creator and optionally the current device. 165 """ 166 167 def __init__(self, 168 strategy, 169 var_creator_scope, 170 var_scope=None, 171 default_device=None): 172 self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access 173 strategy) 174 self._var_creator_scope = var_creator_scope 175 self._var_scope = var_scope 176 if default_device: 177 self._device_scope = ops.device(default_device) 178 else: 179 self._device_scope = None 180 181 def __enter__(self): 182 _push_per_thread_mode(self._context) 183 if self._var_scope: 184 self._var_scope.__enter__() 185 self._var_creator_scope.__enter__() 186 if self._device_scope: 187 self._device_scope.__enter__() 188 return self._context.strategy 189 190 def __exit__(self, exception_type, exception_value, traceback): 191 if self._device_scope: 192 self._device_scope.__exit__(exception_type, exception_value, traceback) 193 self._var_creator_scope.__exit__(exception_type, exception_value, traceback) 194 if self._var_scope: 195 self._var_scope.__exit__(exception_type, exception_value, traceback) 196 _pop_per_thread_mode() 197 198 199class _SameScopeAgainContext(object): 200 """Trivial context manager when you are already in `scope()`.""" 201 202 def __init__(self, strategy): 203 self._strategy = strategy 204 205 def __enter__(self): 206 return self._strategy 207 208 def __exit__(self, exception_type, exception_value, traceback): 209 del exception_type, exception_value, traceback 210 211 212# TODO(yuefengz): add more replication modes. 213@tf_export("distribute.InputReplicationMode") 214class InputReplicationMode(enum.Enum): 215 """Replication mode for input function. 216 217 * `PER_WORKER`: The input function will be called on each worker 218 independently, creating as many input pipelines as number of workers. 219 Replicas will dequeue from the local Dataset on their worker. 220 `tf.distribute.Strategy` doesn't manage any state sharing between such 221 separate input pipelines. 222 """ 223 PER_WORKER = "PER_WORKER" 224 225 226@tf_export("distribute.InputContext") 227class InputContext(object): 228 """A class wrapping information needed by an input function. 229 230 This is a context class that is passed to the user's input fn and contains 231 information about the compute replicas and input pipelines. The number of 232 compute replicas (in sync training) helps compute per input pipeline batch 233 size from the desired global batch size. Input pipeline information can be 234 used to return a different subset of the input in each input pipeline (for 235 e.g. shard the input pipeline, use a different input source etc). 236 """ 237 238 def __init__(self, 239 num_input_pipelines=1, 240 input_pipeline_id=0, 241 num_replicas_in_sync=1): 242 """Initializes an InputContext object. 243 244 Args: 245 num_input_pipelines: the number of input pipelines in a cluster. 246 input_pipeline_id: the current input pipeline id, should be an int in 247 [0,`num_input_pipelines`). 248 num_replicas_in_sync: the number of replicas that are in sync. 249 """ 250 self._num_input_pipelines = num_input_pipelines 251 self._input_pipeline_id = input_pipeline_id 252 self._num_replicas_in_sync = num_replicas_in_sync 253 254 @property 255 def num_replicas_in_sync(self): 256 """Returns the number of compute replicas in sync.""" 257 return self._num_replicas_in_sync 258 259 @property 260 def input_pipeline_id(self): 261 """Returns the input pipeline ID.""" 262 return self._input_pipeline_id 263 264 @property 265 def num_input_pipelines(self): 266 """Returns the number of input pipelines.""" 267 return self._num_input_pipelines 268 269 def get_per_replica_batch_size(self, global_batch_size): 270 """Returns the per-replica batch size. 271 272 Args: 273 global_batch_size: the global batch size which should be divisible by 274 `num_replicas_in_sync`. 275 276 Returns: 277 the per-replica batch size. 278 279 Raises: 280 ValueError: if `global_batch_size` not divisible by 281 `num_replicas_in_sync`. 282 """ 283 if global_batch_size % self._num_replicas_in_sync != 0: 284 raise ValueError("The `global_batch_size` %r is not divisible by " 285 "`num_replicas_in_sync` %r " % 286 (global_batch_size, self._num_replicas_in_sync)) 287 return global_batch_size // self._num_replicas_in_sync 288 289 290# ------------------------------------------------------------------------------ 291# Base classes for all distribution strategies. 292 293 294@tf_export("distribute.Strategy") 295class DistributionStrategy(object): 296 """A list of devices with a state & compute distribution policy. 297 298 See [tensorflow/contrib/distribute/README.md]( 299 https://www.tensorflow.org/code/tensorflow/contrib/distribute/README.md) 300 for overview and examples. 301 """ 302 303 # TODO(josh11b): Raise an exception if variable partitioning requested before 304 # we add support. 305 # TODO(josh11b): Also `parameter_device_index` property? 306 # TODO(josh11b): `map()` 307 # TODO(josh11b): ClusterSpec/ClusterResolver 308 # TODO(josh11b): Partitioned computations, state; sharding 309 # TODO(josh11b): Model parallelism: "replicas" with multiple devices; shuffling 310 # TODO(josh11b): List of replicas with their worker and parameter devices 311 # (where the parameter devices may overlap in the ps case). 312 313 def __init__(self, extended): 314 self._extended = extended 315 316 @property 317 def extended(self): 318 """`tf.distribute.StrategyExtended` with additional methods.""" 319 return self._extended 320 321 def scope(self): 322 """Returns a context manager selecting this Strategy as current. 323 324 Inside a `with strategy.scope():` code block, this thread 325 will use a variable creator set by `strategy`, and will 326 enter its "cross-replica context". 327 328 Returns: 329 A context manager. 330 """ 331 return self._extended._scope(self) # pylint: disable=protected-access 332 333 @doc_controls.do_not_generate_docs # DEPRECATED, moving to `extended` 334 def colocate_vars_with(self, colocate_with_variable): 335 """DEPRECATED: use extended.colocate_vars_with() instead.""" 336 return self._extended.colocate_vars_with(colocate_with_variable) 337 338 def make_dataset_iterator(self, dataset): 339 """Makes an iterator for input provided via `dataset`. 340 341 Data from the given dataset will be distributed evenly across all the 342 compute replicas. We will assume that the input dataset is batched by the 343 global batch size. With this assumption, we will make a best effort to 344 divide each batch across all the replicas (one or more workers). 345 If this effort fails, an error will be thrown, and the user should instead 346 use `make_input_fn_iterator` which provides more control to the user, and 347 does not try to divide a batch across replicas. 348 349 The user could also use `make_input_fn_iterator` if they want to 350 customize which input is fed to which replica/worker etc. 351 352 Args: 353 dataset: `tf.data.Dataset` that will be distributed evenly across all 354 replicas. 355 356 Returns: 357 An `tf.distribute.InputIterator` which returns inputs for each step of the 358 computation. User should call `initialize` on the returned iterator. 359 """ 360 return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access 361 362 def make_input_fn_iterator(self, 363 input_fn, 364 replication_mode=InputReplicationMode.PER_WORKER): 365 """Returns an iterator split across replicas created from an input function. 366 367 The `input_fn` should take an `tf.distribute.InputContext` object where 368 information about batching and input sharding can be accessed: 369 370 ``` 371 def input_fn(input_context): 372 batch_size = input_context.get_per_replica_batch_size(global_batch_size) 373 d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size) 374 return d.shard(input_context.num_input_pipelines, 375 input_context.input_pipeline_id) 376 with strategy.scope(): 377 iterator = strategy.make_input_fn_iterator(input_fn) 378 replica_results = strategy.experimental_run(replica_fn, iterator) 379 ``` 380 381 The `tf.data.Dataset` returned by `input_fn` should have a per-replica 382 batch size, which may be computed using 383 `input_context.get_per_replica_batch_size`. 384 385 Args: 386 input_fn: A function taking a `tf.distribute.InputContext` object and 387 returning a `tf.data.Dataset`. 388 replication_mode: an enum value of `tf.distribute.InputReplicationMode`. 389 Only `PER_WORKER` is supported currently, which means there will be 390 a single call to `input_fn` per worker. Replicas will dequeue from the 391 local `tf.data.Dataset` on their worker. 392 393 Returns: 394 An iterator object that should first be `.initialize()`-ed. It may then 395 either be passed to `strategy.experimental_run()` or you can 396 `iterator.get_next()` to get the next value to pass to 397 `strategy.extended.call_for_each_replica()`. 398 """ 399 if replication_mode != InputReplicationMode.PER_WORKER: 400 raise ValueError( 401 "Input replication mode not supported: %r" % replication_mode) 402 with self.scope(): 403 return self.extended._make_input_fn_iterator( # pylint: disable=protected-access 404 input_fn, replication_mode=replication_mode) 405 406 def experimental_make_numpy_iterator( 407 self, numpy_input, batch_size, num_epochs=1, shuffle=1024, session=None): 408 """Makes an iterator for input provided via a nest of numpy arrays. 409 410 Args: 411 numpy_input: A nest of NumPy input arrays that will be distributed evenly 412 across all replicas. Note that lists of Numpy arrays are stacked, 413 as that is normal `tf.data.Dataset` behavior. 414 batch_size: The number of entries from the array we should consume in one 415 step of the computation, across all replicas. This is the global batch 416 size. It should be divisible by `num_replicas_in_sync`. 417 num_epochs: The number of times to iterate through the examples. A value 418 of `None` means repeat forever. 419 shuffle: Size of buffer to use for shuffling the input examples. 420 Use `None` to disable shuffling. 421 session: (TensorFlow v1.x graph execution only) A session used for 422 initialization. 423 424 Returns: 425 An `tf.distribute.InputIterator` which returns inputs for each step of the 426 computation. User should call `initialize` on the returned iterator. 427 """ 428 ds = self.extended.experimental_make_numpy_dataset( 429 numpy_input, session=session) 430 if shuffle: 431 ds = ds.shuffle(shuffle) 432 if num_epochs != 1: 433 ds = ds.repeat(num_epochs) 434 # We need to use the drop_remainder argument to get a known static 435 # input shape which is required for TPUs. 436 drop_remainder = self.extended.experimental_require_static_shapes 437 ds = ds.batch(batch_size, drop_remainder=drop_remainder) 438 return self.make_dataset_iterator(ds) 439 440 def experimental_run(self, fn, input_iterator=None): 441 """Runs ops in `fn` on each replica, with inputs from `input_iterator`. 442 443 When eager execution is enabled, executes ops specified by `fn` on each 444 replica. Otherwise, builds a graph to execute the ops on each replica. 445 446 Each replica will take a single, different input from the inputs provided by 447 one `get_next` call on the input iterator. 448 449 `fn` may call `tf.distribute.get_replica_context()` to access members such 450 as `replica_id_in_sync_group`. 451 452 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 453 used, and whether eager execution is enabled, `fn` may be called one or more 454 times (once for each replica). 455 456 Args: 457 fn: The function to run. The inputs to the function must match the outputs 458 of `input_iterator.get_next()`. The output must be a `tf.nest` of 459 `Tensor`s. 460 input_iterator: (Optional) input iterator from which the inputs are taken. 461 462 Returns: 463 Merged return value of `fn` across replicas. The structure of the return 464 value is the same as the return value from `fn`. Each element in the 465 structure can either be `PerReplica` (if the values are unsynchronized), 466 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 467 single replica). 468 """ 469 with self.scope(): 470 args = (input_iterator.get_next(),) if input_iterator is not None else () 471 return self.experimental_run_v2(fn, args=args) 472 473 def experimental_run_v2(self, fn, args=(), kwargs=None): 474 """Runs ops in `fn` on each replica, with the given arguments. 475 476 When eager execution is enabled, executes ops specified by `fn` on each 477 replica. Otherwise, builds a graph to execute the ops on each replica. 478 479 `fn` may call `tf.distribute.get_replica_context()` to access members such 480 as `replica_id_in_sync_group`. 481 482 IMPORTANT: Depending on the `tf.distribute.Strategy` implementation being 483 used, and whether eager execution is enabled, `fn` may be called one or more 484 times (once for each replica). 485 486 Args: 487 fn: The function to run. The output must be a `tf.nest` of `Tensor`s. 488 args: (Optional) Positional arguments to `fn`. 489 kwargs: (Optional) Keyword arguments to `fn`. 490 491 Returns: 492 Merged return value of `fn` across replicas. The structure of the return 493 value is the same as the return value from `fn`. Each element in the 494 structure can either be `PerReplica` (if the values are unsynchronized), 495 `Mirrored` (if the values are kept in sync), or `Tensor` (if running on a 496 single replica). 497 """ 498 with self.scope(): 499 return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs) 500 501 def reduce(self, reduce_op, value): 502 """Reduce `value` across replicas. 503 504 Args: 505 reduce_op: A `tf.distribute.ReduceOp` value specifying how values should 506 be combined. 507 value: A "per replica" value to be combined into a single tensor. 508 509 Returns: 510 A `Tensor`. 511 """ 512 _require_cross_replica_or_default_context_extended(self._extended) 513 return self._extended._reduce(reduce_op, value) # pylint: disable=protected-access 514 515 @doc_controls.do_not_generate_docs # DEPRECATED 516 def unwrap(self, value): 517 """Returns the list of all local per-replica values contained in `value`. 518 519 DEPRECATED: Please use `experimental_local_results` instead. 520 521 Note: This only returns values on the workers initiated by this client. 522 When using a `Strategy` like 523 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 524 will be its own client, and this function will only return values 525 computed on that worker. 526 527 Args: 528 value: A value returned by `experimental_run()`, 529 `extended.call_for_each_replica()`, or a variable created in `scope`. 530 531 Returns: 532 A tuple of values contained in `value`. If `value` represents a single 533 value, this returns `(value,).` 534 """ 535 return self._extended._local_results(value) # pylint: disable=protected-access 536 537 def experimental_local_results(self, value): 538 """Returns the list of all local per-replica values contained in `value`. 539 540 Note: This only returns values on the workers initiated by this client. 541 When using a `Strategy` like 542 `tf.distribute.experimental.MultiWorkerMirroredStrategy`, each worker 543 will be its own client, and this function will only return values 544 computed on that worker. 545 546 Args: 547 value: A value returned by `experimental_run()`, `experimental_run_v2()`, 548 `extended.call_for_each_replica()`, or a variable created in `scope`. 549 550 Returns: 551 A tuple of values contained in `value`. If `value` represents a single 552 value, this returns `(value,).` 553 """ 554 return self._extended._local_results(value) # pylint: disable=protected-access 555 556 @doc_controls.do_not_generate_docs # DEPRECATED: TF v1.x only 557 def group(self, value, name=None): 558 """Shortcut for `tf.group(self.experimental_local_results(value))`.""" 559 return self._extended._group(value, name) # pylint: disable=protected-access 560 561 @property 562 def num_replicas_in_sync(self): 563 """Returns number of replicas over which gradients are aggregated.""" 564 return self._extended._num_replicas_in_sync # pylint: disable=protected-access 565 566 @doc_controls.do_not_generate_docs # DEPRECATED, being replaced by a new API. 567 def configure(self, 568 session_config=None, 569 cluster_spec=None, 570 task_type=None, 571 task_id=None): 572 # pylint: disable=g-doc-return-or-yield,g-doc-args 573 """DEPRECATED: use `update_config_proto` instead. 574 575 Configures the strategy class. 576 577 DEPRECATED: This method's functionality has been split into the strategy 578 constructor and `update_config_proto`. In the future, we will allow passing 579 cluster and config_proto to the constructor to configure the strategy. And 580 `update_config_proto` can be used to update the config_proto based on the 581 specific strategy. 582 """ 583 return self._extended._configure( # pylint: disable=protected-access 584 session_config, cluster_spec, task_type, task_id) 585 586 def update_config_proto(self, config_proto): 587 """Returns a copy of `config_proto` modified for use with this strategy. 588 589 The updated config has something needed to run a strategy, e.g. 590 configuration to run collective ops, or device filters to improve 591 distributed training performance. 592 593 Args: 594 config_proto: a `tf.ConfigProto` object. 595 596 Returns: 597 The updated copy of the `config_proto`. 598 """ 599 return self._extended._update_config_proto(config_proto) # pylint: disable=protected-access 600 601 def __deepcopy__(self, memo): 602 # First do a regular deepcopy of `self`. 603 cls = self.__class__ 604 result = cls.__new__(cls) 605 memo[id(self)] = result 606 for k, v in self.__dict__.items(): 607 setattr(result, k, copy.deepcopy(v, memo)) 608 # One little fix-up: we want `result._extended` to reference `result` 609 # instead of `self`. 610 result._extended._container_strategy_weakref = weakref.ref(result) # pylint: disable=protected-access 611 return result 612 613 def __copy__(self): 614 raise RuntimeError("Must only deepcopy DistributionStrategy.") 615 616 617@tf_export("distribute.StrategyExtended") 618class DistributionStrategyExtended(object): 619 """Additional APIs for algorithms that need to be distribution-aware. 620 621 The intent is that you can write an algorithm in a stylized way and 622 it will be usable with a variety of different 623 `tf.distribute.Strategy` 624 implementations. Each descendant will implement a different strategy 625 for distributing the algorithm across multiple devices/machines. 626 Furthermore, these changes can be hidden inside the specific layers 627 and other library classes that need special treatment to run in a 628 distributed setting, so that most users' model definition code can 629 run unchanged. The `tf.distribute.Strategy` API works the same way 630 with eager and graph execution. 631 632 First let's introduce a few high-level concepts: 633 634 * _Data parallelism_ is where we run multiple copies of the model 635 on different slices of the input data. This is in contrast to 636 _model parallelism_ where we divide up a single copy of a model 637 across multiple devices. 638 Note: we only support data parallelism for now, but 639 hope to add support for model parallelism in the future. 640 * A _replica_ is one copy of the model, running on one slice of the 641 input data. 642 * _Synchronous_, or more commonly _sync_, training is where the 643 updates from each replica are aggregated together before updating 644 the model variables. This is in contrast to _asynchronous_, or 645 _async_ training, where each replica updates the model variables 646 independently. 647 * Furthermore you might run your computation on multiple devices 648 on one machine (or "host"), or on multiple machines/hosts. 649 If you are running on multiple machines, you might have a 650 single master host that drives computation across all of them, 651 or you might have multiple clients driving the computation 652 asynchronously. 653 654 To distribute an algorithm, we might use some of these ingredients: 655 656 * Parameter servers: These are hosts that hold a single copy of 657 parameters/variables. All replicas that want to operate on a variable 658 retrieve it at the beginning of a step and send an update to be 659 applied at the end of the step. Can support either sync or async 660 training. 661 * Mirrored variables: These are variables that are copied to multiple 662 devices, where we keep the copies in sync by applying the same 663 updates to every copy. Normally would only be used with sync training. 664 * Reductions and Allreduce: A _reduction_ is some method of 665 aggregating multiple values into one value, like "sum" or 666 "mean". If doing sync training, we will perform a reduction on the 667 gradients to a parameter from all replicas before applying the 668 update. Allreduce is an algorithm for performing a reduction on 669 values from multiple devices and making the result available on 670 all of those devices. 671 * In the future we will have support for TensorFlow's partitioned 672 variables, where a single variable is split across multiple 673 devices. 674 675 We have then a few approaches we want to support: 676 677 * Code written (as if) with no knowledge of class `tf.distribute.Strategy`. 678 This code should work as before, even if some of the layers, etc. 679 used by that code are written to be distribution-aware. This is done 680 by having a default `tf.distribute.Strategy` that gives ordinary behavior, 681 and by default being in a single replica context. 682 * Ordinary model code that you want to run using a specific 683 `tf.distribute.Strategy`. This can be as simple as: 684 685 ``` 686 with my_strategy.scope(): 687 iterator = my_strategy.make_dataset_iterator(dataset) 688 session.run(iterator.initialize()) 689 replica_train_ops = my_strategy.extended.call_for_each_replica( 690 replica_fn, args=(iterator.get_next(),)) 691 train_op = my_strategy.group(replica_train_ops) 692 ``` 693 694 This takes an ordinary `dataset` and `replica_fn` and runs it 695 distributed using a particular `tf.distribute.Strategy` in 696 `my_strategy`. Any variables created in `replica_fn` are created 697 using `my_strategy`'s policy, and library functions called by 698 `replica_fn` can use the `get_replica_context()` API to get enhanced 699 behavior in this case. 700 701 * If you want to write a distributed algorithm, you may use any of 702 the `tf.distribute.Strategy` APIs inside a 703 `with my_strategy.scope():` block of code. 704 705 Lower-level concepts: 706 707 * Wrapped values: In order to represent values parallel across devices 708 (either replicas or the devices associated with a particular value), we 709 wrap them in a "PerReplica" or "Mirrored" object that contains a map 710 from device to values. "PerReplica" is used when the value may be 711 different across replicas, and "Mirrored" when the value are the same. 712 * Unwrapping and merging: Consider calling a function `fn` on multiple 713 replicas, like `extended.call_for_each_replica(fn, args=[w])` with an 714 argument `w` that is a wrapped value. This means `w` will have a map taking 715 replica device `d0` to `w0`, replica device `d1` to `w1`, 716 etc. `extended.call_for_each_replica()` unwraps `w` before calling `fn`, so 717 it calls `fn(w0)` on `d0`, `fn(w1)` on `d1`, etc. It then merges the return 718 values from `fn()`, which can possibly result in wrapped values. For 719 example, let's say `fn()` returns a tuple with three components: `(x, a, 720 v0)` from replica 0, `(x, b, v1)` on replica 1, etc. If the first component 721 is the same object `x` from every replica, then the first component of the 722 merged result will also be `x`. If the second component is different (`a`, 723 `b`, ...) from each replica, then the merged value will have a wrapped map 724 from replica device to the different values. If the third component is the 725 members of a mirrored variable (`v` maps `d0` to `v0`, `d1` to `v1`, etc.), 726 then the merged result will be that mirrored variable (`v`). 727 * Replica context vs. Cross-replica context: _replica context_ is when we 728 are in some function that is being called once for each replica. 729 Otherwise we are in cross-replica context, which is useful for 730 calling `tf.distribute.Strategy` methods which operate across the 731 replicas (like `reduce_to()`). By default you start in a replica context 732 (the default "single replica context") and then some methods can 733 switch you back and forth, as described below. 734 * Worker devices vs. parameter devices: Most replica computations will 735 happen on worker devices. Since we don't yet support model 736 parallelism, there will be one worker device per replica. When using 737 parameter servers (see above), the set of devices holding 738 variables may be different, otherwise the parameter devices might 739 match the worker devices. 740 * Non-slot devices are some subset of the parameter devices where we 741 put all the non-slot variables. We need to ensure that all 742 non-slot variables are allocated on the same device, or mirrored 743 across the same set of devices. If you have some variable you want 744 to colocate all the non-slot variables with, you can use 745 `colocate_vars_with()` to get the remaining non-slot variables on 746 the same device. Otherwise you can use `non_slot_devices()` to 747 pick a consistent set of devices to pass to both 748 `colocate_vars_with()` and `update_non_slot()`. 749 750 When using a `tf.distribute.Strategy`, we have a new type dimension 751 called _locality_ that says what values are compatible with which 752 APIs: 753 754 * T: different value for each replica (e.g. a PerReplica-wrapped value). 755 * M: value is "mirrored" across replicas, i.e. there are copies with the 756 same value on each replica (e.g. a Mirrored-wrapped value). 757 * V(`v`): value is "mirrored" across all the devices which have a 758 copy of variable `v` (also a Mirrored-wrapped value, but over 759 parameter devices instead of worker devices). 760 * N: value is "mirrored" across all the "non-slot" devices 761 762 Rules for methods with respect to locality and single-replica vs. 763 cross-replica context: 764 765 * `with d.scope()`: default single-replica context -> cross-replica context 766 for `d` 767 * `with d.extended.colocate_vars_with(v)`: in replica/cross-replica context, 768 variables will be created with locality V(`v`). That is, if we write 769 `with d.extended.colocate_vars_with(v1): v2 = tf.get_variable(...)`, 770 then `v2` will have locality V(`v1`), i.e. locality V(`v2`) will equal 771 V(`v1`). 772 * `with d.extended.colocate_vars_with(d.extended.non_slot_devices(...))`: in 773 replica/cross-replica context, variables will be created with locality N 774 * `v = tf.get_variable(...)`: in replica/cross-replica context, creates 775 a variable (which by definition will have locality V(`v`), though 776 will match another locality if inside a `colocate_vars_with` 777 scope). 778 * `d.make_dataset_iterator(dataset)`: in cross-replica 779 context, produces an iterator with locality T 780 * `d.extended.broadcast_to(t, v)`: in cross-replica context, produces a value 781 with locality V(`v`) 782 * `d.extended.call_for_each_replica(fn, ...)`: in cross-replica context, runs 783 `fn()` in a replica context (and so may call `get_replica_context()` and 784 use its API, including `merge_call()` to get back to cross-replica 785 context), once for each replica. May use values with locality T or 786 M, and any variable. 787 * `d.extended.reduce_to(m, t, t)`: in cross-replica context, accepts t with 788 locality T and produces a value with locality M. 789 * `d.extended.reduce_to(m, t, v)`: in cross-replica context, accepts t with 790 locality T and produces a value with locality V(`v`). 791 * `d.extended.batch_reduce_to(m, [(t, v)]): see `d.extended.reduce_to()` 792 * `d.extended.update(v, fn, ...)`: in cross-replica context, runs `fn()` once 793 for each device `v` is copied to, all inputs should have locality 794 V(`v`), output will have locality V(`v`) as well. 795 * `d.extended.update_non_slot(d.extended.non_slot_devices(), fn)`: in 796 cross-replica context, like `d.extended.update()` except with locality N. 797 * `d.extended.read_var(v)`: Gets the (read-only) value of the variable `v` (on 798 the device determined by the current device scope), aggregating 799 across replicas for replica-local variables. Frequently, this will be 800 done automatically when using `v` in an expression or fetching it in 801 a cross-replica context, but this function can be used to force that 802 conversion happens at a particular point in time (for example, to 803 add the result of the conversion to a graph collection). 804 805 The standard pattern for updating variables is to: 806 807 1. Create an input iterator with `d.make_dataset_iterator()`. 808 2. Define each replica `d.extended.call_for_each_replica()` up to the point of 809 getting a list of gradient, variable pairs. 810 3. Call `d.extended.reduce_to(VariableAggregation.SUM, t, v)` or 811 `d.extended.batch_reduce_to()` to sum the gradients (with locality T) 812 into values with locality V(`v`). 813 4. Call `d.extended.update(v)` for each variable to update its value. 814 815 Steps 3 and 4 are done automatically by class `Optimizer` if you call 816 its `apply_gradients` method in a replica context. Otherwise you can 817 manually call its `_distributed_apply` method in a cross-replica context. 818 819 Another thing you might want to do in the middle of your replica function is 820 an all-reduce of some intermediate value, using `d.extended.reduce_to()` or 821 `d.extended.batch_reduce_to()`. You simply provide the same tensor as the 822 input and destination. 823 824 Layers should expect to be called in a replica context, and can use 825 the `tf.distribute.get_replica_context` function to get a 826 `tf.distribute.ReplicaContext` object. The 827 `ReplicaContext` object has a `merge_call()` method for entering 828 cross-replica context where you can use `reduce_to()` (or 829 `batch_reduce_to()`) and then optionally `update()` to update state. 830 831 You may use this API whether or not a `tf.distribute.Strategy` is 832 being used, since there is a default implementation of 833 `ReplicaContext` and `tf.distribute.Strategy`. 834 835 NOTE for new `tf.distribute.Strategy` implementations: Please put all logic 836 in a subclass of `tf.distribute.StrategyExtended`. The only code needed for 837 the `tf.distribute.Strategy` subclass is for instantiating your subclass of 838 `tf.distribute.StrategyExtended` in the `__init__` method. 839 """ 840 841 def __init__(self, container_strategy): 842 self._container_strategy_weakref = weakref.ref(container_strategy) 843 self._default_device = None 844 # This property is used to determine if we should set drop_remainder=True 845 # when creating Datasets from numpy array inputs. 846 self._require_static_shapes = False 847 848 def _container_strategy(self): 849 """Get the containing `DistributionStrategy`. 850 851 This should not generally be needed except when creating a new 852 `ReplicaContext` and to validate that the caller is in the correct 853 `scope()`. 854 855 Returns: 856 The `DistributionStrategy` such that `strategy.extended` is `self`. 857 """ 858 container_strategy = self._container_strategy_weakref() 859 assert container_strategy is not None 860 return container_strategy 861 862 def _scope(self, strategy): 863 """Implementation of DistributionStrategy.scope().""" 864 if distribution_strategy_context.has_strategy(): 865 _require_cross_replica_or_default_context_extended(self) 866 return _SameScopeAgainContext(strategy) 867 868 def creator_with_resource_vars(*args, **kwargs): 869 _require_strategy_scope_extended(self) 870 kwargs["use_resource"] = True 871 kwargs["distribute_strategy"] = strategy 872 return self._create_variable(*args, **kwargs) 873 874 def distributed_getter(getter, *args, **kwargs): 875 if not self._allow_variable_partition(): 876 if kwargs.pop("partitioner", None) is not None: 877 tf_logging.log_first_n( 878 tf_logging.WARN, "Partitioned variables are disabled when using " 879 "current tf.distribute.Strategy.", 1) 880 return getter(*args, **kwargs) 881 882 return _CurrentDistributionContext( 883 strategy, 884 variable_scope.variable_creator_scope(creator_with_resource_vars), 885 variable_scope.variable_scope( 886 variable_scope.get_variable_scope(), 887 custom_getter=distributed_getter), self._default_device) 888 889 def _allow_variable_partition(self): 890 return False 891 892 def _create_variable(self, next_creator, *args, **kwargs): 893 # Note: should support "colocate_with" argument. 894 raise NotImplementedError("must be implemented in descendants") 895 896 def variable_created_in_scope(self, v): 897 """Tests whether `v` was created while this strategy scope was active. 898 899 Variables created inside the strategy scope are "owned" by it: 900 901 >>> with strategy.scope(): 902 ... v = tf.Variable(1.) 903 >>> strategy.variable_created_in_scope(v) 904 True 905 906 Variables created outside the strategy are not owned by it: 907 908 >>> v = tf.Variable(1.) 909 >>> strategy.variable_created_in_scope(v) 910 False 911 912 Args: 913 v: A `tf.Variable` instance. 914 915 Returns: 916 True if `v` was created inside the scope, False if not. 917 """ 918 return v._distribute_strategy == self._container_strategy_weakref() # pylint: disable=protected-access 919 920 def read_var(self, v): 921 """Reads the value of a variable. 922 923 Returns the aggregate value of a replica-local variable, or the 924 (read-only) value of any other variable. 925 926 Args: 927 v: A variable allocated within the scope of this `tf.distribute.Strategy`. 928 929 Returns: 930 A tensor representing the value of `v`, aggregated across replicas if 931 necessary. 932 """ 933 raise NotImplementedError("must be implemented in descendants") 934 935 def colocate_vars_with(self, colocate_with_variable): 936 """Scope that controls which devices variables will be created on. 937 938 No operations should be added to the graph inside this scope, it 939 should only be used when creating variables (some implementations 940 work by changing variable creation, others work by using a 941 tf.colocate_with() scope). 942 943 This may only be used inside `self.scope()`. 944 945 Example usage: 946 947 ``` 948 with strategy.scope(): 949 var1 = tf.get_variable(...) 950 with strategy.extended.colocate_vars_with(var1): 951 # var2 and var3 will be created on the same device(s) as var1 952 var2 = tf.get_variable(...) 953 var3 = tf.get_variable(...) 954 955 def fn(v1, v2, v3): 956 # operates on v1 from var1, v2 from var2, and v3 from var3 957 958 # `fn` runs on every device `var1` is on, `var2` and `var3` will be there 959 # too. 960 strategy.extended.update(var1, fn, args=(var2, var3)) 961 ``` 962 963 Args: 964 colocate_with_variable: A variable created in this strategy's `scope()`. 965 Variables created while in the returned context manager will be on the 966 same set of devices as `colocate_with_variable`. 967 968 Returns: 969 A context manager. 970 """ 971 def create_colocated_variable(next_creator, *args, **kwargs): 972 _require_strategy_scope_extended(self) 973 kwargs["use_resource"] = True 974 kwargs["colocate_with"] = colocate_with_variable 975 return next_creator(*args, **kwargs) 976 977 _require_strategy_scope_extended(self) 978 self._validate_colocate_with_variable(colocate_with_variable) 979 return variable_scope.variable_creator_scope(create_colocated_variable) 980 981 def _validate_colocate_with_variable(self, colocate_with_variable): 982 """Validate `colocate_with_variable` argument to `colocate_vars_with`.""" 983 pass 984 985 def _make_dataset_iterator(self, dataset): 986 raise NotImplementedError("must be implemented in descendants") 987 988 def _make_input_fn_iterator(self, input_fn, replication_mode): 989 raise NotImplementedError("must be implemented in descendants") 990 991 def experimental_make_numpy_dataset(self, numpy_input, session=None): 992 """Makes a dataset for input provided via a numpy array. 993 994 This avoids adding `numpy_input` as a large constant in the graph, 995 and copies the data to the machine or machines that will be processing 996 the input. 997 998 Args: 999 numpy_input: A nest of NumPy input arrays that will be distributed evenly 1000 across all replicas. Note that lists of Numpy arrays are stacked, 1001 as that is normal `tf.data.Dataset` behavior. 1002 session: (TensorFlow v1.x graph execution only) A session used for 1003 initialization. 1004 1005 Returns: 1006 A `tf.data.Dataset` representing `numpy_input`. 1007 """ 1008 _require_cross_replica_or_default_context_extended(self) 1009 return self._experimental_make_numpy_dataset(numpy_input, session=session) 1010 1011 def _experimental_make_numpy_dataset(self, numpy_input, session): 1012 raise NotImplementedError("must be implemented in descendants") 1013 1014 def broadcast_to(self, tensor, destinations): 1015 """Mirror a tensor on one device to all worker devices. 1016 1017 Args: 1018 tensor: A Tensor value to broadcast. 1019 destinations: A mirrored variable or device string specifying the 1020 destination devices to copy `tensor` to. 1021 1022 Returns: 1023 A value mirrored to `destinations` devices. 1024 """ 1025 assert destinations is not None # from old strategy.broadcast() 1026 # TODO(josh11b): More docstring 1027 _require_cross_replica_or_default_context_extended(self) 1028 assert not isinstance(destinations, (list, tuple)) 1029 return self._broadcast_to(tensor, destinations) 1030 1031 def _broadcast_to(self, tensor, destinations): 1032 raise NotImplementedError("must be implemented in descendants") 1033 1034 def experimental_run_steps_on_iterator(self, fn, iterator, iterations=1, 1035 initial_loop_values=None): 1036 """Run `fn` with input from `iterator` for `iterations` times. 1037 1038 This method can be used to run a step function for training a number of 1039 times using input from a dataset. 1040 1041 Args: 1042 fn: function to run using this distribution strategy. The function must 1043 have the following signature: `def fn(context, inputs)`. 1044 `context` is an instance of `MultiStepContext` that will be passed when 1045 `fn` is run. `context` can be used to specify the outputs to be returned 1046 from `fn` by calling `context.set_last_step_output`. It can also be used 1047 to capture non tensor outputs by `context.set_non_tensor_output`. 1048 See `MultiStepContext` documentation for more information. 1049 `inputs` will have same type/structure as `iterator.get_next()`. 1050 Typically, `fn` will use `call_for_each_replica` method of the strategy 1051 to distribute the computation over multiple replicas. 1052 iterator: Iterator of a dataset that represents the input for `fn`. The 1053 caller is responsible for initializing the iterator as needed. 1054 iterations: (Optional) Number of iterations that `fn` should be run. 1055 Defaults to 1. 1056 initial_loop_values: (Optional) Initial values to be passed into the 1057 loop that runs `fn`. Defaults to `None`. # TODO(priyag): Remove 1058 initial_loop_values argument when we have a mechanism to infer the 1059 outputs of `fn`. 1060 1061 Returns: 1062 Returns the `MultiStepContext` object which has the following properties, 1063 among other things: 1064 - run_op: An op that runs `fn` `iterations` times. 1065 - last_step_outputs: A dictionary containing tensors set using 1066 `context.set_last_step_output`. Evaluating this returns the value of 1067 the tensors after the last iteration. 1068 - non_tensor_outputs: A dictionatry containing anything that was set by 1069 `fn` by calling `context.set_non_tensor_output`. 1070 """ 1071 _require_cross_replica_or_default_context_extended(self) 1072 with self._container_strategy().scope(): 1073 return self._experimental_run_steps_on_iterator( 1074 fn, iterator, iterations, initial_loop_values) 1075 1076 def _experimental_run_steps_on_iterator(self, fn, iterator, iterations, 1077 initial_loop_values): 1078 raise NotImplementedError("must be implemented in descendants") 1079 1080 def call_for_each_replica(self, fn, args=(), kwargs=None): 1081 """Run `fn` once per replica. 1082 1083 `fn` may call `tf.get_replica_context()` to access methods such as 1084 `replica_id_in_sync_group` and `merge_call()`. 1085 1086 `merge_call()` is used to communicate between the replicas and 1087 re-enter the cross-replica context. All replicas pause their execution 1088 having encountered a `merge_call()` call. After that the 1089 `merge_fn`-function is executed. Its results are then unwrapped and 1090 given back to each replica call. After that execution resumes until 1091 `fn` is complete or encounters another `merge_call()`. Example: 1092 1093 ```python 1094 # Called once in "cross-replica" context. 1095 def merge_fn(distribution, three_plus_replica_id): 1096 # sum the values across replicas 1097 return sum(distribution.experimental_local_results(three_plus_replica_id)) 1098 1099 # Called once per replica in `distribution`, in a "replica" context. 1100 def fn(three): 1101 replica_ctx = tf.get_replica_context() 1102 v = three + replica_ctx.replica_id_in_sync_group 1103 # Computes the sum of the `v` values across all replicas. 1104 s = replica_ctx.merge_call(merge_fn, args=(v,)) 1105 return s + v 1106 1107 with distribution.scope(): 1108 # in "cross-replica" context 1109 ... 1110 merged_results = distribution.call_for_each_replica(fn, args=[3]) 1111 # merged_results has the values from every replica execution of `fn`. 1112 # This statement prints a list: 1113 print(distribution.experimental_local_results(merged_results)) 1114 ``` 1115 1116 Args: 1117 fn: function to run (will be run once per replica). 1118 args: Tuple or list with positional arguments for `fn`. 1119 kwargs: Dict with keyword arguments for `fn`. 1120 1121 Returns: 1122 Merged return value of `fn` across all replicas. 1123 """ 1124 _require_cross_replica_or_default_context_extended(self) 1125 if kwargs is None: 1126 kwargs = {} 1127 with self._container_strategy().scope(): 1128 return self._call_for_each_replica(fn, args, kwargs) 1129 1130 def _call_for_each_replica(self, fn, args, kwargs): 1131 raise NotImplementedError("must be implemented in descendants") 1132 1133 def _reduce(self, reduce_op, value): 1134 # Default implementation until we have an implementation for each strategy. 1135 return self._local_results( 1136 self._reduce_to(reduce_op, value, 1137 device_util.current() or "/device:CPU:0"))[0] 1138 1139 def reduce_to(self, reduce_op, value, destinations): 1140 """Combine (via e.g. sum or mean) values across replicas. 1141 1142 Args: 1143 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 1144 value: A per-replica value with one value per replica. 1145 destinations: A mirrored variable, a per-replica tensor, or a device 1146 string. The return value will be copied to all destination devices (or 1147 all the devices where the `destinations` value resides). To perform an 1148 all-reduction, pass `value` to `destinations`. 1149 1150 Returns: 1151 A value mirrored to `destinations`. 1152 """ 1153 # TODO(josh11b): More docstring 1154 _require_cross_replica_or_default_context_extended(self) 1155 assert not isinstance(destinations, (list, tuple)) 1156 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 1157 assert (reduce_op == reduce_util.ReduceOp.SUM or 1158 reduce_op == reduce_util.ReduceOp.MEAN) 1159 return self._reduce_to(reduce_op, value, destinations) 1160 1161 def _reduce_to(self, reduce_op, value, destinations): 1162 raise NotImplementedError("must be implemented in descendants") 1163 1164 def batch_reduce_to(self, reduce_op, value_destination_pairs): 1165 """Combine multiple `reduce_to` calls into one for faster execution. 1166 1167 Args: 1168 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 1169 value_destination_pairs: A sequence of (value, destinations) 1170 pairs. See `reduce_to()` for a description. 1171 1172 Returns: 1173 A list of mirrored values, one per pair in `value_destination_pairs`. 1174 """ 1175 # TODO(josh11b): More docstring 1176 _require_cross_replica_or_default_context_extended(self) 1177 assert not isinstance(reduce_op, variable_scope.VariableAggregation) 1178 return self._batch_reduce_to(reduce_op, value_destination_pairs) 1179 1180 def _batch_reduce_to(self, reduce_op, value_destination_pairs): 1181 return [ 1182 self.reduce_to(reduce_op, t, destinations=v) 1183 for t, v in value_destination_pairs 1184 ] 1185 1186 def update(self, var, fn, args=(), kwargs=None, group=True): 1187 """Run `fn` to update `var` using inputs mirrored to the same devices. 1188 1189 If `var` is mirrored across multiple devices, then this implements 1190 logic like: 1191 1192 ``` 1193 results = {} 1194 for device, v in var: 1195 with tf.device(device): 1196 # args and kwargs will be unwrapped if they are mirrored. 1197 results[device] = fn(v, *args, **kwargs) 1198 return merged(results) 1199 ``` 1200 1201 Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`. 1202 1203 Neither `args` nor `kwargs` may contain per-replica values. 1204 If they contain mirrored values, they will be unwrapped before 1205 calling `fn`. 1206 1207 Args: 1208 var: Variable, possibly mirrored to multiple devices, to operate on. 1209 fn: Function to call. Should take the variable as the first argument. 1210 args: Tuple or list. Additional positional arguments to pass to `fn()`. 1211 kwargs: Dict with keyword arguments to pass to `fn()`. 1212 group: Boolean. Defaults to True. If False, the return value will be 1213 unwrapped. 1214 1215 Returns: 1216 By default, the merged return value of `fn` across all replicas. The 1217 merged result has dependencies to make sure that if it is evaluated at 1218 all, the side effects (updates) will happen on every replica. If instead 1219 "group=False" is specified, this function will return a nest of lists 1220 where each list has an element per replica, and the caller is responsible 1221 for ensuring all elements are executed. 1222 """ 1223 _require_cross_replica_or_default_context_extended(self) 1224 if kwargs is None: 1225 kwargs = {} 1226 with self._container_strategy().scope(): 1227 return self._update(var, fn, args, kwargs, group) 1228 1229 def _update(self, var, fn, args, kwargs, group): 1230 raise NotImplementedError("must be implemented in descendants") 1231 1232 def update_non_slot( 1233 self, colocate_with, fn, args=(), kwargs=None, group=True): 1234 """Runs `fn(*args, **kwargs)` on `colocate_with` devices. 1235 1236 Args: 1237 colocate_with: The return value of `non_slot_devices()`. 1238 fn: Function to execute. 1239 args: Tuple or list. Positional arguments to pass to `fn()`. 1240 kwargs: Dict with keyword arguments to pass to `fn()`. 1241 group: Boolean. Defaults to True. If False, the return value will be 1242 unwrapped. 1243 1244 Returns: 1245 Return value of `fn`, possibly merged across devices. 1246 """ 1247 _require_cross_replica_or_default_context_extended(self) 1248 if kwargs is None: 1249 kwargs = {} 1250 with self._container_strategy().scope(): 1251 return self._update_non_slot(colocate_with, fn, args, kwargs, group) 1252 1253 def _update_non_slot(self, colocate_with, fn, args, kwargs, group): 1254 raise NotImplementedError("must be implemented in descendants") 1255 1256 def _local_results(self, distributed_value): 1257 raise NotImplementedError("must be implemented in descendants") 1258 1259 def value_container(self, value): 1260 """Returns the container that this per-replica `value` belongs to. 1261 1262 Args: 1263 value: A value returned by `call_for_each_replica()` or a variable 1264 created in `scope()`. 1265 1266 Returns: 1267 A container that `value` belongs to. 1268 If value does not belong to any container (including the case of 1269 container having been destroyed), returns the value itself. 1270 `value in experimental_local_results(value_container(value))` will 1271 always be true. 1272 """ 1273 raise NotImplementedError("must be implemented in descendants") 1274 1275 def _group(self, value, name=None): 1276 """Implementation of `group`.""" 1277 value = nest.flatten(self._local_results(value)) 1278 1279 if len(value) != 1 or name is not None: 1280 return control_flow_ops.group(value, name=name) 1281 # Special handling for the common case of one op. 1282 v, = value 1283 if hasattr(v, "op"): 1284 v = v.op 1285 return v 1286 1287 @property 1288 def experimental_require_static_shapes(self): 1289 return self._require_static_shapes 1290 1291 @property 1292 def _num_replicas_in_sync(self): 1293 """Returns number of replicas over which gradients are aggregated.""" 1294 raise NotImplementedError("must be implemented in descendants") 1295 1296 @property 1297 def worker_devices(self): 1298 """Returns the tuple of all devices used to for compute replica execution. 1299 """ 1300 # TODO(josh11b): More docstring 1301 raise NotImplementedError("must be implemented in descendants") 1302 1303 @property 1304 def parameter_devices(self): 1305 """Returns the tuple of all devices used to place variables.""" 1306 # TODO(josh11b): More docstring 1307 raise NotImplementedError("must be implemented in descendants") 1308 1309 def non_slot_devices(self, var_list): 1310 """Device(s) for non-slot variables. 1311 1312 Create variables on these devices in a 1313 `with colocate_vars_with(non_slot_devices(...)):` block. 1314 Update those using `update_non_slot()`. 1315 1316 Args: 1317 var_list: The list of variables being optimized, needed with the 1318 default `tf.distribute.Strategy`. 1319 """ 1320 raise NotImplementedError("must be implemented in descendants") 1321 1322 @property 1323 def experimental_between_graph(self): 1324 """Whether the strategy uses between-graph replication or not. 1325 1326 This is expected to return a constant value that will not be changed 1327 throughout its life cycle. 1328 """ 1329 raise NotImplementedError("must be implemented in descendants") 1330 1331 def _configure(self, 1332 session_config=None, 1333 cluster_spec=None, 1334 task_type=None, 1335 task_id=None): 1336 """Configures the strategy class.""" 1337 del session_config, cluster_spec, task_type, task_id 1338 1339 def _update_config_proto(self, config_proto): 1340 return copy.deepcopy(config_proto) 1341 1342 @property 1343 def experimental_should_init(self): 1344 """Whether initialization is needed.""" 1345 raise NotImplementedError("must be implemented in descendants") 1346 1347 @property 1348 def should_checkpoint(self): 1349 """Whether checkpointing is needed.""" 1350 raise NotImplementedError("must be implemented in descendants") 1351 1352 @property 1353 def should_save_summary(self): 1354 """Whether saving summaries is needed.""" 1355 raise NotImplementedError("must be implemented in descendants") 1356 1357 1358# A note about the difference between the context managers 1359# `ReplicaContext` (defined here) and `_CurrentDistributionContext` 1360# (defined above) used by `DistributionStrategy.scope()`: 1361# 1362# * a ReplicaContext is only present during a `call_for_each_replica()` 1363# call (except during a `merge_run` call) and in such a scope it 1364# will be returned by calls to `get_replica_context()`. Implementers of new 1365# DistributionStrategy descendants will frequently also need to 1366# define a descendant of ReplicaContext, and are responsible for 1367# entering and exiting this context. 1368# 1369# * DistributionStrategy.scope() sets up a variable_creator scope that 1370# changes variable creation calls (e.g. to make mirrored 1371# variables). This is intended as an outer scope that users enter once 1372# around their model creation and graph definition. There is no 1373# anticipated need to define descendants of _CurrentDistributionContext. 1374# It sets the current DistributionStrategy for purposes of 1375# `get_strategy()` and `has_strategy()` 1376# and switches the thread mode to a "cross-replica context". 1377@tf_export("distribute.ReplicaContext") 1378class ReplicaContext(object): 1379 """`tf.distribute.Strategy` API when in a replica context. 1380 1381 To be used inside your replicated step function, such as in a 1382 `tf.distribute.StrategyExtended.call_for_each_replica` call. 1383 """ 1384 1385 def __init__(self, strategy, replica_id_in_sync_group): 1386 self._strategy = strategy 1387 self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access 1388 self) 1389 self._replica_id_in_sync_group = replica_id_in_sync_group 1390 self._summary_recording_distribution_strategy = None 1391 1392 def __enter__(self): 1393 _push_per_thread_mode(self._thread_context) 1394 ctx = eager_context.context() 1395 1396 def replica_id_is_zero(): 1397 return math_ops.equal(self._replica_id_in_sync_group, 1398 constant_op.constant(0)) 1399 1400 self._summary_recording_distribution_strategy = ( 1401 ctx.summary_recording_distribution_strategy) 1402 ctx.summary_recording_distribution_strategy = replica_id_is_zero 1403 1404 def __exit__(self, exception_type, exception_value, traceback): 1405 ctx = eager_context.context() 1406 ctx.summary_recording_distribution_strategy = ( 1407 self._summary_recording_distribution_strategy) 1408 _pop_per_thread_mode() 1409 1410 def merge_call(self, merge_fn, args=(), kwargs=None): 1411 """Merge args across replicas and run `merge_fn` in a cross-replica context. 1412 1413 This allows communication and coordination when there are multiple calls 1414 to a model function triggered by a call to 1415 `strategy.extended.call_for_each_replica(model_fn, ...)`. 1416 1417 See `tf.distribute.StrategyExtended.call_for_each_replica` for an 1418 explanation. 1419 1420 If not inside a distributed scope, this is equivalent to: 1421 1422 ``` 1423 strategy = tf.distribute.get_strategy() 1424 with cross-replica-context(strategy): 1425 return merge_fn(strategy, *args, **kwargs) 1426 ``` 1427 1428 Args: 1429 merge_fn: function that joins arguments from threads that are given as 1430 PerReplica. It accepts `tf.distribute.Strategy` object as 1431 the first argument. 1432 args: List or tuple with positional per-thread arguments for `merge_fn`. 1433 kwargs: Dict with keyword per-thread arguments for `merge_fn`. 1434 1435 Returns: 1436 The return value of `merge_fn`, except for `PerReplica` values which are 1437 unpacked. 1438 """ 1439 require_replica_context(self) 1440 if kwargs is None: 1441 kwargs = {} 1442 return self._merge_call(merge_fn, args, kwargs) 1443 1444 def _merge_call(self, merge_fn, args, kwargs): 1445 """Default implementation for single replica.""" 1446 _push_per_thread_mode( # thread-local, so not needed with multiple threads 1447 distribution_strategy_context._CrossReplicaThreadMode(self._strategy)) # pylint: disable=protected-access 1448 try: 1449 return merge_fn(self._strategy, *args, **kwargs) 1450 finally: 1451 _pop_per_thread_mode() 1452 1453 @property 1454 def num_replicas_in_sync(self): 1455 """Returns number of replicas over which gradients are aggregated.""" 1456 return self._strategy.num_replicas_in_sync 1457 1458 @property 1459 def replica_id_in_sync_group(self): 1460 """Which replica is being defined, from 0 to `num_replicas_in_sync - 1`.""" 1461 require_replica_context(self) 1462 return self._replica_id_in_sync_group 1463 1464 @property 1465 def strategy(self): 1466 """The current `tf.distribute.Strategy` object.""" 1467 return self._strategy 1468 1469 @property 1470 def devices(self): 1471 """The devices this replica is to be executed on, as a tuple of strings.""" 1472 require_replica_context(self) 1473 return (device_util.current(),) 1474 1475 def all_reduce(self, reduce_op, value): 1476 """All-reduces the given `Tensor` nest across replicas. 1477 1478 If `all_reduce` is called in any replica, it must be called in all replicas. 1479 The nested structure and `Tensor` shapes must be identical in all replicas. 1480 1481 IMPORTANT: The ordering of communications must be identical in all replicas. 1482 1483 Example with two replicas: 1484 Replica 0 `value`: {'a': 1, 'b': [40, 1]} 1485 Replica 1 `value`: {'a': 3, 'b': [ 2, 98]} 1486 1487 If `reduce_op` == `SUM`: 1488 Result (on all replicas): {'a': 4, 'b': [42, 99]} 1489 1490 If `reduce_op` == `MEAN`: 1491 Result (on all replicas): {'a': 2, 'b': [21, 49.5]} 1492 1493 Args: 1494 reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. 1495 value: The nested structure of `Tensor`s to all-reduced. 1496 The structure must be compatible with `tf.nest`. 1497 1498 Returns: 1499 A `Tensor` nest with the reduced `value`s from each replica. 1500 """ 1501 def batch_all_reduce(strategy, *value_flat): 1502 return strategy.extended.batch_reduce_to( 1503 reduce_op, [(v, _batch_reduce_destination(v)) for v in value_flat]) 1504 1505 if reduce_op in [reduce_util.ReduceOp.SUM, reduce_util.ReduceOp.MEAN]: 1506 # TODO(cjfj): Work out why `batch_reduce` doesn't return the correct grad. 1507 @custom_gradient.custom_gradient 1508 def grad_wrapper(*xs): 1509 ys = self.merge_call(batch_all_reduce, args=xs) 1510 # The gradient of an all-sum is itself an all-sum (all-mean, likewise). 1511 return ys, lambda *dy_s: self.all_reduce(reduce_op, dy_s) 1512 return nest.pack_sequence_as(value, grad_wrapper(*nest.flatten(value))) 1513 else: 1514 # TODO(cjfj): Implement gradients for other reductions. 1515 reduced = nest.pack_sequence_as( 1516 value, self.merge_call(batch_all_reduce, args=nest.flatten(value))) 1517 return nest.map_structure(array_ops.prevent_gradient, reduced) 1518 1519 # TODO(josh11b): Implement `start_all_reduce(method, t)` for efficient 1520 # all-reduce. It would return a function returning the result of reducing `t` 1521 # across all replicas. The caller would wait to call this function until they 1522 # needed the reduce result, allowing an efficient implementation: 1523 # * With eager execution, the reduction could be performed asynchronously 1524 # in the background, not blocking until the result was needed. 1525 # * When constructing a graph, it could batch up all reduction requests up 1526 # to that point that the first result is needed. Most likely this can be 1527 # implemented in terms of `merge_call()` and `batch_reduce_to()`. 1528 1529 1530def _batch_reduce_destination(x): 1531 """Returns the destinations for batch all-reduce.""" 1532 if isinstance(x, ops.Tensor): # One device strategies. 1533 return x.device 1534 else: 1535 return x 1536 1537 1538# ------------------------------------------------------------------------------ 1539 1540 1541class _DefaultDistributionStrategy(DistributionStrategy): 1542 """Default `tf.distribute.Strategy` if none is explicitly selected.""" 1543 1544 def __init__(self): 1545 super(_DefaultDistributionStrategy, self).__init__( 1546 _DefaultDistributionExtended(self)) 1547 1548 1549class _DefaultDistributionExtended(DistributionStrategyExtended): 1550 """Implementation of _DefaultDistributionStrategy.""" 1551 1552 def _scope(self, strategy): 1553 """Context manager setting a variable creator and `self` as current.""" 1554 if distribution_strategy_context.has_strategy(): 1555 raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") 1556 1557 def creator(next_creator, *args, **kwargs): 1558 _require_strategy_scope_strategy(strategy) 1559 return next_creator(*args, **kwargs) 1560 1561 return _CurrentDistributionContext( 1562 strategy, variable_scope.variable_creator_scope(creator)) 1563 1564 def colocate_vars_with(self, colocate_with_variable): 1565 """Does not require `self.scope`.""" 1566 _require_strategy_scope_extended(self) 1567 return ops.colocate_with(colocate_with_variable) 1568 1569 def variable_created_in_scope(self, v): 1570 return v._distribute_strategy is None # pylint: disable=protected-access 1571 1572 def _make_dataset_iterator(self, dataset): 1573 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 1574 1575 def _make_input_fn_iterator(self, 1576 input_fn, 1577 replication_mode=InputReplicationMode.PER_WORKER): 1578 dataset = input_fn(InputContext()) 1579 return _DefaultDistributionExtended.DefaultInputIterator(dataset) 1580 1581 def _experimental_make_numpy_dataset(self, numpy_input, session): 1582 numpy_flat = nest.flatten(numpy_input) 1583 vars_flat = tuple( 1584 variable_scope.variable(array_ops.zeros(i.shape, i.dtype), 1585 trainable=False, use_resource=True) 1586 for i in numpy_flat 1587 ) 1588 for v, i in zip(vars_flat, numpy_flat): 1589 numpy_dataset.init_var_from_numpy(v, i, session) 1590 vars_nested = nest.pack_sequence_as(numpy_input, vars_flat) 1591 return dataset_ops.Dataset.from_tensor_slices(vars_nested) 1592 1593 def _broadcast_to(self, tensor, destinations): 1594 if destinations is None: 1595 return tensor 1596 else: 1597 raise NotImplementedError("TODO") 1598 1599 def _call_for_each_replica(self, fn, args, kwargs): 1600 with ReplicaContext( 1601 self._container_strategy(), 1602 replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)): 1603 return fn(*args, **kwargs) 1604 1605 def _reduce_to(self, reduce_op, value, destinations): 1606 # TODO(josh11b): Use destinations? 1607 del reduce_op, destinations 1608 return value 1609 1610 def _update(self, var, fn, args, kwargs, group): 1611 # The implementations of _update() and _update_non_slot() are identical 1612 # except _update() passes `var` as the first argument to `fn()`. 1613 return self._update_non_slot(var, fn, (var,) + tuple(args), kwargs, group) 1614 1615 def _update_non_slot(self, colocate_with, fn, args, kwargs, should_group): 1616 # TODO(josh11b): Figure out what we should be passing to UpdateContext() 1617 # once that value is used for something. 1618 with ops.colocate_with(colocate_with), UpdateContext(colocate_with): 1619 result = fn(*args, **kwargs) 1620 if should_group: 1621 return result 1622 else: 1623 return nest.map_structure(self._local_results, result) 1624 1625 def read_var(self, replica_local_var): 1626 return array_ops.identity(replica_local_var) 1627 1628 def _local_results(self, distributed_value): 1629 return (distributed_value,) 1630 1631 def value_container(self, value): 1632 return value 1633 1634 @property 1635 def _num_replicas_in_sync(self): 1636 return 1 1637 1638 @property 1639 def worker_devices(self): 1640 raise RuntimeError("worker_devices() method unsupported by default " 1641 "tf.distribute.Strategy.") 1642 1643 @property 1644 def parameter_devices(self): 1645 raise RuntimeError("parameter_devices() method unsupported by default " 1646 "tf.distribute.Strategy.") 1647 1648 def non_slot_devices(self, var_list): 1649 return min(var_list, key=lambda x: x.name) 1650 1651 # TODO(priyag): This should inherit from `InputIterator`, once dependency 1652 # issues have been resolved. 1653 class DefaultInputIterator(object): 1654 """Default implementation of `InputIterator` for default strategy.""" 1655 1656 def __init__(self, dataset): 1657 self._dataset = dataset 1658 if eager_context.executing_eagerly(): 1659 self._iterator = dataset.make_one_shot_iterator() 1660 else: 1661 self._iterator = dataset.make_initializable_iterator() 1662 1663 def get_next(self): 1664 return self._iterator.get_next() 1665 1666 def initialize(self): 1667 if eager_context.executing_eagerly(): 1668 self._iterator = self._dataset.make_one_shot_iterator() 1669 return [] 1670 else: 1671 return [self._iterator.initializer] 1672 1673 # TODO(priyag): Delete this once all strategies use global batch size. 1674 @property 1675 def _global_batch_size(self): 1676 """Global and per-replica batching are equivalent for this strategy.""" 1677 return True 1678 1679 1680# ------------------------------------------------------------------------------ 1681# We haven't yet implemented deserialization for DistributedVariables. 1682# So here we catch any attempts to deserialize variables 1683# when using distribution strategies. 1684# pylint: disable=protected-access 1685_original_from_proto = resource_variable_ops._from_proto_fn 1686 1687 1688def _from_proto_fn(v, import_scope=None): 1689 if distribution_strategy_context.has_strategy(): 1690 raise NotImplementedError( 1691 "Deserialization of variables is not yet supported when using a " 1692 "tf.distribute.Strategy.") 1693 else: 1694 return _original_from_proto(v, import_scope=import_scope) 1695 1696resource_variable_ops._from_proto_fn = _from_proto_fn 1697# pylint: enable=protected-access 1698 1699 1700#------------------------------------------------------------------------------- 1701# Shorthand for some methods from distribution_strategy_context. 1702_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access 1703_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access 1704_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access 1705_get_default_replica_mode = ( 1706 distribution_strategy_context._get_default_replica_mode) # pylint: disable=protected-access 1707