# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python import tf2
from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute import values_util
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for a single or for multiple workers.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects, for
      either local or remote devices.

  Returns:
    a boolean indicating whether all the given devices belong to a single
    (local) worker.

  Raises:
    ValueError: if device strings are not consistent.
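
  For example (illustrative device strings, not an exhaustive list):
  `["/device:GPU:0", "/device:GPU:1"]` describes a single local worker, while
  `["/job:worker/task:0/device:GPU:0", "/job:worker/task:1/device:GPU:0"]`
  spans two workers.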
69 """ 70 specs = [] 71 for d in devices: 72 name = d.name if isinstance(d, context.LogicalDevice) else d 73 specs.append(tf_device.DeviceSpec.from_string(name)) 74 num_workers = len({(d.job, d.task, d.replica) for d in specs}) 75 all_local = all(d.job in (None, "localhost") for d in specs) 76 any_local = any(d.job in (None, "localhost") for d in specs) 77 78 if any_local and not all_local: 79 raise ValueError("Local device string cannot have job specified other " 80 "than 'localhost'") 81 82 if num_workers == 1 and not all_local: 83 if any(d.task is None for d in specs): 84 raise ValueError("Remote device string must have task specified.") 85 86 return num_workers == 1 87 88 89def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker): 90 """Returns a device list given a cluster spec.""" 91 cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) 92 devices = [] 93 for task_type in ("chief", "worker"): 94 for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): 95 if num_gpus_per_worker == 0: 96 devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id)) 97 else: 98 devices.extend([ 99 "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id) 100 for gpu_id in range(num_gpus_per_worker) 101 ]) 102 return devices 103 104 105def _group_device_list(devices): 106 """Groups the devices list by task_type and task_id. 107 108 Args: 109 devices: a list of device strings for remote devices. 110 111 Returns: 112 a dict of list of device strings mapping from task_type to a list of devices 113 for the task_type in the ascending order of task_id. 114 """ 115 assert not _is_device_list_single_worker(devices) 116 device_dict = {} 117 118 for d in devices: 119 d_spec = tf_device.DeviceSpec.from_string(d) 120 121 # Create an entry for the task_type. 122 if d_spec.job not in device_dict: 123 device_dict[d_spec.job] = [] 124 125 # Fill the device list for task_type until it covers the task_id. 126 while len(device_dict[d_spec.job]) <= d_spec.task: 127 device_dict[d_spec.job].append([]) 128 129 device_dict[d_spec.job][d_spec.task].append(d) 130 131 return device_dict 132 133 134def _is_gpu_device(device): 135 return tf_device.DeviceSpec.from_string(device).device_type == "GPU" 136 137 138def _infer_num_gpus_per_worker(devices): 139 """Infers the number of GPUs on each worker. 140 141 Currently to make multi-worker cross device ops work, we need all workers to 142 have the same number of GPUs. 143 144 Args: 145 devices: a list of device strings, can be either local devices or remote 146 devices. 147 148 Returns: 149 number of GPUs per worker. 150 151 Raises: 152 ValueError if workers have different number of GPUs or GPU indices are not 153 consecutive and starting from 0. 154 """ 155 if _is_device_list_single_worker(devices): 156 return sum(1 for d in devices if _is_gpu_device(d)) 157 else: 158 device_dict = _group_device_list(devices) 159 num_gpus = None 160 for _, devices_in_task in device_dict.items(): 161 for device_in_task in devices_in_task: 162 if num_gpus is None: 163 num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d)) 164 165 # Verify other workers have the same number of GPUs. 
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus


def all_local_devices(num_gpus=None):
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one machine with multiple
  GPUs. For TPUs, use `tf.distribute.TPUStrategy`. To use `MirroredStrategy`
  with multiple workers, please refer to
  `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument
  of the strategy, all available GPUs are used. If no GPUs are found, the
  available CPUs are used. Note that TensorFlow treats all CPUs on a machine
  as a single device and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  When using distribution strategies, all variable creation should be done
  within the strategy's scope. This replicates the variables across all the
  replicas and keeps them in sync using an all-reduce algorithm.

  Variables created inside a `MirroredStrategy`, even when the creating
  function is wrapped with `tf.function`, are still `MirroredVariables`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset
  across the replicas when writing your own training loop. If you are using
  the `.fit` and `.compile` methods available in `tf.keras`, then `tf.keras`
  will handle the distribution for you.
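
  As a minimal sketch of that Keras path (the layer sizes, optimizer, and
  `dataset` below are placeholders rather than part of this API):

  ```python
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(loss="mse", optimizer="sgd")
  model.fit(dataset, epochs=2)
  ```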

  For example, with a custom training loop:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that
      exploits the particular hardware is available.
  """

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")


@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring

  __doc__ = MirroredStrategy.__doc__

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MirroredStrategy")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  # If this is set to True, use NCCL collective ops instead of NCCL cross
  # device ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # When executing in eager mode, only the single-machine code path is
          # supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._initialize_strategy(devices)
    self._communication_options = collective_util.Options(
        implementation=collective_util.CommunicationImplementation.NCCL)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _use_merge_call(self):
    # We currently only disable merge_call when XLA is used to compile the `fn`
    # passed to `strategy.run` and all devices are GPU.
    return not control_flow_util.GraphOrParentsInXlaContext(
        ops.get_default_graph()) or not all(
            [_is_gpu_device(d) for d in self._devices])

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      self._collective_ops = self._make_collective_ops(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._collective_ops_in_use = True
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _make_collective_ops(self, devices):
    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)  # pylint: disable=protected-access
    return cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        collective_keys=self._collective_keys)

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device.
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope with
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_fetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           input_contexts,
                                           self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, e.g.
    # to create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if not self._use_merge_call():
      return self._collective_ops

    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()
      else:
        return self._collective_ops

    return self._cross_device_ops or self._inferred_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    def get_values(value):
      if not isinstance(value, values.DistributedValues):
        # This function handles reducing values that are not PerReplica or
        # Mirrored values. For example, the same value could be present on all
        # replicas, in which case `value` would be a single value, or `value`
        # could be 0.
        return cross_device_ops_lib.reduce_non_distributed_value(
            reduce_op, value, destinations, self._num_replicas_in_sync)
      if self._use_merge_call() and self._collective_ops_in_use and ((
          not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
          any("cpu" in d.lower()
              for d in cross_device_ops_lib.get_devices_from(destinations)))):
        return cross_device_ops_lib.ReductionToOneDevice().reduce(
            reduce_op, value, destinations)
      return self._get_cross_device_ops(value).reduce(
          reduce_op,
          value,
          destinations=destinations,
          options=self._communication_options.merge(options))

    return nest.map_structure(get_values, value)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("Inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices.")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _replica_ctx_all_reduce(self, reduce_op, value, options=None):
    """Implements `StrategyExtendedV2._replica_ctx_all_reduce`."""
    # This implementation avoids using `merge_call` and just launches collective
    # ops in one replica.
    if options is None:
      options = collective_util.Options()

    if context.executing_eagerly() or (
        not tf2.enabled()) or self._use_merge_call():
      # In eager mode, fall back to the default implementation that uses
      # `merge_call`. Replica functions run sequentially in eager mode, and due
      # to the blocking nature of collective ops, execution will hang if
      # collective ops are launched sequentially.
      return super()._replica_ctx_all_reduce(reduce_op, value, options)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context, (
        "`StrategyExtended._replica_ctx_all_reduce` must be called in a "
        "replica context")
    return self._get_cross_device_ops(value)._all_reduce(  # pylint: disable=protected-access
        reduce_op,
        value,
        replica_context._replica_id,  # pylint: disable=protected-access
        options)

  def _replica_ctx_update(self, var, fn, args, kwargs, group):
    if self._use_merge_call():
      return super()._replica_ctx_update(var, fn, args, kwargs, group)

    replica_context = distribution_strategy_context.get_replica_context()
    assert replica_context
    replica_id = values_util.get_current_replica_id_as_int()
    name = "update_%d" % replica_id

    if isinstance(var, values.DistributedVariable):
      var = var._get_replica(replica_id)  # pylint: disable=protected-access

    with ops.device(var.device), ops.name_scope(name):
      result = fn(var, *args, **kwargs)
    return result

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id